""" huggingface model adapter

Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
"""

import re

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import TensorType
# `transformers` is an optional dependency: import it when available, otherwise
# fall back to a `None` sentinel plus minimal stand-in classes.
try:
    import transformers
    from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
    from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
        BaseModelOutputWithPoolingAndCrossAttentions
except ImportError as e:
    # Sentinel checked by HFTextEncoder.__init__, which raises a RuntimeError
    # asking the user to install `transformers`.
    transformers = None


    # Placeholder so annotations that name `BaseModelOutput` (the pooler
    # `forward` signatures) can still be evaluated when this module is imported.
    class BaseModelOutput:
        pass


    # Placeholder so the `config: PretrainedConfig` annotation on
    # HFTextEncoder.__init__ can still be evaluated when this module is imported.
    class PretrainedConfig:
        pass

from .hf_configs import arch_dict

# utils
def _camel2snake(s):
    """Convert a CamelCase identifier to snake_case.

    An underscore is inserted before every uppercase letter that is not the
    first character, then the whole string is lower-cased
    (e.g. ``"MeanPooler"`` -> ``"mean_pooler"``).
    """
    with_underscores = re.sub(r'(?<!^)(?=[A-Z])', '_', s)
    return with_underscores.lower()

# Registry mapping the snake_case name of a pooler class to the class itself.
_POOLERS = {}

def register_pooler(cls):
    """Class decorator that adds `cls` to the pooler registry.

    The registry key is the snake_case form of the class name; the class is
    returned unchanged so the decorator can be stacked or used transparently.
    """
    registry_key = _camel2snake(cls.__name__)
    _POOLERS[registry_key] = cls
    return cls


@register_pooler
class MeanPooler(nn.Module):
    """Mean pooling over the sequence dimension, weighted by the attention mask"""

    def forward(self, x:BaseModelOutput, attention_mask:TensorType):
        # Zero out positions whose mask value is 0 before summing over dim 1.
        weights = attention_mask.unsqueeze(-1)
        summed = (x.last_hidden_state * weights).sum(dim=1)
        # Divide by the per-row mask total so masked-out positions do not
        # dilute the average.
        counts = attention_mask.sum(-1, keepdim=True)
        return summed / counts

@register_pooler
class MaxPooler(nn.Module):
    """Max pooling over the sequence dimension, ignoring masked-out positions"""

    def forward(self, x:BaseModelOutput, attention_mask:TensorType):
        """Return the element-wise maximum of the last hidden state over dim 1.

        Positions where `attention_mask` is zero are excluded from the maximum.
        """
        # `masked_fill` needs a boolean mask and fills where the mask is True,
        # so select the positions that are NOT attended (mask == 0).
        ignored = ~attention_mask.bool().unsqueeze(-1)
        # -inf guarantees an ignored position can never win the max.
        masked_output = x.last_hidden_state.masked_fill(ignored, -torch.inf)
        return masked_output.max(1).values

@register_pooler
class ClsPooler(nn.Module):
    """CLS token pooling, optionally preferring the model's own `pooler_output`"""

    def __init__(self, use_pooler_output=True):
        super().__init__()
        self.use_pooler_output = use_pooler_output
        # Index along the sequence dimension that is read as the CLS token.
        self.cls_token_position = 0

    def forward(self, x:BaseModelOutput, attention_mask:TensorType):
        """Return `x.pooler_output` when enabled and available, else the hidden state at the CLS position.

        `attention_mask` is accepted for interface parity with the other
        poolers but is not used here.
        """
        if self.use_pooler_output:
            pooled_output_types = (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)
            if isinstance(x, pooled_output_types) and x.pooler_output is not None:
                return x.pooler_output

        # Fallback: slice out the hidden state at the CLS position.
        return x.last_hidden_state[:, self.cls_token_position, :]

class HFTextEncoder(nn.Module):
    """HuggingFace model adapter.

    Bundles a HuggingFace transformer, a pooler taken from the `_POOLERS`
    registry, an optional projection head and the matching tokenizer.
    """
    def __init__(
            self,
            model_name_or_path: str,
            output_dim: int,
            tokenizer_name: str = None,
            config: PretrainedConfig = None,
            pooler_type: str = None,
            proj: str = None,
            pretrained: bool = True,
            masked_language_modeling: bool = False):
        """Build the transformer, pooler, projection and tokenizer.

        Args:
            model_name_or_path: identifier passed to the HuggingFace `Auto*` loaders.
            output_dim: size of the embedding produced by `forward`.
            tokenizer_name: identifier passed to `AutoTokenizer.from_pretrained`;
                falls back to `model_name_or_path` when None.
            config: when given, the model is built from this config (no weights
                are loaded) and `pretrained` is ignored.
            pooler_type: key into the pooler registry; when None the default
                pooler listed in `arch_dict` for the model type is used.
            proj: `'linear'`, `'mlp'` or None. With None, an identity is used
                only if the model width already equals `output_dim`.
            pretrained: load pretrained weights (only used when `config` is None).
            masked_language_modeling: build an `AutoModelForMaskedLM` instead of
                an `AutoModel`.

        Raises:
            RuntimeError: if the `transformers` package is not installed.
        """
        super().__init__()

        self.output_dim = output_dim

        uses_transformer_pooler = (pooler_type == "cls_pooler")

        if transformers is None:
            raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")

        model_cls = AutoModelForMaskedLM if masked_language_modeling else AutoModel
        if config is None:
            self.config = AutoConfig.from_pretrained(model_name_or_path)
            if pretrained:
                create_func, model_args = model_cls.from_pretrained, model_name_or_path
            else:
                create_func, model_args = model_cls.from_config, self.config
            if getattr(self.config, "is_encoder_decoder", False):
                # Encoder-decoder architectures: keep only the encoder stack.
                self.transformer = create_func(model_args).encoder
            else:
                self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
        else:
            self.config = config
            self.transformer = model_cls.from_config(config)

        if pooler_type is None: # get default arch pooler
            self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
        else:
            self.pooler = _POOLERS[pooler_type]()

        d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
        # NOTE(review): when `proj` is None and d_model != output_dim (or `proj`
        # is an unrecognised string) no `self.proj` is created and `forward`
        # will fail on attribute access. Left as-is so MLM-only usage keeps working.
        if (d_model == output_dim) and (proj is None): # do we always need a proj?
            self.proj = nn.Identity()
        elif proj == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=False)
        elif proj == 'mlp':
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=False),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=False),
            )

        # self.itm_proj = nn.Linear(d_model, 2, bias=False)
        # self.mlm_proj = nn.Linear(d_model, self.config.vocab_size), bias=False)

        # `tokenizer_name` defaults to None, which AutoTokenizer cannot load;
        # fall back to the model identifier in that case.
        tokenizer_source = tokenizer_name if tokenizer_name is not None else model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)

    # def forward_itm(self, x:TensorType, image_embeds:TensorType) -> TensorType:
    #     image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device)  
    #     attn_mask = (x != self.config.pad_token_id).long()
    #     out = self.transformer(
    #         input_ids=x, 
    #         attention_mask=attn_mask,
    #         encoder_hidden_states = image_embeds,
    #         encoder_attention_mask = image_atts,
    #         )
    #     pooled_out = self.pooler(out, attn_mask)

    #     return self.itm_proj(pooled_out)

    def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
        """Apply masked-language-modeling corruption to `input_ids` in place.

        Args:
            input_ids: token ids; modified in place.
            vocab_size: exclusive upper bound for randomly drawn replacement ids.
            device: device on which the boolean masks and random ids are placed;
                expected to be the device of `input_ids`.
            targets: optional label tensor; positions that are not selected for
                masking are set to -100 in place.
            masked_indices: optional boolean tensor of positions to mask. When
                None it is sampled from `probability_matrix`.
            probability_matrix: per-position masking probability; required when
                `masked_indices` is None.

        Returns:
            `(input_ids, targets)` when `targets` is given, otherwise `input_ids`.
        """
        if masked_indices is None:
            masked_indices = torch.bernoulli(probability_matrix).bool()
        # The masks are sampled on the device of their inputs (CPU for
        # `probability_matrix` built with `torch.full`); move them next to
        # `input_ids` so the boolean indexing below stays on one device.
        masked_indices = masked_indices.to(device)

        # Never mask padding or the CLS token.
        masked_indices[input_ids == self.tokenizer.pad_token_id] = False
        masked_indices[input_ids == self.tokenizer.cls_token_id] = False

        if targets is not None:
            targets[~masked_indices] = -100 # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool().to(device) & masked_indices
        input_ids[indices_replaced] = self.tokenizer.mask_token_id

        # 10% of the time, we replace masked input tokens with random word
        # (half of the remaining 20%).
        indices_random = (
            torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool().to(device)
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
        input_ids[indices_random] = random_words[indices_random]
        # The rest of the time (10% of the time) we keep the masked input tokens unchanged

        if targets is not None:
            return input_ids, targets
        else:
            return input_ids

    def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):
        """Compute the masked-language-modeling loss conditioned on `image_embeds`.

        Args:
            input_ids: token ids; left untouched (masking is applied to a copy).
            image_embeds: passed to the transformer as `encoder_hidden_states`.
            mlm_probability: probability with which each position is selected
                for masking.

        Returns:
            The `loss` attribute of the transformer output.
        """
        # The attention mask is derived from the uncorrupted ids.
        attn_mask = (input_ids != self.config.pad_token_id).long()
        labels = input_ids.clone()
        # `mask` edits its input in place; work on a copy so the caller's
        # tensor is not corrupted with [MASK]/random tokens.
        input_ids = input_ids.clone()
        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(input_ids.device)
        vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"])
        probability_matrix = torch.full(labels.shape, mlm_probability)
        input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels,
                                      probability_matrix = probability_matrix)
        mlm_output = self.transformer(input_ids,
                        attention_mask = attn_mask,
                        encoder_hidden_states = image_embeds,
                        encoder_attention_mask = image_atts,
                        return_dict = True,
                        labels = labels,
                    )
        return mlm_output.loss
        # mlm_output = self.transformer(input_ids,
        #                 attention_mask = attn_mask,
        #                 encoder_hidden_states = image_embeds,
        #                 encoder_attention_mask = image_atts,
        #                 return_dict = True,
        #             ).last_hidden_state
        # logits = self.mlm_proj(mlm_output)

        # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
        # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size)
        # labels = labels[:, 1:].contiguous().view(-1)

        # mlm_loss = F.cross_entropy(
        #     logits,
        #     labels,
        #     # label_smoothing=0.1,
        # )
        # return mlm_loss


    def forward(self, x:TensorType) -> TensorType:
        """Encode token ids `x`: transformer -> pooler -> projection.

        The attention mask marks every position whose id differs from
        `config.pad_token_id`.
        """
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask)
        pooled_out = self.pooler(out, attn_mask)

        return self.proj(pooled_out)

    def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
        """Freeze transformer parameters, optionally leaving the last layers trainable.

        Args:
            unlocked_layers: number of trailing entries of
                `[embeddings, *layers]` to leave untouched; 0 freezes the whole
                transformer.
            freeze_layer_norm: when False, parameters whose dotted name contains
                a `LayerNorm` component are set trainable instead of frozen.
        """
        if not unlocked_layers: # full freezing
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
            return

        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
        embeddings = getattr(
            self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
        # The embeddings count as the first "layer"; drop the trailing
        # `unlocked_layers` entries so they are left as they are.
        modules = [embeddings, *layer_list][:-unlocked_layers]
        # freeze layers
        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False


    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Turn gradient checkpointing on the wrapped transformer on or off."""
        if enable:
            self.transformer.gradient_checkpointing_enable()
        else:
            # Honour `enable=False` instead of unconditionally enabling.
            self.transformer.gradient_checkpointing_disable()

    def get_num_layers(self):
        """Return the number of entries in the transformer's layer list (as named in `arch_dict`)."""
        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        return len(layer_list)

    def init_parameters(self):
        """No-op placeholder; no parameters are initialised here."""
        pass
