import torch
import torch.nn as nn
import torch.nn.functional as F
import functools
from tqdm.auto import tqdm

from torch.autograd import grad
from torch_scatter import scatter_mean

from utils.nn_utils import variadic_meshgrid

from .transition import construct_transition

from ...dyMEAN.modules.am_egnn import AMEGNN
from ...dyMEAN.modules.radial_basis import RadialBasis

from dataclasses import dataclass

@dataclass
class SoftMHConfig:
    """Settings for the soft Metropolis-Hastings-style correction applied during sampling.

    A proposal is blended with the current state using a weight
    ``alpha = sigmoid(log_r / tau)`` clamped to ``[alpha_min, alpha_max]``.

    Raises:
        ValueError: at construction time, if ``apply_to`` is not one of
            ``"coords"``, ``"h"``, ``"both"``, if ``tau`` is not positive, or if
            ``alpha_min > alpha_max``.
    """

    enabled: bool = True
    apply_to: str = "both"      # which states to correct: "coords" | "h" | "both"
    start_ratio: float = 0.30   # start after this fraction of steps (in forward time)
    interval: int = 1           # apply every N steps once started
    tau: float = 1.0            # temperature: alpha = sigmoid(log_r / tau)
    alpha_min: float = 0.05     # lower clamp on alpha
    alpha_max: float = 0.95     # upper clamp on alpha
    deadzone: float = 0.02      # skip the correction if |log_r| is this small
    normalize_dx: bool = True   # per-batch std normalization of the step Δz

    def __post_init__(self):
        # Fail fast: an unrecognized ``apply_to`` would match none of the
        # subjects and silently turn the correction off.
        if self.apply_to not in ("coords", "h", "both"):
            raise ValueError(
                f"apply_to must be 'coords', 'h' or 'both', got {self.apply_to!r}")
        # tau is a divisor in sigmoid(log_r / tau).
        if self.tau <= 0:
            raise ValueError(f"tau must be positive, got {self.tau}")
        if self.alpha_min > self.alpha_max:
            raise ValueError(
                f"alpha_min ({self.alpha_min}) must not exceed alpha_max ({self.alpha_max})")


def low_trianguler_inv(L):
    """Invert lower-triangular matrices with a triangular solve.

    Args:
        L: (..., n, n) lower-triangular matrix or batch of matrices
           (in this file: (bs, 3, 3) Cholesky factors).

    Returns:
        L^{-1} with the same shape, dtype and device as ``L``.
    """
    # The identity right-hand side must share L's dtype/device:
    # solve_triangular rejects mixed dtypes, so a default float32 eye would
    # break float64 inputs. Using L.shape[-1] also drops the hard-coded 3.
    eye = torch.eye(L.shape[-1], dtype=L.dtype, device=L.device).expand_as(L)
    return torch.linalg.solve_triangular(L, eye, upper=False)


class EpsilonNet(nn.Module):
    """Equivariant denoiser predicting updates for latent node states and coordinates.

    Wraps an ``AMEGNN`` encoder. Node inputs are the noisy latent states
    concatenated with a 3-d timestep embedding ``[beta, sin(beta), cos(beta)]``
    and, optionally, a per-node position embedding. Edges come in two groups
    (``ctx_edges`` and ``inter_edges``) that are told apart by a learned
    2-way edge-type embedding.
    """

    def __init__(
            self,
            input_size,
            hidden_size,
            n_channel,
            n_layers=3,
            edge_size=0,
            n_rbf=0,
            cutoff=1.0,
            dropout=0.1,
            additional_pos_embed=True
        ):
        """
        Args:
            input_size: dimension of the latent node states; also the expected
                dimension of ``position_embedding`` when ``additional_pos_embed``.
            hidden_size: hidden width of the AMEGNN encoder.
            n_channel: number of coordinate channels per node (passed to AMEGNN).
            n_layers: number of AMEGNN layers.
            edge_size: dimension of the optional extra edge attributes given to
                ``forward`` (0 when none are used).
            n_rbf, cutoff: radial-basis settings forwarded to AMEGNN.
            dropout: dropout rate forwarded to AMEGNN.
            additional_pos_embed: whether ``forward`` receives a position
                embedding to concatenate to the node features.
        """
        super().__init__()

        atom_embed_size = hidden_size // 4
        edge_embed_size = hidden_size // 4
        pos_embed_size = input_size
        # node features = latent state + 3-d timestep embedding (+ position embedding)
        enc_input_size = input_size + 3 + (pos_embed_size if additional_pos_embed else 0)
        self.encoder = AMEGNN(
            enc_input_size, hidden_size, hidden_size, n_channel,
            channel_nf=atom_embed_size, radial_nf=hidden_size,
            in_edge_nf=edge_embed_size + edge_size, n_layers=n_layers, residual=True,
            dropout=dropout, dense=False, n_rbf=n_rbf, cutoff=cutoff)
        self.hidden2input = nn.Linear(hidden_size, input_size)
        # edge type: index 0 for ctx_edges, index 1 for inter_edges
        self.edge_embedding = nn.Embedding(2, edge_embed_size)

    def forward(
            self, H_noisy, X_noisy, position_embedding, ctx_edges, inter_edges,
            atom_embeddings, atom_weights, mask_generate, beta,
            ctx_edge_attr=None, inter_edge_attr=None):
        """Predict the state updates for the nodes selected by ``mask_generate``.

        Args:
            H_noisy: (N, input_size) noisy latent node states.
            X_noisy: (N, 14, 3) noisy coordinates.
            position_embedding: (N, input_size) or None.
            ctx_edges: (2, Ec) edge index.
            inter_edges: (2, Ei) edge index.
            atom_embeddings: forwarded to the encoder as ``channel_attr``.
            atom_weights: forwarded to the encoder as ``channel_weights``.
            mask_generate: (N,) bool; outputs are zeroed where it is False.
            beta: (N,) per-node timestep value.
            ctx_edge_attr: optional (Ec, edge_size) extra edge features; when
                given, ``inter_edge_attr`` (Ei, edge_size) must be given too.
            inter_edge_attr: see ``ctx_edge_attr``.

        Returns:
            eps_H: (N, input_size) predicted ``next_H - H_noisy``, zero on unmasked nodes.
            eps_X: (N, 14, 3) predicted ``next_X - X_noisy``, zero on unmasked nodes.
        """
        t_embed = torch.stack([beta, torch.sin(beta), torch.cos(beta)], dim=-1)  # [N, 3]
        if position_embedding is None:
            in_feat = torch.cat([H_noisy, t_embed], dim=-1)  # [N, input_size + 3]
        else:
            in_feat = torch.cat([H_noisy, t_embed, position_embedding], dim=-1)  # [N, 2 * input_size + 3]
        edges = torch.cat([ctx_edges, inter_edges], dim=-1)  # [2, Ec + Ei]
        # edge-type ids: 0 for every ctx edge, 1 for every inter edge
        edge_type = torch.cat([
            torch.zeros_like(ctx_edges[0]), torch.ones_like(inter_edges[0])
        ], dim=-1)
        edge_embed = self.edge_embedding(edge_type)
        if ctx_edge_attr is None:
            edge_attr = edge_embed
        else:
            edge_attr = torch.cat([
                edge_embed,
                torch.cat([ctx_edge_attr, inter_edge_attr], dim=0)],
                dim=-1
            )  # [E, edge_embed_size + edge_size]
        next_H, next_X = self.encoder(in_feat, X_noisy, edges, ctx_edge_attr=edge_attr, channel_attr=atom_embeddings, channel_weights=atom_weights)

        # equivariant vector features: coordinate change, kept only on generated nodes
        eps_X = next_X - X_noisy
        eps_X = torch.where(mask_generate[:, None, None].expand_as(eps_X), eps_X, torch.zeros_like(eps_X))

        # invariant scalar features: project back to input_size, keep change on generated nodes
        next_H = self.hidden2input(next_H)
        eps_H = next_H - H_noisy
        eps_H = torch.where(mask_generate[:, None].expand_as(eps_H), eps_H, torch.zeros_like(eps_H))

        return eps_H, eps_X


class FullDPM(nn.Module):

    def __init__(
        self, 
        latent_size,
        hidden_size,
        n_channel,
        num_steps, 
        n_layers=3,
        dropout=0.1,
        trans_pos_type='Diffusion',
        trans_seq_type='Diffusion',
        trans_pos_opt=None, 
        trans_seq_opt=None,
        n_rbf=0,
        cutoff=1.0,
        std=10.0,
        additional_pos_embed=True,
        dist_rbf=0,
        dist_rbf_cutoff=7.0
    ):
        """
        Args:
            latent_size: dimension of the latent node states.
            hidden_size: hidden width of the denoising network.
            n_channel: number of coordinate channels per node.
            num_steps: number of diffusion steps.
            n_layers, dropout, n_rbf, cutoff, additional_pos_embed: forwarded
                to ``EpsilonNet``.
            trans_pos_type, trans_pos_opt: transition type/options for coordinates.
            trans_seq_type, trans_seq_opt: transition type/options for latent states.
                ``None`` (the default) means an empty options dict.
            std: coordinate scale; ``_normalize_position`` divides by it when no
                ``L`` is given.
            dist_rbf: if > 0, size of the radial-basis encoding of edge distances
                added as extra edge features.
            dist_rbf_cutoff: cutoff passed to ``RadialBasis``.
        """
        super().__init__()
        # Fresh dicts per instance: a shared mutable default could leak options
        # between models if construct_transition keeps or modifies it.
        trans_pos_opt = {} if trans_pos_opt is None else trans_pos_opt
        trans_seq_opt = {} if trans_seq_opt is None else trans_seq_opt

        self.eps_net = EpsilonNet(
            latent_size, hidden_size, n_channel, n_layers=n_layers, edge_size=dist_rbf,
            n_rbf=n_rbf, cutoff=cutoff, dropout=dropout, additional_pos_embed=additional_pos_embed)
        if dist_rbf > 0:
            # Other methods check hasattr(self, 'dist_rbf') to decide whether
            # distance edge features are used.
            self.dist_rbf = RadialBasis(dist_rbf, dist_rbf_cutoff)
        self.num_steps = num_steps
        self.trans_x = construct_transition(trans_pos_type, num_steps, trans_pos_opt)
        self.trans_h = construct_transition(trans_seq_type, num_steps, trans_seq_opt)

        self.register_buffer('std', torch.tensor(std, dtype=torch.float))

    def _normalize_position(self, X, batch_ids, mask_generate, atom_mask, L=None):
        """Center each complex on its context CA atoms and rescale.

        Args:
            X: (N, C, 3) coordinates; channel 1 is used as CA.
            batch_ids: (N,) complex index of each node.
            mask_generate: (N,) bool, True for nodes to be generated.
            atom_mask: (N, C) bool, True for present atoms.
            L: optional (bs, 3, 3) lower-triangular factor; when given,
               coordinates are mapped with L^{-1} instead of divided by ``self.std``.

        Returns:
            (X_norm, centers), where ``centers`` is (N, 1, 3): each node's copy
            of its complex center.
        """
        # Context (non-generated) nodes, present atoms, CA channel only.
        ctx_mask = (~mask_generate[:, None].expand_as(atom_mask)) & atom_mask
        ctx_mask[:, 0] = 0
        ctx_mask[:, 2:] = 0 # only retain CA
        # dim_size keeps one row per complex even when a complex (e.g. the last
        # one) has no context CA; scatter_mean leaves such rows at zero instead
        # of shrinking the output and breaking centers[batch_ids] below.
        num_graphs = int(batch_ids.max().item()) + 1 if batch_ids.numel() > 0 else 0
        centers = scatter_mean(
            X[ctx_mask], batch_ids[:, None].expand_as(ctx_mask)[ctx_mask],
            dim=0, dim_size=num_graphs) # [bs, 3]
        centers = centers[batch_ids].unsqueeze(1) # [N, 1, 3]
        if L is None:
            X = (X - centers) / self.std
        else:
            with torch.no_grad():
                L_inv = low_trianguler_inv(L)
            X = X - centers
            # per-node L^{-1} applied to every channel: [N, 1, 3, 3] @ [N, C, 3, 1]
            X = torch.matmul(L_inv[batch_ids][..., None, :, :], X.unsqueeze(-1)).squeeze(-1)
        return X, centers

    def _unnormalize_position(self, X_norm, centers, batch_ids, L=None):
        if L is None:
            X = X_norm * self.std + centers
        else:
            X = torch.matmul(L[batch_ids][..., None, :, :], X_norm.unsqueeze(-1)).squeeze(-1) + centers
        return X
    
    @torch.no_grad()
    def _get_batch_ids(self, mask_generate, lengths):

        # batch ids
        batch_ids = torch.zeros_like(mask_generate).long()
        batch_ids[torch.cumsum(lengths, dim=0)[:-1]] = 1
        batch_ids.cumsum_(dim=0)

        return batch_ids

    @torch.no_grad()
    def _get_edges(self, mask_generate, batch_ids, lengths):
        """Enumerate intra-complex node pairs and split them by generation side.

        Pairs come from ``variadic_meshgrid`` over each complex's node indices
        (segment sizes given by ``lengths``).

        Returns:
            ctx_edges: (2, Ec) pairs whose endpoints have equal ``mask_generate``.
            inter_edges: (2, Ei) pairs whose endpoints differ in ``mask_generate``.
        """
        n_nodes = batch_ids.shape[0]
        device = batch_ids.device
        src, dst = variadic_meshgrid(
            input1=torch.arange(n_nodes, device=device),
            size1=lengths,
            input2=torch.arange(n_nodes, device=device),
            size2=lengths,
        )

        same_side = mask_generate[src] == mask_generate[dst]
        cross_side = ~same_side
        ctx_edges = torch.stack((src[same_side], dst[same_side]))      # [2, Ec]
        inter_edges = torch.stack((src[cross_side], dst[cross_side]))  # [2, Ei]
        return ctx_edges, inter_edges
    
    @torch.no_grad()
    def _get_edge_dist(self, X, edges, atom_mask):
        '''
        Args:
            X: [N, 14, 3]
            edges: [2, E]
            atom_mask: [N, 14]
        '''
        ca_x = X[:, 1] # [N, 3]
        no_ca_mask = torch.logical_not(atom_mask[:, 1]) # [N]
        ca_x[no_ca_mask] = X[:, 0][no_ca_mask] # latent coordinates
        dist = torch.norm(ca_x[edges[0]] - ca_x[edges[1]], dim=-1)  # [N]
        return dist

    def forward(self, H_0, X_0, position_embedding, mask_generate, lengths, atom_embeddings, atom_mask, L=None, t=None, sample_structure=True, sample_sequence=True):
        """Compute the denoising losses for one training batch.

        Args:
            H_0: (N, latent_size) clean latent node states.
            X_0: (N, 14, 3) clean coordinates.
            position_embedding: optional per-node embedding forwarded to ``eps_net``.
            mask_generate: (N,) bool, True for nodes that are noised and predicted.
            lengths: (bs,) number of nodes per complex.
            atom_embeddings: per-channel attributes forwarded to ``eps_net``.
            atom_mask: (N, 14) bool, True for present atoms.
            L: optional (bs, 3, 3) factor used for coordinate normalization.
            t: optional (bs,) long timesteps; drawn uniformly from
               ``[0, num_steps]`` per complex when omitted.
            sample_structure: if False, coordinates are not noised and loss 'X' is 0.
            sample_sequence: if False, latent states are not noised and loss 'H' is 0.

        Returns:
            dict with 'X' (squared error summed over xyz, averaged over generated
            atoms) and 'H' (squared error summed over features, averaged over
            generated nodes).
        """
        batch_ids = self._get_batch_ids(mask_generate, lengths)
        if t is None:
            # Only needed (and only synchronized) when timesteps must be drawn.
            batch_size = int(batch_ids.max().item()) + 1
            t = torch.randint(0, self.num_steps + 1, (batch_size,), dtype=torch.long, device=H_0.device)
        X_0, centers = self._normalize_position(X_0, batch_ids, mask_generate, atom_mask, L)

        if sample_structure:
            X_noisy, eps_X = self.trans_x.add_noise(X_0, mask_generate, batch_ids, t)
        else:
            X_noisy, eps_X = X_0, torch.zeros_like(X_0)
        if sample_sequence:
            H_noisy, eps_H = self.trans_h.add_noise(H_0, mask_generate, batch_ids, t)
        else:
            H_noisy, eps_H = H_0, torch.zeros_like(H_0)

        ctx_edges, inter_edges = self._get_edges(mask_generate, batch_ids, lengths)
        if hasattr(self, 'dist_rbf'):
            # Edge distances are measured in the original frame; unnormalize once
            # and reuse for both edge sets.
            X_noisy_abs = self._unnormalize_position(X_noisy, centers, batch_ids, L)
            ctx_edge_attr = self.dist_rbf(
                self._get_edge_dist(X_noisy_abs, ctx_edges, atom_mask)).view(ctx_edges.shape[1], -1)
            inter_edge_attr = self.dist_rbf(
                self._get_edge_dist(X_noisy_abs, inter_edges, atom_mask)).view(inter_edges.shape[1], -1)
        else:
            ctx_edge_attr, inter_edge_attr = None, None

        beta = self.trans_x.get_timestamp(t)[batch_ids]  # [N]
        eps_H_pred, eps_X_pred = self.eps_net(
            H_noisy, X_noisy, position_embedding, ctx_edges, inter_edges, atom_embeddings, atom_mask.float(), mask_generate, beta,
            ctx_edge_attr=ctx_edge_attr, inter_edge_attr=inter_edge_attr)

        loss_dict = {}

        # equivariant vector feature loss, TODO: latent channel
        if sample_structure:
            mask_loss = mask_generate[:, None] & atom_mask
            loss_X = F.mse_loss(eps_X_pred[mask_loss], eps_X[mask_loss], reduction='none').sum(dim=-1)  # (Ntgt * n_latent_channel)
            loss_X = loss_X.sum() / (mask_loss.sum().float() + 1e-8)
            loss_dict['X'] = loss_X
        else:
            loss_dict['X'] = 0

        # invariant scalar feature loss
        if sample_sequence:
            loss_H = F.mse_loss(eps_H_pred[mask_generate], eps_H[mask_generate], reduction='none').sum(dim=-1)  # [Ntgt]
            loss_H = loss_H.sum() / (mask_generate.sum().float() + 1e-8)
            loss_dict['H'] = loss_H
        else:
            loss_dict['H'] = 0

        return loss_dict
    
    def _sigma_from_transition(self, trans, t: int, device):
        """
        Best-effort σ_t from your transition; falls back to 1.0 if not available.
        """
        # direct getter
        if hasattr(trans, "get_sigma"):
            s = trans.get_sigma(t)
            return s.to(device).float() if torch.is_tensor(s) else torch.tensor(float(s), device=device)

        # common schedules
        if hasattr(trans, "var_sched"):
            vs = trans.var_sched
            abar = None
            if hasattr(vs, "alphas_cumprod"): abar = vs.alphas_cumprod
            elif hasattr(vs, "alphas2"):      abar = vs.alphas2
            elif hasattr(vs, "alpha_bars"):   abar = vs.alpha_bars
            if abar is not None:
                abar_t = abar[t] if torch.is_tensor(abar) else torch.tensor(float(abar[t]), device=device)
                abar_t = abar_t.to(device).float()
                return torch.sqrt((1.0 - abar_t).clamp_min(1e-12))
            if hasattr(vs, "betas"):
                betas = vs.betas
                betas = betas.to(device).float() if torch.is_tensor(betas) else torch.tensor(betas, device=device).float()
                abar = torch.cumprod(1.0 - betas, dim=0)
                return torch.sqrt((1.0 - abar[t]).clamp_min(1e-12))

        # fallback
        return torch.tensor(1.0, device=device)

    # NOTE: an earlier, disabled version of ``sample`` (per-graph Soft-MH
    # correction with optional CUDA autocast) used to sit here as an unused
    # string literal. It was never executed, and its non-raw "\Sigma" triggered
    # an invalid-escape warning. It has been removed; retrieve it from the
    # project's history if it is needed again.
    @torch.no_grad()
    def sample(self, H, X, position_embedding, mask_generate, lengths, atom_embeddings, atom_mask,
            L=None, sample_structure=True, sample_sequence=True, pbar=False,
            energy_func=None, energy_lambda=0.01):
        """
        Args:
            H: contextual hidden states, (N, latent_size)
            X: contextual atomic coordinates, (N, 14, 3)
            L: cholesky decomposition of the covariance matrix \Sigma=LL^T, (bs, 3, 3)
            energy_func: guide diffusion towards lower energy landscape
        """
        # --- Soft-MH config: keep math, just choose subjects/schedule for co-design ---
        soft_mh = SoftMHConfig()
        co_design = bool(sample_structure and sample_sequence)
        if co_design:
            # Apply the same rule to *both* subjects to avoid X/H asynchrony,
            # and delay/sparsify to preserve early exploration.
            soft_mh.apply_to   = "both"     # <- only change of subject; math unchanged
            soft_mh.start_ratio = max(soft_mh.start_ratio, 0.30)
            soft_mh.interval    = 1
            soft_mh.tau         = max(soft_mh.tau, 2.0)   # gentler sigmoid slope
            soft_mh.alpha_min   = 0.00
            soft_mh.alpha_max   = 1.00
            soft_mh.deadzone    = max(soft_mh.deadzone, 0.08)
        # For structure-only (fixed sequence), keep your original defaults (coords-only).

        batch_ids = self._get_batch_ids(mask_generate, lengths)
        X, centers = self._normalize_position(X, batch_ids, mask_generate, atom_mask, L)

        # Init noisy vars on generated region
        if sample_structure:
            X_rand = torch.randn_like(X)
            X_init = torch.where(mask_generate[:, None, None].expand_as(X), X_rand, X)
        else:
            X_init = X

        if sample_sequence:
            H_rand = torch.randn_like(H)
            H_init = torch.where(mask_generate[:, None].expand_as(H), H_rand, H)
        else:
            H_init = H

        traj = {self.num_steps: (X_init, H_init)}
        if pbar:
            pbar = functools.partial(tqdm, total=self.num_steps, desc='Sampling')
        else:
            pbar = lambda x: x

        for t in pbar(range(self.num_steps, 0, -1)):
            X_t, H_t = traj[t]  # keep full precision; no rounding for co-design sensitivity

            beta = self.trans_x.get_timestamp(t).view(1).repeat(X_t.shape[0])
            t_tensor = torch.full([X_t.shape[0],], fill_value=t, dtype=torch.long, device=X_t.device)

            ctx_edges, inter_edges = self._get_edges(mask_generate, batch_ids, lengths)
            if hasattr(self, 'dist_rbf'):
                ctx_edge_attr = self._get_edge_dist(self._unnormalize_position(X_t, centers, batch_ids, L), ctx_edges, atom_mask)
                inter_edge_attr = self._get_edge_dist(self._unnormalize_position(X_t, centers, batch_ids, L), inter_edges, atom_mask)
                ctx_edge_attr = self.dist_rbf(ctx_edge_attr).view(ctx_edges.shape[1], -1)
                inter_edge_attr = self.dist_rbf(inter_edge_attr).view(inter_edges.shape[1], -1)
            else:
                ctx_edge_attr, inter_edge_attr = None, None

            eps_H, eps_X = self.eps_net(
                H_t, X_t, position_embedding, ctx_edges, inter_edges,
                atom_embeddings, atom_mask.float(), mask_generate, beta,
                ctx_edge_attr=ctx_edge_attr, inter_edge_attr=inter_edge_attr
            )

            if energy_func is not None:
                with torch.enable_grad():
                    cur_X_state = X_t.clone().double()
                    cur_X_state.requires_grad = True
                    energy = energy_func(
                        X=self._unnormalize_position(cur_X_state, centers.double(), batch_ids, L.double()),
                        mask_generate=mask_generate, batch_ids=batch_ids)
                    energy_eps_X = grad([energy], [cur_X_state], create_graph=False, retain_graph=False)[0].float()
                energy_eps_X[~mask_generate] = 0
                energy_eps_X = -energy_eps_X
            else:
                energy_eps_X = None

            H_prop = self.trans_h.denoise(H_t, eps_H, mask_generate, batch_ids, t_tensor)
            X_prop = self.trans_x.denoise(X_t, eps_X, mask_generate, batch_ids, t_tensor,
                                        guidance=energy_eps_X, guidance_weight=energy_lambda)

            if not sample_structure:
                X_prop = X_t
            if not sample_sequence:
                H_prop = H_t

            # ---- Soft-MH (same math), but scheduled later/sparser and applied to chosen subjects ----
            do_softmh = (soft_mh.enabled and t > 1)
            if do_softmh:
                f = self.num_steps - t
                start_f = int(soft_mh.start_ratio * self.num_steps)
                trigger = (f >= start_f) and ((f - start_f) % max(1, soft_mh.interval) == 0)
            else:
                trigger = False

            if not trigger:
                H_next, X_next = H_prop, X_prop
            else:
                dev = X_t.device
                sigma_x_t = self._sigma_from_transition(self.trans_x, t,   dev)
                sigma_h_t = self._sigma_from_transition(self.trans_h, t,   dev)
                sigma_x_s = self._sigma_from_transition(self.trans_x, t-1, dev)
                sigma_h_s = self._sigma_from_transition(self.trans_h, t-1, dev)
                beta_s_scalar = self.trans_x.get_timestamp(t - 1).item()

                H_next = H_prop.clone()
                X_next = X_prop.clone()

                score_X_t_full = (-eps_X) / (sigma_x_t + 1e-8)
                score_H_t_full = (-eps_H) / (sigma_h_t + 1e-8)

                use_amp  = True if X_t.is_cuda else False
                amp_dtype = torch.float16

                offset = 0
                for b, n in enumerate(lengths.tolist()):
                    if n == 0: continue
                    sl = slice(offset, offset + n); offset += n

                    mgen_b = mask_generate[sl]
                    if not bool(mgen_b.any()): continue

                    Ht_b, Xt_b = H_t[sl], X_t[sl]
                    Hp_b, Xp_b = H_prop[sl], X_prop[sl]
                    Sx_t_b = score_X_t_full[sl]
                    Sh_t_b = score_H_t_full[sl]
                    atom_mask_b = atom_mask[sl]
                    pos_emb_b = None if (position_embedding is None) else position_embedding[sl]
                    beta_s_b = torch.full((n,), beta_s_scalar, device=dev)

                    ctx_b, inter_b = self._get_edges(mgen_b, torch.zeros_like(mgen_b), torch.tensor([n], device=dev))

                    if hasattr(self, 'dist_rbf'):
                        Xp_b_unnorm = self._unnormalize_position(Xp_b, centers[sl], batch_ids[sl], L)
                        ctx_attr_s_b   = self.dist_rbf(self._get_edge_dist(Xp_b_unnorm, ctx_b,   atom_mask_b)).view(ctx_b.shape[1],   -1)
                        inter_attr_s_b = self.dist_rbf(self._get_edge_dist(Xp_b_unnorm, inter_b, atom_mask_b)).view(inter_b.shape[1], -1)
                    else:
                        ctx_attr_s_b = inter_attr_s_b = None

                    if use_amp:
                        with torch.autocast(device_type="cuda", dtype=amp_dtype):
                            Heps_s_b, Xeps_s_b = self.eps_net(
                                Hp_b, Xp_b, pos_emb_b, ctx_b, inter_b, atom_embeddings[sl], atom_mask_b.float(), mgen_b, beta_s_b,
                                ctx_edge_attr=ctx_attr_s_b, inter_edge_attr=inter_attr_s_b
                            )
                    else:
                        Heps_s_b, Xeps_s_b = self.eps_net(
                            Hp_b, Xp_b, pos_emb_b, ctx_b, inter_b, atom_embeddings[sl], atom_mask_b.float(), mgen_b, beta_s_b,
                            ctx_edge_attr=ctx_attr_s_b, inter_edge_attr=inter_attr_s_b
                        )
                    del ctx_attr_s_b, inter_attr_s_b

                    Sx_s_b = (-Xeps_s_b).to(torch.float32) / (sigma_x_s + 1e-8)
                    Sh_s_b = (-Heps_s_b).to(torch.float32) / (sigma_h_s + 1e-8)
                    Sx_t_b = Sx_t_b.to(torch.float32)
                    Sh_t_b = Sh_t_b.to(torch.float32)

                    use_X = sample_structure and (soft_mh.apply_to in ("coords", "both"))
                    use_H = sample_sequence  and (soft_mh.apply_to in ("h", "both"))

                    # log_r = Xt_b.new_zeros(())
                    # dims  = Xt_b.new_zeros(())

                    # if use_X:
                    #     dxX  = (Xp_b - Xt_b).to(torch.float32)
                    #     mxyz = (mgen_b[:, None] & atom_mask_b)[:, :, None].expand_as(dxX)
                    #     dxX  = dxX * mxyz
                    #     savgX = 0.5 * (Sx_t_b + Sx_s_b) * mxyz
                    #     log_r = log_r + (savgX * dxX).sum()
                    #     dims  = dims  + (mgen_b[:, None] & atom_mask_b).float().sum() * 3.0
                    # if use_H:
                    #     dxH   = (Hp_b - Ht_b).to(torch.float32) * mgen_b[:, None]
                    #     savgH = (0.5 * (Sh_t_b + Sh_s_b)) * mgen_b[:, None]
                    #     log_r = log_r + (savgH * dxH).sum()
                    #     dims  = dims  + mgen_b.float().sum()

                    log_r_X = Xt_b.new_zeros(())
                    dims_X  = Xt_b.new_zeros(())

                    if use_X:
                        dxX  = (Xp_b - Xt_b).to(torch.float32)
                        mxyz = (mgen_b[:, None] & atom_mask_b)[:, :, None].expand_as(dxX)
                        dxX  = dxX * mxyz
                        savgX = 0.5 * (Sx_t_b + Sx_s_b) * mxyz
                        log_r_X = (savgX * dxX).sum()
                        dims_X  = (mgen_b[:, None] & atom_mask_b).float().sum() * 3.0

                    # === STRUCTURE-DRIVEN ACCEPTANCE ===
                    log_r = log_r_X
                    dims  = dims_X

                    # Optional: keep normalization but only w.r.t. X
                    if soft_mh.normalize_dx and dims.item() > 0:
                        norm = (dxX * dxX).sum().item() if use_X else 0.0
                        norm = torch.tensor(norm, device=dev).clamp_min(1e-12)
                        std  = torch.sqrt((norm / dims.clamp_min(1e-6)).clamp_min(1e-12))
                        log_r = log_r / std

                    # if soft_mh.normalize_dx and dims.item() > 0:
                    #     norm = 0.0
                    #     if use_X: norm += (dxX * dxX).sum().item()
                    #     if use_H: norm += (dxH * dxH).sum().item()
                    #     norm = torch.tensor(norm, device=dev).clamp_min(1e-12)
                    #     std  = torch.sqrt((norm / dims.clamp_min(1e-6)).clamp_min(1e-12))
                    #     log_r = log_r / std

                    alpha = torch.sigmoid(log_r / soft_mh.tau)
                    alpha = torch.clamp(alpha, soft_mh.alpha_min, soft_mh.alpha_max)
                    if soft_mh.deadzone > 0 and torch.abs(log_r) <= soft_mh.deadzone:
                        alpha = alpha.new_tensor(1.0)

                    if use_H:
                        H_next[sl] = torch.where(mgen_b[:, None], alpha * Hp_b + (1 - alpha) * Ht_b, Hp_b)
                    if use_X:
                        X_next[sl] = torch.where(
                            (mgen_b[:, None, None] & atom_mask_b[:, :, None]),
                            alpha * Xp_b + (1 - alpha) * Xt_b,
                            Xp_b
                        )

            # store
            traj[t-1] = (X_next, H_next)
            traj[t] = (self._unnormalize_position(traj[t][0], centers, batch_ids, L).cpu(), traj[t][1].cpu())

        traj[0] = (self._unnormalize_position(traj[0][0], centers, batch_ids, L), traj[0][1])
        return traj
