import os
import torch
from typing import List, Dict, Union
from .abstract_language_model import AbstractLanguageModel
from time import time
import transformers
import numpy as np
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

# Hub-style identifier of the default checkpoint (the name suggests an AWQ-quantized
# Llama-2 70B chat model). Nothing in the visible part of this file reads this
# constant; Llama2HF.__init__ carries the same string as its own `model_name` default.
model_name_or_path = "TheBloke/Llama-2-70B-chat-AWQ"

class Llama2HF(AbstractLanguageModel):
    """
    An interface to use LLaMA 2 models through the HuggingFace library.
    """

    def __init__(
        self, config_path: str = "", model_name: str = "TheBloke/Llama-2-70B-chat-AWQ", cache: bool = False, verbose: bool = False
    ) -> None:
        """
        Initialize an instance of the Llama2HF class with configuration, model details, and caching options.

        Besides reading the sampling settings, this loads the tokenizer and the quantized
        model identified by `model_name` and switches the model to evaluation mode.

        :param config_path: Path to the configuration file. Defaults to an empty string.
        :type config_path: str
        :param model_name: Name of the model checkpoint. It is used both as the key that selects
                           this model's section of the loaded configuration and as the identifier
                           handed to the tokenizer and model loaders.
                           Defaults to "TheBloke/Llama-2-70B-chat-AWQ".
        :type model_name: str
        :param cache: Flag to determine whether to cache responses. Defaults to False.
        :type cache: bool
        :param verbose: When True, the query and score methods print progress and results.
                        Defaults to False.
        :type verbose: bool
        :raises KeyError: If the configuration mapping has no entry for `model_name`, or that
                          entry lacks "temperature", "top_k" or "max_tokens".
        """
        super().__init__(config_path, model_name, cache)
        # Narrow the full configuration down to the section that belongs to this model.
        self.config: Dict = self.config[model_name]

        self.verbose = verbose
        self.temperature: float = self.config["temperature"]
        self.top_k: int = self.config["top_k"]
        self.max_tokens: int = self.config["max_tokens"]

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=False)

        self.model = AutoAWQForCausalLM.from_quantized(
            model_name, fuse_layers=True, trust_remote_code=False, safetensors=True
        )

        self.model.eval()
        # No `torch.no_grad()` call here on purpose: called as a bare statement it only
        # constructs a context manager and throws it away, which disables nothing.
        

    def query(self, query: str, num_responses: int = 1, temperature=None, top_p=None, top_k=None, repetition_penalty=None, max_new_tokens=500) -> List[List]:
        """
        Query the LLaMA 2 model for responses.

        Prompts longer than 4024 tokens are shortened to their first 1000 and last 1000
        tokens before generation.

        :param query: The query to be posed to the language model.
        :type query: str
        :param num_responses: Number of desired responses, default is 1.
        :type num_responses: int
        :param temperature: Sampling temperature; falls back to the configured value when None.
        :param top_p: Nucleus sampling threshold; falls back to the configured "top_p" when None.
        :param top_k: Top-k sampling cutoff; falls back to the configured value when None.
        :param repetition_penalty: Repetition penalty; falls back to the configured
                                   "repetition_penalty" when None.
        :param max_new_tokens: Upper bound on the number of generated tokens, default is 500.
        :return: A pair `(prompts, responses)`. Each element is a list with one string per
                 generated sequence: the leading (prompt) part of the decoded sequence and the
                 remaining text decoded with special tokens skipped.
        :rtype: Tuple[List[str], List[str]]
        """
        # Resolve defaults first so the cache key reflects the settings actually used.
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.config["top_p"]
        if repetition_penalty is None:
            repetition_penalty = self.config["repetition_penalty"]
        if top_k is None:
            top_k = self.top_k

        # The cache dict is shared by every method of this class (attribute spelled
        # `respone_cache` throughout). Keying on the method name plus all arguments keeps
        # a cached entry from being returned for a different request or by another method.
        cache_key = ("query", query, num_responses, temperature, top_p, top_k, repetition_penalty, max_new_tokens)
        if self.cache and cache_key in self.respone_cache:
            return self.respone_cache[cache_key]

        tokens = self.tokenizer(query, return_tensors='pt').input_ids.cuda()

        # Number of characters the prompt occupies at the start of a decoded sequence.
        prompt_chars = len(query)
        if tokens.shape[1] > 4024:
            tokens = torch.concat([tokens[:, :1000], tokens[:, -1000:]], dim=1)
            torch.cuda.empty_cache()
            # The model now sees a shortened prompt, so the split point between prompt and
            # response must be measured on that shortened text, not on the full query.
            prompt_chars = len(self.tokenizer.decode(tokens[0], skip_special_tokens=True))

        if self.verbose:
            print("ANSWERING QUERY")

        sequences = list(
            self.model.generate(
                tokens,
                do_sample=True,
                top_k=top_k,
                num_return_sequences=num_responses,
                eos_token_id=self.tokenizer.eos_token_id,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )
        )

        # The extra 5 characters presumably cover the start-of-sequence marker that decode()
        # emits when special tokens are kept -- TODO confirm against the tokenizer.
        prompts = [self.tokenizer.decode(sequence)[:prompt_chars + 5] for sequence in sequences]
        answers = [
            self.tokenizer.decode(sequence, skip_special_tokens=True)[prompt_chars:]
            for sequence in sequences
        ]
        response = (prompts, answers)

        if self.verbose:
            print("Prompt : ", response[0][0])
            print("Response : ", response[1][0])

        if self.cache:
            self.respone_cache[cache_key] = response
        return response

    def query_trivia(self, query: str, num_responses: int = 1, temperature=None, top_p=None, top_k=None, repetition_penalty=None, max_new_tokens=500) -> List[List]:
        """
        Query the LLaMA 2 model for responses without shortening the prompt.

        The response text is cut from the sequence decoded with special tokens kept,
        starting `len(query) + 5` characters in.

        :param query: The query to be posed to the language model.
        :type query: str
        :param num_responses: Number of desired responses, default is 1.
        :type num_responses: int
        :param temperature: Sampling temperature; falls back to the configured value when None.
        :param top_p: Nucleus sampling threshold; falls back to the configured "top_p" when None.
        :param top_k: Top-k sampling cutoff; falls back to the configured value when None.
        :param repetition_penalty: Repetition penalty; falls back to the configured
                                   "repetition_penalty" when None.
        :param max_new_tokens: Upper bound on the number of generated tokens, default is 500.
        :return: A pair `(prompts, responses)`, each a list with one string per generated
                 sequence.
        :rtype: Tuple[List[str], List[str]]
        """
        # Resolve defaults first so the cache key reflects the settings actually used.
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.config["top_p"]
        if repetition_penalty is None:
            repetition_penalty = self.config["repetition_penalty"]
        if top_k is None:
            top_k = self.top_k

        # The cache dict is shared by every method of this class (attribute spelled
        # `respone_cache` throughout). Keying on the method name plus all arguments keeps
        # a cached entry from being returned for a different request or by another method.
        cache_key = ("query_trivia", query, num_responses, temperature, top_p, top_k, repetition_penalty, max_new_tokens)
        if self.cache and cache_key in self.respone_cache:
            return self.respone_cache[cache_key]

        tokens = self.tokenizer(query, return_tensors='pt').input_ids.cuda()

        if self.verbose:
            print("ANSWERING QUERY")

        sequences = list(
            self.model.generate(
                tokens,
                do_sample=True,
                top_k=top_k,
                num_return_sequences=num_responses,
                eos_token_id=self.tokenizer.eos_token_id,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )
        )

        # The extra 5 characters presumably cover the start-of-sequence marker that decode()
        # emits when special tokens are kept -- TODO confirm against the tokenizer.
        split_at = len(query) + 5
        decoded = [self.tokenizer.decode(sequence) for sequence in sequences]
        response = ([text[:split_at] for text in decoded], [text[split_at:] for text in decoded])

        if self.verbose:
            print("Prompt : ", response[0][0])
            print("Response : ", response[1][0])

        if self.cache:
            self.respone_cache[cache_key] = response
        return response

    def score(self, query: str, temperature=None, top_p=None, top_k=None, repetition_penalty=None, max_new_tokens=1) -> List[List]:
        """
        Return the model's scores for the first token generated after the query.

        Prompts longer than 4024 tokens are shortened to their first 2012 and last 2012
        tokens before generation.

        :param query: The prompt whose continuation is scored.
        :type query: str
        :param temperature: Sampling temperature; falls back to the configured value when None.
        :param top_p: Nucleus sampling threshold; falls back to the configured "top_p" when None.
        :param top_k: Top-k sampling cutoff; falls back to the configured value when None.
        :param repetition_penalty: Forwarded to generation only when given explicitly. When None,
                                   no repetition penalty argument is passed, so a default call
                                   uses whatever the model's own generation default is.
        :param max_new_tokens: Number of tokens to generate, default is 1. Only the scores of
                               the first generated token are returned.
        :return: `output.scores[0]` of the generation call, i.e. the score tensor of the first
                 generation step (presumably shaped (1, vocab_size) -- TODO confirm). The tensor
                 is returned as produced, without being moved off the GPU.
        :rtype: torch.Tensor
        """
        # Resolve defaults first so the cache key reflects the settings actually used.
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.config["top_p"]
        if top_k is None:
            top_k = self.top_k

        # The cache dict is shared with the text-generating methods of this class, which
        # store a different kind of value. The method name in the key keeps them apart.
        cache_key = ("score", query, temperature, top_p, top_k, repetition_penalty, max_new_tokens)
        if self.cache and cache_key in self.respone_cache:
            return self.respone_cache[cache_key]

        tokens = self.tokenizer(query, return_tensors='pt').input_ids.cuda()

        if tokens.shape[1] > 4024:
            tokens = torch.concat([tokens[:, :2012], tokens[:, -2012:]], dim=1)
            # Emits the over-long prompt so the truncation does not go unnoticed.
            print(query)

        if self.verbose:
            print("ANSWERING SCORE")

        # Use the resolved arguments rather than the instance defaults so that values
        # supplied by the caller actually take effect.
        generation_kwargs = dict(
            do_sample=True,
            top_k=top_k,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            output_scores=True,
            return_dict_in_generate=True,
        )
        if repetition_penalty is not None:
            generation_kwargs["repetition_penalty"] = repetition_penalty

        output = self.model.generate(tokens, **generation_kwargs)
        response = output.scores[0]
        if self.verbose:
            print("Scoring answer : ", self.tokenizer.decode(output.sequences[0]))

        if self.cache:
            self.respone_cache[cache_key] = response
        return response
    
    def score_v2(self, query: str) -> List[List]:
        """
        Generate exactly one token after the query and return the generation scores on the CPU.

        Sampling always uses the instance's configured temperature, top-k and top-p values.
        Nothing is read from or written to the response cache.

        :param query: The prompt whose continuation is scored.
        :type query: str
        :return: One CPU tensor per entry of the generation output's `scores`.
        :rtype: List[torch.Tensor]
        """
        encoded = self.tokenizer(query, return_tensors='pt')
        prompt_ids = encoded.input_ids.cuda()

        if self.verbose:
            print("ANSWERING SCORE")

        generation_kwargs = {
            "do_sample": True,
            "top_k": self.top_k,
            "num_return_sequences": 1,
            "eos_token_id": self.tokenizer.eos_token_id,
            "max_new_tokens": 1,
            "temperature": self.temperature,
            "top_p": self.config["top_p"],
            "output_scores": True,
            "return_dict_in_generate": True,
        }
        generated = self.model.generate(prompt_ids, **generation_kwargs)

        # Copy every score tensor off the GPU before handing it back.
        cpu_scores = []
        for step_scores in generated.scores:
            cpu_scores.append(step_scores.cpu())
        return cpu_scores


    def embeddings(self, query):
        """
        Run the model's embedding layer on the tokenized query.

        :param query: Text handed to the tokenizer.
        :return: NumPy array holding element 0 of the embedding layer's output, copied to the
                 CPU. Presumably one row per token of `query` -- TODO confirm the shape.
        """
        encoded = self.tokenizer(query, return_tensors='pt')
        token_ids = encoded.input_ids.cuda()

        embedding_layer = self.model.model.model.embedding
        embedded = embedding_layer(token_ids)

        # Index 0 picks the first entry along the leading dimension (presumably the batch).
        first_entry = embedded[0]
        return first_entry.cpu().detach().numpy()
    
    

    def get_response_texts(self, query_responses: List[Dict]) -> List[str]:
        """
        Extract the response texts from the query response.

        :param query_responses: The response list of dictionaries generated from the `query` method.
        :type query_responses: List[Dict]
        :return: List of response strings.
        :rtype: List[str]
        """
        return [query_response["generated_text"] for query_response in query_responses]