
from collections import OrderedDict
import torch
from torch.nn.parameter import Parameter
from fairseq.models import FairseqEncoderDecoderModel
from fairseq.models import register_model, register_model_architecture

from ..modules import ProteinBertEncoder, ProteinFusionDecoder


@register_model("protein_esm_fusion")
class FusionModel(FairseqEncoderDecoderModel):
    """Encoder/decoder model pairing ``ProteinBertEncoder`` with ``ProteinFusionDecoder``.

    ``build_model`` constructs both halves and initialises them from a
    pretrained checkpoint. The three ``*_forward`` methods run the encoder
    once and hand its ``'logits'`` output to the matching decoder head
    (``mlm_forward``, ``crd_forward`` or ``ppi_forward``).
    """

    # Fallback checkpoint location used when ``--checkpoint_path`` is not
    # given. The path is relative to the current working directory.
    DEFAULT_CHECKPOINT_PATH = '../../esm1b_t33_650M_UR50S.pt'

    @staticmethod
    def add_args(parser):
        """Register the model-specific command-line arguments on ``parser``."""

        def _str2bool(value):
            # ``type=bool`` would turn every non-empty string (including
            # "False") into True, so parse the text explicitly. argparse
            # reports a ValueError raised here as a normal usage error.
            if isinstance(value, bool):
                return value
            lowered = value.strip().lower()
            if lowered in ("true", "t", "yes", "y", "1"):
                return True
            if lowered in ("false", "f", "no", "n", "0", ""):
                return False
            raise ValueError("expected a boolean value, got {!r}".format(value))

        parser.add_argument(
            "--layer_num", default=33, type=int, metavar="N", help="number of layers"
        )
        parser.add_argument(
            "--embed_dim", default=1280, type=int, metavar="N", help="embedding dimension"
        )
        parser.add_argument(
            "--ffn_embed_dim",
            default=5120,
            type=int,
            metavar="N",
            help="embedding dimension for FFN",
        )
        parser.add_argument(
            "--attention_head_num",
            default=20,
            type=int,
            metavar="N",
            help="number of attention heads",
        )
        parser.add_argument(
            "--max_position_num",
            default=1024,
            type=int,
            help="number of positional embeddings to learn",
        )
        parser.add_argument("--emb_layer_norm_before", default=True, type=_str2bool)
        parser.add_argument(
            "--checkpoint_path",
            type=str,
            help="path of the pretrained checkpoint used to initialise the model "
                 "(defaults to FusionModel.DEFAULT_CHECKPOINT_PATH)",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build the model and initialise it from a pretrained checkpoint.

        Args:
            args: parsed arguments; reads ``max_position_num``, ``layer_num``,
                ``attention_head_num``, ``embed_dim``, ``ffn_embed_dim`` and,
                when present, ``checkpoint_path``.
            task: the fairseq task; its ``alphabet`` attribute is passed to
                both the encoder and the decoder.

        Returns:
            The constructed model with pretrained weights copied in and every
            encoder ``layer_gated.weight`` parameter reset to zeros.

        Raises:
            OSError: if the checkpoint file cannot be opened.
        """
        encoder = ProteinBertEncoder(
            args,
            args.max_position_num,
            args.layer_num,
            args.attention_head_num,
            args.embed_dim,
            args.ffn_embed_dim,
            task.alphabet,
        )
        decoder = ProteinFusionDecoder(args, args.embed_dim, encoder.embed_tokens.weight, task.alphabet)
        # Use ``cls`` so that subclasses are instantiated as themselves.
        model = cls(encoder, decoder)

        # NOTE(review): ``--checkpoint_path`` was declared but never read in
        # this file; it is assumed to name the pretrained checkpoint. Confirm
        # no other component uses the flag for a different file.
        checkpoint_path = getattr(args, 'checkpoint_path', None) or cls.DEFAULT_CHECKPOINT_PATH

        with torch.no_grad():
            new_state_dict = OrderedDict()
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            with open(checkpoint_path, 'rb') as f:
                pretrain_decoder_dict = torch.load(f, map_location=torch.device('cpu'))

            for k, v in pretrain_decoder_dict['model'].items():
                if 'embed_tokens' in k:
                    # Copy the pretrained rows into the leading rows of the
                    # embedding table; any further rows keep their fresh
                    # initialisation.
                    encoder.embed_tokens.weight[:v.shape[0], :] = v
                elif 'sentence_encoder.' in k:
                    # Drop the 'sentence_encoder.' component so the key lines
                    # up with this model's parameter names.
                    new_state_dict[k.replace('sentence_encoder.', '')] = v
                elif 'encoder.lm_head.weight' == k:
                    # Same partial row copy as for the embedding table.
                    decoder.mlm_decoder.weight[:v.shape[0], :] = v
                elif 'lm_head.' in k:
                    # Remaining LM-head tensors are re-keyed onto the decoder.
                    new_state_dict[k.replace('encoder.lm_head.', 'decoder.mlm_decoder.')] = v

            # strict=False: parameters absent from the checkpoint keep their
            # initial values, and unmatched checkpoint keys are ignored.
            model.load_state_dict(new_state_dict, strict=False)

        # Reset every gate weight to zeros. The width follows ``embed_dim``
        # instead of a literal 1280 (identical at the default setting).
        for name, params in model.encoder.named_parameters():
            if 'layer_gated.weight' in name:
                params.data = torch.zeros((1, args.embed_dim), dtype=torch.float32)
        return model

    def _encode(self, tokens, with_prompt_num, layer_gate=None):
        """Run the encoder and return the ``'logits'`` entry of its output."""
        return self.encoder(tokens, with_prompt_num=with_prompt_num, layer_gate=layer_gate)['logits']

    def mlm_forward(self, tokens, with_prompt_num, layer_gate=None):
        """Encode ``tokens`` and return ``decoder.mlm_forward`` of the result."""
        return self.decoder.mlm_forward(self._encode(tokens, with_prompt_num, layer_gate))

    def crd_forward(self, tokens, with_prompt_num, layer_gate=None):
        """Encode ``tokens`` and return ``decoder.crd_forward`` of the result."""
        return self.decoder.crd_forward(self._encode(tokens, with_prompt_num, layer_gate))

    def ppi_forward(self, tokens, with_prompt_num, layer_gate=None):
        """Encode ``tokens`` and return ``decoder.ppi_forward`` of the result.

        Unlike the other heads, this one also receives the raw ``tokens``.
        """
        return self.decoder.ppi_forward(self._encode(tokens, with_prompt_num, layer_gate), tokens)


@register_model_architecture('protein_esm_fusion', 'fusion')
def roberta_large(args):
    """Fill in defaults for the 'fusion' architecture of 'protein_esm_fusion'.

    Each hyper-parameter listed below is set on ``args`` only when ``args``
    does not already carry an attribute of that name; existing values are
    kept as they are.
    """
    fallback_values = (
        ('layer_num', 33),
        ('embed_dim', 1280),
        ('ffn_embed_dim', 5120),
        ('attention_head_num', 20),
        ('max_position_num', 1024),
        ('emb_layer_norm_before', True),
    )
    for attr_name, fallback in fallback_values:
        setattr(args, attr_name, getattr(args, attr_name, fallback))