import torch
import torch.nn as nn
import torch.nn.functional as F
import functools
from tqdm.auto import tqdm

from interfacediff.modules.common.geometry import quaternion_1ijk_to_rotation_matrix
from interfacediff.modules.common.so3 import so3vec_to_rotation, rotation_to_so3vec, random_uniform_so3
from interfacediff.modules.encoders.ga import GAInterfacePipeline
from .transition import RotationTransition, PositionTransition, AminoacidCategoricalTransition


def _counts_from_ptr_or_batch(data):
    """Return the number of nodes belonging to each graph of a batched object.

    When ``data.ptr`` exists and is not ``None`` the counts are the differences
    between consecutive ``ptr`` entries. Otherwise they are obtained by
    counting how often each graph index occurs in ``data.batch``.

    Args:
        data: Object exposing ``ptr`` and/or ``batch`` tensor attributes.

    Returns:
        A 1-D ``torch.long`` tensor with one count per graph.
    """
    ptr = getattr(data, "ptr", None)
    if ptr is not None:
        return (ptr[1:] - ptr[:-1]).to(torch.long)

    graph_index = data.batch
    # An empty assignment vector means there are no graphs to count.
    n_graphs = int(graph_index.max().item()) + 1 if graph_index.numel() > 0 else 0
    return torch.bincount(graph_index, minlength=n_graphs).to(torch.long)

def _split_iface_tensor_to_chains(tensor, g1, g2, gi):
    """Split a per-node interface tensor into chain-1 and chain-2 tensors.

    For graph ``i`` the rows are copied from the interface segment of that
    graph: the first ``n1_i`` rows go to chain 1 and the following ``n2_i``
    rows go to chain 2, where ``n1_i`` / ``n2_i`` are the node counts of graph
    ``i`` in ``g1`` / ``g2``. Any interface rows beyond ``n1_i + n2_i`` in a
    segment are ignored.

    Args:
        tensor: Tensor whose first dimension indexes interface nodes.
        g1: Batched chain-1 graph (must expose ``ptr`` or ``batch``).
        g2: Batched chain-2 graph (must expose ``ptr`` or ``batch``).
        gi: Batched interface graph (must expose ``ptr`` or ``batch``).

    Returns:
        Tuple ``(t1, t2)`` of contiguous tensors with the same dtype, device
        and trailing shape as ``tensor``; their first dimensions equal the
        total node counts of ``g1`` and ``g2``.
    """
    # Read the counts back to the host once per graph object instead of
    # issuing several ``.item()`` calls per sample inside the loop; plain
    # Python ints also avoid mixing devices between counts and ``tensor``.
    n1_counts = _counts_from_ptr_or_batch(g1).tolist()
    n2_counts = _counts_from_ptr_or_batch(g2).tolist()
    ni_counts = _counts_from_ptr_or_batch(gi).tolist()

    shape_tail = tuple(tensor.shape[1:])
    t1 = torch.empty((sum(n1_counts),) + shape_tail, dtype=tensor.dtype, device=tensor.device)
    t2 = torch.empty((sum(n2_counts),) + shape_tail, dtype=tensor.dtype, device=tensor.device)

    iface_start = 0   # running offset into the interface rows
    chain1_start = 0  # running offset into t1
    chain2_start = 0  # running offset into t2
    for i, n1_i in enumerate(n1_counts):
        # Indexing (rather than zip) keeps an IndexError when the three graph
        # objects disagree on the number of graphs.
        n2_i = n2_counts[i]
        ni_i = ni_counts[i]

        chain2_src = iface_start + n1_i
        t1[chain1_start:chain1_start + n1_i] = tensor[iface_start:chain2_src]
        t2[chain2_start:chain2_start + n2_i] = tensor[chain2_src:chain2_src + n2_i]

        chain1_start += n1_i
        chain2_start += n2_i
        iface_start += ni_i

    return t1.contiguous(), t2.contiguous()

def split_interface_to_chains(batch,
                              fields=("R_t", "trans_t", "node_embed"),
                              delete_from_interface=False):
    """Copy per-node interface attributes onto the two chain graphs.

    Every name in ``fields`` that exists on ``batch["interface_graph"]`` is
    split with ``_split_iface_tensor_to_chains`` and the two parts are stored
    under the same name on ``batch["chain_1_graph"]`` and
    ``batch["chain_2_graph"]``. Names missing from the interface graph are
    skipped silently.

    Args:
        batch: Mapping holding ``"chain_1_graph"``, ``"chain_2_graph"`` and
            ``"interface_graph"`` entries.
        fields: Attribute names to propagate from the interface graph.
        delete_from_interface: When true, each propagated attribute is removed
            from the interface graph, and so are ``R``, ``t``, ``rot_t``,
            ``seq``, ``ss_embed`` and ``seq_t`` if present.

    Returns:
        The same ``batch`` object; the graphs are modified in place.
    """
    chain_a = batch["chain_1_graph"]
    chain_b = batch["chain_2_graph"]
    iface = batch["interface_graph"]

    for name in fields:
        if hasattr(iface, name):
            part_a, part_b = _split_iface_tensor_to_chains(
                getattr(iface, name), chain_a, chain_b, iface
            )
            setattr(chain_a, name, part_a)
            setattr(chain_b, name, part_b)
            if delete_from_interface:
                delattr(iface, name)

    if delete_from_interface:
        # Also drop the remaining per-node interface attributes listed here.
        for name in ("R", "t", "rot_t", "seq", "ss_embed", "seq_t"):
            if hasattr(iface, name):
                delattr(iface, name)

    return batch

def rotation_matrix_cosine_loss(R_pred, R_true):
    """Sum the cosine-embedding loss over the three columns of each matrix.

    Args:
        R_pred: Predicted matrices; the last two dimensions are the matrix
            dimensions (reshaped below into rows of length 3).
        R_true: Reference matrices with the same layout as ``R_pred``.

    Returns:
        Tensor shaped like the leading dimensions of ``R_pred`` holding, per
        matrix, the sum over the three columns of the cosine-embedding loss
        between corresponding predicted and reference columns.
    """
    leading_shape = list(R_pred.shape[:-2])

    # After the transpose, each row of the flattened view is one matrix column.
    cols_pred = R_pred.transpose(-2, -1).reshape(-1, 3)
    cols_true = R_true.transpose(-2, -1).reshape(-1, 3)

    # A target of +1 for every pair asks the columns to point the same way.
    target = cols_pred.new_ones(cols_pred.size(0), dtype=torch.long)
    per_column = F.cosine_embedding_loss(cols_pred, cols_true, target, reduction='none')

    return per_column.view(*leading_shape, 3).sum(dim=-1)


class UpdateNet(nn.Module):
    """Network predicting per-node rotation, translation noise and residue type.

    The interface node embeddings are mixed with an embedding of the current
    (noisy) residue types, passed through the graph encoder, concatenated with
    a 3-feature embedding of ``beta`` and decoded by three small MLP heads.
    """

    def __init__(self, cfg):
        """Build the sub-modules.

        Args:
            cfg: Mapping read with ``.get``; ``embedding_dim`` (default 64)
                sizes the embeddings and decoder heads, ``num_layers``
                (default 6) is stored on the instance.
        """
        super().__init__()
        self.embedding_dim = cfg.get('embedding_dim', 64)
        self.num_layers = cfg.get('num_layers', 6)
        dim = self.embedding_dim

        self.seq_embedding = nn.Embedding(20, dim)
        self.node_feat_mixer = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )
        # NOTE(review): the encoder sizes and depth below are fixed literals;
        # ``self.embedding_dim`` / ``self.num_layers`` are not forwarded.
        # Presumably the encoder output width must equal ``embedding_dim`` for
        # the decoder heads to accept it -- confirm before changing the config.
        self.graph_encoder = GAInterfacePipeline(
            node_feat_dim=64,
            pair_feat_dim=64,
            num_layers=2,
            ga_block_opt={
                'value_dim': 32,
                'query_key_dim': 32,
                'num_query_points': 8,
                'num_value_points': 8,
                'num_heads': 12,
                'bias': False,
            },
        )

        # Heads are created in the order rot -> trans -> seq.
        self.rot_decoder = nn.Sequential(*self._decoder_layers(dim, 3))
        self.trans_decoder = nn.Sequential(*self._decoder_layers(dim, 3))
        # The sequence head ends in a softmax, so it emits probabilities.
        self.seq_decoder = nn.Sequential(*self._decoder_layers(dim, 20), nn.Softmax(dim=-1))

    @staticmethod
    def _decoder_layers(dim, out_dim):
        """Return the layers of one head: (dim + 3) -> dim -> dim -> out_dim."""
        return [
            nn.Linear(dim + 3, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, out_dim),
        ]

    @staticmethod
    def _per_node_beta(beta, graph_index, num_graphs, reference):
        """Broadcast ``beta`` to one value per row of ``reference``.

        Accepts a Python float or one-element tensor, a 1-D tensor with one
        entry per graph, or a 1-D tensor with one entry per node (checked in
        that order). The result uses the device and dtype of ``reference``.

        Raises:
            ValueError: If ``beta`` matches none of the accepted forms.
        """
        num_nodes = reference.size(0)
        is_tensor = torch.is_tensor(beta)

        if isinstance(beta, float) or (is_tensor and beta.numel() == 1):
            return torch.full(
                (num_nodes,), float(beta), device=reference.device, dtype=reference.dtype
            )
        if is_tensor and beta.ndim == 1:
            if beta.numel() == num_graphs:
                per_graph = beta.to(device=reference.device, dtype=reference.dtype)
                return per_graph[graph_index]
            if beta.numel() == num_nodes:
                return beta.to(device=reference.device, dtype=reference.dtype)
        raise ValueError(f"Unexpected beta shape: {getattr(beta,'shape',None)}; expected scalar, (G,), or (Ni,)")

    def forward(self, batch, beta):
        """Run one denoising prediction on the interface nodes.

        Side effect: ``node_embed`` and ``R_t`` on ``batch['interface_graph']``
        are overwritten, and ``R_t`` / ``trans_t`` / ``node_embed`` are copied
        onto the two chain graphs via ``split_interface_to_chains``.

        Args:
            batch: Mapping with ``'interface_graph'``, ``'chain_1_graph'`` and
                ``'chain_2_graph'``; the interface graph must provide
                ``rot_t``, ``seq_t``, ``node_embed``, ``ss_embed``,
                ``trans_t``, ``batch`` and ``ptr``.
            beta: Float, one-element tensor, per-graph 1-D tensor or per-node
                1-D tensor.

        Returns:
            Tuple ``(rot_pred, R_pred, noise_trans_pred, seq_pred)``.

        Raises:
            ValueError: If ``beta`` has an unsupported shape.
        """
        iface_in = batch['interface_graph']

        R_t = so3vec_to_rotation(iface_in.rot_t)
        seq_t = iface_in.seq_t
        node_embed = iface_in.node_embed
        ss_embed = iface_in.ss_embed

        # Fuse structural node features with the noisy residue-type embedding,
        # then add the ``ss_embed`` conditioning term.
        mixer_input = torch.cat([node_embed, self.seq_embedding(seq_t)], dim=-1)
        iface_in.node_embed = self.node_feat_mixer(mixer_input) + ss_embed
        iface_in.R_t = R_t

        batch = split_interface_to_chains(
            batch,
            fields=("R_t", "trans_t", "node_embed"),
            delete_from_interface=False,
        )
        x_updated = self.graph_encoder(batch)

        iface = batch['interface_graph']
        R_t_nodes = iface.R_t
        # Read but not used below; kept so that a missing ``trans_t`` still
        # surfaces as an AttributeError at this point.
        t_nodes = iface.trans_t
        graph_index = iface.batch
        num_graphs = iface.ptr.numel() - 1

        beta_nodes = self._per_node_beta(beta, graph_index, num_graphs, x_updated)

        # Three-feature time embedding: beta, sin(beta), cos(beta).
        t_embed_nodes = torch.stack(
            [beta_nodes, torch.sin(beta_nodes), torch.cos(beta_nodes)], dim=-1
        )
        in_feat = torch.cat([x_updated, t_embed_nodes], dim=-1)

        # Rotation head: its 3 outputs are turned into a matrix that is
        # right-multiplied onto the current rotation.
        noise_rot_pred = self.rot_decoder(in_feat)
        U = quaternion_1ijk_to_rotation_matrix(noise_rot_pred)
        R_pred = torch.matmul(R_t_nodes, U)
        rot_pred = rotation_to_so3vec(R_pred)

        # Translation head: the predicted vector is rotated by ``R_t``.
        eps_crd = self.trans_decoder(in_feat)
        noise_trans_pred = torch.einsum('nij,nj->ni', R_t_nodes, eps_crd)

        seq_pred = self.seq_decoder(in_feat)

        return rot_pred, R_pred, noise_trans_pred, seq_pred


class FullDPM(nn.Module):
    """Diffusion model over interface node rotations, positions and residue types.

    ``forward`` noises the interface graph at a random step per sample, runs
    ``UpdateNet`` and returns the training losses. ``sample`` runs the reverse
    process on dense ``(N, L, ...)`` tensors.
    """

    def __init__(self, cfg):
        """Build the denoising network, the three transitions and the buffers.

        Args:
            cfg: Mapping with a ``'graph_encoder'`` entry (passed to
                ``UpdateNet``) and optional ``num_steps`` (default 500),
                ``position_mean`` (default ``[0, 0, 0]``), ``position_scale``
                (default ``[10]``), ``embedding_mode`` (default ``'complex'``)
                and ``design_mode`` (default ``'codesign'``). The whole
                mapping is also passed to ``AminoacidCategoricalTransition``.
        """
        super().__init__()
        self.update_net = UpdateNet(cfg['graph_encoder'])
        self.num_steps = cfg.get('num_steps', 500)
        pm = cfg.get('position_mean', [0.0, 0.0, 0.0])
        ps = cfg.get('position_scale', [10.0])
        self.embedding_mode = cfg.get('embedding_mode', 'complex')
        self.design_mode = cfg.get('design_mode', 'codesign')

        self.trans_rot = RotationTransition(self.num_steps)
        self.trans_pos = PositionTransition(self.num_steps)
        self.trans_seq = AminoacidCategoricalTransition(self.num_steps, cfg)

        # Stored as (1, D) buffers so they broadcast over leading dimensions
        # and follow the module across devices.
        self.register_buffer('position_mean', torch.as_tensor(pm, dtype=torch.float).view(1, -1))
        self.register_buffer('position_scale', torch.as_tensor(ps, dtype=torch.float).view(1, -1))

    def scale_down(self, p):
        """Normalise positions: ``(p - position_mean) / position_scale``."""
        return (p - self.position_mean) / self.position_scale

    def scale_up(self, p_norm):
        """Invert ``scale_down``: ``p_norm * position_scale + position_mean``."""
        return p_norm * self.position_scale + self.position_mean

    def _finalize_chain_graph(self, g):
        """Return a clone of ``g`` without ``pos_heavyatom``, ``mask_heavyatom`` and ``x``."""
        out = g.clone()
        for k in ('pos_heavyatom', 'mask_heavyatom', 'x'):
            if hasattr(out, k):
                delattr(out, k)
        return out

    def _finalize_interface_graph(self, g, rot_t, trans_t, seq_t):
        """Return a clone of ``g`` carrying the noised state.

        ``pos_heavyatom`` and ``mask_heavyatom`` are removed from the clone and
        ``rot_t``, ``trans_t`` and ``seq_t`` are attached to it.
        """
        out = g.clone()
        for k in ('pos_heavyatom', 'mask_heavyatom'):
            if hasattr(out, k):
                delattr(out, k)
        out.rot_t = rot_t
        out.trans_t = trans_t
        out.seq_t = seq_t
        return out

    def forward(self, batch):
        """Compute the training losses for one batch.

        Args:
            batch: Mapping with ``'id'`` (its length is the number of
                samples), ``'chain_1_graph'``, ``'chain_2_graph'`` and
                ``'interface_graph'``. The chain graphs must provide ``R``;
                the interface graph must provide ``R``, ``t``, ``seq`` and
                ``batch``.

        Returns:
            Dict with ``'rot'`` and ``'pos'`` (scalar means) and ``'seq'``.

        Raises:
            NotImplementedError: If ``embedding_mode`` / ``design_mode`` is not
                ``'complex'`` / ``'codesign'``. Only that combination has a
                loss implemented; previously other settings failed with an
                ``UnboundLocalError`` at the final ``return``.
        """
        if self.embedding_mode != 'complex' or self.design_mode != 'codesign':
            raise NotImplementedError(
                "FullDPM.forward only supports embedding_mode='complex' with "
                "design_mode='codesign'; got "
                f"embedding_mode={self.embedding_mode!r}, design_mode={self.design_mode!r}"
            )

        B = len(batch['id'])
        chain_1_graph = batch['chain_1_graph']
        chain_2_graph = batch['chain_2_graph']
        interface_graph = batch['interface_graph']
        batch_inter = interface_graph.batch

        # Clean (step-0) targets.
        R_A_0 = chain_1_graph.R.clone()
        R_B_0 = chain_2_graph.R.clone()
        R_inter_0 = interface_graph.R.clone()
        rot_inter_0 = rotation_to_so3vec(R_inter_0)
        trans_inter_0 = self.scale_down(interface_graph.t)
        seq_inter_0 = interface_graph.seq

        # One random diffusion step per sample.
        device = self.position_mean.device
        t = torch.randint(0, self.num_steps, (B,), dtype=torch.long, device=device)

        # Noise the interface nodes (rotation, position, sequence -- in this
        # order, which also fixes the order of random draws).
        rot_inter_t, _ = self.trans_rot.add_noise(rot_inter_0, batch_inter, t)
        trans_inter_t, noise_trans = self.trans_pos.add_noise(trans_inter_0, batch_inter, t)
        _, seq_inter_t = self.trans_seq.add_noise(seq_inter_0, batch_inter, t)

        # Work on clones so the caller's graphs are not modified.
        chain_1_final = self._finalize_chain_graph(chain_1_graph)
        chain_2_final = self._finalize_chain_graph(chain_2_graph)
        interface_final = self._finalize_interface_graph(
            interface_graph, rot_inter_t, trans_inter_t, seq_inter_t
        )
        updated_batch = {
            'id': batch['id'],
            'chain_1_graph': chain_1_final,
            'chain_2_graph': chain_2_final,
            'interface_graph': interface_final,
        }

        beta = self.trans_pos.var_sched.betas[t]
        _, R_pred, noise_trans_pred, seq_pred = self.update_net(updated_batch, beta)

        # Per-chain views of the interface predictions and of the true noise.
        R_A_pred, R_B_pred = _split_iface_tensor_to_chains(
            R_pred, chain_1_final, chain_2_final, interface_final
        )
        noise_trans_A_pred, noise_trans_B_pred = _split_iface_tensor_to_chains(
            noise_trans_pred, chain_1_final, chain_2_final, interface_final
        )
        noise_trans_A, noise_trans_B = _split_iface_tensor_to_chains(
            noise_trans, chain_1_final, chain_2_final, interface_final
        )

        loss_dict = {}

        # Rotation loss: average of the chain-1, chain-2 and interface terms.
        loss_rot_A = rotation_matrix_cosine_loss(R_A_pred, R_A_0).mean()
        loss_rot_B = rotation_matrix_cosine_loss(R_B_pred, R_B_0).mean()
        loss_rot_inter = rotation_matrix_cosine_loss(R_pred, R_inter_0).mean()
        loss_dict['rot'] = (loss_rot_A + loss_rot_B + loss_rot_inter) / 3

        # Position loss: squared error on the translation noise, same averaging.
        loss_pos_A = F.mse_loss(noise_trans_A_pred, noise_trans_A, reduction='none').sum(dim=-1)
        loss_pos_B = F.mse_loss(noise_trans_B_pred, noise_trans_B, reduction='none').sum(dim=-1)
        loss_pos_inter = F.mse_loss(noise_trans_pred, noise_trans, reduction='none').sum(dim=-1)
        loss_dict['pos'] = (loss_pos_A.mean() + loss_pos_B.mean() + loss_pos_inter.mean()) / 3

        # Sequence loss: KL between the posterior built from the true sequence
        # and the one built from the predicted distribution.
        post_true_inter = self.trans_seq._blosum_posterior(seq_inter_t, seq_inter_0, batch_inter, t)
        log_post_inter_pred = torch.log(
            self.trans_seq._blosum_posterior(seq_inter_t, seq_pred, batch_inter, t) + 1e-8
        )
        kl_inter = F.kl_div(
            input=log_post_inter_pred,
            target=post_true_inter,
            reduction='none',
            log_target=False,
        ).sum(dim=-1)
        # NOTE(review): unlike 'rot' and 'pos', this entry is not reduced to a
        # scalar here; presumably the training loop reduces it -- confirm.
        loss_dict['seq'] = kl_inter

        return loss_dict

    @torch.no_grad()
    def sample(
        self,
        v, p, s,
        res_feat, pair_feat,
        mask_generate, mask_res,
        sample_structure=True, sample_sequence=True,
        pbar=False,
    ):
        """Run the reverse diffusion process and return the whole trajectory.

        NOTE(review): the denoising step calls ``self.eps_net``, which this
        class does not define (``__init__`` only creates ``update_net``, whose
        ``forward`` takes ``(batch, beta)``). The loop therefore needs an
        ``eps_net`` attribute to be provided elsewhere -- confirm or port this
        method to the graph-batch interface.

        Args:
            v: Orientation vectors; the first two dimensions are ``(N, L)``.
            p: Positions, normalised internally with ``scale_down``.
            s: Integer residue types.
            res_feat, pair_feat: Features forwarded to the denoising network.
            mask_generate: Boolean ``(N, L)`` mask of positions to generate.
            mask_res: Mask forwarded to the denoising network.
            sample_structure: If true, start from random rotations/positions
                at generated positions; otherwise keep the structure fixed.
            sample_sequence: If true, start from random residue types at
                generated positions; otherwise keep the sequence fixed.
            pbar: Show a tqdm progress bar when true.

        Returns:
            Dict mapping step index (``num_steps`` down to 0) to a tuple
            ``(v, p, s)`` with positions in the original (unnormalised) scale.
            Entries for steps ``>= 1`` are moved to CPU once consumed.
        """
        # The module registers no `_dummy` buffer; use a registered buffer to
        # find the device the model lives on.
        device = self.position_mean.device

        N, L = v.shape[:2]
        p = self.scale_down(p)

        if sample_structure:
            v_rand = random_uniform_so3([N, L], device=device)
            p_rand = torch.randn_like(p)
            v_init = torch.where(mask_generate[:, :, None].expand_as(v), v_rand, v)
            p_init = torch.where(mask_generate[:, :, None].expand_as(p), p_rand, p)
        else:
            v_init, p_init = v, p

        if sample_sequence:
            # `high` is exclusive: 20 covers all residue indices 0..19 used by
            # the 20-class embedding / sequence head (19 excluded index 19).
            s_rand = torch.randint_like(s, low=0, high=20)
            s_init = torch.where(mask_generate, s_rand, s)
        else:
            s_init = s

        traj = {self.num_steps: (v_init, self.scale_up(p_init), s_init)}
        if pbar:
            progress = functools.partial(tqdm, total=self.num_steps, desc='Sampling')
        else:
            progress = lambda x: x
        for t in progress(range(self.num_steps, 0, -1)):
            v_t, p_t, s_t = traj[t]
            p_t = self.scale_down(p_t)

            beta = self.trans_pos.var_sched.betas[t].expand([N, ])
            t_tensor = torch.full([N, ], fill_value=t, dtype=torch.long, device=device)

            v_next, R_next, eps_p, c_denoised = self.eps_net(
                v_t, p_t, s_t, res_feat, pair_feat, beta, mask_generate, mask_res
            )

            v_next = self.trans_rot.denoise(v_t, v_next, mask_generate, t_tensor)
            p_next = self.trans_pos.denoise(p_t, eps_p, mask_generate, t_tensor)
            _, s_next = self.trans_seq.denoise(s_t, c_denoised, mask_generate, t_tensor)

            if not sample_structure:
                v_next, p_next = v_t, p_t
            if not sample_sequence:
                s_next = s_t

            traj[t-1] = (v_next, self.scale_up(p_next), s_next)
            # Move the consumed step to CPU to free accelerator memory.
            traj[t] = tuple(x.cpu() for x in traj[t])

        return traj
