"""
Custom wrapper class for a pretrained NanoGPT model """

import tiktoken
import torch
from omegaconf import OmegaConf

# Own library imports
from src.nlp.finetune_heads.PoolingModule import PoolingModule
from src.nlp.models.BaseModel import BaseModel
from src.nanoGPT.model import GPT
from src.nanoGPT.utils import pad_sequence


class NanoGPTWrapper(BaseModel):
    """Frozen NanoGPT / GPT-2 backbone that turns raw text into pooled embeddings.

    The backbone is loaded from a local checkpoint selected by ``cfg.backbone``,
    switched to embedding mode, put in eval mode and frozen. On top of it the
    wrapper keeps a bfloat16 language-modeling head (``lm_head``) and a bfloat16
    projection (``linear_head``) that ``forward`` applies to the pooled output
    when ``linear_finetune`` is enabled.
    """

    # Maps each supported ``cfg.backbone`` value to the cfg key holding the path
    # of its checkpoint file.
    _CHECKPOINT_PATH_KEYS = {
        "NanoGPT": "l2_gpt2_path",  # pretrained L2-GPT model
        "GPT2-Custom": "gpt2_path",  # pretrained GPT-2 model
    }

    # Prefix stripped from state_dict keys before loading (the prefix carried by
    # parameters saved from a torch.compile-wrapped model).
    _COMPILED_PREFIX = '_orig_mod.'

    def __init__(self, cfg, device):
        super().__init__(cfg, device)

        self.model_name = "Custom-NanoGPT"

        # Rebuild the GPT backbone from the model config stored in the checkpoint.
        checkpoint = self._load_checkpoint(cfg)
        checkpoint_cfg = OmegaConf.create(checkpoint['config'])
        model = GPT(checkpoint_cfg.model)
        model.load_state_dict(self._strip_prefix(checkpoint['model'], self._COMPILED_PREFIX))

        # Hidden size as reported by the instantiated model's own config.
        self.hidden_dim = OmegaConf.create(model.config).n_embed

        self.backbone = model
        self.backbone.set_embed_mode(True)  # Only return last hidden states
        self.backbone.eval()

        # Freeze all backbone weights.
        for param in self.backbone.parameters():
            param.requires_grad = False

        # GPT-2 BPE tokenizer; ``self.tokenizer`` is the bare encode function.
        self.tokenizer_object = tiktoken.get_encoding("gpt2")
        self.tokenizer = self.tokenizer_object.encode
        self.eos_str = "<|endoftext|>"
        self.eos_token = self.tokenizer(self.eos_str, allowed_special={"<|endoftext|>"})[0]

        self._init_heads(cfg)

        self.linear_finetune = False

        self.init_pooling(torch.bfloat16)

    def _load_checkpoint(self, cfg):
        """Load and return the raw checkpoint object for ``cfg.backbone``.

        Raises:
            ValueError: if ``cfg.backbone`` is not a key of ``_CHECKPOINT_PATH_KEYS``.
        """
        try:
            path_key = self._CHECKPOINT_PATH_KEYS[cfg.backbone]
        except KeyError:
            raise ValueError(
                f"Unsupported backbone {cfg.backbone!r}; "
                f"expected one of {sorted(self._CHECKPOINT_PATH_KEYS)}"
            ) from None

        checkpoint_path = getattr(cfg, path_key)
        # NOTE: torch.load may unpickle arbitrary objects; only load trusted checkpoints.
        with open(checkpoint_path, 'rb') as f:
            return torch.load(f, map_location=self.cfg.device)

    @staticmethod
    def _strip_prefix(state_dict, prefix):
        """Return a copy of ``state_dict`` with ``prefix`` removed from keys that start with it."""
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def _init_heads(self, cfg):
        """Create the bfloat16 ``lm_head`` and ``linear_head`` modules.

        ``lm_head`` is a bias-free Linear shaped as the transpose of the backbone's
        input-embedding matrix; unless ``cfg.lm_head_random_init`` is set, its
        weight starts as a copy of that embedding matrix. ``linear_head`` is either
        ``lm_head`` itself (``cfg.use_lm_downstream``) or a fresh
        ``hidden_dim x hidden_dim`` projection.
        """
        embedding_weight = self.backbone.get_input_embeddings().weight
        out_features, in_features = embedding_weight.shape
        self.lm_head = torch.nn.Linear(in_features, out_features, bias=False).to(dtype=torch.bfloat16)

        if not cfg.lm_head_random_init:  # Use pretrained weights
            # Cast the copy to bfloat16 so the head keeps the dtype it was created
            # with, whatever precision the backbone weights are stored in.
            self.lm_head.weight = torch.nn.Parameter(
                embedding_weight.detach().clone().to(dtype=torch.bfloat16)
            )

        self.lm_head.weight.requires_grad = True

        # Initialize finetune linear projection
        if self.cfg.use_lm_downstream:
            # Same module object: training linear_head also updates lm_head.
            self.linear_head = self.lm_head
        else:
            self.linear_head = torch.nn.Linear(self.hidden_dim, self.hidden_dim, dtype=torch.bfloat16)

    def _tokenize(self, x):
        """Tokenize a batch of texts into stacked ``input_ids`` / ``attention_mask`` tensors.

        Each text is encoded on its own (``<|endoftext|>`` is allowed as a special
        token) and handed to ``pad_sequence`` together with the EOS token id
        (presumably the pad id), ``cfg.datasets.max_length`` and
        ``cfg.padding_side``. The per-text results are stacked into batch tensors.
        """
        tokenized = [
            pad_sequence(
                self.tokenizer(x_i, allowed_special={"<|endoftext|>"}),
                self.eos_token,
                self.cfg.datasets.max_length,
                self.cfg.padding_side)
            for x_i in x
        ]

        return {
            'input_ids': torch.stack([t['input_ids'] for t in tokenized]),
            'attention_mask': torch.stack([t['attn_mask'] for t in tokenized]),
        }

    def forward(self, x):
        """Embed a batch of texts and return their pooled representation.

        Args:
            x: iterable of input strings.

        Returns:
            The output of ``self.pooling`` applied to the backbone's hidden states
            (cast to bfloat16), passed through ``linear_head`` when
            ``linear_finetune`` is enabled.
        """
        # Append the EOS marker to every input text.
        x = [x_i + self.eos_str for x_i in x]

        x_tokenized = self._tokenize(x)
        assert isinstance(x_tokenized, dict), "Tokenized input must be a dictionary with input ids and attention mask."

        # Move to device
        x_tokenized = {k: v.to(self.device) for k, v in x_tokenized.items()}

        # Hidden states from the backbone, converted to bfloat16 for the heads.
        outputs = self.backbone(x_tokenized).to(torch.bfloat16)

        pooled_output = self.pooling(outputs, attention_mask=x_tokenized['attention_mask'])
        if self.linear_finetune:
            pooled_output = self.linear_head(pooled_output)
        return pooled_output
