from turtle import forward
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import LayerNorm
from torch.nn import LayerNorm as ESM1bLayerNorm
from torch.nn.utils.weight_norm import weight_norm

from fairseq.models import FairseqDecoder

from .modules import gelu, apc, symmetrize


class ProteinContactMapDecoder(FairseqDecoder):
    """Pairwise contact head built from per-token embeddings.

    For every position pair (i, j) the head concatenates the elementwise
    product and the difference of the two embeddings. It maps that
    ``2 * embed_dim`` feature to two logits and averages the (i, j) and (j, i)
    predictions, so the output is symmetric in the two position axes.
    """

    def __init__(self, args, embed_dim, alphabet):
        super().__init__(alphabet)
        self.embed_dim = embed_dim
        # nn.Dropout() with no argument uses PyTorch's default p=0.5.
        self.predict = nn.Sequential(
            nn.Dropout(),
            nn.Linear(2 * embed_dim, 2),
        )

    def forward(self, prev_result, with_prompt_num):
        """Return symmetric pairwise logits.

        Args:
            prev_result: encoder output; assumed (batch, seq_len, embed_dim) — TODO confirm.
            with_prompt_num: number of extra trailing positions to drop in
                addition to the first and last position.

        Returns:
            Contiguous tensor of shape (batch, L', L', 2), where
            L' = seq_len - 2 - with_prompt_num.
        """
        rows = prev_result.unsqueeze(2)  # broadcast over j
        cols = prev_result.unsqueeze(1)  # broadcast over i
        features = torch.cat((rows * cols, rows - cols), -1)
        logits = self.predict(features)
        # Average with the transpose so logits[i, j] == logits[j, i].
        logits = (logits + logits.transpose(1, 2)) / 2
        end = -1 - with_prompt_num
        # Remove start/stop tokens and any trailing prompt positions.
        return logits[:, 1:end, 1:end].contiguous()


class ProteinContactMapLogisticRegressionDecoder(FairseqDecoder):
    """Logistic-regression contact head over stacked attention maps.

    Every (layer, head) attention map contributes one input feature to a
    single linear unit, followed by a sigmoid.
    """

    def __init__(self, args, embed_dim, alphabet):
        super().__init__(alphabet)
        # forward() feeds layers * heads features, so embed_dim must equal that product.
        self.regression = nn.Linear(embed_dim, 1, True)
        self.activation = nn.Sigmoid()

    def forward(self, prev_result, with_prompt_num):
        """Return contact probabilities.

        Args:
            prev_result: attention maps of shape (batch, layers, heads, T, T).
            with_prompt_num: extra trailing positions to drop besides the
                first and last position.

        Returns:
            Tensor of shape (batch, L', L') with values in (0, 1), where
            L' = T - 2 - with_prompt_num.
        """
        end = -1 - with_prompt_num
        maps = prev_result[:, :, :, 1:end, 1:end]
        n_batch, n_layers, n_heads, length, _ = maps.size()
        maps = maps.view(n_batch, n_layers * n_heads, length, length)
        # Cast to this module's parameter dtype/device.
        maps = maps.to(next(self.parameters()))

        maps = apc(symmetrize(maps))
        # Put the (layers * heads) feature axis last for the linear layer.
        features = maps.permute(0, 2, 3, 1)
        logits = self.regression(features).squeeze(3)
        return self.activation(logits)


class ProteinSecondaryStructureDecoder(FairseqDecoder):
    """Per-position classification head.

    Pipeline: Conv1d(kernel_size=3) -> Linear -> GELU -> LayerNorm -> Linear.
    """

    def __init__(self, args, embed_dim, label_num, alphabet):
        super().__init__(alphabet)
        hidden_dim = int(embed_dim / 2)
        # Default padding=0, so the output is 2 positions shorter than the input.
        self.conv = nn.Conv1d(in_channels=embed_dim, out_channels=embed_dim, kernel_size=3)
        self.dense = nn.Linear(embed_dim, hidden_dim)
        self.layer_norm = LayerNorm(hidden_dim)
        self.classifier = nn.Linear(hidden_dim, label_num)

    def forward(self, prev_result):
        """Return per-position logits.

        For a (batch, seq_len, embed_dim) input the result is
        (batch, seq_len - 2, label_num).
        NOTE(review): presumably the two positions lost to the unpadded
        convolution line up with dropping the start/stop tokens — confirm
        against the loss code.
        """
        # Conv1d expects channels before length.
        channels_first = prev_result.permute(0, 2, 1)
        hidden = self.conv(channels_first).permute(0, 2, 1)
        hidden = self.layer_norm(gelu(self.dense(hidden)))
        return self.classifier(hidden)


class ProteinLMDecoder(FairseqDecoder):
    """Masked-LM output head with a projection shared through ``weight``.

    Only the first ``len(alphabet) - 3`` rows of ``weight`` are used, so the
    head produces that many logits per position.
    """

    def __init__(self, args, embed_dim, weight, alphabet):
        super().__init__(alphabet)
        self.args = args
        self.embed_dim = embed_dim
        self.alphabet = alphabet
        # NOTE(review): presumably the last 3 alphabet entries are non-protein symbols — confirm.
        self.protein_toks_size = len(alphabet) - 3

        self.dense = nn.Linear(self.embed_dim, self.embed_dim)
        self.layer_norm = LayerNorm(embed_dim)
        # Stored by reference (not copied), so it stays tied to the caller's tensor.
        # F.linear treats it as (out_features, embed_dim).
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(self.protein_toks_size))

    def forward(self, prev_result):
        """Return logits over the first ``len(alphabet) - 3`` symbols."""
        hidden = self.layer_norm(gelu(self.dense(prev_result)))
        n_symbols = len(self.alphabet) - 3
        projection = self.weight[:n_symbols, :]
        return F.linear(hidden, projection) + self.bias


class ProteinCoordinateDecoder(FairseqDecoder):
    """Per-position 3-D coordinate regression head.

    Pipeline: Conv1d(kernel_size=3) -> Linear -> GELU -> LayerNorm -> Linear(embed_dim, 3).
    """

    def __init__(self, args, embed_dim, alphabet):
        super().__init__(alphabet)
        self.args = args
        # Stored for external use; forward() does not read it.
        self.padding_idx = alphabet.padding_idx
        # Default padding=0, so the output is 2 positions shorter than the input.
        self.conv = nn.Conv1d(in_channels=embed_dim, out_channels=embed_dim, kernel_size=3)
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = LayerNorm(embed_dim)
        self.dense2coord = nn.Linear(embed_dim, 3)

    def forward(self, prev_result):
        """Return per-position coordinates.

        For a (batch, seq_len, embed_dim) input the result is
        (batch, seq_len - 2, 3).
        """
        conv_out = self.conv(prev_result.permute(0, 2, 1))
        hidden = conv_out.permute(0, 2, 1)
        hidden = self.layer_norm(gelu(self.dense(hidden)))
        coords = self.dense2coord(hidden)
        return coords


class ProteinInteractionDecoder(FairseqDecoder):
    """All-pairs protein-protein interaction head.

    Each sequence in the batch is mean-pooled to one vector. Every ordered
    pair (i, j) of pooled vectors is then scored by the same MLP.
    """

    def __init__(self, args, embed_dim, alphabet):
        super().__init__(alphabet)
        self.args = args
        self.embed_dim = embed_dim
        self.alphabet = alphabet
        self.dense = nn.Linear(2 * self.embed_dim, int(self.embed_dim / 2))
        self.linear = nn.Linear(int(self.embed_dim / 2), int(self.embed_dim / 8))
        self.classifier = nn.Linear(int(self.embed_dim / 8), 2)

    def forward(self, prev_result, tokens):
        """Score every ordered pair of sequences in the batch.

        Args:
            prev_result: encoder output; assumed (batch, seq_len, embed_dim) — TODO confirm.
            tokens: (batch, tok_len) token ids. Row i is pooled over its first
                n_i positions, where n_i counts entries != alphabet.padding_idx.
                This assumes padding sits at the end of each row — TODO confirm.

        Returns:
            Pairwise logits, (batch, batch, 2) assuming ``symmetrize``
            preserves shape.
        """
        batch_size = prev_result.size(0)
        # Allocate on the input's device/dtype rather than the current CUDA
        # device, so the head runs on CPU and on non-default GPUs.
        ppi_embedding = prev_result.new_empty((batch_size, prev_result.size(-1)))
        # One host sync for all row lengths instead of a Python-level sum per row.
        lengths = tokens.ne(self.alphabet.padding_idx).sum(dim=1).tolist()
        for idx in range(batch_size):
            ppi_embedding[idx] = prev_result[idx, :lengths[idx]].mean(dim=0)

        # Build all ordered pairs in one shot: pairs[i, j] = cat(e_i, e_j).
        dim = ppi_embedding.size(-1)
        left = ppi_embedding.unsqueeze(1).expand(batch_size, batch_size, dim)
        right = ppi_embedding.unsqueeze(0).expand(batch_size, batch_size, dim)
        result = self.dense(torch.cat([left, right], dim=-1))
        # NOTE(review): `symmetrize` comes from `.modules` (not visible here).
        # If it is ESM-style `x + x.transpose(-1, -2)`, it symmetrizes the last
        # two axes of this (B, B, embed_dim // 2) tensor. That only works when
        # B == embed_dim // 2, and it would not symmetrize the pair axes (0, 1).
        # Confirm the intended axes.
        result = symmetrize(result)
        result = gelu(self.linear(result))
        result = self.classifier(result)
        return result


class ProteinPairInteractionDecoder(FairseqDecoder):
    """Binary interaction head for pairs laid out as two batch halves.

    Row k of the first half is paired with row k of the second half. The pair
    score is order-invariant because dense(a, b) and dense(b, a) are summed.
    """

    def __init__(self, args, embed_dim, alphabet):
        super().__init__(alphabet)
        self.args = args
        self.embed_dim = embed_dim
        self.alphabet = alphabet
        self.dense = nn.Linear(2 * self.embed_dim, int(self.embed_dim / 2))
        self.linear = nn.Linear(int(self.embed_dim / 2), int(self.embed_dim / 8))
        self.classifier = nn.Linear(int(self.embed_dim / 8), 2)

    def forward(self, prev_result, with_prompt_num, tokens):
        """Return (batch // 2, 2) logits for the paired sequences.

        Args:
            prev_result: encoder output; assumed (batch, seq_len, embed_dim) — TODO confirm.
                The batch size must be even for the two-half split.
            with_prompt_num: positions subtracted from each row's non-padding
                count before pooling.
            tokens: (batch, tok_len) token ids; entries != alphabet.padding_idx
                are counted per row.
        """
        n_seqs = prev_result.size(0)
        dim = prev_result.size(-1)
        # Allocate on the input's device/dtype instead of the current CUDA device.
        ppi_embedding = prev_result.new_empty((n_seqs, dim))
        # One host sync for all row lengths instead of a Python-level sum per row.
        lengths = tokens.ne(self.alphabet.padding_idx).sum(dim=1).tolist()
        for idx in range(n_seqs):
            stop = lengths[idx] - with_prompt_num
            ppi_embedding[idx] = prev_result[idx, :stop].mean(dim=0)
        # halves[0, k] is row k and halves[1, k] is row n_seqs // 2 + k.
        halves = ppi_embedding.reshape(2, -1, dim)
        first, second = halves[0], halves[1]
        result = self.dense(torch.cat([first, second], dim=1)) + self.dense(torch.cat([second, first], dim=1))
        result = gelu(self.linear(result))
        result = self.classifier(result)
        return result


class ProteinFusionDecoder(FairseqDecoder):
    """Multi-task container holding the masked-LM, coordinate and PPI heads.

    Each ``*_forward`` method passes its inputs straight to one head.
    """

    def __init__(self, args, embed_dim, weight, alphabet):
        super().__init__(alphabet)
        self.mlm_decoder = ProteinLMDecoder(args, embed_dim, weight, alphabet)
        self.crd_decoder = ProteinCoordinateDecoder(args, embed_dim, alphabet)
        self.ppi_decoder = ProteinInteractionDecoder(args, embed_dim, alphabet)

    def mlm_forward(self, prev_result):
        """Return the output of the masked-LM head."""
        return self.mlm_decoder(prev_result)

    def crd_forward(self, prev_result):
        """Return the output of the coordinate head."""
        return self.crd_decoder(prev_result)

    def ppi_forward(self, prev_result, tokens):
        """Return the output of the all-pairs interaction head."""
        return self.ppi_decoder(prev_result, tokens)





# class ProteinFunctionDecoder(FairseqDecoder):
#     def __init__(self, args, embed_dim, class_num, alphabet):
#         super().__init__(alphabet)
#         self.alphabet = alphabet
#         self.embed_dim = embed_dim
#         self.class_num = class_num
#         self.dense = nn.Linear(self.embed_dim, int(self.embed_dim / 2))
#         self.layer_norm = LayerNorm(int(embed_dim / 2))
#         self.classifier = nn.Linear(int(self.embed_dim / 2), class_num * 2)

#     def forward(self, prev_result, with_prompt_num, tokens):
#         batch_size = prev_result.size(0)
#         embedding = torch.empty(
#             (prev_result.size(0), prev_result.size(-1)),
#             dtype=prev_result.dtype
#         ).cuda()
#         for idx, output_sequence in enumerate(prev_result):
#             embedding[idx, :] = torch.mean(output_sequence[:sum(tokens[idx].ne(self.alphabet.padding_idx))-with_prompt_num, :], axis=0)

#         x = self.dense(embedding)
#         x = gelu(x)
#         x = self.layer_norm(x)
#         x = self.classifier(x).reshape(batch_size, self.class_num, 2)
#         return x


class ProteinAnnotationDecoder(FairseqDecoder):
    """Sequence-level classification head.

    Mean-pools each sequence, then applies Linear -> GELU -> LayerNorm ->
    Linear(class_num).
    """

    def __init__(self, args, embed_dim, class_num, alphabet):
        super().__init__(alphabet)
        self.alphabet = alphabet
        self.embed_dim = embed_dim
        self.class_num = class_num
        self.dense = nn.Linear(self.embed_dim, int(self.embed_dim / 2))
        self.layer_norm = LayerNorm(int(embed_dim / 2))
        self.classifier = nn.Linear(int(self.embed_dim / 2), class_num)

    def forward(self, prev_result, with_prompt_num, tokens):
        """Return (batch, class_num) logits.

        Args:
            prev_result: encoder output; assumed (batch, seq_len, embed_dim) — TODO confirm.
            with_prompt_num: positions subtracted from each row's non-padding
                count before pooling.
            tokens: (batch, tok_len) token ids; entries != alphabet.padding_idx
                are counted per row.
        """
        batch_size = prev_result.size(0)
        # Allocate on the input's device/dtype instead of the current CUDA device.
        embedding = prev_result.new_empty((batch_size, prev_result.size(-1)))
        # One host sync for all row lengths instead of a Python-level sum per row.
        lengths = tokens.ne(self.alphabet.padding_idx).sum(dim=1).tolist()
        for idx in range(batch_size):
            stop = lengths[idx] - with_prompt_num
            embedding[idx] = prev_result[idx, :stop].mean(dim=0)
        x = self.dense(embedding)
        x = gelu(x)
        x = self.layer_norm(x)
        x = self.classifier(x)
        return x


class ProteinFoldDecoder(FairseqDecoder):
    """Sequence-level fold classification head.

    Mean-pools each sequence, then applies weight-normed Linear -> ReLU ->
    Dropout(0.1) -> weight-normed Linear(class_num).
    """

    def __init__(self, args, embed_dim, class_num, alphabet):
        super().__init__(alphabet)
        self.dropout = 0.1
        self.alphabet = alphabet
        self.classifier = nn.Sequential(
            weight_norm(nn.Linear(embed_dim, int(embed_dim / 2)), dim=None),
            nn.ReLU(),
            nn.Dropout(self.dropout, inplace=False),
            weight_norm(nn.Linear(int(embed_dim / 2), class_num), dim=None)
        )
        self.init_weights()

    def init_weights(self):
        """ Initialize and prunes weights if needed. """
        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize one submodule (called on every submodule via ``self.apply``).

        Linear/Embedding weights ~ N(0, 0.02), Linear biases = 0,
        LayerNorm weight = 1 and bias = 0.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
                self._init_weight_normed(module)
            else:
                module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @staticmethod
    def _init_weight_normed(module):
        """Give a weight_norm-wrapped layer an effective weight ~ N(0, 0.02).

        ``weight_norm`` replaces ``weight`` with ``weight_g`` / ``weight_v`` and
        recomputes ``weight = g * v / ||v||`` before every forward, so writes
        to ``module.weight.data`` are discarded. Instead, draw ``v`` and set
        ``g = ||v||`` over the axes where ``g`` has size 1; the recomputed
        weight then equals ``v``.
        """
        v = module.weight_v.data
        g = module.weight_g.data
        v.normal_(mean=0.0, std=0.02)
        if g.dim() == 0:
            # weight_norm(..., dim=None): a single scalar norm over the whole matrix.
            g.copy_(v.norm())
        else:
            # Per-slice norm. Linear/Embedding weights are 2-D, so g always has
            # at least one size-1 axis.
            reduce_dims = [d for d in range(v.dim()) if g.size(d) == 1]
            g.copy_(v.norm(p=2, dim=reduce_dims, keepdim=True))

    def forward(self, prev_result, with_prompt_num, tokens):
        """Return (batch, class_num) logits.

        Args:
            prev_result: encoder output; assumed (batch, seq_len, embed_dim) — TODO confirm.
            with_prompt_num: positions subtracted from each row's non-padding
                count before pooling.
            tokens: (batch, tok_len) token ids; entries != alphabet.padding_idx
                are counted per row.
        """
        batch_size = prev_result.size(0)
        # Allocate on the input's device/dtype instead of the current CUDA device.
        embedding = prev_result.new_empty((batch_size, prev_result.size(-1)))
        # One host sync for all row lengths instead of a Python-level sum per row.
        lengths = tokens.ne(self.alphabet.padding_idx).sum(dim=1).tolist()
        for idx in range(batch_size):
            stop = lengths[idx] - with_prompt_num
            embedding[idx] = prev_result[idx, :stop].mean(dim=0)

        outputs = self.classifier(embedding)
        return outputs
 

class ProteinStabilityDecoder(FairseqDecoder):
    """Sequence-level scalar regression head.

    Mean-pools each sequence, then applies weight-normed Linear -> ReLU ->
    Dropout -> weight-normed Linear(1).
    """

    def __init__(self, args, embed_dim, alphabet, dropout=0.1):
        super().__init__(alphabet)
        self.embed_dim = embed_dim
        self.alphabet = alphabet
        self.main = nn.Sequential(
            weight_norm(nn.Linear(embed_dim, int(embed_dim / 2)), dim=None),
            nn.ReLU(),
            nn.Dropout(dropout, inplace=False),
            weight_norm(nn.Linear(int(embed_dim / 2), 1), dim=None)
        )

    def forward(self, prev_result, with_prompt_num, tokens):
        """Return (batch, 1) predictions.

        Args:
            prev_result: encoder output; assumed (batch, seq_len, embed_dim) — TODO confirm.
            with_prompt_num: positions subtracted from each row's non-padding
                count before pooling.
            tokens: (batch, tok_len) token ids; entries != alphabet.padding_idx
                are counted per row.
        """
        batch_size = prev_result.size(0)
        # Allocate on the input's device/dtype instead of the current CUDA device.
        embedding = prev_result.new_empty((batch_size, prev_result.size(-1)))
        # One host sync for all row lengths instead of a Python-level sum per row.
        lengths = tokens.ne(self.alphabet.padding_idx).sum(dim=1).tolist()
        for idx in range(batch_size):
            stop = lengths[idx] - with_prompt_num
            embedding[idx] = prev_result[idx, :stop].mean(dim=0)
        x = self.main(embedding)
        return x


class ProteinFluorescenceDecoder(FairseqDecoder):
    """Sequence-level scalar regression head.

    Mean-pools each sequence, then applies weight-normed Linear -> ReLU ->
    Dropout -> weight-normed Linear(1).
    """

    def __init__(self, args, embed_dim, alphabet, dropout=0.1):
        super().__init__(alphabet)
        self.embed_dim = embed_dim
        self.alphabet = alphabet
        self.main = nn.Sequential(
            weight_norm(nn.Linear(embed_dim, int(embed_dim / 2)), dim=None),
            nn.ReLU(),
            nn.Dropout(dropout, inplace=False),
            weight_norm(nn.Linear(int(embed_dim / 2), 1), dim=None)
        )

    def forward(self, prev_result, with_prompt_num, tokens):
        """Return (batch, 1) predictions.

        Args:
            prev_result: encoder output; assumed (batch, seq_len, embed_dim) — TODO confirm.
            with_prompt_num: positions subtracted from each row's non-padding
                count before pooling.
            tokens: (batch, tok_len) token ids; entries != alphabet.padding_idx
                are counted per row.
        """
        batch_size = prev_result.size(0)
        # Allocate on the input's device/dtype instead of the current CUDA device.
        embedding = prev_result.new_empty((batch_size, prev_result.size(-1)))
        # One host sync for all row lengths instead of a Python-level sum per row.
        lengths = tokens.ne(self.alphabet.padding_idx).sum(dim=1).tolist()
        for idx in range(batch_size):
            stop = lengths[idx] - with_prompt_num
            embedding[idx] = prev_result[idx, :stop].mean(dim=0)
        x = self.main(embedding)
        return x


class ProteinMutationDecoder(FairseqDecoder):
    """Mutation-effect head returning a ratio of two scalar scores.

    The same MLP scores the mean-pooled original and mutated sequences; the
    output is score(original) / score(mutated).
    """

    def __init__(self, args, embed_dim, alphabet, dropout=0.1):
        super().__init__(alphabet)
        self.embed_dim = embed_dim
        self.alphabet = alphabet
        self.main = nn.Sequential(
            weight_norm(nn.Linear(embed_dim, int(embed_dim / 2)), dim=None),
            nn.ReLU(),
            nn.Dropout(dropout, inplace=False),
            weight_norm(nn.Linear(int(embed_dim / 2), 1), dim=None)
        )

    def forward(self, origin_prev_result, mutate_pre_result, tokens, with_prompt_num):
        """Return a (batch, 1) float32 tensor of score ratios.

        Args:
            origin_prev_result: encoder output for the original sequences;
                assumed (batch, seq_len, embed_dim) — TODO confirm.
            mutate_pre_result: encoder output for the mutated sequences, same layout.
            tokens: (batch, tok_len) token ids; entries != alphabet.padding_idx
                are counted per row, and both inputs are pooled over that count.
            with_prompt_num: positions subtracted from that count before pooling.

        The ratio is not guarded; a zero score for the mutated sequence yields inf/nan.
        """
        batch_size = origin_prev_result.size(0)
        # Allocate on the input's device instead of the current CUDA device.
        result = torch.empty(
            (batch_size, 1),
            dtype=torch.float32,
            device=origin_prev_result.device,
        )
        # One host sync for all row lengths; each length is shared by both inputs.
        lengths = tokens.ne(self.alphabet.padding_idx).sum(dim=1).tolist()
        for idx in range(batch_size):
            stop = lengths[idx] - with_prompt_num
            origin_score = self.main(torch.mean(origin_prev_result[idx, :stop], dim=0))
            mutate_score = self.main(torch.mean(mutate_pre_result[idx, :stop], dim=0))
            result[idx, :] = origin_score / mutate_score
        return result
