import torch
import torch.nn as nn

from interfacediff.modules.common.so3 import rotation_to_so3vec
from interfacediff.modules.encoders.node import NodeEncoder
from interfacediff.modules.encoders.edge import ChainEdgeEncoder, InterfaceEdgeEncoder
from interfacediff.modules.diffusion.dpm_full import FullDPM

class InterFaceDiff(nn.Module):
    """Wrap node/edge encoders and a ``FullDPM`` diffusion model.

    ``forward`` embeds the two chain graphs and the interface graph of a
    batch and hands the result to the diffusion module; ``sample`` runs the
    diffusion module's sampler without gradient tracking.
    """

    # Values of ``embedding_mode`` accepted by ``forward``.
    _EMBEDDING_MODES = ('complex', 'binder')

    def __init__(self, cfg):
        """Build the encoders and the diffusion module from ``cfg``.

        Args:
            cfg: Configuration object. It must support ``cfg.get(key, default)``
                and nested item access: ``cfg["model"]["node_encoder"]``,
                ``["chain_edge_encoder"]``, ``["interface_edge_encoder"]``
                and ``["diffusion"]`` are each passed to the matching
                sub-module constructor. The optional top-level key
                ``'embedding_mode'`` defaults to ``'complex'``.
        """
        super().__init__()
        self.cfg = cfg
        self.embedding_mode = cfg.get('embedding_mode', 'complex')

        model_cfg = cfg["model"]
        self.node_embed = NodeEncoder(model_cfg["node_encoder"])
        self.chain_edge_embed = ChainEdgeEncoder(model_cfg["chain_edge_encoder"])
        self.interface_edge_embed = InterfaceEdgeEncoder(model_cfg["interface_edge_encoder"])

        self.diffusion = FullDPM(model_cfg["diffusion"])

    def _encode_chain(self, g):
        """Return a clone of ``g`` whose ``edge_attr`` has been embedded.

        The input graph is not modified; only the clone receives the output
        of ``self.chain_edge_embed``.
        """
        encoded = g.clone()
        encoded.edge_attr = self.chain_edge_embed(encoded.edge_attr)
        return encoded

    def _encode_interface(self, g):
        """Return a clone of ``g`` carrying node and edge embeddings.

        ``self.node_embed`` is called on the clone and its three outputs are
        stored as ``seq``, ``ss_embed`` and ``node_embed``. ``edge_attr`` is
        replaced by the output of ``self.interface_edge_embed``. If the clone
        has an ``x`` attribute it is removed. The input graph is not modified.
        """
        encoded = g.clone()
        # Both encoders read the clone before any attribute is overwritten.
        seq, ss_embed, node_embed = self.node_embed(encoded)
        edge_attr = self.interface_edge_embed(encoded.edge_attr)

        encoded.seq = seq
        encoded.ss_embed = ss_embed
        encoded.node_embed = node_embed
        encoded.edge_attr = edge_attr
        if hasattr(encoded, 'x'):
            delattr(encoded, 'x')
        return encoded

    def forward(self, batch):
        """Encode the graphs in ``batch`` and run the diffusion module on them.

        Args:
            batch: Mapping with keys ``'id'``, ``'chain_1_graph'``,
                ``'chain_2_graph'`` and ``'interface_graph'``.

        Returns:
            Whatever ``self.diffusion`` returns for the encoded batch (the
            original code names it ``loss_dict``).

        Raises:
            ValueError: If ``self.embedding_mode`` is neither ``'complex'``
                nor ``'binder'``.
        """
        chain_1_raw = batch['chain_1_graph']
        chain_2_raw = batch['chain_2_graph']
        inter_raw = batch['interface_graph']

        # Reject unknown modes before doing any encoding work.
        if self.embedding_mode not in self._EMBEDDING_MODES:
            raise ValueError(f"Unknown embedding_mode: {self.embedding_mode}")

        # NOTE(review): 'complex' and 'binder' currently encode all three
        # graphs identically (the former 'binder' branch was an inlined copy
        # of _encode_chain). Confirm whether 'binder' was meant to differ.
        chain_1_graph = self._encode_chain(chain_1_raw)
        chain_2_graph = self._encode_chain(chain_2_raw)
        interface_graph = self._encode_interface(inter_raw)

        encoded_batch = {
            'id': batch['id'],
            'chain_1_graph': chain_1_graph,
            'chain_2_graph': chain_2_graph,
            'interface_graph': interface_graph,
        }

        return self.diffusion(encoded_batch)

    @torch.no_grad()
    def sample(self, batch, sample_opt=None):
        """Run the diffusion sampler on ``batch`` and return its trajectory.

        Args:
            batch: Mapping with keys ``'generate_flag'``, ``'mask'`` and
                ``'aa'``; it is also passed to ``self.encode``.
                NOTE(review): this flat schema differs from the graph keys
                read by ``forward`` - confirm which one callers provide.
            sample_opt: Keyword arguments forwarded to
                ``self.diffusion.sample``. ``None`` (the default) means
                ``{'sample_structure': True, 'sample_sequence': True}``.
                A dict passed explicitly replaces these defaults entirely.

        Returns:
            The value returned by ``self.diffusion.sample``.
        """
        # Build the defaults per call rather than sharing one dict object
        # across calls through the function signature.
        if sample_opt is None:
            sample_opt = {
                'sample_structure': True,
                'sample_sequence': True,
            }

        mask_generate = batch['generate_flag']
        mask_res = batch['mask']
        # NOTE(review): this class defines no `encode` method, so unless a
        # subclass or mixin supplies one this line raises AttributeError.
        node_feat, edge_feat, R_0, p_0 = self.encode(batch)

        v_0 = rotation_to_so3vec(R_0)
        s_0 = batch['aa']
        return self.diffusion.sample(
            v_0, p_0, s_0, node_feat, edge_feat, mask_generate, mask_res, **sample_opt
        )
