import transformers 
import torch
import os
from torch.nn.parallel import DistributedDataParallel as DDP
import warnings

# Own library imports
from src.nlp.finetune_heads.PoolingModule import PoolingModule
from src.nlp.models.BaseModel import BaseModel


class QwenModel(BaseModel):
    """Text encoder built on a frozen Qwen2.5-0.5B backbone.

    Each input text gets the tokenizer's EOS token appended. It is tokenized
    to a fixed length and encoded by the backbone. The last entry of the
    backbone's ``hidden_states`` is then reduced by ``self.pooling``
    (presumably created by ``BaseModel.init_pooling`` — confirm). When
    ``linear_finetune`` is True, the pooled output is additionally passed
    through ``linear_head``.
    """

    def __init__(self, cfg, device):
        """Load backbone and tokenizer and build the trainable heads.

        :param cfg: experiment config. Fields read here: ``padding_side``,
            ``lm_head_random_init``, ``use_lm_downstream``,
            ``svd.add_to_pooling`` and ``svd.combine``. ``datasets.max_length``
            is read later by ``_tokenize``.
        :param device: forwarded to ``BaseModel``. ``forward`` moves the
            tokenized inputs to ``self.device``.
        """
        super().__init__(cfg, device)

        self.model_name = "Qwen/Qwen2.5-0.5B"

        # Load the backbone in the checkpoint's own dtype ('auto').
        self.backbone = transformers.AutoModel.from_pretrained(
            self.model_name,
            torch_dtype='auto')
        self.backbone.config.output_hidden_states = True

        # Load the tokenizer matching the backbone.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
        self.tokenizer.padding_side = cfg.padding_side

        # Freeze the entire backbone.
        # NOTE: an earlier version re-enabled gradients on the last backbone
        # parameter and then froze every parameter again, so that unfreeze
        # never took effect. It was removed to make the effective behaviour
        # (fully frozen backbone) explicit. If training the final parameter
        # was actually intended, re-enable it *after* this loop.
        for param in self.backbone.parameters():
            param.requires_grad = False

        self.hidden_dim = self.backbone.config.hidden_size

        # Language-modelling head mapping hidden states to vocabulary logits.
        # The embedding matrix is (vocab_size, embed_dim), which matches
        # nn.Linear's (out_features, in_features) weight layout, so it can be
        # copied directly as the initialisation.
        embedding_weight = self.backbone.get_input_embeddings().weight
        vocab_size, embed_dim = embedding_weight.shape
        self.lm_head = torch.nn.Linear(embed_dim, vocab_size, dtype=torch.bfloat16, bias=False)

        if not cfg.lm_head_random_init:  # Use pretrained (embedding) weights
            # Copy, so training the head never touches the frozen embeddings.
            self.lm_head.weight = torch.nn.Parameter(embedding_weight.detach().clone())

        self.lm_head.weight.requires_grad = True

        if self.cfg.use_lm_downstream:
            # Module.to() casts in place and returns the same module, so
            # linear_head and lm_head are one shared (bfloat16) module here.
            self.linear_head = self.lm_head.to(dtype=torch.bfloat16)
        else:
            self.linear_head = torch.nn.Linear(self.hidden_dim, self.hidden_dim, dtype=torch.bfloat16)

        # When True, forward() applies linear_head to the pooled output.
        # Presumably toggled externally by the training code — confirm.
        self.linear_finetune = False

        if self.cfg.svd.add_to_pooling and (self.cfg.svd.combine == "concat"):
            # Linear layer combining the singular values with the 'normal'
            # pooling output (concatenated, hence hidden_dim * 2 inputs).
            self.svd_concat = torch.nn.Linear(self.hidden_dim * 2, self.hidden_dim, dtype=torch.bfloat16)

        self.init_pooling(self.backbone.dtype)

    def _tokenize(self, x):
        """
        Tokenize the input text(s) into fixed-length PyTorch tensors.
        :param x: the text/batched texts to process.
        :return: The tokenizer's encoding (input ids, attention mask, ...),
            truncated and padded to ``cfg.datasets.max_length``.
        """
        return self.tokenizer(x,
                              return_tensors='pt',
                              truncation=True,
                              max_length=self.cfg.datasets.max_length,
                              padding='max_length')

    def forward(self, x):
        """
        Forward pass of the model.
        :param x: a single text (``str``) or a batch (list) of texts.
        :return: The pooled representation for each text, passed through
            ``linear_head`` when ``linear_finetune`` is True. A single string
            is treated as a batch of one.
        """
        # Without this, a bare string would be iterated character by character.
        if isinstance(x, str):
            x = [x]

        # Append EOS to each text. NOTE(review): because _tokenize truncates
        # to max_length, texts longer than that lose this EOS token. Confirm
        # this is acceptable for the pooling strategy in use.
        texts = [text + self.tokenizer.eos_token for text in x]
        encoded = self._tokenize(texts)

        # Move every tensor of the encoding to the model's device.
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        # Last entry of the backbone's hidden states (enabled in __init__).
        last_hidden = self.backbone(**encoded).hidden_states[-1]

        # Reduce the token-level states to one representation per text.
        pooled_output = self.pooling(last_hidden, attention_mask=encoded['attention_mask'])

        if self.linear_finetune:
            pooled_output = self.linear_head(pooled_output)

        return pooled_output