import os

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AutoModel
from transformers import AutoTokenizer

from src.nlp.finetune_heads.PoolingModule import PoolingModule
from src.nlp.models.BaseModel import BaseModel

class LlamaModel(BaseModel):
    """Frozen Llama-3.1-8B backbone with a trainable head for pooled embeddings.

    The constructor loads the backbone and tokenizer from the Hugging Face hub
    (authenticated with the ``HF_TOKEN`` environment variable), freezes all
    backbone parameters, guarantees that a pad token exists and builds the
    language-modeling / linear heads. ``forward`` returns pooled embeddings,
    optionally projected through ``linear_head``.

    NOTE(review): ``init_pooling`` and ``self.pooling`` are not defined in this
    class; presumably they are provided by ``BaseModel`` -- confirm.
    """

    def __init__(self, cfg, device):
        """
        Build the backbone, tokenizer and heads.
        :param cfg: configuration object; this class reads ``padding_side``,
            ``lm_head_random_init``, ``use_lm_downstream`` and
            ``datasets.max_length`` from it.
        :param device: device the tokenized inputs are moved to in ``forward``.
        :raises KeyError: if the ``HF_TOKEN`` environment variable is not set.
        """
        super().__init__(cfg, device)

        self.cfg = cfg
        self.device = device

        # Read the hub token once; it is needed for both the model and the tokenizer.
        hf_token = os.environ["HF_TOKEN"]

        self.model_name = "meta-llama/Llama-3.1-8B"
        self.backbone = AutoModel.from_pretrained(
            self.model_name,
            torch_dtype="auto",
            token=hf_token,
        )
        self.backbone.config.output_hidden_states = True

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            token=hf_token,
        )
        self.tokenizer.padding_side = cfg.padding_side

        self.hidden_dim = self.backbone.config.hidden_size

        # Check if pad token exists
        if self.tokenizer.pad_token_id is None:
            print(f"Pad token not detected, current vocab size:{len(self.tokenizer)}. Adding pad token...")
            self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
            self.backbone.resize_token_embeddings(len(self.tokenizer))
            print(f"Pad token added, new vocab size:{len(self.tokenizer)}")

        # Disable gradients for the backbone (only the heads below are trained).
        # This runs after the optional embedding resize so that the freeze
        # covers whatever parameters the backbone ends up with.
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Create language modeling head by using embedding weights as initialisation.
        # The embedding weight is (vocab, hidden); reversing the shape gives
        # Linear(in_features=hidden, out_features=vocab).
        embedding_weight = self.backbone.get_input_embeddings().weight
        self.lm_head = torch.nn.Linear(*embedding_weight.shape[::-1], dtype=torch.bfloat16, bias=False)

        if not cfg.lm_head_random_init:  # Use pretrained weights
            # Detached copy: the head trains independently of the frozen embeddings.
            self.lm_head.weight = torch.nn.Parameter(embedding_weight.detach().clone())

        self.lm_head.weight.requires_grad = True

        if self.cfg.use_lm_downstream:
            # Module.to() returns the module itself, so linear_head is the
            # same object as lm_head here (shared weights).
            self.linear_head = self.lm_head.to(dtype=torch.bfloat16)
        else:
            self.linear_head = torch.nn.Linear(self.hidden_dim, self.hidden_dim, dtype=torch.bfloat16)
        # When False, forward() returns the pooled outputs without applying linear_head.
        self.linear_finetune = False

        self.init_pooling(self.backbone.dtype)

    def _tokenize(self, x):
        """
        Tokenize the input text.
        :param x: the text/batched texts to process.
        :return: The tokenized text as PyTorch tensors, truncated and padded
            to ``cfg.datasets.max_length``.
        """
        return self.tokenizer(
            x,
            return_tensors='pt',
            truncation=True,
            max_length=self.cfg.datasets.max_length,
            padding='max_length',
        )

    def forward(self, texts):
        """
        Embed texts with the frozen backbone and pool the token states.
        :param texts: a single string or an iterable of strings.
        :return: the pooled outputs, passed through ``linear_head`` when
            ``self.linear_finetune`` is set.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(texts, str):
            texts = [texts]

        # NOTE(review): truncation is enabled in _tokenize, so the EOS appended
        # here may be cut off for texts longer than max_length -- confirm
        # whether the pooling strategy relies on it being present.
        texts = [t + self.tokenizer.eos_token for t in texts]

        inputs = self._tokenize(texts)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Only the final layer is used here, so skip collecting every
        # intermediate hidden state for this call (the config flag is untouched).
        outputs = self.backbone(**inputs, output_hidden_states=False).last_hidden_state
        pooled_outputs = self.pooling(outputs, attention_mask=inputs["attention_mask"])

        if self.linear_finetune:
            pooled_outputs = self.linear_head(pooled_outputs)

        return pooled_outputs