import os
import gc
import torch
from typing import List

from transformers import AutoModelForCausalLM, AutoTokenizer
from src.utils import conversation_temp


# Registry of supported models: maps the short model name used throughout this
# module to the model identifier that is passed to ``from_pretrained`` as
# ``pretrained_model_name_or_path`` when the tokenizer and model are loaded.
PATH_DICT = {
    "sheared-llama": "princeton-nlp/Sheared-LLaMA-1.3B",
    "llama-3": "meta-llama/Meta-Llama-3-8B",
    "mistral-7b": "mistralai/Mistral-7B-v0.1",
    "llama-2": "meta-llama/Llama-2-7B-hf"
}


class LanguageModel:
    """
    Common interface shared by all language-model wrappers.

    Subclasses are expected to override ``batched_generation``; the base
    implementation only raises ``NotImplementedError``.
    """

    def __init__(self, model_name):
        """
        Args:
            model_name: The name identifying the wrapped language model.
        """
        self.model_name = model_name

    def batched_generation(self, prompts: List[str], max_seq_len: int, max_batch_size: int):
        """
        Generate one response per prompt (to be implemented by subclasses).

        Args:
            prompts (List[str]): The prompts to respond to.
            max_seq_len (int): The maximum of the sequence length.
            max_batch_size (int): The maximum of the batch size.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()


class HuggingFace(LanguageModel):
    """
    A language model served through a HuggingFace-style model/tokenizer pair.

    The wrapped ``model`` must expose ``generate``, ``device`` and
    ``config.is_encoder_decoder``; the ``tokenizer`` must be callable and
    expose ``eos_token_id`` and ``batch_decode``.
    """

    def __init__(self, model_name:str, model, tokenizer):
        """
        Args:
            model_name (str): The name of the adopted language model, e.g., LLama.
            model: The language model.
            tokenizer: The tokenizer.
        """
        super().__init__(model_name)
        self.model = model
        self.tokenizer = tokenizer
        # Generation stops as soon as the tokenizer's EOS token is produced.
        self.eos_tk_ids = [self.tokenizer.eos_token_id]

    def batched_generation(self,
                           prompts:List[str],
                           max_seq_len:int,
                           max_batch_size:int=1,
                           temperature:float=1.0,
                           top_p:float=1.0) -> List[str]:
        """
        Response generation with language models.

        Args:
            prompts (List[str]): A list of strings.
            max_seq_len (int): The maximum number of newly generated tokens
                (forwarded to ``generate`` as ``max_new_tokens``).
            max_batch_size (int): The maximum of the batch size. Default 1 in our work regarding agent.
                Accepted for interface compatibility; not used by this method.
            temperature (float): The temperature coefficient of the language model.
                Values ``<= 0`` disable sampling (``do_sample=False``).
            top_p (float): The nucleus-sampling threshold; only used when sampling.

        Returns:
            A list of responses (the output of the language model).

        Raises:
            AssertionError: If the model is LLaMA-3 and ``max_seq_len`` is
                outside ``[1, 8192]``.
        """
        # The rest of this module names the model "llama-3"; the legacy
        # spelling "llama3" is still accepted so existing callers keep working.
        if self.model_name in ("llama-3", "llama3"):
            assert 1 <= max_seq_len <= 8192, "The maximum of the sequence length is 8192."

        encoded = self.tokenizer(prompts, return_tensors="pt", padding=True)
        # Use the device object itself: ``device.index`` is None for CPU models.
        device = self.model.device
        inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

        if temperature > 0.:
            sampling_kwargs = {
                "do_sample": True,
                "temperature": temperature,
                "top_p": top_p,
            }
        else:
            # NOTE(review): temperature/top_p are pinned to neutral values here,
            # presumably so sampling settings from the model's generation
            # config do not leak into the non-sampling path - confirm.
            sampling_kwargs = {
                "do_sample": False,
                "temperature": 1.0,
                "top_p": 1,
            }

        output_ids = self.model.generate(
            **inputs,
            max_new_tokens=max_seq_len,
            eos_token_id=self.eos_tk_ids,
            **sampling_kwargs
        )

        # For decoder-only models, drop the leading prompt-length columns so
        # that only the continuation is decoded.
        if not self.model.config.is_encoder_decoder:
            prompt_len = inputs["input_ids"].shape[1]
            output_ids = output_ids[:, prompt_len:]

        output_list = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        # Drop the tensor references, then collect garbage and release cached
        # CUDA memory. (``Tensor.to`` is not in-place, so copying the tensors
        # to the CPU first would free nothing.)
        del encoded, inputs, output_ids
        gc.collect()
        torch.cuda.empty_cache()

        return output_list


class LanguageModelSpecification():
    """
    Specify a language model.

    Stores the generation settings, loads the tokenizer and model for
    ``model_name`` on construction, and exposes ``response_generation`` to
    answer a single prompt using the model's conversation template.
    """
    def __init__(self,
                 model_name:str,
                 max_seq_len:int,
                 max_batch_size:int,
                 temperature:float=1.,
                 top_p:float=1.) -> None:
        """
        Args:
            model_name (str): A key of ``PATH_DICT`` selecting the model.
            max_seq_len (int): The maximum of the sequence length used for generation.
            max_batch_size (int): The maximum of the batch size (stored only).
            temperature (float): The temperature coefficient of the language model.
            top_p (float): The nucleus-sampling threshold.

        Raises:
            AssertionError: If ``model_name`` is "llama-3" and ``max_seq_len``
                is outside ``[1, 8192]``.
            ValueError: If ``model_name`` is not a recognized model.
        """
        self.model_name = model_name
        self.max_batch_size = max_batch_size
        self.top_p = top_p
        self.temperature = temperature
        self.max_seq_len = max_seq_len
        # Validate before loading so an invalid length fails fast.
        if model_name == "llama-3":
            assert 1 <= self.max_seq_len <= 8192, "The maximum of the sequence length of LLaMA3 is 8192."

        # NOTE: ``self.model`` holds the ``HuggingFace`` wrapper, not the raw model.
        self.template, self.tokenizer, self.model = self._get_model_and_tokenizer(model_name=model_name)

    def _get_model_and_tokenizer(self, model_name:str):
        """
        Initialize the language model and tokenizer.

        Args:
            model_name (str): A key of ``PATH_DICT``.

        Returns:
            A tuple ``(template, tokenizer, lm)`` where ``template`` is the
            conversation-template name (identical to ``model_name``),
            ``tokenizer`` is the loaded tokenizer and ``lm`` is the
            ``HuggingFace`` wrapper around the loaded model.

        Raises:
            ValueError: If ``model_name`` is not a key of ``PATH_DICT``.
        """
        ### TODO: More language model can be added later.
        # Every supported model uses its own name as the template name, so the
        # registry lookup replaces a per-model if/elif chain.
        if model_name not in PATH_DICT:
            raise ValueError("Unrecognized model name.")
        model_id = PATH_DICT[model_name]
        template = model_name

        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=model_id,
            use_fast=False
        )

        # The model is loaded in bfloat16, switched to eval mode and moved to CUDA.
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=model_id,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
        ).eval().cuda()

        if model_name == "llama-2":
            tokenizer.pad_token = tokenizer.unk_token

        # All registered models are decoder-only causal LMs; batched generation
        # needs left padding so each prompt ends directly before its continuation.
        tokenizer.padding_side = "left"

        # Fall back to the EOS token when the tokenizer defines no pad token.
        if not tokenizer.pad_token:
            tokenizer.pad_token = tokenizer.eos_token
        model.generation_config.pad_token_id = tokenizer.pad_token_id

        lm = HuggingFace(model_name=model_name,
                         model=model,
                         tokenizer=tokenizer)
        return template, tokenizer, lm

    def response_generation(self, prompt:str) -> List[str]:
        """
        Generate a response to a single user prompt.

        The prompt is placed in the first role of the conversation template
        and the second role is left empty for the model to complete.

        Args:
            prompt (str): The user prompt.

        Returns:
            The output of ``batched_generation`` for the single formatted
            prompt (a list of responses).
        """
        conversation = conversation_temp(self.template)
        conversation.append_message(conversation.roles[0], prompt)
        conversation.append_message(conversation.roles[1], None)

        full_prompt = conversation.get_prompt()

        # ``batched_generation`` expects a list of prompts; wrap the single one.
        output = self.model.batched_generation(prompts=[full_prompt],
                                               max_seq_len=self.max_seq_len,
                                               temperature=self.temperature,
                                               top_p=self.top_p)

        return output