
import torch
from torch import nn
import torch.nn.functional as F
from .activations import get_activation
from .triplet import get_triplet_layer
from .torsion_angle import get_torsionangle_layer
import numpy as np 

@torch.jit.script
def form_degree_scalers(gates: torch.Tensor) -> torch.Tensor:
    """Return ``log(1 + degree)`` scalers computed from attention gates.

    The gates are summed over dim 2 (the axis the attention softmax runs
    over in ``EGT_Attention``), keeping that axis as a singleton so the
    result broadcasts back against the attended values.
    """
    return torch.log(1 + gates.sum(dim=2, keepdim=True))


class EGT_Attention(nn.Module):
    """Edge-augmented multi-head attention that also updates angle/torsion features.

    Node features attend over one another with scaled dot-product attention.
    The attention logits are biased by an edge projection ``E``, and the
    attention weights are multiplied by gates ``sigmoid(G + mask)``. The
    per-head pair scores ``H_hat`` (shape ``(batch, n, n, num_heads)``) are
    reused to build three further updates:

    * Edge update: ``lin_O_e(H_hat)``.
    * Angle update: for each angle ``(i, j, k)``, the scores of the pairs
      ``(i, j)``, ``(j, k)`` and ``(i, k)`` are summed, scaled, and added to
      the angle projection before ``lin_O_a``.
    * Torsion update: the same construction over the six node pairs of each
      torsion ``(i, j, k, l)``, before ``lin_O_t``.

    ``forward`` returns update terms; it does not add residuals. When
    ``edge_update`` / ``angle_update`` is False, the corresponding inputs
    (``e`` / ``a`` and ``t``) are returned unchanged, and the layers that
    would produce them are not created.

    Args:
        node_width: Feature width of ``h``; must be divisible by ``num_heads``.
        edge_width: Feature width of ``e``.
        angle_width: Feature width of ``a``.
        torsion_angle_width: Feature width of ``t``.
        num_heads: Number of attention heads.
        source_dropout: Training-time probability of masking out a whole
            source node (attention axis 2).
        scale_degree: Multiply attended values by ``log(1 + sum of gates)``.
        edge_update: Produce an edge update.
        angle_update: Produce angle and torsion updates.
    """
    def __init__(self,
                 node_width            ,
                 edge_width            ,
                 angle_width           ,
                 torsion_angle_width   ,
                 num_heads             ,
                 source_dropout = 0    ,
                 scale_degree   = True ,
                 edge_update    = True ,
                 angle_update    = True ,
                 ):
        super().__init__()
        self.node_width          = node_width
        self.edge_width          = edge_width
        self.angle_width         = angle_width
        self.torsion_angle_width = torsion_angle_width
        self.num_heads           = num_heads
        self.source_dropout      = source_dropout
        self.scale_degree        = scale_degree
        self.edge_update         = edge_update
        self.angle_update        = angle_update

        assert not (self.node_width % self.num_heads),\
                'node_width must be divisible by num_heads'
        self._dot_dim = self.node_width//self.num_heads
        self._scale_factor = self._dot_dim ** -0.5

        # Module creation / init order is kept identical to the all-enabled
        # configuration so parameter order and RNG consumption do not change.
        self.mha_ln_h   = nn.LayerNorm(self.node_width)
        self.mha_ln_e   = nn.LayerNorm(self.edge_width)
        if self.angle_update:
            self.mha_ln_a   = nn.LayerNorm(self.angle_width)
            self.mha_ln_t   = nn.LayerNorm(self.torsion_angle_width)
        self.lin_QKV    = nn.Linear(self.node_width, self.node_width*3)
        self.lin_EG     = nn.Linear(self.edge_width, self.num_heads*2)
        if self.angle_update:
            self.lin_A      = nn.Linear(self.angle_width, self.num_heads)
            self.lin_T      = nn.Linear(self.torsion_angle_width, self.num_heads)

        self.lin_O_h    = nn.Linear(self.node_width, self.node_width)
        if self.edge_update:
            self.lin_O_e    = nn.Linear(self.num_heads, self.edge_width)
        if self.angle_update:
            self.lin_O_a    = nn.Linear(self.num_heads, self.angle_width)
            self.lin_O_t    = nn.Linear(self.num_heads, self.torsion_angle_width)

        nn.init.xavier_uniform_(self.lin_QKV.weight)
        nn.init.xavier_uniform_(self.lin_EG.weight)
        if self.angle_update:
            nn.init.xavier_uniform_(self.lin_A.weight)
            nn.init.xavier_uniform_(self.lin_T.weight)
        # Small gain on the output projections keeps the initial updates small.
        nn.init.xavier_uniform_(self.lin_O_h.weight, gain=0.1)
        if self.edge_update:
            nn.init.xavier_uniform_(self.lin_O_e.weight, gain=0.1)
        if self.angle_update:
            nn.init.xavier_uniform_(self.lin_O_a.weight, gain=0.1)
            nn.init.xavier_uniform_(self.lin_O_t.weight, gain=0.1)

    def _angle_torsion_update(self, H_hat, a, t, angle_indices, torsion_indices):
        """Return ``(angle_update, torsion_update)`` built from pair scores ``H_hat``.

        The needed pair scores are gathered directly rather than materialising
        all (n, n, n) angle / (n, n, n, n) torsion combinations. The pairs are
        summed in the same order as that full construction, so the result is
        the same while memory stays O(number of angles/torsions).
        """
        A = self.lin_A(self.mha_ln_a(a))  # (batch, num_angles, num_heads)
        T = self.lin_T(self.mha_ln_t(t))  # (batch, num_torsions, num_heads)

        batch_idx = torch.arange(H_hat.size(0), device=H_hat.device).unsqueeze(1)

        def pair(i, j):
            # H_hat[b, i[b, m], j[b, m], :] -> (batch, m, num_heads)
            return H_hat[batch_idx, i, j]

        ai = angle_indices[:, :, 0]
        aj = angle_indices[:, :, 1]
        ak = angle_indices[:, :, 2]
        P_hat = A + (pair(ai, aj) + pair(aj, ak) + pair(ai, ak)) * self._scale_factor

        ti = torsion_indices[:, :, 0]
        tj = torsion_indices[:, :, 1]
        tk = torsion_indices[:, :, 2]
        tl = torsion_indices[:, :, 3]
        T_hat = T + (pair(ti, tj) + pair(tj, tk) + pair(tk, tl)
                     + pair(ti, tk) + pair(tj, tl) + pair(ti, tl)) * self._scale_factor

        return self.lin_O_a(P_hat), self.lin_O_t(T_hat)

    def forward(self, h, e, a, t, mask, angle_indices, torsion_indices):
        """Compute node, edge, angle and torsion update terms.

        Args:
            h: Node features, unpacked as ``(batch, n, node_width)``.
            e: Edge features; presumably ``(batch, n, n, edge_width)``, since
                its projection is added to ``(batch, n, n, num_heads)``
                scores — TODO confirm with caller.
            a: Angle features; ``lin_A(a)`` must broadcast against
                ``(batch, num_angles, num_heads)``.
            t: Torsion features; ``lin_T(t)`` must broadcast against
                ``(batch, num_torsions, num_heads)``.
            mask: Additive attention mask broadcastable to
                ``(batch, n, n, num_heads)``. Assumed 0 for valid pairs and a
                large negative value otherwise — TODO confirm.
            angle_indices: Integer node indices, ``(batch, num_angles, >=3)``.
            torsion_indices: Integer node indices, ``(batch, num_torsions, >=4)``.

        Returns:
            Tuple ``(h, e, a, t)`` of update terms. ``e`` (and ``a``, ``t``)
            are the unchanged inputs when the corresponding update is disabled.
        """
        bsize, num_nodes, embed_dim = h.shape
        h_ln = self.mha_ln_h(h)
        e_ln = self.mha_ln_e(e)

        # Projections
        Q, K, V = self.lin_QKV(h_ln).chunk(3, dim=-1)
        E, G = self.lin_EG(e_ln).chunk(2, dim=-1)

        if self.source_dropout > 0 and self.training:
            # Drop whole source nodes by pushing their mask entries to the
            # most negative value of the mask dtype.
            rmask = h.new_empty(size=[bsize, 1, num_nodes, 1]).bernoulli_(self.source_dropout)
            mask = mask + rmask * torch.finfo(mask.dtype).min

        # Multi-head attention; the head axis is last.
        Q = Q.view(bsize, num_nodes, self._dot_dim, self.num_heads)
        K = K.view(bsize, num_nodes, self._dot_dim, self.num_heads)
        V = V.view(bsize, num_nodes, self._dot_dim, self.num_heads)

        Q = Q * self._scale_factor

        gates = torch.sigmoid(G + mask)
        H_hat = torch.einsum('bldh,bmdh->blmh', Q, K) + E

        A_tild = F.softmax(H_hat + mask, dim=2) * gates
        V_att = torch.einsum('blmh,bmkh->blkh', A_tild, V)

        if self.scale_degree:
            V_att = V_att * form_degree_scalers(gates)

        V_att = V_att.reshape(bsize, num_nodes, embed_dim)

        h = self.lin_O_h(V_att)
        if self.edge_update:
            e = self.lin_O_e(H_hat)
        if self.angle_update:
            a, t = self._angle_torsion_update(H_hat, a, t, angle_indices, torsion_indices)

        return h, e, a, t

class EdgeUpdate(nn.Module):
    """Edge/angle/torsion update used when node features are not updated.

    Builds the same per-head pair scores as ``EGT_Attention``,
    ``H_hat = (Q * scale) K^T + E``, from layer-normed node and edge features.
    It turns them into edge, angle and torsion update terms. No attention
    aggregation over nodes is performed, and ``h`` is returned unchanged.

    Args:
        node_width: Feature width of ``h``; must be divisible by ``num_heads``.
        edge_width: Feature width of ``e``.
        angle_width: Feature width of ``a``.
        torsion_angle_width: Feature width of ``t``.
        num_heads: Number of score heads.
    """
    def __init__(self,
                 node_width            ,
                 edge_width            ,
                 angle_width           ,
                 torsion_angle_width   ,
                 num_heads             ,
                 ):
        super().__init__()
        self.node_width          = node_width
        self.edge_width          = edge_width
        self.angle_width         = angle_width
        self.torsion_angle_width = torsion_angle_width
        self.num_heads           = num_heads

        assert not (self.node_width % self.num_heads),\
                'node_width must be divisible by num_heads'
        self._dot_dim = self.node_width//self.num_heads
        self._scale_factor = self._dot_dim ** -0.5

        self.mha_ln_h   = nn.LayerNorm(self.node_width)
        self.mha_ln_e   = nn.LayerNorm(self.edge_width)
        self.mha_ln_a   = nn.LayerNorm(self.angle_width)
        self.mha_ln_t   = nn.LayerNorm(self.torsion_angle_width)
        self.lin_QKV    = nn.Linear(self.node_width, self.node_width*3)
        self.lin_EG     = nn.Linear(self.edge_width, self.num_heads*2)
        self.lin_A      = nn.Linear(self.angle_width, self.num_heads)
        self.lin_T      = nn.Linear(self.torsion_angle_width, self.num_heads)

        self.lin_O_e    = nn.Linear(self.num_heads, self.edge_width)
        self.lin_O_a    = nn.Linear(self.num_heads, self.angle_width)
        self.lin_O_t    = nn.Linear(self.num_heads, self.torsion_angle_width)

        nn.init.xavier_uniform_(self.lin_QKV.weight)
        nn.init.xavier_uniform_(self.lin_EG.weight)
        nn.init.xavier_uniform_(self.lin_A.weight)
        nn.init.xavier_uniform_(self.lin_T.weight)
        nn.init.xavier_uniform_(self.lin_O_e.weight)
        nn.init.xavier_uniform_(self.lin_O_a.weight)
        nn.init.xavier_uniform_(self.lin_O_t.weight)

    def _angle_torsion_update(self, H_hat, a, t, angle_indices, torsion_indices):
        """Return ``(angle_update, torsion_update)`` built from pair scores ``H_hat``.

        The needed pair scores are gathered directly rather than materialising
        all (n, n, n) angle / (n, n, n, n) torsion combinations. The pairs are
        summed in the same order as that full construction, so the result is
        the same while memory stays O(number of angles/torsions).
        """
        A = self.lin_A(self.mha_ln_a(a))  # (batch, num_angles, num_heads)
        T = self.lin_T(self.mha_ln_t(t))  # (batch, num_torsions, num_heads)

        batch_idx = torch.arange(H_hat.size(0), device=H_hat.device).unsqueeze(1)

        def pair(i, j):
            # H_hat[b, i[b, m], j[b, m], :] -> (batch, m, num_heads)
            return H_hat[batch_idx, i, j]

        ai = angle_indices[:, :, 0]
        aj = angle_indices[:, :, 1]
        ak = angle_indices[:, :, 2]
        P_hat = A + (pair(ai, aj) + pair(aj, ak) + pair(ai, ak)) * self._scale_factor

        ti = torsion_indices[:, :, 0]
        tj = torsion_indices[:, :, 1]
        tk = torsion_indices[:, :, 2]
        tl = torsion_indices[:, :, 3]
        T_hat = T + (pair(ti, tj) + pair(tj, tk) + pair(tk, tl)
                     + pair(ti, tk) + pair(tj, tl) + pair(ti, tl)) * self._scale_factor

        return self.lin_O_a(P_hat), self.lin_O_t(T_hat)

    def forward(self, h, e, a, t, mask, angle_indices, torsion_indices):
        """Compute edge, angle and torsion update terms; pass ``h`` through.

        Args:
            h: Node features, unpacked as ``(batch, n, node_width)``.
            e: Edge features; presumably ``(batch, n, n, edge_width)`` — TODO confirm.
            a: Angle features; ``lin_A(a)`` must broadcast against
                ``(batch, num_angles, num_heads)``.
            t: Torsion features; ``lin_T(t)`` must broadcast against
                ``(batch, num_torsions, num_heads)``.
            mask: Unused; kept so the signature matches ``EGT_Attention.forward``.
            angle_indices: Integer node indices, ``(batch, num_angles, >=3)``.
            torsion_indices: Integer node indices, ``(batch, num_torsions, >=4)``.

        Returns:
            Tuple ``(h, e, a, t)``: the input ``h`` object and update terms
            for the other three streams.
        """
        bsize, num_nodes, _ = h.shape
        h_ln = self.mha_ln_h(h)
        e_ln = self.mha_ln_e(e)

        # Only Q, K and the edge bias E are needed here; the V and gate (G)
        # slices of the shared projections are unused.
        Q, K, _ = self.lin_QKV(h_ln).chunk(3, dim=-1)
        E, _ = self.lin_EG(e_ln).chunk(2, dim=-1)

        Q = Q.view(bsize, num_nodes, self._dot_dim, self.num_heads) * self._scale_factor
        K = K.view(bsize, num_nodes, self._dot_dim, self.num_heads)
        H_hat = torch.einsum('bldh,bmdh->blmh', Q, K) + E

        e = self.lin_O_e(H_hat)
        a, t = self._angle_torsion_update(H_hat, a, t, angle_indices, torsion_indices)
        return h, e, a, t



class FFN(nn.Module):
    """Pre-norm position-wise feed-forward sub-block.

    Computes ``lin_W2(dropout(ffn_fn(lin_W1(LayerNorm(x)))))``; the caller
    adds any residual. ``get_activation(activation)`` supplies the activation
    callable and a multiplier ``act_mul`` for the width of ``lin_W1``'s output.
    Presumably ``act_mul > 1`` is for gated activations that shrink their
    input back to the hidden width — confirm against ``.activations``.

    Args:
        width: Input and output feature width.
        multiplier: Hidden width is ``round(width * multiplier)``.
        act_dropout: Dropout probability applied after the activation.
        activation: Name passed to ``get_activation``.
    """
    def __init__(self,
                 width,
                 multiplier = 1.,
                 act_dropout = 0.,
                 activation = 'gelu',
                 ):
        super().__init__()
        self.width = width
        self.multiplier = multiplier
        self.act_dropout = act_dropout
        self.activation = activation

        self.ffn_fn, self.act_mul = get_activation(activation)
        hidden = round(width * multiplier)

        # Creation order (norm, W1, W2, dropout) fixes parameter order and init RNG.
        self.ffn_ln = nn.LayerNorm(width)
        self.lin_W1 = nn.Linear(width, hidden * self.act_mul)
        self.lin_W2 = nn.Linear(hidden, width)
        self.dropout = nn.Dropout(act_dropout)

        for layer in (self.lin_W1, self.lin_W2):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        activated = self.ffn_fn(self.lin_W1(self.ffn_ln(x)))
        return self.lin_W2(self.dropout(activated))


class DropPath(nn.Module):
    """Stochastic depth: zero an entire residual branch per sample.

    During training, each sample along dim 0 is kept with probability
    ``1 - drop_path``. Kept samples are divided by ``1 - drop_path`` so the
    expected value is unchanged. In eval mode, or when ``drop_path == 0``,
    the input object is returned as-is.

    Args:
        drop_path: Probability of dropping a sample's branch.
    """
    def __init__(self, drop_path=0.):
        super().__init__()
        self.drop_path = drop_path
        self._keep_prob = 1 - self.drop_path

    def forward(self, x):
        if not (self.drop_path > 0 and self.training):
            return x
        # One Bernoulli draw per sample, broadcast over all remaining dims.
        mask_shape = [x.size(0)] + [1]*(x.ndim-1)
        mask = x.new_empty(size=mask_shape).bernoulli_(self._keep_prob)
        if self._keep_prob > 0:
            return x.div(self._keep_prob) * mask
        # drop_path == 1: every sample is dropped. Skip the rescale, since
        # x / 0 * 0 would give NaN instead of zeros.
        return x * mask

    def __repr__(self):
        return f'{self.__class__.__name__}(drop_path={self.drop_path})'

class AGT_Layer_noVN(nn.Module):
    """Angle-aware graph-transformer layer (no virtual nodes).

    Updates four feature streams carried on the graph object ``g``: node
    features ``g.h``, edge features ``g.e``, angle features ``g.p`` and
    torsion features ``g.t``.

    ``self.update`` produces update terms for all four streams. It is an
    ``EGT_Attention`` when ``node_update`` is set, otherwise an
    ``EdgeUpdate``. Each enabled stream then goes through:

    * a drop-path residual add,
    * an optional attention sub-block with its own residual (triplet layer
      for edges, angle/torsion attention layers for ``p``/``t``),
    * a pre-norm FFN with its own residual.

    A stream whose ``*_update`` flag is False keeps its incoming features.

    Args:
        node_width, edge_width, angle_width, torsion_angle_width: Stream widths.
        num_heads: Heads of the node/edge attention.
        activation: FFN activation name.
        scale_degree: Forwarded to ``EGT_Attention``.
        node_update, edge_update, angle_update: Which streams to update; at
            least one of ``node_update`` / ``edge_update`` must be True.
        triplet_heads: Heads of the edge triplet layer; 0 disables it.
        triplet_type: Name passed to ``get_triplet_layer``.
        triplet_dropout: Attention dropout of the triplet layer.
        node_ffn_multiplier, edge_ffn_multiplier: FFN hidden-width multipliers
            (the edge value is also used for the angle/torsion FFNs).
        source_dropout: Forwarded to ``EGT_Attention``.
        drop_path: Stochastic-depth probability for every residual branch.
        node_act_dropout, edge_act_dropout: FFN activation dropout (the edge
            value is also used for the angle/torsion FFNs).
        angle_heads: Heads of the angle and torsion attention layers.
        angle_attn_dropout: Attention dropout of those layers.
    """
    def __init__(self,
                 node_width                         ,
                 edge_width                         ,
                 angle_width                        ,
                 torsion_angle_width                ,
                 num_heads                          ,
                 activation          = 'gelu'       ,
                 scale_degree        = True         ,
                 node_update         = True         ,
                 edge_update         = True         ,
                 angle_update        = True         ,
                 triplet_heads       = 0            ,
                 triplet_type        = 'aggregate'  ,
                 triplet_dropout     = 0            ,
                 node_ffn_multiplier = 1.           ,
                 edge_ffn_multiplier = 1.           ,
                 source_dropout      = 0            ,
                 drop_path           = 0            ,
                 node_act_dropout    = 0            ,
                 edge_act_dropout    = 0            ,
                 angle_heads         = 16           ,
                 angle_attn_dropout  = 0.           ,
                 ):
        super().__init__()
        self.node_width          = node_width
        self.edge_width          = edge_width
        self.angle_width         = angle_width
        self.torsion_angle_width = torsion_angle_width
        self.num_heads           = num_heads
        self.activation          = activation
        self.node_ffn_multiplier = node_ffn_multiplier
        self.edge_ffn_multiplier = edge_ffn_multiplier
        self.node_act_dropout    = node_act_dropout
        self.edge_act_dropout    = edge_act_dropout
        self.source_dropout      = source_dropout
        self.drop_path           = drop_path
        self.scale_degree        = scale_degree
        self.node_update         = node_update
        self.edge_update         = edge_update
        self.angle_update        = angle_update
        self.triplet_heads       = triplet_heads
        self.triplet_type        = triplet_type
        self.triplet_dropout     = triplet_dropout
        self.angle_heads         = angle_heads
        self.angle_attn_dropout  = angle_attn_dropout

        self._triplet_update     = self.triplet_heads > 0

        # Submodule creation order is unchanged so parameter order / init RNG match.
        if self.node_update:
            # angle_update is not forwarded: EGT_Attention then always returns
            # angle/torsion terms, which forward() discards when angle_update is False.
            self.update = EGT_Attention(
                node_width      = self.node_width,
                edge_width      = self.edge_width,
                angle_width     = self.angle_width,
                torsion_angle_width = self.torsion_angle_width,
                num_heads       = self.num_heads,
                source_dropout  = self.source_dropout,
                scale_degree    = self.scale_degree,
                edge_update     = self.edge_update,
                )
        elif self.edge_update:
            self.update = EdgeUpdate(
                node_width      = self.node_width,
                edge_width      = self.edge_width,
                angle_width     = self.angle_width,
                torsion_angle_width = self.torsion_angle_width,
                num_heads       = self.num_heads,
                )
        else:
            raise ValueError('At least one of node_update and edge_update must be True')

        if self.node_update:
            self.node_ffn = FFN(
                width           = self.node_width,
                multiplier      = self.node_ffn_multiplier,
                act_dropout     = self.node_act_dropout,
                activation      = self.activation,
                )
        if self.edge_update:
            if self._triplet_update:
                TripletLayer = get_triplet_layer(self.triplet_type)
                self.tria = TripletLayer(
                    edge_width        = self.edge_width,
                    num_heads         = self.triplet_heads,
                    attention_dropout = self.triplet_dropout,
                )

            self.edge_ffn = FFN(
                width           = self.edge_width,
                multiplier      = self.edge_ffn_multiplier,
                act_dropout     = self.edge_act_dropout,
                activation      = self.activation,
                )
        if self.angle_update:
            AngleLayer = get_torsionangle_layer('torsion_attention')
            self.angle = AngleLayer(
                    angle_width       = self.angle_width,
                    num_heads         = self.angle_heads,
                    attention_dropout = self.angle_attn_dropout,
                )

            self.angle_ffn = FFN(
                width           = self.angle_width,
                multiplier      = self.edge_ffn_multiplier,
                act_dropout     = self.edge_act_dropout,
                activation      = self.activation,
                )

            TorsionAngleLayer = get_torsionangle_layer('torsion_attention')
            self.tors = TorsionAngleLayer(
                    angle_width       = self.torsion_angle_width,
                    num_heads         = self.angle_heads,
                    attention_dropout = self.angle_attn_dropout,
                )

            self.torsion_angle_ffn = FFN(
                width           = self.torsion_angle_width,
                multiplier      = self.edge_ffn_multiplier,
                act_dropout     = self.edge_act_dropout,
                activation      = self.activation,
                )

        # Replaces the float stored above with the module of the same name.
        self.drop_path = DropPath(self.drop_path)

    def _add_residual(self, update, residual):
        """Apply drop-path to ``update`` and add ``residual`` to it in place."""
        update = self.drop_path(update)
        update.add_(residual)
        return update

    def forward(self, g):
        """Run one layer on ``g`` and return a copy with updated h, e, p, t.

        Reads ``g.h``, ``g.e``, ``g.p``, ``g.t``, ``g.mask``, ``g.mask_a``,
        ``g.mask_t``, ``g.angle_indices`` and ``g.torsion_indices``.
        ``g`` itself is not reassigned; ``g.copy()`` receives the new features.
        """
        h, e, p, t = g.h, g.e, g.p, g.t
        mask, mask_a, mask_t = g.mask, g.mask_a, g.mask_t
        angle_indices, torsion_indices = g.angle_indices, g.torsion_indices

        h_r1, e_r1, p_r1, t_r1 = h, e, p, t
        h, e, p, t = self.update(h, e, p, t, mask, angle_indices, torsion_indices)

        # A disabled stream keeps its incoming features rather than being
        # replaced by the raw (non-residual) update term.
        if not self.node_update:
            h = h_r1
        if not self.edge_update:
            e = e_r1
        if not self.angle_update:
            p, t = p_r1, t_r1

        if self.node_update:
            h = self._add_residual(h, h_r1)
            h = self._add_residual(self.node_ffn(h), h)

        if self.edge_update:
            e = self._add_residual(e, e_r1)
            if self._triplet_update:
                e = self._add_residual(self.tria(e, mask), e)
            e = self._add_residual(self.edge_ffn(e), e)

        if self.angle_update:
            p = self._add_residual(p, p_r1)
            p = self._add_residual(self.angle(p, mask_a), p)
            p = self._add_residual(self.angle_ffn(p), p)

            t = self._add_residual(t, t_r1)
            t = self._add_residual(self.tors(t, mask_t), t)
            t = self._add_residual(self.torsion_angle_ffn(t), t)

        g = g.copy()
        g.h, g.e, g.p, g.t = h, e, p, t
        return g

    def __repr__(self):
        rep = super().__repr__()
        rep = (rep + ' ('
                   + f'activation: {self.activation}, '
                   + f'source_dropout: {self.source_dropout}'
                   +')')
        return rep