from transformers import pipeline, set_seed, utils, BitsAndBytesConfig
import numpy as np
import torch
import seaborn as sns
import os
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
import scipy
from typing import Optional, Callable, Iterable, Sequence
import itertools

# End-of-text marker string. TextProcess.get_golden_answer appends it to the
# golden-answer token list when add_end_of_text is True.
EOT = '<|endoftext|>'
class Generator:
    def __init__(self, model_name: str, device: torch.device | str, random_seed: int = 42, load_in_4bit: bool = False, load_in_16bit: bool = False):
        self.model_name = model_name
        self.device = device
        utils.logging.set_verbosity_error()
        with open('huggingface_auth_token') as f:
            auth_token = f.read().strip()
        if load_in_4bit is True: 
            from transformers import BitsAndBytesConfig
            if device in [torch.device('cpu'), 'cpu']:
                raise ValueError('The load_in_4bit must be on GPUs.')
            model_args = {'load_in_4bit': True, 'quantization_config': BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.bfloat16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type='nf4'
                )}
            device_map = 'auto'
            self.generator = pipeline('text-generation', model=model_name, device_map=device_map, torch_dtype=torch.bfloat16, use_auth_token=auth_token, model_kwargs=model_args)
        else:
            if isinstance(device, str):
                device_map = device
                self.generator = pipeline('text-generation', model=model_name, device_map=device_map, use_auth_token=auth_token, torch_dtype=torch.bfloat16 if load_in_16bit else None)
            else:
                self.generator = pipeline('text-generation', model=model_name, device=device, use_auth_token=auth_token, torch_dtype=torch.bfloat16 if load_in_16bit else None)
        set_seed(random_seed)
        self.model = self.generator.model
        self.tokenizer = self.generator.tokenizer 
        self.optimizer = None
        self.n_layer = self.model.config.num_hidden_layers
        self.n_head = self.model.config.num_attention_heads
        if isinstance(self.device, str):
            self.device = self.model.device
        if hasattr(self.model.config, 'pretraining_tp') and self.model.config.pretraining_tp != 1:
            raise ValueError('Do not support pretraining_tp != 1. If you want to use pretraining_tp > 1, you need to modify the \'modeling_llama.py\' line 292-307. The shape of the hidden_states does not match the slices.')

    def set_optimizer(self, opt_name='SGD', lr=1e-4, weight_decay=0.0):
        self.optimizer = getattr(torch.optim, opt_name)(self.model.parameters(), lr=lr, weight_decay=weight_decay)

class TextProcess:
    def __init__(self, generator: Generator, input_text: str | list[str], head_mask: Optional[tuple[int, int] | list[tuple[int, int]]] = None, add_end_of_text: bool = True, not_output: bool = False):
        """
        input_text: text or batch of text. IF LIST, PLEASE CONFIRM THE LENGTH IS THE SAME.
        add_end_of_text: if True, add the <|end_of_text|> at the end of text, please confirm the text is complete. It is used when autoregression.
        not_output: if True, the output is None, and the pred_dists is None
        -----------------------------
        output:
        output[0]: last_hidden_state (batch, seq_len, embed_dim)
        output[1]: past_key_values (tuple: len==num_layers), each tuple contains the key and value in the layer, and each tensor has shape (batch, num_heads, seq_len, embed_dim/num_heads). The output[1][layer][k=0, v=1] is a tensor.
        output[2]: If the output_attentions is True, the attention matrix. taple: len==num_layers, each tensor has shape (batch, num_heads, seq_len, seq_len)
        tokens: tokenized input, a list of strings, len: seq_len
        pred_dist: the output predicted distribution **without softmax**, Tensor: (batch, seq_len, vocab_size)
        """
        self.generator = generator
        self.model = self.generator.model
        self.tokenizer = self.generator.tokenizer
        self.input_text = input_text
        self.is_batch = isinstance(input_text, list)
        self.is_llama = 'Llama-2' in self.generator.model.name_or_path
        self.is_opt = 'opt' in self.generator.model.name_or_path
        self.batchsize = len(input_text) if self.is_batch else None
        self.encoded_input = self.tokenizer(self.input_text, return_tensors='pt') 
        for k, v in self.encoded_input.items():
            self.encoded_input[k] = v.to(self.generator.device)
        self.seq_len = self.encoded_input["input_ids"].shape[1]
        if not self.is_batch:
            self.tokens = self.tokenizer.convert_ids_to_tokens(self.encoded_input["input_ids"][0].tolist()) # a list of strings, len: seq_len
            self.tokens_id = self.encoded_input["input_ids"][0].tolist() # a list of int, len: seq_len
        else:
            assert self.batchsize is not None
            self.tokens = [self.tokenizer.convert_ids_to_tokens(self.encoded_input["input_ids"][i].tolist()) for i in range(self.batchsize)]
            self.tokens_id = [self.encoded_input["input_ids"][i].tolist() for i in range(self.batchsize)]
    
        # output: see the docstring
        if not_output:
            self.output = None
        elif self.is_opt:
            self.output = self.model.model(**self.encoded_input, output_attentions=True, head_mask=self.head_mask_tensor(head_mask)) if not not_output else None
        elif self.is_llama: # llama-2 does not provide head_mask.
            self.output = self.model.model(**self.encoded_input, output_attentions=True) if not not_output else None
        else:
            self.output = self.model.transformer(**self.encoded_input, output_attentions=True, head_mask=self.head_mask_tensor(head_mask)) if not not_output else None

        self.pred_dists = self.model.lm_head(self.output[0]) if not not_output else None # (batch, seq_len, vocab_size), batch == 1.

        if self.is_llama: # llama-2 add a special token at the beginning of the sequence, remove it.
            self.tokens = self.tokens[1:]
            self.tokens_id = self.tokens_id[1:]
            self.pred_dists = self.pred_dists[:, 1:, :] 

        self.eos_token_id = self.model.config.eos_token_id
        self.add_end_of_text = add_end_of_text

    def get_golden_answer(self, idx: Optional[int] = None):
        if self.is_batch:
            assert self.batchsize is not None
            return [self.get_golden_answer(idx=idx) for idx in range(self.batchsize)]

        if self.add_end_of_text:
            return self.tokens[1:] + [EOT]
        else:
            return self.tokens[1:]
    def get_golden_answer_idx(self):
        """If the data is batched, return a list of tensor, shape: (len_seq or len_seq-1, ); else return a tensor."""
        if self.is_batch:
            return [self.get_golden_answer_idx(idx=idx) for idx in range(self.batchsize)]

        if self.add_end_of_text:
            return torch.concat([self.encoded_input["input_ids"][0, 1:], torch.tensor([self.eos_token_id]).to(self.encoded_input['input_ids'].device)])
        else:
            return self.encoded_input["input_ids"][0, 1:]
    def get_golden_answer_onehot(self):
        """
        size: (len_seq or len_seq-1, vocab_size)
        """
        if self.is_batch:
            return [self.get_golden_answer_onehot(idx=idx) for idx in range(self.batchsize)]
        return torch.nn.functional.one_hot(self.get_golden_answer_idx(), num_classes=self.model.config.vocab_size).float()

    def head_mask_tensor(self, head_mask: Optional[tuple[int, int] | list[tuple[int, int]]]):
        if head_mask is None:
            return None
        t = torch.ones(self.generator.n_layer, self.generator.n_head, dtype=torch.float)
        if isinstance(head_mask, tuple):
            t[head_mask[0], head_mask[1]] = 0.0
        elif isinstance(head_mask, list):
            for m in head_mask:
                t[m[0], m[1]] = 0.0
        else:
            raise ValueError('head_mask should be tuple or list of tuple')
        return t.to(self.model.device)