# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from copy import deepcopy
import math
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf

import esm
from esm.modules import (
    TransformerLayer,
    LearnedPositionalEmbedding,
    SinusoidalPositionalEmbedding,
    RobertaLMHead,
    ESM1bLayerNorm,
    ContactPredictionHead,
    ESM1LayerNorm,
    FeedForwardNetwork,
    NormalizedResidualBlock,
    gelu,
)
from esm.multihead_attention import MultiheadAttention
from src.utils.config import compose_config as Cfg, merge_config
from .timestep import (
    modulate,
    TimestepEmbedder,
    ConditionalLayerNorm,
    MyRobertaLMHead
)

class ESM2WithStructuralAdatperTime(nn.Module):
    """ESM-2 style masked LM conditioned on a diffusion timestep and structure features.

    Every layer is a standard esm ``TransformerLayer``, except the layers listed in
    ``args.adapter_layer_indices``. Those are ``TransforerLayerWithStructralAdapterTime``
    layers, which take a conditioning vector ``c`` and cross-attend to the structure
    encoder output.

    ``c`` is built in ``forward`` from three parts:
      - the timestep embedding;
      - an optional ``context`` tensor;
      - pooled, projected structure features.

    ``c`` is also fed to the final conditional layer norm and to the LM head.
    """

    @classmethod
    def from_pretrained(cls, args, override_args=None, name='esm2_t33_650M_UR50D'):
        """Build a model whose architecture matches the hub ESM checkpoint ``name``.

        Only hyper-parameters are read from the hub model: ``num_layers``,
        ``embed_dim``, ``attention_heads`` and ``token_dropout``. The alphabet is
        read as well. The hub model's weights are NOT copied into the new model,
        because the ``load_state_dict`` call below is commented out.

        Args:
            args: config merged on top of the hub hyper-parameters.
            override_args: unused; kept for interface compatibility.
            name: ESM hub model name passed to ``load_model_and_alphabet_hub``.

        Returns:
            A freshly initialised model instance.
        """
        pretrained_model, alphabet = esm.pretrained.load_model_and_alphabet_hub(name)

        pretrained_args = Cfg(
            num_layers=pretrained_model.num_layers,
            embed_dim=pretrained_model.embed_dim,
            attention_heads=pretrained_model.attention_heads,
            token_dropout=pretrained_model.token_dropout,
        )
        args = merge_config(pretrained_args, args)

        # NOTE: this unconditionally overrides any configured adapter_layer_indices.
        # Only the last layer receives a structural adapter. Negative indices are
        # normalised into [0, num_layers).
        args.adapter_layer_indices = [-1]
        args.adapter_layer_indices = [
            (args.num_layers + idx) % args.num_layers
            for idx in args.adapter_layer_indices
        ]

        model = cls(args, deepcopy(alphabet))
        # NOTE: pretrained weights are not loaded; re-enable to start from the checkpoint:
        # model.load_state_dict(pretrained_model.state_dict(), strict=False)

        del pretrained_model

        # Parameter freezing is disabled, so all parameters stay trainable:
        # for pname, param in model.named_parameters():
        #     if ('adapter' not in pname) and ('timesteps' not in pname) and ('embed_structures' not in pname) and ('adaLN' not in pname):
        #         param.requires_grad = False
        return model

    def __init__(
        self,
        args,
        alphabet: Union[esm.data.Alphabet, str] = "ESM-1b",
    ):
        """Create the model.

        Args:
            args: config providing ``num_layers``, ``embed_dim``, ``attention_heads``,
                ``token_dropout``, ``dropout``, ``adapter_layer_indices`` and
                ``encoder.d_model``.
            alphabet: an ``esm.data.Alphabet``, or an architecture name that is
                resolved with ``Alphabet.from_architecture``.
        """
        super().__init__()
        self.args = args
        self.num_layers = args.num_layers
        self.embed_dim = args.embed_dim
        self.attention_heads = args.attention_heads
        if not isinstance(alphabet, esm.data.Alphabet):
            alphabet = esm.data.Alphabet.from_architecture(alphabet)
        self.alphabet = alphabet
        self.alphabet_size = len(alphabet)
        self.padding_idx = alphabet.padding_idx
        self.mask_idx = alphabet.mask_idx
        self.cls_idx = alphabet.cls_idx
        self.eos_idx = alphabet.eos_idx
        self.prepend_bos = alphabet.prepend_bos
        self.append_eos = alphabet.append_eos
        self.token_dropout = args.token_dropout

        self._init_submodules()
        self.embed_timesteps = TimestepEmbedder(self.args.embed_dim)
        # Projects structure-encoder features (encoder.d_model) to the LM width.
        self.embed_structures = nn.Sequential(
            nn.Linear(self.args.encoder.d_model, self.args.embed_dim, bias=True),
            nn.SiLU(),
            nn.Linear(self.args.embed_dim, self.args.embed_dim, bias=True),
        )
        self.initialize_weights()

    def initialize_weights(self):
        """Initialise parameters in the DiT style.

        Every ``nn.Linear`` gets Xavier-uniform weights and a zero bias. The
        structure and timestep embedding MLP weights are then redrawn from
        N(0, 0.02). Adapter and adaLN output projections are NOT zero-initialised;
        they keep their Xavier weights.
        """
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Structure embedding MLP.
        nn.init.normal_(self.embed_structures[0].weight, std=0.02)
        nn.init.normal_(self.embed_structures[2].weight, std=0.02)

        # Timestep embedding MLP.
        nn.init.normal_(self.embed_timesteps.mlp[0].weight, std=0.02)
        nn.init.normal_(self.embed_timesteps.mlp[2].weight, std=0.02)

    def _init_submodules(self):
        """Create the token embedding, the layer stack, and the output heads."""
        self.embed_scale = 1
        self.embed_tokens = nn.Embedding(
            self.alphabet_size,
            self.embed_dim,
            padding_idx=self.padding_idx,
        )

        self.layers = nn.ModuleList(
            [self._init_layer(layer_idx) for layer_idx in range(self.num_layers)]
        )

        self.contact_head = ContactPredictionHead(
            self.num_layers * self.attention_heads,
            self.prepend_bos,
            self.append_eos,
            eos_idx=self.eos_idx,
        )
        # Conditional replacement for esm's final ESM1bLayerNorm.
        self.emb_layer_norm_after = ConditionalLayerNorm(self.args.embed_dim, self.args.embed_dim)
        # Conditional replacement for esm's RobertaLMHead; the output weight is
        # tied to the token embedding. (Author note: "Spike: unstable training".)
        self.lm_head = MyRobertaLMHead(
            embed_dim=self.args.embed_dim,
            output_dim=self.alphabet_size,
            weight=self.embed_tokens.weight,
        )

    def _init_layer(self, layer_idx):
        """Return the layer at ``layer_idx``: an adapter layer or a plain esm layer."""
        if layer_idx in self.args.adapter_layer_indices:
            return TransforerLayerWithStructralAdapterTime(
                self.embed_dim,
                4 * self.embed_dim,
                self.attention_heads,
                add_bias_kv=False,
                use_esm1b_layer_norm=True,
                use_rotary_embeddings=True,
                encoder_embed_dim=self.args.encoder.d_model,
                dropout=self.args.dropout
            )
        return TransformerLayer(
            self.embed_dim,
            4 * self.embed_dim,
            self.attention_heads,
            add_bias_kv=False,
            use_esm1b_layer_norm=True,
            use_rotary_embeddings=True,
        )

    def forward_layers(self, x, c, encoder_out, padding_mask, repr_layers=(), hidden_representations=None, need_head_weights=False, attn_weights=None):
        """Run ``x`` through every layer.

        Args:
            x: hidden states in (T, B, E) layout (``forward`` transposes to this).
            c: conditioning vector; passed only to adapter layers.
            encoder_out: structure-encoder output dict; passed only to adapter layers.
            padding_mask: (B, T) boolean padding mask, or ``None``.
            repr_layers: 1-based layer numbers whose outputs are recorded.
            hidden_representations: dict to record into; a new dict if ``None``.
            need_head_weights: collect per-layer attention maps.
            attn_weights: list to append attention maps to; a new list is created
                if ``None`` and ``need_head_weights`` is set.

        Returns:
            ``(x, hidden_representations, attn_weights, last_layer_idx)``.
        """
        # Fresh containers per call; mutable defaults would be shared across calls.
        if hidden_representations is None:
            hidden_representations = {}
        if need_head_weights and attn_weights is None:
            attn_weights = []

        for layer_idx, layer in enumerate(self.layers):
            if layer_idx in self.args.adapter_layer_indices:
                x, attn = layer(
                    x, c, encoder_out, self_attn_padding_mask=padding_mask, need_head_weights=need_head_weights
                )
            else:
                x, attn = layer(
                    x, self_attn_padding_mask=padding_mask, need_head_weights=need_head_weights
                )
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x.transpose(0, 1)
            if need_head_weights:
                # (H, B, T, T) => (B, H, T, T)
                attn_weights.append(attn.transpose(1, 0))

        # Equals the final loop index, and stays defined when there are no layers.
        return x, hidden_representations, attn_weights, len(self.layers) - 1

    def forward(self, tokens, alpha_t_bar, context, timesteps, encoder_out, repr_layers=(), need_head_weights=False, return_contacts=False):
        """Compute conditioned LM logits.

        Args:
            tokens: (B, T) token ids.
            alpha_t_bar: unused by this method.
            context: optional tensor added to the timestep embedding, or ``None``.
            timesteps: diffusion timesteps; ``squeeze(-1)`` is applied before embedding.
            encoder_out: dict with ``'aligned_feats'`` and ``'aligned_label_mask'``.
                It is also passed to the adapter layers.
            repr_layers: layer numbers (0 = embeddings) whose representations to return.
            need_head_weights: also return stacked attention maps.
            return_contacts: also return predicted contacts (implies head weights).

        Returns:
            A dict with ``'logits'`` and ``'representations'``. It may also contain
            ``'attentions'`` and ``'contacts'``.
        """
        if return_contacts:
            need_head_weights = True

        assert tokens.ndim == 2
        padding_mask = tokens.eq(self.padding_idx)  # B, T

        x = self.embed_scale * self.embed_tokens(tokens)

        if self.token_dropout:
            x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
            # x: B x T x C
            mask_ratio_train = 0.15 * 0.8
            src_lengths = (~padding_mask).sum(-1)
            mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths
            x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

        # Conditioning vector. Accumulate out of place so the timestep-embedding
        # output is never modified in place (in-place ops on autograd outputs are fragile).
        c = self.embed_timesteps(timesteps.squeeze(-1))  # B, H
        if context is not None:
            c = c + context
        if self.embed_structures is not None:
            # Feature sum over dim 1 divided by the number of labelled positions.
            # This is a masked mean only if aligned_feats is zero at unlabelled
            # positions (TODO confirm). The clamp prevents 0/0 -> NaN for a sample
            # with no labelled positions.
            num_labelled = encoder_out['aligned_label_mask'].sum(1, keepdim=True).clamp_min(1)
            pooled_feats = encoder_out['aligned_feats'].sum(1) / num_labelled
            c = c + self.embed_structures(pooled_feats)

        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        repr_layers = set(repr_layers)
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x

        attn_weights = [] if need_head_weights else None

        # (B, T, E) => (T, B, E)
        x = x.transpose(0, 1)

        if not padding_mask.any():
            padding_mask = None

        x, hidden_representations, attn_weights, layer_idx = self.forward_layers(
            x, c, encoder_out, padding_mask,
            repr_layers=repr_layers,
            hidden_representations=hidden_representations,
            need_head_weights=need_head_weights,
            attn_weights=attn_weights,
        )

        x = self.emb_layer_norm_after(x, c)
        x = x.transpose(0, 1)  # (T, B, E) => (B, T, E)

        # last hidden representation should have layer norm applied
        if (layer_idx + 1) in repr_layers:
            hidden_representations[layer_idx + 1] = x
        x = self.lm_head(x, c)

        result = {"logits": x, "representations": hidden_representations}
        if need_head_weights:
            # attentions: B x L x H x T x T
            attentions = torch.stack(attn_weights, 1)
            if padding_mask is not None:
                attention_mask = 1 - padding_mask.type_as(attentions)
                attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
                attentions = attentions * attention_mask[:, None, None, :, :]
            result["attentions"] = attentions
            if return_contacts:
                contacts = self.contact_head(tokens, attentions)
                result["contacts"] = contacts

        return result

    def predict_contacts(self, tokens, alpha_t_bar=None, context=None, timesteps=None, encoder_out=None):
        """Return predicted contacts for ``tokens``.

        ``forward`` cannot run without ``timesteps`` and ``encoder_out``, so both
        must be supplied.

        Raises:
            TypeError: if ``timesteps`` or ``encoder_out`` is missing.
        """
        if timesteps is None or encoder_out is None:
            raise TypeError("predict_contacts requires `timesteps` and `encoder_out`")
        return self(
            tokens, alpha_t_bar, context, timesteps, encoder_out, return_contacts=True
        )["contacts"]


class TransforerLayerWithStructralAdapterTime(nn.Module):
    """Pre-LN ESM transformer layer with conditional norms and a structural adapter.

    Self-attention and the FFN are each preceded by a ``ConditionalLayerNorm``,
    which is modulated by the conditioning vector ``c``. After the FFN, a
    structural adapter refines the hidden states. The adapter has two parts:
      - cross-attention from the hidden states to ``encoder_out['aligned_feats']``;
      - a bottleneck FFN.
    Each part is wrapped in an esm ``NormalizedResidualBlock``.
    """

    def __init__(
        self,
        embed_dim,
        ffn_embed_dim,
        attention_heads,
        encoder_embed_dim,
        add_bias_kv=True,
        use_esm1b_layer_norm=False,
        use_rotary_embeddings: bool = False,
        dropout=0.1,
    ):
        """Create the layer.

        Args:
            embed_dim: hidden size of the language model.
            ffn_embed_dim: inner size of the main FFN.
            attention_heads: heads used for both self- and cross-attention.
            encoder_embed_dim: feature size of the structure encoder (cross-attn k/v dim).
            add_bias_kv: forwarded to both attention modules.
            use_esm1b_layer_norm: accepted for compatibility; unused (see ``_init_submodules``).
            use_rotary_embeddings: rotary embeddings for self-attention.
            dropout: dropout of the adapter residual blocks and adapter FFN activations.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_embed_dim = ffn_embed_dim
        self.attention_heads = attention_heads
        self.use_rotary_embeddings = use_rotary_embeddings

        self.encoder_embed_dim = encoder_embed_dim
        self.dropout = dropout
        self._init_submodules(add_bias_kv, use_esm1b_layer_norm)

    def _init_submodules(self, add_bias_kv, use_esm1b_layer_norm):
        """Create attention, FFN, conditional norms and the structural adapter.

        ``use_esm1b_layer_norm`` is ignored: every norm in this layer is a
        ``ConditionalLayerNorm``.
        """
        self.self_attn = MultiheadAttention(
            self.embed_dim,
            self.attention_heads,
            add_bias_kv=add_bias_kv,
            add_zero_attn=False,
            use_rotary_embeddings=self.use_rotary_embeddings,
        )
        self.self_attn_layer_norm = ConditionalLayerNorm(self.embed_dim, self.embed_dim)

        self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
        self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = ConditionalLayerNorm(self.embed_dim, self.embed_dim)

        # Structural adapter. The author noted the loss became NaN when adaLN was
        # used here, so a plain NormalizedResidualBlock is used instead.
        # NOTE: cross-attention always uses rotary embeddings, independent of
        # self.use_rotary_embeddings.
        self.structural_adapter_attn = NormalizedResidualBlock(
            layer=MultiheadAttention(
                self.embed_dim,
                self.attention_heads,
                kdim=self.encoder_embed_dim,
                vdim=self.encoder_embed_dim,
                add_bias_kv=add_bias_kv,
                add_zero_attn=False,
                use_rotary_embeddings=True,
            ),
            embedding_dim=self.embed_dim,
            dropout=self.dropout
        )
        self.structural_adapter_ffn = NormalizedResidualBlock(
            layer=FeedForwardNetwork(
                self.embed_dim,
                self.embed_dim // 2,  # NOTE: bottleneck FFN is important
                activation_dropout=self.dropout
            ),
            embedding_dim=self.embed_dim,
            dropout=self.dropout
        )

    def forward(
        self, x, c, encoder_out, self_attn_mask=None, self_attn_padding_mask=None, need_head_weights=False
    ):
        """Apply self-attention, the FFN, then the structural adapter.

        Args:
            x: hidden states, in the (T, B, E) layout used by the calling model.
            c: conditioning vector passed to the conditional layer norms.
            encoder_out: dict containing ``'aligned_feats'``.
            self_attn_mask: optional attention mask, reused for cross-attention.
            self_attn_padding_mask: optional key padding mask, reused for cross-attention.
            need_head_weights: forwarded to self-attention.

        Returns:
            ``(x, attn)``, where ``attn`` comes from self-attention.
        """
        residual = x
        x = self.self_attn_layer_norm(x, c)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            need_weights=True,
            need_head_weights=need_head_weights,
            attn_mask=self_attn_mask,
        )
        x = residual + x

        residual = x
        x = self.final_layer_norm(x, c)
        x = gelu(self.fc1(x))
        x = self.fc2(x)
        x = residual + x

        # Both adapter sub-blocks are NormalizedResidualBlocks, which already return
        # ``input + layer(norm(input))``. Assign instead of adding: the previous
        # ``x + forward_adapter(x)`` counted the residual stream twice (2*x + delta).
        x = self.forward_adapter(x, encoder_out, attn_mask=self_attn_mask, attn_padding_mask=self_attn_padding_mask)
        return x, attn

    def forward_adapter(self, x, encoder_out, attn_mask, attn_padding_mask):
        """Apply cross-attention to the structure features, then the bottleneck FFN.

        The result includes ``x`` itself (residual form); it is not a delta.
        The token padding mask and attention mask are reused for the encoder keys.
        This presumes ``aligned_feats`` is position-aligned with the tokens, with
        the same length (TODO confirm).
        """
        # (B, L, D) -> (L, B, D), presumably matching x's sequence-first layout.
        encoder_feats = encoder_out['aligned_feats'].transpose(0, 1)

        x = self.structural_adapter_attn(
            x,
            key=encoder_feats,
            value=encoder_feats,
            key_padding_mask=attn_padding_mask,
            attn_mask=attn_mask,
            need_weights=False
        )[0]

        return self.structural_adapter_ffn(x)