import numpy as np
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

from project.utils.checkpointing import checkpoint_blocks


""" conv1d """
def conv1d(in_channels: int,
           out_channels: int,
           kernel_size: int,
           stride: int = 1,
           padding: str = "same",
           dilation: int = 1,
           group: int = 1,
           bias: bool = False) -> nn.Conv1d:
    """Build an ``nn.Conv1d`` whose default padding preserves the input length.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel width.
        stride: Convolution stride.
        padding: ``"same"`` (default) is converted to the integer padding
            ``dilation * (kernel_size - 1) // 2``, which keeps the length
            unchanged for stride 1 and odd kernels.  Any other value (for
            example an explicit int) is forwarded to ``nn.Conv1d`` untouched.
        dilation: Spacing between kernel taps.
        group: Number of channel groups (``groups`` of ``nn.Conv1d``).
        bias: Whether the layer has a learnable bias.

    Returns:
        The configured ``nn.Conv1d`` layer.
    """
    if padding == "same":
        # A dilated kernel spans dilation * (kernel_size - 1) + 1 samples, so
        # the padding has to scale with the dilation to keep the length.
        if isinstance(dilation, (tuple, list)):
            padding = tuple(d * (kernel_size - 1) // 2 for d in dilation)
        else:
            padding = dilation * (kernel_size - 1) // 2

    return nn.Conv1d(in_channels, out_channels, kernel_size,
                     stride=stride, padding=padding, dilation=dilation,
                     groups=group, bias=bias)

""" conv2d """
def conv2d(in_channels: int,
           out_channels: int,
           kernel_size: int,
           stride: int = 1,
           padding: str = "same",
           dilation: int = 1,
           group: int = 1,
           bias: bool = False) -> nn.Conv2d:
    """Build an ``nn.Conv2d`` whose default padding preserves the spatial size.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Side length of the (square) kernel.
        stride: Convolution stride.
        padding: ``"same"`` (default) is converted to the integer padding
            ``dilation * (kernel_size - 1) // 2``, which keeps height and
            width unchanged for stride 1 and odd kernels.  Any other value
            (for example an explicit int) is forwarded to ``nn.Conv2d``
            untouched.
        dilation: Spacing between kernel taps.
        group: Number of channel groups (``groups`` of ``nn.Conv2d``).
        bias: Whether the layer has a learnable bias.

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    if padding == "same":
        # A dilated kernel spans dilation * (kernel_size - 1) + 1 pixels, so
        # the padding has to scale with the dilation to keep the size.
        if isinstance(dilation, (tuple, list)):
            padding = tuple(d * (kernel_size - 1) // 2 for d in dilation)
        else:
            padding = dilation * (kernel_size - 1) // 2

    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     stride=stride, padding=padding, dilation=dilation,
                     groups=group, bias=bias)

""" conv1d 1x1 """
def conv_identity_1d(in_channels: int,
                     out_channels: int,
                     kernel_size: int = 1,
                     stride: int = 1,
                     padding: str = "same",
                     dilation: int = 1,
                     group: int = 1,
                     bias: bool = False,
                     norm: str = "IN",
                     activation: str = "Relu",
                     track_running_stats_: bool = True):
    """Assemble a ``conv1d -> norm -> activation`` stack as ``nn.Sequential``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        group, bias: Forwarded positionally to ``conv1d``.
        norm: ``"BN"`` appends ``nn.BatchNorm1d``, ``"IN"`` appends
            ``nn.InstanceNorm1d`` (both affine); any other value appends no
            normalisation layer.
        activation: ``"ELU"`` appends ``nn.ELU``; ``"Relu"`` appends an
            in-place ``nn.LeakyReLU`` with slope 0.01; any other value
            appends no activation.
        track_running_stats_: Passed to the normalisation layer.

    Returns:
        ``nn.Sequential`` holding the selected layers in order.
    """
    norm_types = {"BN": nn.BatchNorm1d, "IN": nn.InstanceNorm1d}

    # The convolution always comes first.
    stages = [conv1d(in_channels, out_channels, kernel_size, stride,
                     padding, dilation, group, bias)]

    # Optional normalisation over the output channels.
    norm_cls = norm_types.get(norm)
    if norm_cls is not None:
        stages.append(norm_cls(out_channels, affine=True,
                               track_running_stats=track_running_stats_))

    # Optional non-linearity.
    if activation == "ELU":
        stages.append(nn.ELU())
    elif activation == "Relu":
        stages.append(nn.LeakyReLU(negative_slope=0.01, inplace=True))

    return nn.Sequential(*stages)

""" conv2d 1x1"""
def conv_identity_2d(in_channels: int,
                     out_channels: int,
                     kernel_size: int = 1,
                     stride: int = 1,
                     padding: str = "same",
                     dilation: int = 1,
                     group: int = 1,
                     bias: bool = False,
                     norm: str = "IN",
                     activation: str = "Relu",
                     track_running_stats_: bool = True):
    """Assemble a ``conv2d -> norm -> activation`` stack as ``nn.Sequential``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        group, bias: Forwarded positionally to ``conv2d``.
        norm: ``"BN"`` appends ``nn.BatchNorm2d``, ``"IN"`` appends
            ``nn.InstanceNorm2d`` (both affine); any other value appends no
            normalisation layer.
        activation: ``"ELU"`` appends ``nn.ELU``; ``"Relu"`` appends an
            in-place ``nn.LeakyReLU`` with slope 0.01; any other value
            appends no activation.
        track_running_stats_: Passed to the normalisation layer.

    Returns:
        ``nn.Sequential`` holding the selected layers in order.
    """
    norm_types = {"BN": nn.BatchNorm2d, "IN": nn.InstanceNorm2d}

    # The convolution always comes first.
    stages = [conv2d(in_channels, out_channels, kernel_size, stride,
                     padding, dilation, group, bias)]

    # Optional normalisation over the output channels.
    norm_cls = norm_types.get(norm)
    if norm_cls is not None:
        stages.append(norm_cls(out_channels, affine=True,
                               track_running_stats=track_running_stats_))

    # Optional non-linearity.
    if activation == "ELU":
        stages.append(nn.ELU())
    elif activation == "Relu":
        stages.append(nn.LeakyReLU(negative_slope=0.01, inplace=True))

    return nn.Sequential(*stages)




""" ResNetv2 BasicBlock1D """
class BasicBlock_ResNetV2_1D(nn.Module):
    """Pre-activation residual block over 1D feature maps.

    The main path is ``norm -> act -> conv -> norm -> act -> conv``; its
    output is added to the (optionally projected) input.

    Args:
        in_channels: Channels of the block input.
        out_channels: Channels produced by both convolutions.
        kernel_size: Kernel width of both convolutions.
        stride: Stride of the first convolution.  The second convolution
            always uses stride 1 so that the main path is strided once, like
            a shortcut built with a single strided projection.
        downsample: Optional module applied to the input before the residual
            addition; ``None`` adds the input unchanged.
        padding, dilation, group, bias: Forwarded to ``conv1d``.
        track_running_stats_: Passed to the normalisation layers.
        norm: ``"BN"`` (BatchNorm1d) or ``"IN"`` (InstanceNorm1d).
        activation: ``"ELU"`` or ``"Relu"`` (in-place LeakyReLU, slope 0.01).

    Raises:
        ValueError: If ``norm`` or ``activation`` is not one of the values
            listed above.
    """

    def __init__(self,
        in_channels  : int,
        out_channels : int,
        kernel_size  : int,
        stride       : int = 1,
        downsample = None,
        padding      : str = "same",
        dilation     : int = 1,
        group        : int = 1,
        bias         : bool = False,
        track_running_stats_ : bool = True,
        norm         : str = "BN",
        activation   : str = "ELU"):

        super().__init__()

        if norm == "BN":
            norm_layer = nn.BatchNorm1d
        elif norm == "IN":
            norm_layer = nn.InstanceNorm1d
        else:
            # Without this check the block would only fail later, in forward().
            raise ValueError(f"unsupported norm {norm!r}; expected 'BN' or 'IN'")
        self.bn1 = norm_layer(in_channels, affine=True, track_running_stats=track_running_stats_)
        self.bn2 = norm_layer(out_channels, affine=True, track_running_stats=track_running_stats_)

        if activation == "ELU":
            self.relu1 = nn.ELU()
            self.relu2 = nn.ELU()
        elif activation == "Relu":
            self.relu1 = nn.LeakyReLU(negative_slope=0.01, inplace=True)
            self.relu2 = nn.LeakyReLU(negative_slope=0.01, inplace=True)
        else:
            raise ValueError(f"unsupported activation {activation!r}; expected 'ELU' or 'Relu'")

        self.conv1 = conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, group, bias)
        # Only the first convolution is strided; striding both would shrink the
        # main path twice and break the residual addition for stride > 1.
        self.conv2 = conv1d(out_channels, out_channels, kernel_size, 1, padding, dilation, group, bias)

        # forward() reads this attribute, so it must always be set.
        self.downsample = downsample

    def forward(self, x):
        """Return ``main_path(x) + shortcut(x)``."""
        identity = x

        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv1(x)

        x = self.bn2(x)
        x = self.relu2(x)
        x = self.conv2(x)

        if self.downsample is not None:
            identity = self.downsample(identity)

        return x + identity


class BasicBlock_Inception2D_V1(nn.Module):
    """Three-branch residual block over 2D feature maps.

    Each branch applies ``conv -> norm -> act -> conv -> norm`` with a fixed
    kernel shape -- (1, 9), (9, 1) and (3, 3) respectively.  The branch
    outputs are summed, the (optionally projected) input is added, and a
    final activation is applied.

    Args:
        in_channels: Channels of the block input.
        out_channels: Channels produced by every convolution.
        kernel_size: Unused; the kernel shapes are fixed.  Kept so the block
            can be constructed with the same positional arguments as other
            blocks.
        stride: Stride of every convolution.
        downsample: Optional module applied to the input before the residual
            addition; ``None`` adds the input unchanged.
        padding: Unused; the padding is derived from the fixed kernel shapes
            and the dilation so that stride-1 convolutions keep the spatial
            size.
        dilation: Dilation of every convolution (int or pair).
        group: Number of channel groups of every convolution.
        bias: Whether the convolutions have a bias.
        track_running_stats_: Passed to the normalisation layers.
        norm: ``"BN"`` (BatchNorm2d) or ``"IN"`` (InstanceNorm2d).
        activation: ``"ELU"`` or ``"Relu"`` (in-place LeakyReLU, slope 0.01).

    Raises:
        ValueError: If ``norm`` or ``activation`` is not one of the values
            listed above.
    """

    # Kernel shape of each parallel branch.
    _BRANCH_KERNELS = ((1, 9), (9, 1), (3, 3))

    def __init__(self,
        in_channels  : int,
        out_channels : int,
        kernel_size  : int,
        stride       : int = 1,
        downsample = None,
        padding      : str = "same",
        dilation     : int = 1,
        group        : int =1,
        bias         : bool = False,
        track_running_stats_ : bool = True,
        norm         : str = "IN",
        activation   : str = "Relu"):

        super().__init__()

        n_branches = len(self._BRANCH_KERNELS)

        if norm == "BN":
            norm_layer = nn.BatchNorm2d
        elif norm == "IN":
            norm_layer = nn.InstanceNorm2d
        else:
            # Without this check the block would only fail later, in forward().
            raise ValueError(f"unsupported norm {norm!r}; expected 'BN' or 'IN'")
        self.bns1 = nn.ModuleList([norm_layer(out_channels, affine=True, track_running_stats=track_running_stats_)
                                   for _ in range(n_branches)])
        self.bns2 = nn.ModuleList([norm_layer(out_channels, affine=True, track_running_stats=track_running_stats_)
                                   for _ in range(n_branches)])

        if activation == "ELU":
            self.acts1 = nn.ModuleList([nn.ELU() for _ in range(n_branches)])
            self.act = nn.ELU()
        elif activation == "Relu":
            self.acts1 = nn.ModuleList([nn.LeakyReLU(negative_slope=0.01, inplace=True) for _ in range(n_branches)])
            self.act = nn.LeakyReLU(negative_slope=0.01, inplace=True)
        else:
            raise ValueError(f"unsupported activation {activation!r}; expected 'ELU' or 'Relu'")

        self.convs1 = nn.ModuleList(self._branch_convs(in_channels, out_channels, stride, dilation, group, bias))
        self.convs2 = nn.ModuleList(self._branch_convs(out_channels, out_channels, stride, dilation, group, bias))

        self.downsample = downsample

    @classmethod
    def _branch_convs(cls, in_channels, out_channels, stride, dilation, group, bias):
        """Create one size-preserving convolution per branch kernel shape."""
        dil = (dilation, dilation) if isinstance(dilation, int) else tuple(dilation)
        convs = []
        for kernel in cls._BRANCH_KERNELS:
            # Equals (0, 4), (4, 0) and (1, 1) for dilation 1 and scales with the
            # dilation so every branch keeps the same spatial size.
            pad = tuple(d * (k - 1) // 2 for d, k in zip(dil, kernel))
            convs.append(nn.Conv2d(in_channels, out_channels, kernel, stride, pad, dilation, group, bias))
        return convs

    def forward(self, x):
        """Return ``act(sum_of_branches(x) + shortcut(x))``."""
        identity = x

        xs = None
        branches = zip(self.convs1, self.bns1, self.acts1, self.convs2, self.bns2)
        for conv_a, norm_a, act_a, conv_b, norm_b in branches:
            xsi = act_a(norm_a(conv_a(x)))
            xsi = norm_b(conv_b(xsi))
            # Identity test, not ``==``: ``xs`` is a tensor after the first branch.
            xs = xsi if xs is None else xs + xsi

        if self.downsample is not None:
            identity = self.downsample(identity)

        return self.act(xs + identity)


""" concatenate 1D -> 2D """
def seq2pairwise_v3(rec1d, lig1d):
    """Tile two per-position feature maps into one pairwise feature map.

    Args:
        rec1d: Tensor of shape (B, C1, L1).
        lig1d: Tensor of shape (B, C2, L2); moved to ``rec1d``'s device if it
            lives elsewhere.

    Returns:
        Tensor of shape (B, C1 + C2, L1, L2) where
        ``out[:, :C1, i, j] == rec1d[:, :, i]`` and
        ``out[:, C1:, i, j] == lig1d[:, :, j]``.
    """
    # Unpacking also enforces that both inputs are rank 3.
    _, _, n_rec = rec1d.size()
    _, _, n_lig = lig1d.size()

    # ``expand`` creates broadcast views, so the only copy made is the one
    # performed by ``torch.cat`` when it writes the result.
    rec_tiled = rec1d.unsqueeze(3).expand(-1, -1, -1, n_lig)
    lig_tiled = lig1d.to(rec1d.device).unsqueeze(2).expand(-1, -1, n_rec, -1)

    return torch.cat([rec_tiled, lig_tiled], dim=1)


#####################################################################################################################

class TriangleMultiplication(nn.Module):
    """Gated multiplicative update of the complex pair map.

    The complex map is combined with the receptor map along the receptor
    axis and with the ligand map along the ligand axis; the two products are
    summed, normalised, projected and gated by the (normalised) input.

    Args:
        model_args: Mapping providing ``'Channel_z'``, the channel width of
            all three pair maps.
    """

    def __init__(self, model_args):
        super().__init__()

        self.dz = model_args['Channel_z']
        # NOTE(review): the hidden width is read from 'Channel_z' as well, so
        # dc == dz here -- confirm this is intended rather than 'Channel_c'.
        self.dc = model_args['Channel_z']

        # Input normalisation, one per pair map.
        self.norm_com = nn.LayerNorm(self.dz)
        self.norm_rec = nn.LayerNorm(self.dz)
        self.norm_lig = nn.LayerNorm(self.dz)

        # Projection and gate for the two views of the complex map.
        self.Linear_com_rec = nn.Linear(self.dz, self.dc)
        self.Linear_com_lig = nn.Linear(self.dz, self.dc)
        self.gate_com_rec = nn.Linear(self.dz, self.dc)
        self.gate_com_lig = nn.Linear(self.dz, self.dc)

        # Projection and gate for the receptor and ligand maps.
        self.Linear_rec = nn.Linear(self.dz, self.dc)
        self.Linear_lig = nn.Linear(self.dz, self.dc)
        self.gate_rec = nn.Linear(self.dz, self.dc)
        self.gate_lig = nn.Linear(self.dz, self.dc)

        # Output head.
        self.norm_all = nn.LayerNorm(self.dc)
        self.Linear_all = nn.Linear(self.dc, self.dz)
        self.gate_all = nn.Linear(self.dz, self.dz)

    def forward(self, z_com, z_rec, z_lig, mask=None):
        """
        Argument:
            z_com : (B, nrec, nlig, dz)
            z_rec : (B, nrec, nrec, dz)
            z_lig : (B, nlig, nlig, dz)
            mask  : (B, nrec, nlig) or (B, nrec, nlig, 1); multiplied into
                    the gates of the complex map, so 0 removes a pair
        return:
            z_com : (B, nrec, nlig, dz)
        """
        # Apply layer normalization
        z_com = self.norm_com(z_com)
        z_rec = self.norm_rec(z_rec)
        z_lig = self.norm_lig(z_lig)
        z_com_init = z_com

        # Gates for the two views of the complex map, masked if requested.
        gate_com_rec = self.gate_com_rec(z_com).sigmoid()
        gate_com_lig = self.gate_com_lig(z_com).sigmoid()
        if mask is not None:
            if mask.dim() == z_com.dim() - 1:
                # Add a channel axis so a (B, nrec, nlig) mask broadcasts over
                # the channels instead of being aligned with them.
                mask = mask.unsqueeze(-1)
            gate_com_rec = gate_com_rec * mask
            gate_com_lig = gate_com_lig * mask

        # Apply linear transformation and gate
        z_com_rec = self.Linear_com_rec(z_com) * gate_com_rec
        z_com_lig = self.Linear_com_lig(z_com) * gate_com_lig

        z_rec = self.Linear_rec(z_rec) * self.gate_rec(z_rec).sigmoid()
        z_lig = self.Linear_lig(z_lig) * self.gate_lig(z_lig).sigmoid()

        # Contract over the shared receptor index k: (nrec, k) x (k, nlig).
        z_com_rec = torch.einsum("bikc,bkjc->bijc", z_rec, z_com_rec)
        # Contract over the shared ligand index k: (nlig, k) x (nrec, k).
        z_com_lig = torch.einsum("bikc,bjkc->bjic", z_lig, z_com_lig)
        z_all = z_com_rec + z_com_lig

        z_com = self.gate_all(z_com_init).sigmoid() * self.Linear_all(self.norm_all(z_all))

        return z_com


class TriangleSelfAttention(nn.Module):
    """Gated multi-head self-attention applied independently to each row.

    For an input of shape (B, row, col, dz) every row attends over its own
    ``col`` positions.  Pass a transposed map to attend along the other axis.

    Args:
        model_args: Mapping providing ``'Channel_z'`` (input/output width),
            ``'Channel_c'`` (width per head) and ``'num_head'``.
    """

    def __init__(self, model_args):
        super().__init__()

        self.dz = model_args['Channel_z']
        self.dc = model_args['Channel_c']
        self.num_head = model_args['num_head']
        self.dhc = self.num_head * self.dc

        self.norm_com = nn.LayerNorm(self.dz)
        self.Linear_Q = nn.Linear(self.dz, self.dhc)
        self.Linear_K = nn.Linear(self.dz, self.dhc)
        self.Linear_V = nn.Linear(self.dz, self.dhc)

        self.softmax = nn.Softmax(-1)
        self.gate_v = nn.Linear(self.dz, self.dhc)
        self.Linear_final = nn.Linear(self.dhc, self.dz)

    def reshape_dim(self, x):
        """Split the last axis (num_head * dc) into (num_head, dc)."""
        new_shape = x.size()[:-1] + (self.num_head, self.dc)
        return x.view(*new_shape)

    def forward(self, z_com, mask=None, eps=5e4):
        """
        Argument:
            z_com : (B, row, col, dz)
            mask  : (B, row, col), 1 for positions that may be attended to
                    and 0 for positions to ignore as keys; bool or numeric
            eps   : value subtracted from the logits of masked keys
        return:
            (B, row, col, dz)
        """
        z_com = self.norm_com(z_com)

        # Scale queries by 1 / sqrt(dc).
        scale = self.dc ** -0.5
        q = self.reshape_dim(self.Linear_Q(z_com)) * scale
        k = self.reshape_dim(self.Linear_K(z_com))
        v = self.reshape_dim(self.Linear_V(z_com))

        # Logits are laid out as (B, head, row, query, key).
        attn = torch.einsum("bnihc,bnjhc->bhnij", q, k)
        if mask is not None:
            # (B, row, key) -> (B, 1, row, 1, key) so the row axis lines up with
            # the row axis of the logits rather than with the head axis.
            key_mask = mask[:, None, :, None, :].to(attn.dtype)
            attn = attn - (1.0 - key_mask) * eps

        if attn.dtype is torch.bfloat16:
            # Run the softmax with autocast disabled for bfloat16 logits.
            with torch.cuda.amp.autocast(enabled=False):
                attn_weights = F.softmax(attn, -1)
        else:
            attn_weights = F.softmax(attn, -1)

        v_avg = torch.einsum("bhnij,bnjhc->bnihc", attn_weights, v)
        gate_v = self.reshape_dim(self.gate_v(z_com)).sigmoid()

        # Merge (head, dc) back into a single channel axis.
        z_com = (v_avg * gate_v).contiguous().view(v_avg.size()[:-2] + (-1,))

        return self.Linear_final(z_com)

class Transition(nn.Module):
    """Position-wise feed-forward layer: LayerNorm, then Linear-ReLU-Linear.

    Args:
        model_args: Mapping providing ``'Channel_z'`` (input/output width)
            and ``'Transition_n'`` (hidden width multiplier).
    """

    def __init__(self, model_args):
        super().__init__()

        self.dz = model_args['Channel_z']
        self.n = model_args['Transition_n']
        hidden_width = self.dz * self.n

        self.norm = nn.LayerNorm(self.dz)
        # Widen by a factor of n, apply the non-linearity, project back to dz.
        self.transition = nn.Sequential(
            nn.Linear(self.dz, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, self.dz),
        )

    def forward(self, z_com):
        """Normalise the last axis of ``z_com`` and run it through the MLP."""
        return self.transition(self.norm(z_com))


#####################################################################################################################


class DeepHomo_middle(nn.Module):
    """Convolutional trunk that fuses 1D and 2D inputs into one pair map.

    The two 1D inputs are tiled into a pairwise map, the result and the 2D
    input are each passed through a 1x1 convolution stack, concatenated,
    reduced by a third 1x1 stack and fed to a sequence of residual blocks.
    No output head is applied; the raw block output is returned.

    Args:
        model_args: Mapping with the sub-configurations ``'BasicBlock1D'``
            (keys ``'InChannels'``, ``'Channels'``) and ``'BasicBlock2D'``
            (keys ``'InChannels'``, ``'Channels'``, ``'name'``,
            ``'num_Cycle'``, ``'Kernel_size'``, ``'Dilation'``, ``'Group'``,
            ``'Bias'``, ``'track_running_stats'``).
    """

    def __init__(self, model_args):
        super().__init__()

        cfg_1d = model_args['BasicBlock1D']
        cfg_2d = model_args['BasicBlock2D']

        # 1x1 stacks: tiled 1D features, raw 2D features, and their fusion.
        # The fusion layer expects twice the first 2D width, so the first 1D
        # width has to equal the first 2D width for forward() to work.
        self.identity1 = conv_identity_2d(cfg_1d['InChannels'] * 2, cfg_1d['Channels'][0], 1, 1, bias=False)
        self.identity2 = conv_identity_2d(cfg_2d['InChannels'], cfg_2d['Channels'][0], 1, 1, bias=False)
        self.identity3 = conv_identity_2d(cfg_2d['Channels'][0] * 2, cfg_2d['Channels'][0], 1, 1, bias=False)

        self.layer2 = self._make_layer(conv2d, cfg_2d)

        # Re-initialise every convolution weight in the trunk.
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Conv1d)):
                nn.init.kaiming_normal_(module.weight)

    def _downsample(self, conv, in_channels, out_channels, stride):
        """Return a 1x1 projection for the shortcut, or None if none is needed."""
        needs_projection = in_channels != out_channels or stride != 1
        if not needs_projection:
            return None
        return nn.Sequential(conv(in_channels, out_channels, kernel_size=1, stride=stride))

    def _make_layer(self, fn, config):
        """Build the residual stack described by ``config``.

        One stage is created per entry of ``config['Channels']``.  Each stage
        starts with a block that changes the width (with a projection
        shortcut when required) followed by ``num_Cycle - 1`` blocks at
        constant width.  Dilations are taken cyclically from
        ``config['Dilation']``.
        """
        block_cls = config['name']
        stage_widths = config['Channels']
        blocks_per_stage = config['num_Cycle']
        kernel_size = config['Kernel_size']
        dilations = config['Dilation']
        group = config['Group']
        bias = config['Bias']
        track_stats = config['track_running_stats']
        stride = 1
        padding = "same"

        stack = []
        # The first stage starts from the first configured width.
        width_in = stage_widths[0]
        for stage, width_out in enumerate(stage_widths):
            dilation = dilations[stage % len(dilations)]

            shortcut = self._downsample(fn, width_in, width_out, stride)
            stack.append(block_cls(width_in, width_out, kernel_size, stride, shortcut,
                                   padding, dilation, group, bias, track_stats))

            for _ in range(1, blocks_per_stage):
                stack.append(block_cls(width_out, width_out, kernel_size, stride, None,
                                       padding, dilation, group, bias, track_stats))

            width_in = width_out

        return nn.Sequential(*stack)

    def forward(self, rec1d, lig1d, com2d):
        """Fuse the inputs and return the output of the residual stack.

        ``rec1d`` and ``lig1d`` are rank-3 tensors tiled by
        ``seq2pairwise_v3``; ``com2d`` is presumably (B, C, L1, L2) so that it
        can be concatenated with the tiled map -- verify against the caller.
        """
        tiled = self.identity1(seq2pairwise_v3(rec1d, lig1d))
        direct = self.identity2(com2d)

        fused = self.identity3(torch.cat([tiled, direct], dim=1))

        return self.layer2(fused)


class DeepHomo_Triangle(nn.Module):
    """Sequence of triangle updates followed by a two-channel sigmoid head.

    Args:
        model_config: Mapping whose ``'triangle_args'`` entry provides the
            layer counts ``'num_TriangleMulti'``, ``'num_TriangleSelfR'``,
            ``'num_TriangleSelfC'`` and ``'num_Transition'``, the width
            ``'final'`` of the output head, and the settings read by the
            individual triangle modules.
    """

    def __init__(self, model_config):
        super().__init__()

        cfg = model_config['triangle_args']
        self.triangle_args = cfg

        self.TriangleMulti = nn.ModuleList(
            [TriangleMultiplication(cfg) for _ in range(cfg['num_TriangleMulti'])])
        self.TriangleSelfR = nn.ModuleList(
            [TriangleSelfAttention(cfg) for _ in range(cfg['num_TriangleSelfR'])])
        # Column attention layers exist only when at least one is requested.
        if cfg['num_TriangleSelfC'] > 0:
            self.TriangleSelfC = nn.ModuleList(
                [TriangleSelfAttention(cfg) for _ in range(cfg['num_TriangleSelfC'])])
        self.Transition = nn.ModuleList(
            [Transition(cfg) for _ in range(cfg['num_Transition'])])

        self.norm_final = nn.LayerNorm(cfg['final'])
        self.Linear_final = nn.Linear(cfg['final'], 2)

        self.act = nn.Sigmoid()

        self.drop = nn.Dropout(0.10)

    def forward(self, rec2d, lig2d, com2d, mask=None):
        """Refine ``com2d`` and return the head output with channels at axis 1.

        The loop runs ``num_TriangleMulti`` times and indexes every layer
        list with the same counter, so the other lists need at least that
        many entries.
        """
        # The ``mask`` argument is discarded: every layer below runs unmasked.
        mask = None

        cfg = self.triangle_args
        z_com = com2d

        for idx in range(cfg['num_TriangleMulti']):
            update = self.TriangleMulti[idx](z_com, rec2d, lig2d, mask)
            z_com = z_com + self.drop(update)

            update = self.TriangleSelfR[idx](z_com, mask)
            z_com = z_com + self.drop(update)

            if cfg['num_TriangleSelfC'] > 0:
                # Swap the two spatial axes so the same attention module works
                # along the other axis, then swap back.
                transposed = z_com.permute(0, 2, 1, 3)
                transposed = self.TriangleSelfC[idx](transposed, None)
                z_com = z_com + self.drop(transposed.permute(0, 2, 1, 3))

            # The transition update is added without dropout.
            z_com = z_com + self.Transition[idx](z_com)

        scores = self.act(self.Linear_final(self.norm_final(z_com)))
        return scores.permute(0, 3, 1, 2)


class TriangleBlock(nn.Module):
    """One round of triangle updates applied to the complex pair map.

    Holds one multiplication, one row attention, one column attention and
    one transition layer, all configured from
    ``model_config['triangle_args']``.
    """

    def __init__(self, model_config):
        super().__init__()

        cfg = model_config['triangle_args']
        self.triangle_args = cfg

        self.TriangleMulti = TriangleMultiplication(cfg)
        self.TriangleSelfR = TriangleSelfAttention(cfg)
        self.TriangleSelfC = TriangleSelfAttention(cfg)
        self.Transition = Transition(cfg)

        self.drop = nn.Dropout(0.10)

    def forward(self, rec2d, lig2d, com2d):
        """Return ``(rec2d, lig2d, updated_com2d)``.

        ``rec2d`` and ``lig2d`` are passed through unchanged -- presumably so
        that blocks can be chained with identical inputs and outputs.
        """
        # Multiplicative update, then attention along each row; both residual
        # and with dropout.
        z = com2d + self.drop(self.TriangleMulti(com2d, rec2d, lig2d))
        z = z + self.drop(self.TriangleSelfR(z))

        # Column attention: swap the spatial axes, attend unmasked, swap back.
        swapped = self.TriangleSelfC(z.permute(0, 2, 1, 3), None)
        z = z + self.drop(swapped.permute(0, 2, 1, 3))

        # The transition update is added without dropout.
        z = z + self.Transition(z)

        return rec2d, lig2d, z
        
class TriangleStack(nn.Module):
    """``n_blocks`` TriangleBlocks followed by a one-channel sigmoid head.

    Args:
        model_config: Mapping whose ``'triangle_args'`` entry provides
            ``'final'`` (width of the output head) and the settings read by
            ``TriangleBlock``.
        n_blocks: Number of blocks in the stack.
    """

    def __init__(self, model_config, n_blocks: int):
        super().__init__()

        self.triangle_args = model_config['triangle_args']

        self.n_blocks = n_blocks
        block_list = [TriangleBlock(model_config) for _ in range(n_blocks)]
        self.blocks = nn.ModuleList(block_list)

        self.norm_final = nn.LayerNorm(self.triangle_args['final'])
        self.Linear_final = nn.Linear(self.triangle_args['final'], 1)

        self.act = nn.Sigmoid()

    def forward(self, rec2d, lig2d, com2d):
        """Run the blocks via ``checkpoint_blocks`` and apply the output head."""
        # One block per checkpoint while gradients are enabled, otherwise None.
        blocks_per_ckpt = 1 if torch.is_grad_enabled() else None

        outputs = checkpoint_blocks(
            self.blocks,
            args=(rec2d, lig2d, com2d),
            blocks_per_ckpt=blocks_per_ckpt
        )
        # Only the third output, the updated complex map, is used.
        _, _, z_com = outputs

        return self.act(self.Linear_final(self.norm_final(z_com)))
        

class ResnetTriangleStack(nn.Module):
    """Convolutional trunks feeding a stack of TriangleBlocks.

    One ``DeepHomo_middle`` trunk is shared by the receptor and the ligand
    inputs, a second one processes the complex input; the three resulting
    maps go through ``n_blocks`` TriangleBlocks and a one-channel sigmoid
    head.

    Args:
        model_config: Mapping with ``'model_args'`` (shared intra-chain
            trunk), ``'triangle_conv_args'`` (complex trunk) and
            ``'triangle_args'`` (triangle blocks; ``'final'`` is the width of
            the output head).
        n_blocks: Number of triangle blocks.
    """

    def __init__(self, model_config, n_blocks: int):
        super().__init__()

        self.model_args = model_config['model_args']
        self.triangle_args = model_config['triangle_args']

        self.resnet_rec = DeepHomo_middle(self.model_args)
        self.resnet_com = DeepHomo_middle(model_config['triangle_conv_args'])

        self.n_blocks = n_blocks
        block_list = [TriangleBlock(model_config) for _ in range(n_blocks)]
        self.blocks = nn.ModuleList(block_list)

        self.norm_final = nn.LayerNorm(self.triangle_args['final'])
        self.Linear_final = nn.Linear(self.triangle_args['final'], 1)

        self.act = nn.Sigmoid()

    def forward(self, rec1d, rec2d, lig1d, lig2d, com2d):
        """Return the sigmoid head output for the refined complex map.

        Shapes noted by the original author (not checked here):
            rec1d: (L, 1280)       lig1d: (L, 1280)
            rec2d: (1, L, L, 64)   lig2d: (1, L, L, 64)
            com2d: (1, L1, L2, 64)
        """
        # Add a batch axis to the 1D inputs and move channels to axis 1.
        rec1d = rec1d.unsqueeze(0).transpose(1, 2)
        lig1d = lig1d.unsqueeze(0).transpose(1, 2)

        # Channels-last -> channels-first for the convolutional trunks.
        to_channels_first = (0, 3, 1, 2)
        to_channels_last = (0, 2, 3, 1)
        rec2d = rec2d.permute(*to_channels_first)
        lig2d = lig2d.permute(*to_channels_first)
        com2d = com2d.permute(*to_channels_first)

        # Intra-chain maps: the same trunk (shared weights) handles receptor
        # and ligand, each paired with itself.
        rec2d = self.resnet_rec(rec1d, rec1d, rec2d).permute(*to_channels_last)
        lig2d = self.resnet_rec(lig1d, lig1d, lig2d).permute(*to_channels_last)

        # Inter-chain map.
        z_com = self.resnet_com(rec1d, lig1d, com2d).permute(*to_channels_last)

        # One block per checkpoint while gradients are enabled, otherwise None.
        blocks_per_ckpt = 1 if torch.is_grad_enabled() else None

        outputs = checkpoint_blocks(
            self.blocks,
            args=(rec2d, lig2d, z_com),
            blocks_per_ckpt=blocks_per_ckpt
        )
        # Only the third output, the updated complex map, is used.
        _, _, z_com = outputs

        # Normalization and activation
        return self.act(self.Linear_final(self.norm_final(z_com)))