#!/usr/bin/env python
# encoding: utf-8
'''
@license: (C) Copyright 2021, Hey.
@author: Hey
@email: sanyuan.hy@alibaba-inc.com
@tel: 137****6540
@datetime: 2023/7/24 10:01
@project: LucaOne
@file: lucaone_gplm.py
@desc: LucaOne Model
'''

from utils.get_embed.lucaone_utils.alphabet import Alphabet
from utils.get_embed.lucaone_utils.modeling_gplm import *
from transformers.configuration_utils import PretrainedConfig


class LucaGPLMConfig(PretrainedConfig):
    """Configuration object for ``LucaGPLM``.

    Each keyword argument below is stored unchanged as an attribute of the
    same name.  ``pad_token_id`` and any extra ``**kwargs`` are handed to
    ``PretrainedConfig.__init__``.
    """

    def __init__(self,
                 vocab_size=-1,
                 pad_token_id=0,
                 max_position_embeddings: int = 4096,
                 type_vocab_size: int = 2,
                 num_hidden_layers: int = 24,
                 hidden_size: int = 1280,
                 num_attention_heads: int = 20,
                 no_position_embeddings: bool = False,
                 no_token_type_embeddings: bool = False,
                 alphabet: str = "gene_prot",
                 token_dropout: bool = True,
                 attention_probs_dropout_prob=0.1,
                 hidden_dropout_prob=0.1,
                 classifier_dropout_prob=0.1,
                 use_embed_layer_norm=True,
                 use_last_layer_norm=True,
                 embed_scale=1.0,
                 ignore_index=-100,
                 **kwargs):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        # Vocabulary.
        self.alphabet, self.vocab_size = alphabet, vocab_size

        # Embedding tables (sizes and on/off switches).
        self.max_position_embeddings, self.type_vocab_size = (
            max_position_embeddings,
            type_vocab_size,
        )
        self.no_token_type_embeddings, self.no_position_embeddings = (
            no_token_type_embeddings,
            no_position_embeddings,
        )

        # Transformer geometry.
        self.num_hidden_layers, self.hidden_size, self.num_attention_heads = (
            num_hidden_layers,
            hidden_size,
            num_attention_heads,
        )

        # Regularisation.
        self.token_dropout = token_dropout
        self.attention_probs_dropout_prob, self.hidden_dropout_prob, self.classifier_dropout_prob = (
            attention_probs_dropout_prob,
            hidden_dropout_prob,
            classifier_dropout_prob,
        )

        # Loss / normalisation / scaling options.
        self.ignore_index = ignore_index
        self.use_embed_layer_norm, self.use_last_layer_norm = (
            use_embed_layer_norm,
            use_last_layer_norm,
        )
        self.embed_scale = embed_scale
        
        
class LucaGPLM(nn.Module):
    """LucaOne gene/protein language model.

    Token embeddings (plus optional position and token-type embeddings) feed a
    stack of ``LucaGPLMTransformerLayer`` blocks. An optional final LayerNorm
    follows, then a ``RobertaLMHead``. The head is built with
    ``weight=self.embed_tokens.weight``, presumably for weight tying.

    Args:
        config: a ``LucaGPLMConfig``-like object providing the hyper-parameters
            read in ``__init__``.
        args: optional namespace. Only ``args.pretrained_model_name`` is read.
            When that value is not None, weights are initialised from the named
            checkpoint via ``_init_submodules_new``.
    """

    def __init__(
            self,
            config,
            args=None
    ):
        super().__init__()
        self.config = config
        self.max_position_embeddings = config.max_position_embeddings
        self.type_vocab_size = config.type_vocab_size
        self.num_layers = config.num_hidden_layers
        self.embed_dim = config.hidden_size
        self.attention_heads = config.num_attention_heads
        self.no_position_embeddings = config.no_position_embeddings
        self.no_token_type_embeddings = config.no_token_type_embeddings
        # config.alphabet may be a predefined alphabet name or an Alphabet instance.
        if not isinstance(config.alphabet, Alphabet):
            self.alphabet = Alphabet.from_predefined(config.alphabet)
        else:
            self.alphabet = config.alphabet
        self.alphabet_size = len(self.alphabet)
        self.padding_idx = self.alphabet.padding_idx
        self.mask_idx = self.alphabet.mask_idx
        self.cls_idx = self.alphabet.cls_idx
        self.eos_idx = self.alphabet.eos_idx
        self.prepend_bos = self.alphabet.prepend_bos
        self.append_eos = self.alphabet.append_eos
        self.token_dropout = config.token_dropout
        self.ignore_index = config.ignore_index
        self.use_embed_layer_norm = config.use_embed_layer_norm
        self.use_last_layer_norm = config.use_last_layer_norm
        self.embed_scale = config.embed_scale
        # ``args`` defaults to None. Guard the lookup so that building the model
        # from a config alone does not raise AttributeError.
        self.pretrained_model_name = getattr(args, "pretrained_model_name", None)
        self._init_submodules()
        if self.pretrained_model_name is not None:
            print("Load pretrained_model_name=%s" % self.pretrained_model_name)
            self._init_submodules_new(self.pretrained_model_name)

    def _init_submodules(self):
        """Create embeddings, transformer layers, heads and layer norms."""
        # normal_(0, 1)
        self.embed_tokens = nn.Embedding(
            self.alphabet_size,
            self.embed_dim,
            padding_idx=self.padding_idx,
        )
        self.embed_pos = None
        if not self.no_position_embeddings:
            self.embed_pos = nn.Embedding(self.max_position_embeddings, self.embed_dim)
        self.embed_type = None
        if not self.no_token_type_embeddings:
            self.embed_type = nn.Embedding(self.type_vocab_size, self.embed_dim)
        if self.use_embed_layer_norm:
            self.embed_layer_norm = LayerNorm(self.embed_dim)
        else:
            self.embed_layer_norm = None

        self.layers = nn.ModuleList(
            [
                LucaGPLMTransformerLayer(
                    self.embed_dim,
                    4 * self.embed_dim,
                    self.attention_heads,
                    add_bias_kv=False,
                    use_lucagplm1b_layer_norm=True,
                    use_rotary_embeddings=True,
                    )
                for _ in range(self.num_layers)
            ]
        )
        self.layer_size = len(self.layers)

        # NOTE(review): the contact head is constructed here but is not used by
        # __forword__/forward.
        self.contact_head = ContactPredictionHead(
            self.num_layers * self.attention_heads,
            self.prepend_bos,
            self.append_eos,
            eos_idx=self.eos_idx,
            )
        if self.use_last_layer_norm:
            self.last_layer_norm = LayerNorm(self.embed_dim)
        else:
            self.last_layer_norm = None

        self.lm_head = RobertaLMHead(
            embed_dim=self.embed_dim,
            output_dim=self.alphabet_size,
            weight=self.embed_tokens.weight,
        )

    def _init_embedding(self, pretrained_token_matrix, token_matrix):
        '''
        Copy rows of a pretrained token-embedding matrix into ``token_matrix``.

        The mapping is "pretrained row -> our row":
        0->2
        1->0
        2->3
        3->1
        4->10
        ...
        28->34
        29->36
        30->37
        31->38
        32->4

        Rows of ``token_matrix`` not listed above keep their current values.
        ``token_matrix`` is modified in place and also returned.
        '''
        print("Load pretrained exsists embedding vectors:")
        token_matrix[2, :] = pretrained_token_matrix[0, :]
        token_matrix[0, :] = pretrained_token_matrix[1, :]
        token_matrix[3, :] = pretrained_token_matrix[2, :]
        token_matrix[1, :] = pretrained_token_matrix[3, :]
        # Contiguous block: pretrained rows 4..28 -> our rows 10..34.
        for idx in range(10, 35):
            token_matrix[idx, :] = pretrained_token_matrix[idx - 6, :]
        token_matrix[36, :] = pretrained_token_matrix[29, :]
        token_matrix[37, :] = pretrained_token_matrix[30, :]
        token_matrix[38, :] = pretrained_token_matrix[31, :]
        token_matrix[4, :] = pretrained_token_matrix[32, :]
        return token_matrix

    def _init_submodules_new(self, pretrained_model_name):
        """Initialise this model's weights from an ``esm`` pretrained checkpoint.

        The checkpoint is loaded with ``esm.pretrained.load_model_and_alphabet``.
        Processing steps:
          * Layer-norm parameter names are renamed to this model's naming.
          * Parameters of layers with index >= ``self.num_layers`` are dropped.
          * ``embed_tokens.weight`` is remapped row-by-row via ``_init_embedding``.
          * Every other tensor whose name and shape match is copied.
          * Anything not covered keeps its freshly initialised value.
        """
        print("Load pretrained model exists weights:")
        from esm import pretrained
        from collections import OrderedDict
        # Use a distinct local name so the ``pretrained`` module is not shadowed.
        esm_model, _ = pretrained.load_model_and_alphabet(pretrained_model_name)
        pretrained_state_dict = esm_model.state_dict()
        new_state_dict = OrderedDict()
        # Entries still in this dict after the loop were not found in the checkpoint.
        our_model_state_dict = dict(self.state_dict())
        for name, weight in pretrained_state_dict.items():
            if "final_layer_norm" in name:
                name = name.replace("final_layer_norm", "post_layer_norm")
            elif "self_attn_layer_norm" in name:
                name = name.replace("self_attn_layer_norm", "pre_layer_norm")
            elif "emb_layer_norm_after" in name:
                name = name.replace("emb_layer_norm_after", "last_layer_norm")
            if name.startswith("layers."):
                layer_id = name.split(".")[1]
                if int(layer_id) >= self.num_layers:
                    continue
            if name == "embed_tokens.weight":
                new_state_dict[name] = self._init_embedding(weight, our_model_state_dict[name])
                del our_model_state_dict[name]
            elif name in our_model_state_dict and our_model_state_dict[name].shape == weight.shape:
                del our_model_state_dict[name]
                new_state_dict[name] = weight

        print("Exists layer names:")
        print(new_state_dict.keys())
        print("Not exists Layer names:")
        print(our_model_state_dict.keys())
        new_state_dict.update(our_model_state_dict)
        self.load_state_dict(new_state_dict)

    def __forword__(self,
                    input_ids: Optional[torch.Tensor] = None,
                    attention_mask: Optional[torch.Tensor] = None,
                    token_type_ids: Optional[torch.Tensor] = None,
                    position_ids: Optional[torch.Tensor] = None,
                    output_keys: Optional[dict[str, set[str]]] = None,
                    labels: Optional[dict[str, dict[str, torch.Tensor]]] = None,
                    repr_layers=(-1,),
                    need_head_weights=False,
                    return_contacts=False,
                    use_last_layer_norm=True):
        """Run the encoder and LM head.

        Args:
            input_ids: 2-D tensor of token ids (asserted ``ndim == 2``), (B, L).
            attention_mask: optional (B, L) mask. Positions equal to 0 are
                treated as padding (assumes the usual 1 = attend / 0 = pad
                convention). When omitted, padding is inferred from
                ``input_ids == padding_idx``.
            token_type_ids / position_ids: added to the embedding only when the
                matching embedding table exists and the ids are given.
            output_keys / labels: accepted for interface compatibility; unused.
            repr_layers: layer indices (0 = embeddings; negatives count from the
                end) whose hidden states are returned.
            need_head_weights: forwarded to each transformer layer.
            return_contacts: only forces ``need_head_weights``; contacts are not
                computed or returned here.
            use_last_layer_norm: apply the final LayerNorm (if one exists).

        Returns:
            ``(lm_mask_logits, hidden_representations)``. ``lm_mask_logits`` is
            the LM-head output on the final hidden state.
            ``hidden_representations`` maps each requested layer index to a
            (B, L, E) tensor. The last layer is reported after the final
            LayerNorm.
        """
        assert all(-(self.layer_size + 1) <= i <= self.layer_size for i in repr_layers)
        # Normalise negative indices into [0, layer_size].
        repr_layers = {(i + self.layer_size + 1) % (self.layer_size + 1) for i in repr_layers}

        if return_contacts:
            need_head_weights = True

        assert input_ids.ndim == 2
        # Boolean (B, L) padding mask: True marks positions to ignore.
        if attention_mask is None:
            padding_mask = input_ids.eq(self.padding_idx)
        else:
            # The attention mask uses 1 = attend / 0 = pad. Compare against 0,
            # not against the padding *token id* (those only coincide when
            # padding_idx happens to be 0).
            padding_mask = attention_mask.eq(0)

        x = self.embed_scale * self.embed_tokens(input_ids)
        if self.embed_pos is not None and position_ids is not None:
            x += self.embed_scale * self.embed_pos(position_ids)
        if self.embed_type is not None and token_type_ids is not None:
            x += self.embed_scale * self.embed_type(token_type_ids)
        if self.embed_layer_norm is not None:
            x = self.embed_layer_norm(x)
        # Token dropout: zero the <mask> embeddings, then rescale by the ratio of
        # the training mask rate (0.15 * 0.8) to the observed per-row mask rate.
        if self.token_dropout:
            x.masked_fill_((input_ids == self.mask_idx).unsqueeze(-1), 0.0)
            # x: B x L x C
            mask_ratio_train = 0.15 * 0.8
            src_lengths = (~padding_mask).sum(-1)
            mask_ratio_observed = (input_ids == self.mask_idx).sum(-1).to(x.dtype) / src_lengths
            x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

        # Zero the embeddings at padded positions.
        x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        hidden_representations = {}
        # Index 0 is the embedding output.
        if 0 in repr_layers:
            hidden_representations[0] = x

        # (B, L, E) => (L, B, E)
        x = x.transpose(0, 1)

        # Let the layers skip masking entirely when nothing is padded.
        if not padding_mask.any():
            padding_mask = None

        # Per-layer attention weights are not returned by this method, so they
        # are not accumulated. Each layer's weights can be freed immediately.
        for layer_idx, layer in enumerate(self.layers):
            x, _attn = layer(
                x,
                self_attn_padding_mask=padding_mask,
                need_head_weights=need_head_weights,
            )
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x.transpose(0, 1)

        # (L, B, E)
        if self.last_layer_norm is not None and use_last_layer_norm:
            # Apply the final LayerNorm to the last hidden layer.
            x = self.last_layer_norm(x)
        x = x.transpose(0, 1)  # (L, B, E) => (B, L, E)

        # The last hidden representation is reported with the final layer norm
        # applied. Use self.layer_size (not the loop variable) so this also
        # works when there are no transformer layers.
        if self.layer_size in repr_layers:
            hidden_representations[self.layer_size] = x

        # LM head on the final hidden state: B * Seq_len * vocab_size.
        lm_mask_logits = self.lm_head(x)

        return lm_mask_logits, hidden_representations

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            global_attention_mask: Optional[torch.Tensor] = None,
            token_type_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            head_mask: Optional[torch.Tensor] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            output_keys: Optional[dict[str, set[str]]] = None,
            labels: Optional[dict[str, dict[str, torch.Tensor]]] = None,
            input_ids_b: Optional[torch.Tensor] = None,
            attention_mask_b: Optional[torch.Tensor] = None,
            global_attention_mask_b: Optional[torch.Tensor] = None,
            token_type_ids_b: Optional[torch.Tensor] = None,
            position_ids_b: Optional[torch.Tensor] = None,
            head_mask_b: Optional[torch.Tensor] = None,
            inputs_embeds_b: Optional[torch.Tensor] = None,
            output_keys_b: Optional[dict[str, set[str]]] = None,
            labels_b: Optional[dict[str, dict[str, torch.Tensor]]] = None,
            pair_label: Optional[dict[str, dict[str, torch.Tensor]]] = None,
            pair_output_keys: Optional[dict[str, set[str]]] = None,
            output_hidden_states: Optional[dict[str, set[str]]] = None,
            output_attentions: Optional[dict[str, set[str]]] = None,
            need_head_weights: Optional[bool] = None,
            return_contacts: Optional[bool] = None,
            repr_layers: Optional[list[int]] = None,
            return_dict: Optional[bool] = None,
            use_last_layer_norm: Optional[bool] = True
    ):
        """Module entry point; fills defaults and delegates to ``__forword__``.

        Only ``input_ids``, ``attention_mask``, ``token_type_ids``,
        ``position_ids``, ``output_keys``, ``labels``, ``repr_layers``,
        ``need_head_weights``, ``return_contacts`` and ``use_last_layer_norm``
        are forwarded. The remaining (pair / "_b" / HF-style) arguments are
        accepted but ignored.

        Defaults:
            * empty or missing ``repr_layers`` becomes ``[-1]`` (last layer);
            * ``return_contacts`` None becomes False;
            * ``need_head_weights`` None becomes True.
        """
        if repr_layers is None or len(repr_layers) == 0:
            repr_layers = [-1]
        if return_contacts is None:
            return_contacts = False
        if need_head_weights is None:
            need_head_weights = True

        return self.__forword__(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                output_keys=output_keys,
                labels=labels,
                repr_layers=repr_layers,
                need_head_weights=need_head_weights,
                return_contacts=return_contacts,
                use_last_layer_norm=use_last_layer_norm
            )