import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy

from bindgen.utils import *
from bindgen.nnutils import *
from bindgen.data import ALPHABET, ATOM_TYPES
from bindgen.protein_features import ProteinFeatures


class EGNNEncoder(nn.Module):
    """Message-passing encoder over a neighbour graph built by ``ProteinFeatures``.

    Node states are initialised from ``V`` and refined by ``args.depth``
    ``MPNNLayer`` steps; each step sees the gathered neighbour states, the
    gathered ``S`` features and the projected edge features. When ``update_X``
    is set and ``features_type == 'backbone'``, coordinates are additionally
    shifted by a learned, mask-averaged combination of relative position
    vectors (clamped to +/-20 per component).
    """

    def __init__(self, args, node_hdim=0, features_type='backbone', update_X=True):
        super().__init__()
        self.update_X = update_X
        self.features_type = features_type
        self.features = ProteinFeatures(
            top_k=args.k_neighbors,
            num_rbf=args.num_rbf,
            features_type=features_type,
            direction='bidirectional',
        )
        base_node_dim, base_edge_dim = self.features.feature_dimensions[features_type]
        # Extra node channels (node_hdim) are appended by the caller to V.
        self.node_in = base_node_dim + node_hdim
        self.edge_in = base_edge_dim

        hidden = args.hidden_size
        self.W_v = nn.Linear(self.node_in, hidden)
        self.W_e = nn.Linear(self.edge_in, hidden)
        # Each layer consumes [neighbour state | neighbour S | edge] = 3 * hidden.
        self.layers = nn.ModuleList(
            [MPNNLayer(hidden, hidden * 3, dropout=args.dropout) for _ in range(args.depth)]
        )
        if update_X:
            self.W_x = nn.Linear(hidden, hidden)
            self.U_x = nn.Linear(hidden, hidden)
            # One gate per atom slot; 14 must match the atom axis L of X.
            self.T_x = nn.Sequential(nn.ReLU(), nn.Linear(hidden, 14))

        # Xavier-initialise every parameter with more than one dimension.
        for weight in filter(lambda p: p.dim() > 1, self.parameters()):
            nn.init.xavier_uniform_(weight)

    # [backbone] X: [B,N,L,3], V/S: [B,N,H], A: [B,N,L]
    # [atom] X: [B,N*L,3], V/S: [B,N*L,H], A: [B,N*L]
    def forward(self, X, V, S, A):
        """Encode the graph and optionally update coordinates.

        Args:
            X: coordinates (shapes per the comment above).
            V: input node features fed to ``W_v``.
            S: node features gathered per neighbour and concatenated into messages.
            A: per-atom values; entries >= 1 are treated as present
               (assumes A is non-negative with 0 marking padding).

        Returns:
            (h, X): final node states, and coordinates (updated if enabled)
            multiplied by the atom mask.
        """
        mask = A.clamp(max=1).float()
        is_backbone = self.features_type == 'backbone'
        # Backbone graphs are masked by atom slot 1 (the CA slot); atom graphs
        # use the full mask.
        node_mask = mask[:, :, 1] if is_backbone else mask
        _, E, E_idx = self.features(X, node_mask)

        h = self.W_v(V)                                   # [B, N, H]
        h_e = self.W_e(E)                                 # [B, N, K, H]
        nei_s = gather_nodes(S, E_idx)                    # [B, N, K, H]
        edge_mask = gather_nodes(node_mask.unsqueeze(-1), E_idx).squeeze(-1)
        keep = node_mask.unsqueeze(-1)

        for mpnn in self.layers:
            neighbour_states = gather_nodes(h, E_idx)     # [B, N, K, H]
            messages = torch.cat([neighbour_states, nei_s, h_e], dim=-1)
            h = mpnn(h, messages, mask_attend=edge_mask) * keep

        if self.update_X and is_backbone:
            ca_mask = mask[:, :, 1]                                          # [B, N]
            pair_msg = self.W_x(h).unsqueeze(2) + self.U_x(h).unsqueeze(1)  # [B,N,N,H]
            rel = X.unsqueeze(2) - X.unsqueeze(1)                            # [B,N,N,L,3]
            rel = rel * self.T_x(pair_msg).unsqueeze(-1)
            # Mean over present CA nodes j; 1e-6 guards an all-masked row.
            shift = (rel * ca_mask[:, None, :, None, None]).sum(dim=2)
            shift = shift / (1e-6 + ca_mask.sum(dim=1)[:, None, None, None])
            X = X + shift.clamp(min=-20.0, max=20.0)

        return h, X * mask[..., None]


class HierEGNNEncoder(nn.Module):
    """Hierarchical atom -> residue encoder with optional coordinate refinement.

    Atom-level states (from an ``EGNNEncoder`` over atoms) are mask-averaged
    per residue and appended to the residue features ``V`` for a residue-level
    ``EGNNEncoder``. When ``update_X`` is set, the four leading atom slots are
    moved by a learned pairwise update across residues, followed by
    ``clash_step`` repulsion passes. All atoms are also moved by a learned
    pairwise update within each residue.
    """

    def __init__(self, args, update_X=True, backbone_CA_only=True):
        super().__init__()
        self.update_X = update_X
        self.backbone_CA_only = backbone_CA_only
        self.clash_step = args.clash_step
        hidden = args.hidden_size
        # Registration order below is kept fixed: it drives parameter order
        # (and therefore RNG use) in the Xavier loop at the end.
        self.residue_mpn = EGNNEncoder(
            args, features_type='backbone', node_hdim=hidden, update_X=False,
        )
        self.atom_mpn = EGNNEncoder(
            args, features_type='atom', node_hdim=hidden, update_X=False,
        )
        if update_X:
            # residue-level update of the 4 leading atom slots
            self.W_x = nn.Linear(hidden, hidden)
            self.U_x = nn.Linear(hidden, hidden)
            self.T_x = nn.Sequential(nn.ReLU(), nn.Linear(hidden, 4))
            # intra-residue update of every atom slot
            self.W_a = nn.Linear(hidden, hidden)
            self.U_a = nn.Linear(hidden, hidden)
            self.T_a = nn.Sequential(nn.ReLU(), nn.Linear(hidden, 1))

        self.embedding = nn.Embedding(len(ATOM_TYPES), hidden)
        for weight in (p for p in self.parameters() if p.dim() > 1):
            nn.init.xavier_uniform_(weight)

    # X: [B,N,L,3], V: [B,N,6], S: [B,N,H], A: [B,N,L]
    def forward(self, X, V, S, A):
        """Return residue states and (optionally refined) masked coordinates.

        Args:
            X: atom coordinates [B, N, L, 3].
            V: residue features [B, N, 6].
            S: residue features [B, N, H] passed to the residue encoder.
            A: atom-type indices [B, N, L]; presumably 0 marks padding, since
               the mask is ``A.clamp(max=1)``.

        Returns:
            (h_res, X): residue states and coordinates multiplied by the mask.
        """
        B, N, L = X.size()[:3]
        mask = A.clamp(max=1).float()

        # ---- atom level ----
        atom_states = self.embedding(A).view(B, N * L, -1)
        atom_states, _ = self.atom_mpn(
            X.view(B, N * L, 3), atom_states, atom_states, A.view(B, -1),
        )
        h_atom = atom_states.view(B, N, L, -1) * mask[..., None]
        atoms_per_residue = 1e-6 + mask.sum(dim=-1)[..., None]
        h_A = h_atom.sum(dim=-2) / atoms_per_residue

        # ---- residue level ----
        h_res, _ = self.residue_mpn(X, torch.cat([V, h_A], dim=-1), S, A)

        if not self.update_X:
            return h_res, X * mask[..., None]

        bb_mask = mask[:, :, :4]                                       # [B, N, 4]
        bb_denom = 1e-6 + bb_mask.sum(dim=1, keepdim=True)[..., None]  # [B, 1, 4, 1]

        def pairwise_step(coords, gate_fn):
            # Move coords [B,N,4,3] by the masked mean over residues j of
            # gate(|x_i - x_j|) * (x_i - x_j), clamped to +/-20 per component.
            rel = coords.unsqueeze(2) - coords.unsqueeze(1)           # [B,N,N,4,3]
            gate = gate_fn(rel.norm(dim=-1))                           # [B,N,N,4]
            weighted = rel * gate.unsqueeze(-1) * bb_mask[:, None, :, :, None]
            shift = weighted.sum(dim=2) / bb_denom                     # [B,N,4,3]
            return coords + shift.clamp(min=-20.0, max=20.0)

        # Learned update; the 3.8 - d floor pushes apart pairs closer than 3.8.
        pair_msg = self.W_x(h_res).unsqueeze(2) + self.U_x(h_res).unsqueeze(1)  # [B,N,N,H]
        X_bb = pairwise_step(
            X[:, :, :4], lambda dist: torch.maximum(self.T_x(pair_msg), 3.8 - dist),
        )

        # Pure repulsion passes between slot-matched atoms closer than 3.8.
        repel = lambda dist: F.relu(3.8 - dist)
        for _ in range(self.clash_step):
            X_bb = pairwise_step(X_bb, repel)

        # Intra-residue update over all atom pairs (l, m); 1.5 - d floor.
        atom_msg = self.W_a(h_atom).unsqueeze(3) + self.U_a(h_atom).unsqueeze(2)  # [B,N,L,L,H]
        rel_atoms = X.unsqueeze(3) - X.unsqueeze(2)                                # [B,N,L,L,3]
        gate = torch.maximum(self.T_a(atom_msg).squeeze(-1), 1.5 - rel_atoms.norm(dim=-1))
        f_atom = (rel_atoms * gate.unsqueeze(-1) * mask[:, :, None, :, None]).sum(dim=3)
        X_sc = X + 0.1 * f_atom

        if self.backbone_CA_only:
            # Only atom slot 1 takes the residue-level update.
            pieces = (X_sc[:, :, :1], X_bb[:, :, 1:2], X_sc[:, :, 2:])
        else:
            pieces = (X_bb[:, :, :4], X_sc[:, :, 4:])
        X = torch.cat(pieces, dim=2)

        return h_res, X * mask[..., None]


class ReverseDiffusionDense(nn.Module):
    """Dense EGNN network that estimates the noise on corrupted backbone positions.

    Node features are built from ``self.embedding_layer`` and a per-protein
    embedding looked up by name, concatenated along dim 1. They are mixed by
    ``self.attn`` and projected by the learned matrix ``self.weight``. The
    result is refined together with the coordinates by a stack of dense
    EGNN + LayerNorm layers. The coordinate change is then converted into a
    noise estimate.
    """

    def __init__(
            self,
            *,
            T,
            B,
            max_len,
            device,
            pos_embed_size,
            h_embed_size,
            num_layers,
            e_embed_size,
            num_heads,
            n_layers_per_egnn,
            network,
            use_sequence_transformer,
            use_positional_encoding,
            use_rel_positional_encoding,
            b_0,
            b_T,
            scale_eps,
    ):
        super(ReverseDiffusionDense, self).__init__()
        self.T = T
        self.B = B
        self.max_len = max_len
        self.pos_embed_size = pos_embed_size
        self.h_embed_size = h_embed_size
        self.e_embed_size = e_embed_size if e_embed_size is not None else h_embed_size
        self.device = device
        self.network = network
        self.num_heads = num_heads
        self.use_sequence_transformer = use_sequence_transformer
        self.use_positional_encoding = use_positional_encoding
        self.use_rel_positional_encoding = use_rel_positional_encoding
        self.scale_eps = scale_eps

        self.embedding_layer = Embedding(
            max_len, T, self.pos_embed_size, h_embed_size, self.use_positional_encoding)
        self.edge_embedding_layer = fn.partial(
            index_embedding, N=max_len, embed_size=self.e_embed_size)
        self.diffuser = diffuser.Diffuser(T=T, b_0=b_0, b_T=b_T)
        self.cum_a_schedule = torch.Tensor(
            self.diffuser.cum_a_schedule).to(device)

        # Plain list of [egnn, norm] pairs used for iteration in forward; the
        # modules are registered with nn.Module via self.layers_pytorch below.
        self.layers = []
        for i in range(num_layers):
            layer = []
            egnn = egnn_dense.EGNN(
                in_node_nf=self.h_embed_size,
                hidden_nf=self.h_embed_size,
                out_node_nf=self.h_embed_size,
                in_edge_nf=self.e_embed_size if use_rel_positional_encoding else 0,
                n_layers=n_layers_per_egnn,
                normalize=True
            ).to(device)
            layer.append(egnn)
            layer.append(nn.LayerNorm(self.h_embed_size))
            self.layers.append(layer)
        # NOTE(review): the attention width (256), the embedding width used in
        # forward (256) and the projection shape (500, 1767) are hard-coded.
        # 500 presumably equals max_len and 1767 the concatenated sequence
        # length; confirm before changing max_len or the protein embeddings.
        self.attn = Attention(dim=256, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.)
        self.layers_pytorch = nn.ModuleList([l for sublist in self.layers for l in sublist])
        self.weight = nn.Parameter(torch.randn(500, 1767))

    def forward(self, input_feats, names, selected_list):
        """Compute the noise estimate for the reverse step p(X^t | X^{t+1}).

        Args:
            input_feats: dict providing 'bb_corrupted' (noised backbone
                positions [B, N, D]), 'bb_mask' ([B, N]), 't' (per-example
                timestep indices) and 'residue_index'.
            names: per-example protein names, matched against the 'name' key
                of the entries in ``selected_list``.
            selected_list: list of dicts with 'name' and 'embedding' keys.
                Examples whose name has no match keep an all-zero embedding.

        Returns:
            eps_theta_val: masked noise estimate with the shape of 'bb_corrupted'.
        """
        bb_pos = input_feats['bb_corrupted'].type(torch.float32)
        curr_pos = bb_pos.clone()  # [B, N, D]
        bb_mask = input_feats['bb_mask'].type(torch.float32)  # [B, N]
        bb_2d_mask = bb_mask[:, None, :] * bb_mask[:, :, None]
        # Out of place: .type() returns the very same tensor when it is already
        # float32, so an in-place *= would mutate the caller's input_feats.
        bb_pos = bb_pos * bb_mask[..., None]
        t = input_feats['t']
        B, N, _ = bb_pos.shape

        # Edge features: embedding of pairwise residue-index offsets.
        if self.use_rel_positional_encoding:
            res_index = input_feats['residue_index']
            edge_attr = res_index[:, :, None] - res_index[:, None, :]
            edge_attr = edge_attr.reshape([B, N ** 2])
            edge_attr = self.edge_embedding_layer(edge_attr, device=self.device)
            edge_attr = edge_attr.reshape([B, N, N, self.e_embed_size])
            assert edge_attr.shape[0] == B
            assert edge_attr.shape[1] == N
            assert edge_attr.shape[1] == edge_attr.shape[2]
        else:
            edge_attr = None

        # Node representations for the first layer.
        H = self.embedding_layer(
            input_feats['residue_index'], B, t, N, bb_mask, device=self.device)

        # Per-protein embeddings looked up by name (one dict build instead of
        # a scan per example; a later duplicate name wins, as in a linear scan).
        samples_by_name = {sample.get('name'): sample for sample in selected_list}
        # Allocate on H's device so the concatenation below works on CPU and
        # on any GPU, instead of forcing the default CUDA device.
        embedding_tensor = torch.zeros(
            (B, selected_list[0]['embedding'].shape[0], 256), device=H.device)
        for i in range(B):
            sample = samples_by_name.get(names[i])
            if sample is not None:
                embedding_tensor[i] = sample['embedding']

        concatenated_tensor = torch.cat((H, embedding_tensor), dim=1)
        atten_tensor = self.attn(concatenated_tensor)
        updated_embedding = torch.matmul(self.weight, atten_tensor)

        for layer in self.layers:
            updated_embedding *= bb_mask[..., None]
            curr_pos *= bb_mask[..., None]
            # edge_attr is None when relative positional encoding is disabled.
            if edge_attr is not None:
                edge_attr *= bb_2d_mask[..., None]
            if len(layer) == 3:
                # NOTE(review): unreachable with the current __init__ (every
                # layer is [egnn, norm]), and it updates H, which is not used
                # after the attention step.
                tfmr, egnn, norm = layer
                H = tfmr(H, src_key_padding_mask=1 - bb_mask)
            else:
                egnn, norm = layer
            updated_embedding, curr_pos = egnn(updated_embedding, curr_pos, edge_attr, mask=bb_mask)
            updated_embedding *= bb_mask[..., None]
            updated_embedding = norm(updated_embedding)

        if self.scale_eps:
            cum_a_t = self.cum_a_schedule[t[:, None, None]]
            eps_theta_val = bb_pos - curr_pos * cum_a_t
            eps_theta_val = eps_theta_val / torch.sqrt(1 - cum_a_t)
            eps_theta_val = eps_theta_val * bb_mask[..., None]
        else:
            eps_theta_val = curr_pos - bb_pos
            eps_theta_val = eps_theta_val.reshape(bb_pos.shape)
            eps_theta_val = eps_theta_val * bb_mask[..., None]
        return eps_theta_val
