import transformers
import torch
import os
import warnings
import json

# Own library imports
from src.nlp.finetune_heads.PoolingModule import PoolingModule
from src.nlp.models.BaseModel import BaseModel


class BertModel(BaseModel):
    """
    Wrapper around a frozen ``bert-base-uncased`` backbone with a pooling
    module and an optional trainable linear head on top.

    The backbone parameters are frozen at construction time; only the
    language-modeling head / projection head created here are trainable.
    """

    def __init__(self, cfg, device):
        """
        Build the backbone, tokenizer and heads.
        :param cfg: Configuration object. This constructor reads
            ``padding_side``, ``lm_head_random_init``, ``use_lm_downstream``
            and ``log_diagnostics``; ``datasets.max_length`` is read later
            during tokenization.
        :param device: Device forwarded to the base class; inputs are moved
            to ``self.device`` in the forward pass (assumes the base class
            stores it under that name — confirm in BaseModel).
        """
        super().__init__(cfg, device)

        # Single source of truth for the checkpoint used by both the
        # backbone and the tokenizer.
        self.model_name = "bert-base-uncased"

        # Load model (set output_hidden_states to True for consistency)
        self.backbone = transformers.AutoModel.from_pretrained(
            self.model_name,
            output_hidden_states=True,  # Ensure hidden states are returned
        )

        # Freeze all backbone parameters first
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Load tokenizer
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
        self.tokenizer.padding_side = cfg.padding_side

        # Hidden dimension from model config
        self.hidden_dim = self.backbone.config.hidden_size

        # Create the language modeling head. The embedding weight has shape
        # (vocab_size, hidden_dim); reversing it yields a Linear layer that
        # maps hidden_dim -> vocab_size.
        embedding_weight = self.backbone.get_input_embeddings().weight
        self.lm_head = torch.nn.Linear(
            *embedding_weight.shape[::-1],
            dtype=torch.float,
            bias=False,
        )

        if not cfg.lm_head_random_init:  # Use pretrained weights for initialization
            # Detached clone: the head starts from the embedding values but
            # does not share storage with the (frozen) backbone embeddings.
            self.lm_head.weight = torch.nn.Parameter(
                embedding_weight.detach().clone()
            )
        self.lm_head.weight.requires_grad = True

        # Set up the linear head based on configuration
        if self.cfg.use_lm_downstream:
            # NOTE: Module.to() returns the module itself, so linear_head
            # and lm_head are the same object (shared weights) here.
            self.linear_head = self.lm_head.to(dtype=torch.float)
        else:  # Projection
            self.linear_head = torch.nn.Linear(
                self.hidden_dim, self.hidden_dim, dtype=torch.float
            )

        # The linear head is only applied in forward() once this is enabled.
        self.linear_finetune = False

        # Presumably defined on BaseModel and responsible for creating
        # self.pooling, which forward() uses — confirm in BaseModel.
        self.init_pooling(self.backbone.dtype)

        if self.cfg.log_diagnostics:
            self.norm_stats = self._new_norm_stats()

    @staticmethod
    def _new_norm_stats():
        """
        Create an empty diagnostics accumulator.
        :return: A dictionary mapping each tracked statistic to an empty list.
        """
        return {
            "head_weight_norms": [],
            "head_grad_norms": [],
            "avg_token_norms": [],
            "sum_token_norms": [],
            "pooled_repr_norms": [],
        }

    def _tokenize(self, x):
        """
        Tokenize the input text.

        Every sequence is truncated and padded to ``cfg.datasets.max_length``.
        :param x: The text or list of texts to process.
        :return: The tokenized text as a dictionary.
        """
        return self.tokenizer(
            x,
            return_tensors='pt',
            truncation=True,
            max_length=self.cfg.datasets.max_length,
            padding='max_length'
        )

    def _bert_last_hidden_state(self, x):
        """
        Tokenize the input, move it to the model device and run the backbone.
        :param x: The text or batched texts to process.
        :return: A tuple ``(last_hidden_state, tokenized_inputs)`` where
            ``tokenized_inputs`` is the dictionary of tensors on ``self.device``.
        """
        # The tokenizer automatically adds special tokens ([CLS], [SEP]).
        x_tokenized = self._tokenize(x)
        x_tokenized = {k: v.to(self.device) for k, v in x_tokenized.items()}

        # Use last_hidden_state from the BERT output.
        hidden = self.backbone(**x_tokenized).last_hidden_state
        return hidden, x_tokenized

    def forward(self, x):
        """
        Forward pass of the model.
        :param x: The text or batched texts to process.
        :return: The pooled representation for each text, passed through the
            linear head when ``self.linear_finetune`` is enabled.
        """
        outputs, x_tokenized = self._bert_last_hidden_state(x)

        # Pool the hidden states using the provided pooling module.
        pooled_output = self.pooling(outputs, attention_mask=x_tokenized['attention_mask'])

        # If linear fine-tuning is enabled, pass through the linear head.
        if self.linear_finetune:
            pooled_output = self.linear_head(pooled_output)

        return pooled_output

    def get_cls_token(self, x):
        """
        Retrieve the [CLS] token representation from the model.
        :param x: The text or batched texts to process.
        :return: A tensor with the CLS token representation for each input.
        """
        outputs, _ = self._bert_last_hidden_state(x)

        # The [CLS] token is typically the first token in the sequence.
        # NOTE(review): this takes position 0, which only holds [CLS] when
        # cfg.padding_side is "right" — confirm for left-padded configs.
        return outputs[:, 0, :]

    def return_aggregated_diagnostic_data(self):
        """
        Average the collected diagnostics and reset the accumulator.
        :return: A dictionary mapping each statistic name to the mean of the
            values collected since the last call, or ``None`` (with a warning)
            when diagnostics logging is disabled. A statistic with no
            collected values is reported as ``nan`` instead of raising
            ``ZeroDivisionError``.
        """
        if not self.cfg.log_diagnostics:
            warnings.warn("Diagnostics logging is not enabled.")
            return None

        # Average each list; an empty list has no meaningful mean, so report
        # nan and keep the key set stable for downstream loggers.
        avg_dict = {
            name: (sum(values) / len(values)) if values else float("nan")
            for name, values in self.norm_stats.items()
        }

        # Start a fresh accumulation window.
        self.norm_stats = self._new_norm_stats()
        return avg_dict
        

