import torch
import numpy as np
import torch.nn as nn
from torch.distributed import get_rank, get_world_size
from torch.nn.functional import cross_entropy
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel, EsmModel
from model.layers import CLIPLoss


class ProteinTextCLIPConfig(PretrainedConfig):
    """Composite configuration for a protein/text dual-encoder CLIP model.

    Holds one sub-configuration per encoder plus the width of the projection
    space and the initial value used for the contrastive logit scale.

    Args:
        protein_model_config: Configuration of the protein encoder, given
            either as a config object or as a ``dict``. A dict is expanded
            with ``AutoConfig.for_model(**protein_model_config)``, so it has
            to carry a ``model_type`` entry.
        text_model_config: Configuration of the text encoder; accepts the
            same two forms as ``protein_model_config``.
        projection_dim: Stored as ``projection_dim`` and listed last in
            ``hidden_sizes``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``. The optional
            ``logit_scale_init`` entry (default ``0.07``) is consumed here
            and stored as ``self.logit_scale_init``.
    """

    model_type = "protein_text_clip"
    is_composition = True

    def __init__(self,
                 protein_model_config,
                 text_model_config,
                 projection_dim,
                 **kwargs):
        # Take our own option out of kwargs *before* handing the rest to the
        # parent; otherwise the parent has already consumed it as an unknown
        # keyword by the time we read it.
        logit_scale_init = kwargs.pop("logit_scale_init", 0.07)
        super().__init__(**kwargs)

        # Sub-configs arrive as plain dicts when this config is rebuilt from
        # its serialised form; turn them back into config objects.
        if isinstance(protein_model_config, dict):
            protein_model_config = AutoConfig.for_model(**protein_model_config)
        if isinstance(text_model_config, dict):
            text_model_config = AutoConfig.for_model(**text_model_config)

        self.protein_model_config = protein_model_config
        self.text_model_config = text_model_config
        self.projection_dim = projection_dim

        # [protein encoder width, text encoder width, projection width]
        self.hidden_sizes = [self.protein_model_config.hidden_size,
                             self.text_model_config.hidden_size,
                             self.projection_dim]
        self.logit_scale_init = logit_scale_init


class ProteinTextCLIPForPretrainSequenceOnly(PreTrainedModel):
    """Protein/text dual encoder trained with a contrastive and an MLM loss.

    A protein encoder (``EsmModel``) and a text encoder (``AutoModel``) each
    produce token states that are pooled, projected to
    ``config.projection_dim`` and L2-normalised before being passed to
    ``CLIPLoss``. The protein encoder is additionally run on a masked copy of
    the protein tokens and scored with a token-level cross-entropy loss.
    """

    config_class = ProteinTextCLIPConfig

    def __init__(self, config):
        """Build both encoders, the projection heads and the MLM head.

        Args:
            config: A ``ProteinTextCLIPConfig``. Each sub-config's
                ``_name_or_path`` names the checkpoint its encoder is loaded
                from.
        """
        super().__init__(config)
        protein_model_config = config.protein_model_config
        text_model_config = config.text_model_config

        # Both encoders start from the pretrained weights found at the
        # checkpoint recorded in their sub-config (this does NOT train from
        # scratch; that would require building the models from the configs).
        self.protein_model = EsmModel.from_pretrained(
            protein_model_config._name_or_path)
        self.protein_model.gradient_checkpointing_enable()
        self.text_model = AutoModel.from_pretrained(
            text_model_config._name_or_path)

        self.protein_projection = nn.Sequential(
            nn.Linear(protein_model_config.hidden_size, self.config.projection_dim),
            nn.GELU(),
            nn.Linear(self.config.projection_dim, self.config.projection_dim),
        )
        self.text_projection = nn.Sequential(
            nn.Linear(text_model_config.hidden_size, self.config.projection_dim),
            nn.GELU(),
            nn.Linear(self.config.projection_dim, self.config.projection_dim),
        )
        # Stored in log space; forward() applies exp(), so the effective
        # initial scale is 1 / logit_scale_init.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / self.config.logit_scale_init))

        # NOTE(review): the last layer emits `projection_dim` logits, not the
        # protein vocabulary size. Left as is so parameter shapes do not
        # change; every MLM label therefore has to be < projection_dim.
        self.mlm_head = nn.Sequential(
            nn.Linear(protein_model_config.hidden_size, self.config.projection_dim),
            nn.GELU(),
            nn.LayerNorm(self.config.projection_dim),
            nn.Linear(self.config.projection_dim, self.config.projection_dim),
        )

    @staticmethod
    def _masked_mean(hidden_states, attention_mask):
        """Average token states over the positions kept by ``attention_mask``.

        Assumes ``hidden_states`` is (batch, seq, hidden) and
        ``attention_mask`` is (batch, seq) with non-zero marking real tokens
        -- TODO confirm against the data collator. Falls back to a plain mean
        over the sequence axis when no mask is supplied.
        """
        if attention_mask is None:
            return hidden_states.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        # Normalise the weights first so the reduction stays in the range of
        # the hidden states; clamp guards against an all-padding row.
        weights = mask / mask.sum(dim=1, keepdim=True).clamp(min=1.0)
        return (hidden_states * weights).sum(dim=1)

    def forward(self,
                protein_input_ids,
                protein_attention_mask,
                text_input_ids,
                text_attention_mask,
                protein_masked_input_ids,
                protein_masked_labels
                ):
        """Compute the contrastive loss, the MLM loss and their sum.

        Args:
            protein_input_ids: Token ids of the unmasked protein sequences.
            protein_attention_mask: Attention mask for the protein tokens;
                reused for the masked pass as well.
            text_input_ids: Token ids of the paired texts.
            text_attention_mask: Attention mask for the text tokens.
            protein_masked_input_ids: Protein token ids with some positions
                masked, fed to the protein encoder for the MLM objective.
            protein_masked_labels: Target ids for the MLM objective, flattened
                and passed to ``cross_entropy`` (positions carrying its
                default ``ignore_index`` of -100 do not contribute).

        Returns:
            dict with ``"loss"`` (``cl_loss + mlm_loss``), ``"cl_loss"`` and
            ``"mlm_loss"``.
        """
        # Pool with the attention mask so padding positions do not leak into
        # the sequence-level embeddings.
        protein_hidden = self.protein_model(
            input_ids=protein_input_ids, attention_mask=protein_attention_mask
        ).last_hidden_state
        protein_embeds = self.protein_projection(
            self._masked_mean(protein_hidden, protein_attention_mask))

        text_hidden = self.text_model(
            input_ids=text_input_ids, attention_mask=text_attention_mask
        ).last_hidden_state
        text_embeds = self.text_projection(
            self._masked_mean(text_hidden, text_attention_mask))

        # L2-normalise; normalize() clamps the denominator, so an all-zero
        # embedding yields zeros instead of NaN.
        protein_embeds = nn.functional.normalize(protein_embeds, dim=-1)
        text_embeds = nn.functional.normalize(text_embeds, dim=-1)

        # NOTE(review): the loss object is rebuilt on every call, so whatever
        # `cache_labels=True` caches cannot carry over between steps.
        distributed = torch.distributed.is_initialized()
        cl_loss = CLIPLoss(
            local_loss=False,
            gather_with_grad=True,
            cache_labels=True,
            rank=get_rank() if distributed else 0,
            world_size=get_world_size() if distributed else 1
        )(protein_embeds, text_embeds, self.logit_scale.exp())

        # Second pass through the protein encoder on the masked tokens.
        protein_perturbed_output = self.protein_model(
            input_ids=protein_masked_input_ids,
            attention_mask=protein_attention_mask,
        )

        mlm_logits = self.mlm_head(protein_perturbed_output.last_hidden_state)
        # reshape() rather than view() so non-contiguous inputs are accepted.
        mlm_loss = cross_entropy(mlm_logits.reshape(-1, mlm_logits.shape[-1]),
                                 protein_masked_labels.reshape(-1))

        return {
            "loss": cl_loss + mlm_loss,
            "cl_loss": cl_loss,
            "mlm_loss": mlm_loss
        }
