import torch
from transformers import (
    AutoConfig, AutoTokenizer, AutoModelForCausalLM, 
    BitsAndBytesConfig, StoppingCriteriaList, StoppingCriteria
)
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, PrefixTuningConfig, TaskType,PeftModelForCausalLM,PromptTuningInit, PromptTuningConfig
from typing import List, Tuple, Optional
import pickle,os
from peft import PeftModel, PeftConfig
# Module-level buffers shared by the attention forward hooks and LLM.remove_hooks().
idx: int = 0  # numeric name of the next "<idx>.pkl" shard written for attn_values
attn_values: list = []  # head-split attention outputs not yet pickled to disk
attn_norms_inf: list = []  # per-hook-call p=inf norm tensors
attn_norms_fro: list = []  # per-hook-call default (Frobenius) norm tensors


def _split_heads(tensor, num_heads, attn_head_size):
    new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
    tensor = tensor.view(new_shape) #(batch,seq_length,head,head_dim)
    return tensor.permute(0,2, 1, 3)  # (batch, head, seq_length, head_features)


def set_attn_hook_llama(model, save_path, use_lora=False, *,
                        inf_norm_device='cuda:0', fro_norm_device='cuda:1',
                        flush_threshold=5000):
    """Register a forward hook on every decoder layer's ``self_attn`` module.

    On each forward call, the hook:
      * takes the first element of the attention module's output, detaches it
        and splits it into heads with ``_split_heads``
        (-> (batch, head, seq_length, head_features));
      * appends that tensor to the module-level ``attn_values`` buffer;
      * appends ``torch.norm(..., p=inf, dim=(2, 3))`` to ``attn_norms_inf`` and
        ``torch.norm(..., dim=(2, 3))`` to ``attn_norms_fro``. These are
        per-(batch, head) values moved to the configured devices;
      * once ``attn_values`` holds more than ``flush_threshold`` entries,
        pickles it to ``<save_path>/<idx>.pkl`` (creating ``save_path`` if
        needed), increments the global ``idx`` and clears the buffer.

    Args:
        model: model exposing ``config.num_attention_heads``,
            ``config.hidden_size``, ``config.num_hidden_layers`` and
            ``model.model.layers[i].self_attn``.
        save_path: directory that receives the pickle shards.
        use_lora: currently unused; kept for backward compatibility.
        inf_norm_device: device the p=inf norms are moved to. The default
            preserves the previously hard-coded GPU.
        fro_norm_device: device the Frobenius norms are moved to. The default
            preserves the previously hard-coded GPU.
        flush_threshold: a shard is written when the buffer length exceeds this.

    Returns:
        List of hook handles, one per layer; call ``.remove()`` to detach them.
    """
    num_heads = model.config.num_attention_heads
    head_size = model.config.hidden_size // num_heads

    def attn_hook(module, _inputs, output):
        # Only ``idx`` is rebound; the list buffers are mutated in place.
        global idx
        # Detach so stored tensors never keep an autograd graph (and its
        # activations) alive if the forward pass runs with gradients enabled.
        attn_value = _split_heads(output[0].detach(), num_heads, head_size)
        attn_values.append(attn_value)
        attn_norms_inf.append(
            torch.norm(attn_value, p=float('inf'), dim=(2, 3)).to(inf_norm_device))
        attn_norms_fro.append(torch.norm(attn_value, dim=(2, 3)).to(fro_norm_device))
        if len(attn_values) > flush_threshold:
            # The shard directory may not exist yet; creating it here avoids a
            # FileNotFoundError after thousands of hook calls of work.
            os.makedirs(save_path, exist_ok=True)
            with open(os.path.join(save_path, f'{idx}.pkl'), 'wb') as f:
                pickle.dump(attn_values, f)
            idx += 1
            attn_values.clear()

    # The hook keeps no per-layer state, so a single function can serve every layer.
    return [
        model.model.layers[i].self_attn.register_forward_hook(attn_hook)
        for i in range(model.config.num_hidden_layers)
    ]


class LLM:
    """
    Load a causal language model (optionally quantized and/or wrapped with a PEFT
    checkpoint plus a freshly initialized prompt-tuning adapter) and generate text
    with greedy decoding.

    Constructor args:
        model_id (str): Hub identifier (or path) of the model to load. Some ids are
            redirected to local snapshots via ``_LOCAL_SNAPSHOTS``.
        setting (str): Experiment-setting name; used only to build ``attn_value_save_path``.
        device (str): Device the tokenized inputs are moved to, e.g. 'cuda'.
        quantization_bits (Optional[int]): 4 or 8 enables bitsandbytes quantization;
            any other value loads without quantization.
        stop_list (Optional[List[str]]): Strings used to build ``stopping_criteria``.
        model_max_length (int): Maximum tokenized input length.
        use_lora: Ignored; ``self.use_lora`` is derived from ``check_point``.
        check_point: Optional PEFT checkpoint path loaded on top of the base model.

    Attributes:
        attn_value_save_path (str): 'attn_values/<model_id>/<original|check_point>/<setting>'.
        model, tokenizer, hooks, stopping_criteria, bnb_config: built in ``__init__``.
    """

    # Hub ids resolved to pre-downloaded local snapshot directories on the author's
    # machine; any other model_id is passed through unchanged.
    _LOCAL_SNAPSHOTS = {
        'lmsys/vicuna-7b-v1.5': '/data/somebody/data/huggingface_cache/hub/models--lmsys--vicuna-7b-v1.5/snapshots/3321f76e3f527bd14065daf69dad9344000a201d',
        'mistralai/Mistral-7B-Instruct-v0.3': '/data/somebody/data/huggingface_cache/hub/models--mistralai--Mistral-7B-Instruct-v0.3/snapshots/e0bc86c23ce5aae1db576c8cca6f06f1f73af2db',
    }

    def __init__(
        self,
        model_id: str,
        setting: str,
        device: str = 'cuda',
        quantization_bits: Optional[int] = None,
        stop_list: Optional[List[str]] = None,
        model_max_length: int = 4096,
        use_lora=True,
        check_point=None
    ):
        self.device = device
        self.model_max_length = model_max_length
        # NOTE: the ``use_lora`` argument is ignored; the flag reflects whether a
        # checkpoint path was supplied.
        self.use_lora = check_point is not None
        self.check_point = check_point
        self.stop_list = stop_list
        if stop_list is None:
            self.stop_list = ['\nHuman:', '\n```\n', '\nQuestion:', '<|endoftext|>', '\n']
        second_dir = 'original' if check_point is None else 'check_point'
        self.attn_value_save_path = os.path.join('attn_values', model_id, second_dir, setting)

        self.bnb_config = self._set_quantization(quantization_bits)
        self.model, self.tokenizer, self.hooks = self._initialize_model_tokenizer(model_id)
        self.stopping_criteria = self._define_stopping_criteria()

    def _set_quantization(self, quantization_bits: Optional[int]) -> 'Optional[BitsAndBytesConfig]':
        """
        Return a BitsAndBytesConfig for 4-bit (NF4, double quant, bfloat16 compute)
        or 8-bit loading, or None for any other ``quantization_bits`` value.
        """
        if quantization_bits in [4, 8]:
            bnb_config = BitsAndBytesConfig()
            if quantization_bits == 4:
                bnb_config.load_in_4bit = True
                bnb_config.bnb_4bit_quant_type = 'nf4'
                bnb_config.bnb_4bit_use_double_quant = True
                bnb_config.bnb_4bit_compute_dtype = torch.bfloat16
            elif quantization_bits == 8:
                bnb_config.load_in_8bit = True
            return bnb_config
        return None

    def _initialize_model_tokenizer(self, model_id: str) -> 'Tuple[PeftModel, AutoTokenizer, list]':
        """
        Load the model and tokenizer for ``model_id``.

        The base model is loaded in bfloat16 with ``device_map='auto'``. If
        ``self.check_point`` is set, that PEFT checkpoint is loaded on top. The
        result is then wrapped with a new text-initialized prompt-tuning adapter
        (8 virtual tokens).

        Returns:
            (model, tokenizer, hooks), where ``hooks`` is currently always empty.
        """
        print("model_id==", model_id)
        model_id = self._LOCAL_SNAPSHOTS.get(model_id, model_id)
        model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
        model_config.max_seq_len = self.model_max_length
        print("model config==", model_config)

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            config=model_config,
            quantization_config=self.bnb_config,
            torch_dtype=torch.bfloat16,
            device_map='auto',
        )
        if self.check_point is not None:
            model = PeftModel.from_pretrained(model, self.check_point)
            #model=model.merge_and_upload()
        # NOTE: the init text (including its "irrelavant" spelling) is kept verbatim
        # because it determines the initial virtual-token embeddings.
        prompt_config = PromptTuningConfig(
            task_type=TaskType.CAUSAL_LM,
            prompt_tuning_init=PromptTuningInit.TEXT,
            num_virtual_tokens=8,
            prompt_tuning_init_text="Answer the question based on the provided documents and feel free to ignore the irrelavant or distracting ones",
            tokenizer_name_or_path=model_id,
        )
        model = get_peft_model(model, prompt_config)
        print("model==", model)
        # Attention-hook registration is disabled; re-enable the lines below to
        # record attention statistics for llama models.
        hooks = []
        # if 'llama' in model_id:
        #     hooks=set_attn_hook_llama(model,self.attn_value_save_path)
        model.eval()  # Set the model to evaluation mode

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, padding_side="left", truncation_side="left",
            model_max_length=self.model_max_length
        )
        # Most LLMs don't have a pad token by default
        tokenizer.pad_token = tokenizer.eos_token

        return model, tokenizer, hooks

    def _define_stopping_criteria(self) -> 'StoppingCriteriaList':
        """
        Build a criterion that fires when batch item 0 ends with the token ids of
        any string in ``self.stop_list``.

        NOTE(review): stop strings are tokenized with the tokenizer's default
        special-token handling. For tokenizers that prepend a BOS token, the stop
        sequence therefore includes it and may never match generated text; confirm
        this is intended.
        NOTE: ``generate_answer`` does not currently pass these criteria to ``generate``.
        """
        stop_token_ids = [
            torch.LongTensor(self.tokenizer(x)['input_ids']).to(self.device)
            for x in self.stop_list
        ]

        class StopOnTokens(StoppingCriteria):
            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
                for stop_ids in stop_token_ids:
                    n = len(stop_ids)
                    # Skip empty stop sequences and ones longer than the current
                    # input. Slicing would yield a tensor of a different length,
                    # and torch.eq would then raise or broadcast into a spurious match.
                    if n == 0 or input_ids.shape[-1] < n:
                        continue
                    if torch.eq(input_ids[0][-n:], stop_ids).all():
                        return True
                return False

        return StoppingCriteriaList([StopOnTokens()])

    def generate_answer(self, prompt: str, max_new_tokens: int = 15) -> List[str]:
        """
        Greedily generate a continuation for ``prompt``.

        Args:
            prompt (str): Input text (a list of strings is padded as a batch).
            max_new_tokens (int): Maximum number of tokens to generate.

        Returns:
            List[str]: Decoded sequences (prompt plus continuation), with special
            tokens removed.
        """
        inputs = self.tokenizer(
            prompt,
            padding=True,
            truncation=True,
            max_length=self.model_max_length,
            return_tensors="pt"
        ).to(self.device)

        generated_ids = self.model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=max_new_tokens,
            repetition_penalty=1.1,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    def remove_hooks(self):
        """
        Remove registered hooks and persist the attention statistics they collected.

        If any norms were recorded, this writes three files into
        ``self.attn_value_save_path``, creating the directory if needed:
          * ``attn_norm_fro.pkl``: the Frobenius-norm tensors concatenated along dim 0;
          * ``attn_norm_inf.pkl``: the p=inf-norm tensors concatenated along dim 0;
          * ``<idx>.pkl``: any not-yet-flushed ``attn_values``.
        The module-level buffers are then emptied and ``idx`` is incremented.
        """
        global idx
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
        if not attn_norms_fro:
            return
        os.makedirs(self.attn_value_save_path, exist_ok=True)
        # Concatenate into a local and clear the shared list, rather than rebinding
        # the global to a tensor, so later hooks can keep appending to it.
        for file_name, norms in (('attn_norm_fro.pkl', attn_norms_fro),
                                 ('attn_norm_inf.pkl', attn_norms_inf)):
            with open(os.path.join(self.attn_value_save_path, file_name), 'wb') as f:
                pickle.dump(torch.cat(norms, dim=0), f)
            norms.clear()

        path = os.path.join(self.attn_value_save_path, str(idx) + '.pkl')
        with open(path, 'wb') as f:
            pickle.dump(attn_values, f)
        idx += 1
        attn_values.clear()

        

