import torch
import torch.nn as nn
from models import register_model
from models.modules import FeatNet, PropagationLayer, CrossAttentionLayer,GAT

class MPBlock(nn.Module):
    """One round of intra-surface propagation followed by inter-surface cross attention.

    The same propagation module(s) are applied to both the ligand and the
    receptor surface, so their weights are shared between the two. A stack
    of cross-attention layers then exchanges information between the surfaces.

    Args:
        config: Model configuration. This class reads ``msgpass_mechanism``,
            ``h_dim``, ``num_propagation_layers`` and ``num_cross_attn_layers``.
            The whole object is also forwarded to ``PropagationLayer`` and
            ``CrossAttentionLayer``.
    """

    def __init__(self, config):
        super().__init__()
        self.msg_pass = config.msgpass_mechanism

        # intra-surface propagation
        if self.msg_pass == 'harmonic':
            self.propagation_layers = nn.ModuleList(
                [PropagationLayer(config) for _ in range(config.num_propagation_layers)]
            )
        else:
            # Every non-'harmonic' mechanism uses a single multi-step GAT.
            # forward() decides whether it receives an explicit edge list.
            self.propagation_layers = GAT(node_input_dim=config.h_dim,
                                          output_dim=config.h_dim,
                                          num_step_message_passing=config.num_propagation_layers)

        # inter-surface cross attention
        self.cross_attn_layers = nn.ModuleList(
            [CrossAttentionLayer(config) for _ in range(config.num_cross_attn_layers)]
        )

    def forward(self, lig_h, rec_h, lig_dict, rec_dict):
        """Propagate features within each surface, then cross-attend between them.

        Args:
            lig_h: Ligand surface features (assumed per-vertex; TODO confirm shape).
            rec_h: Receptor surface features (same layout as ``lig_h``).
            lig_dict: Ligand feature dict. ``'num_verts'`` is always read;
                ``'edge'`` is read for the ``'graph'`` mechanism. It is also
                passed whole to ``PropagationLayer`` in ``'harmonic'`` mode.
            rec_dict: Receptor feature dict, with the same keys as ``lig_dict``.

        Returns:
            Tuple ``(lig_h, rec_h, lig_attn, rec_attn)``. The attention maps come
            from the last cross-attention layer, or are ``None`` when there
            are no cross-attention layers.
        """
        # Submodules are invoked through __call__ rather than .forward so that
        # registered forward/backward hooks are run.
        if self.msg_pass == 'harmonic':
            for layer in self.propagation_layers:
                lig_h = layer(h=lig_h, feat_dict=lig_dict)
                rec_h = layer(h=rec_h, feat_dict=rec_dict)
        elif self.msg_pass == 'graph':
            lig_h = self.propagation_layers(x=lig_h, edge_index=lig_dict['edge'].to(torch.long))
            rec_h = self.propagation_layers(x=rec_h, edge_index=rec_dict['edge'].to(torch.long))
        else:
            # NOTE(review): no edge list is passed here. Presumably GAT then treats
            # each surface as fully connected; confirm against the GAT implementation.
            lig_h = self.propagation_layers(x=lig_h, edge_index=None)
            rec_h = self.propagation_layers(x=rec_h, edge_index=None)

        # inter-surface attention
        lig_attn, rec_attn = None, None
        for layer in self.cross_attn_layers:
            # Both directions read the pre-layer features. The new features are
            # committed only after both calls, so the update is symmetric.
            lig_out, lig_attn = layer(src_h=rec_h, dst_h=lig_h,
                                      src_batch_idx=rec_dict['num_verts'],
                                      dst_batch_idx=lig_dict['num_verts'])
            rec_out, rec_attn = layer(src_h=lig_h, dst_h=rec_h,
                                      src_batch_idx=lig_dict['num_verts'],
                                      dst_batch_idx=rec_dict['num_verts'])
            lig_h, rec_h = lig_out, rec_out

        return lig_h, rec_h, lig_attn, rec_attn


@register_model
class PuzzleDock(nn.Module):
    """Surface-based docking model with message passing and binding-site prediction.

    Pipeline: shared feature initialisation (``FeatNet``), a stack of
    ``MPBlock``s, a smoothing stage and a shared binding-site prediction (BSP)
    head. The BSP head is applied separately to the ligand and the receptor.

    Args:
        config: Model configuration. This class reads ``h_dim``, ``dropout``,
            ``msgpass_mechanism``, ``num_message_passing_blocks`` and
            ``num_smoothing_layers``. The whole object is also forwarded to
            ``FeatNet``, ``MPBlock`` and ``PropagationLayer``.
    """

    def __init__(self, config):
        super().__init__()
        h_dim = config.h_dim
        dropout = config.dropout
        self.msg_pass = config.msgpass_mechanism

        # initialize features
        self.feat_init = FeatNet(config)

        # message passing
        self.message_passing_blocks = nn.ModuleList(
            [MPBlock(config) for _ in range(config.num_message_passing_blocks)]
        )

        # smoothing: same mechanism split as the propagation step in MPBlock
        if self.msg_pass == 'harmonic':
            self.smoothing_layers = nn.ModuleList(
                [PropagationLayer(config) for _ in range(config.num_smoothing_layers)]
            )
        else:
            self.smoothing_layers = GAT(node_input_dim=config.h_dim,
                                        output_dim=config.h_dim,
                                        num_step_message_passing=config.num_smoothing_layers)

        # binding site prediction: one scalar output per input row.
        # The same head (and the same BatchNorm running statistics) is used for
        # both ligand and receptor.
        self.bsp = nn.Sequential(
            nn.Linear(h_dim, h_dim),
            nn.Dropout(dropout),
            nn.BatchNorm1d(h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, 1)
        )

    def forward(self, feat_dict):
        """Run the full model and write results into the per-surface dicts.

        Args:
            feat_dict: Dict with ``'lig_dict'`` and ``'rec_dict'`` entries. Each
                entry is the per-surface feature dict that is passed to
                ``FeatNet``, every ``MPBlock`` and the smoothing layers.
                ``'edge'`` is read here only for the ``'graph'`` mechanism.

        Returns:
            The same ``feat_dict`` object. Its ``lig_dict`` and ``rec_dict`` are
            updated in place with ``'h'`` (final features), ``'attn'`` (the
            attention from the last cross-attention layer of the last block,
            or ``None``) and ``'bsp'`` (BSP head output, last dim 1;
            presumably per-vertex logits, TODO confirm).
        """
        lig_dict = feat_dict['lig_dict']
        rec_dict = feat_dict['rec_dict']

        # Submodules are invoked through __call__ rather than .forward so that
        # registered forward/backward hooks are run.

        # feature extraction
        lig_h = self.feat_init(feat_dict=lig_dict)
        rec_h = self.feat_init(feat_dict=rec_dict)

        # propagation
        lig_attn, rec_attn = None, None
        for block in self.message_passing_blocks:
            lig_h, rec_h, lig_attn, rec_attn = block(lig_h=lig_h, rec_h=rec_h,
                                                     lig_dict=lig_dict, rec_dict=rec_dict)

        # smoothing
        if self.msg_pass == 'harmonic':
            for layer in self.smoothing_layers:
                lig_h = layer(h=lig_h, feat_dict=lig_dict)
                rec_h = layer(h=rec_h, feat_dict=rec_dict)
        elif self.msg_pass == 'graph':
            lig_h = self.smoothing_layers(x=lig_h, edge_index=lig_dict['edge'].to(torch.long))
            rec_h = self.smoothing_layers(x=rec_h, edge_index=rec_dict['edge'].to(torch.long))
        else:
            # NOTE(review): no edge list is passed; see the GAT implementation
            # for how edge_index=None is handled.
            lig_h = self.smoothing_layers(x=lig_h, edge_index=None)
            rec_h = self.smoothing_layers(x=rec_h, edge_index=None)

        lig_dict['h'] = lig_h
        rec_dict['h'] = rec_h

        lig_dict['attn'] = lig_attn
        rec_dict['attn'] = rec_attn

        # binding site prediction. The two calls are separate, so in training
        # mode BatchNorm1d normalises each surface with its own batch statistics.
        lig_dict['bsp'] = self.bsp(lig_h)
        rec_dict['bsp'] = self.bsp(rec_h)

        return feat_dict


