import torch
import torch.nn as nn
import math


def PE1d_sincos(seq_length, dim):
    """Build a 1-D sinusoidal positional-encoding table.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)``, where ``w_i = exp(-2i * ln(10000) / dim)``.

    Args:
        seq_length: number of positions (rows of the table).
        dim: feature width; must be even so sin/cos columns pair up.

    Returns:
        Tensor of shape ``(seq_length, 1, dim)`` in torch's default dtype.

    Raises:
        ValueError: if ``dim`` is odd.
    """
    if dim % 2 != 0:
        raise ValueError(
            f"Cannot use sin/cos positional encoding with odd dim (got dim={dim:d})"
        )

    positions = torch.arange(0, seq_length).unsqueeze(1).float()  # (seq_length, 1)
    freqs = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
    angles = positions * freqs  # (seq_length, dim // 2)

    # Interleave so column 2i is sin and column 2i+1 is cos of the same angle.
    table = torch.stack((torch.sin(angles), torch.cos(angles)), dim=-1).reshape(seq_length, dim)
    # The table is stored in torch's current default dtype.
    return table.to(torch.get_default_dtype()).unsqueeze(1)


class PositionEmbedding(nn.Module):
    """Add a sin/cos positional-encoding table to a sequence-first input, then dropout.

    The table comes from ``PE1d_sincos(seq_length, dim)`` and has shape
    ``(seq_length, 1, dim)``. It is stored as an ``nn.Parameter`` that is
    frozen unless ``grad=True``.

    Args:
        seq_length: maximum number of positions supported.
        dim: feature width of the input.
        dropout: dropout probability applied after the addition.
        grad: whether the table is trainable.
    """

    def __init__(self, seq_length, dim, dropout, grad=False):
        super().__init__()
        self.embed = nn.Parameter(data=PE1d_sincos(seq_length, dim), requires_grad=grad)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return ``dropout(x + table[:x.shape[0]])``.

        ``x`` is indexed sequence-first: dim 0 selects positions, and the
        table's singleton dim 1 is expanded over dim 1 of ``x``. Sequences
        shorter than ``seq_length`` use the leading rows of the table. Longer
        sequences, or a feature width other than ``dim``, still fail in
        ``expand`` with ``RuntimeError``.
        """
        seq_len = x.shape[0]
        # Slicing past the end just returns the whole table, so expand() still
        # rejects inputs longer than seq_length.
        x = x + self.embed[:seq_len].expand(x.shape)
        return self.dropout(x)
    
class ResConv1DBlock(nn.Module):
    """Pre-activation residual block: ``x + Conv1x1(ReLU(BN(Conv_k(ReLU(BN(x))))))``.

    Args:
        n_in: channels of the input, and of the output because of the residual add.
        n_state: hidden channels between the two convolutions.
        kernel_sizes: kernel size of the first convolution (stride 1).
        padding: padding of the first convolution. The residual add only lines
            up when the length is preserved, i.e. ``2 * padding == kernel_sizes - 1``.
    """

    def __init__(self, n_in, n_state, kernel_sizes, padding):
        super().__init__()
        self.block = nn.Sequential(
            nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True),
            nn.ReLU(),
            nn.Conv1d(n_in, n_state, kernel_sizes, 1, padding),
            # Normalizes the first conv's output, which has n_state channels.
            # Sizing it with n_in only worked when n_state == n_in.
            nn.BatchNorm1d(num_features=n_state, eps=1e-6, affine=True),
            nn.ReLU(),
            nn.Conv1d(n_state, n_in, 1, 1, 0),
        )

    def forward(self, x):
        """Return ``x + block(x)``. The shape matches ``x`` when padding preserves length."""
        return x + self.block(x)


class Resnet1D(nn.Module):
    """Sequential stack of ``n_depth`` width-preserving ``ResConv1DBlock`` layers.

    Every block maps ``n_in`` channels to ``n_in`` channels (hidden width also
    ``n_in``). All blocks share the same ``kernel_sizes`` and ``padding``.
    """

    def __init__(self, n_in, n_depth, kernel_sizes, padding):
        super().__init__()
        self.model = nn.Sequential(
            *(ResConv1DBlock(n_in, n_in, kernel_sizes, padding) for _ in range(n_depth))
        )

    def forward(self, x):
        """Apply every residual block to ``x`` in order."""
        out = self.model(x)
        return out
    

class ProteinConv(nn.Module):
    """Convolutional stem that projects and downsamples a channel-first sequence.

    ``conv1`` (kernel 3, stride 1, padding 1) maps ``latent_conv_dim`` channels
    to ``latent_trans_dim``, followed by ReLU and then BatchNorm. ``down_t``
    stages follow. Each stage is a convolution with stride ``stride_t`` and
    padding ``kernel_size // 2``, then a ``Resnet1D`` stack of ``depth``
    residual blocks. The result is returned with the last two axes swapped
    (channel-last).
    """

    def __init__(self,
                 latent_conv_dim=256,
                 latent_trans_dim=256,
                 down_t=2,
                 stride_t=2,
                 depth=3,
                 kernel_size=15):
        super().__init__()
        pad = kernel_size // 2

        self.conv1 = nn.Conv1d(latent_conv_dim, latent_trans_dim, 3, 1, 1)
        self.act1 = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(num_features=latent_trans_dim, eps=1e-6, affine=True)

        stages = [
            nn.Sequential(
                nn.Conv1d(latent_trans_dim, latent_trans_dim, kernel_size, stride_t, padding=pad),
                Resnet1D(latent_trans_dim, depth, kernel_size, padding=pad),
            )
            for _ in range(down_t)
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Map channel-first input, presumably ``(batch, latent_conv_dim, length)``,
        to ``(batch, length', latent_trans_dim)``."""
        features = self.bn1(self.act1(self.conv1(x)))
        return self.model(features).transpose(-2, -1)
    

class ProteinEncoder(nn.Module):
    """Encode fixed-length protein token sequences with a conv stem and a Transformer.

    The pipeline is: token embedding, then ``ProteinConv`` (strided 1-D
    convolutions that shorten the sequence), then a prepended learnable global
    token, then sin/cos positional encodings, then ``nn.TransformerEncoder``.

    Args:
        pos_grad: whether the positional-encoding table is trainable.
        down_t: number of strided downsampling stages in ``ProteinConv``.
        stride_t: stride of each downsampling convolution.
        depth: residual blocks per downsampling stage.
        kernel_size: downsampling/residual kernel size (padding ``kernel_size // 2``).
        n_token: exact input sequence length accepted by ``forward``.
        protein_vocab_size: number of token ids accepted by the embedding.
        latent_conv_dim: token-embedding width fed to the conv stem.
        latent_trans_dim: Transformer model width.
        nb_layer: number of Transformer encoder layers.
        nb_head: attention heads per layer.
        dropout: dict with keys ``'pos'`` (positional-encoding dropout) and
            ``'trans'`` (Transformer dropout). Defaults to
            ``{'pos': 0.0, 'trans': 0.1}``.
        activation: Transformer feed-forward activation.
    """

    def __init__(self,
                 pos_grad=False,
                 down_t=2,
                 stride_t=2,
                 depth=3,
                 kernel_size=15,
                 n_token=800,
                 protein_vocab_size=30,
                 latent_conv_dim=256,
                 latent_trans_dim=256,
                 nb_layer=4,
                 nb_head=4,
                 dropout=None,
                 activation="gelu"):
        super().__init__()
        if dropout is None:
            # Fresh dict per instance. A shared mutable default would leak edits
            # of one encoder's self.dropout into every other encoder.
            dropout = {'pos': 0.0, 'trans': 0.1}

        self.n_token = n_token
        self.dropout = dropout
        self.down_t = down_t
        self.stride_t = stride_t
        self.kernel_size = kernel_size
        self.latent_trans_dim = latent_trans_dim

        self.token_embed = nn.Embedding(protein_vocab_size, latent_conv_dim)
        self.resconv = ProteinConv(latent_conv_dim, latent_trans_dim, down_t, stride_t, depth, kernel_size)
        self.global_token = nn.Parameter(torch.randn(1, 1, latent_trans_dim))
        # The extra slot (+1) is for the global token prepended in forward().
        pos_len = self._downsampled_length(n_token, down_t, stride_t, kernel_size) + 1
        self.pos_embed = PositionEmbedding(pos_len, latent_trans_dim, dropout['pos'], pos_grad)
        ff_size = 4 * latent_trans_dim
        seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=latent_trans_dim,
                                                          nhead=nb_head,
                                                          dim_feedforward=ff_size,
                                                          dropout=dropout['trans'],
                                                          activation=activation)

        self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
                                                     num_layers=nb_layer)

    @staticmethod
    def _downsampled_length(length, down_t, stride_t, kernel_size):
        """Return the sequence length after ``ProteinConv``'s ``down_t`` strided convolutions.

        This is the nn.Conv1d output-length formula with padding ``kernel_size // 2``.
        The stride-1 layers in ``ProteinConv`` keep the length for odd kernels.
        """
        padding = kernel_size // 2
        for _ in range(down_t):
            length = (length + 2 * padding - kernel_size) // stride_t + 1
        return length

    def forward(self, protein):
        """Encode a batch of token ids.

        Args:
            protein: 2-D integer tensor ``(batch, n_token)``.

        Returns:
            ``(global_feat, token_feats)`` of shapes ``(batch, 1, latent_trans_dim)``
            and ``(batch, L, latent_trans_dim)``, where ``L`` is the downsampled length.
        """
        bs, num_token = protein.shape
        assert num_token == self.n_token, "Different length !"
        token_embeddings = self.token_embed(protein)                    # (bs, n_token, conv_dim)
        downsample = self.resconv(token_embeddings.transpose(-2, -1))  # (bs, L, trans_dim)
        # The Transformer layer is built without batch_first, so it expects (seq, batch, dim).
        protein_in = downsample.transpose(0, 1)

        # Prepend the learnable global token at position 0 of every sequence.
        global_token = self.global_token.expand(-1, bs, -1)
        x = torch.cat((global_token, protein_in), dim=0)

        xseq = self.pos_embed(x)
        final = self.seqTransEncoder(xseq)

        return final[0:1].transpose(0, 1), final[1:].transpose(0, 1)

    def update(self, seq_length):
        """Re-target the encoder to inputs of length ``seq_length``.

        Rebuilds the positional table as a frozen sin/cos table; any learned
        values are not carried over, matching the previous behaviour. The new
        table is placed on the encoder's device and train/eval mode, then
        ``n_token`` is updated.
        """
        pos_len = self._downsampled_length(seq_length, self.down_t, self.stride_t, self.kernel_size) + 1
        new_pos_embed = PositionEmbedding(pos_len, self.latent_trans_dim, self.dropout['pos'], False)
        # A freshly built module is on CPU and in training mode. Match the encoder
        # so update() is safe after .to(device) and does not re-enable dropout after .eval().
        new_pos_embed.to(self.global_token.device)
        new_pos_embed.train(self.training)
        self.pos_embed = new_pos_embed
        self.n_token = seq_length

