from turtle import forward
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import LayerNorm
from torch.nn import LayerNorm as ESM1bLayerNorm
from torch.nn.utils.weight_norm import weight_norm

from fairseq.models import FairseqDecoder

from .modules import gelu, apc, symmetrize


class ProteinSequenceMulticlassDecoder(FairseqDecoder):
    """Sequence-level classification head producing two logits per class.

    Each sequence's encoder output is mean-pooled into a single vector,
    projected to half the embedding width, passed through GELU and layer
    normalisation, and finally mapped to ``class_num * 2`` values that are
    reshaped to ``(batch_size, class_num, 2)``.
    """

    def __init__(self, args, embed_dim, class_num, alphabet):
        """Build the pooling head.

        Args:
            args: Accepted for interface compatibility; not read here.
            embed_dim: Size of the last dimension of the encoder output.
            class_num: Number of classes; two logits are emitted per class.
            alphabet: Object passed to ``FairseqDecoder`` and used for its
                ``padding_idx`` attribute when computing sequence lengths.
        """
        super().__init__(alphabet)
        self.alphabet = alphabet
        self.embed_dim = embed_dim
        self.class_num = class_num
        hidden_dim = embed_dim // 2
        self.dense = nn.Linear(embed_dim, hidden_dim)
        self.layer_norm = LayerNorm(hidden_dim)
        self.classifier = nn.Linear(hidden_dim, class_num * 2)

    def forward(self, prev_result, with_prompt_num, tokens):
        """Pool each sequence and classify it.

        Args:
            prev_result: Encoder output, iterated over its first dimension;
                presumably ``(batch, seq_len, embed_dim)`` -- TODO confirm.
            with_prompt_num: Count subtracted from each sequence's number of
                non-padding tokens to obtain how many leading positions are
                averaged (presumably the number of prompt tokens -- confirm
                against the caller).
            tokens: Token ids aligned with ``prev_result`` along the batch
                dimension; compared against ``alphabet.padding_idx``.

        Returns:
            Tensor of shape ``(batch_size, class_num, 2)``.
        """
        batch_size = prev_result.size(0)
        # Allocate on the input's own device/dtype rather than forcing CUDA,
        # so the head also works on CPU and on non-default GPUs.
        embedding = prev_result.new_empty((batch_size, prev_result.size(-1)))
        padding_idx = self.alphabet.padding_idx
        for idx, output_sequence in enumerate(prev_result):
            # Number of leading positions to average for this sequence.
            length = tokens[idx].ne(padding_idx).sum() - with_prompt_num
            embedding[idx, :] = torch.mean(output_sequence[:length, :], dim=0)
        # The classifier consumes embed_dim // 2 features, so the projection
        # below is required; with it disabled `x` was undefined and every
        # call raised NameError.
        x = self.dense(embedding)
        # NOTE(review): the disabled draft called the project's `gelu`
        # helper; F.gelu is used on the assumption the two are equivalent
        # -- confirm.
        x = F.gelu(x)
        x = self.layer_norm(x)
        x = self.classifier(x).reshape(batch_size, self.class_num, 2)
        return x