import os
import torch
from typing import List, Dict, Union
from .abstract_language_model import AbstractLanguageModel
from time import time
import transformers
import numpy as np
from awq import AutoAWQForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

# Appears to be a Hugging Face Hub model id (Mixtral-8x7B-Instruct).
# NOTE(review): not referenced in this part of the file — LLM_vllm takes its model id
# from the `model_name` constructor argument instead; confirm whether this is still used.
model_name_or_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"

class LLM_vllm(AbstractLanguageModel):
    """
    An interface to serve a causal language model through the vLLM engine.

    Sampling settings (temperature, top_k, top_p, ...) are read from the configuration
    entry selected by ``model_name``; the same ``model_name`` is also passed to vLLM and
    to ``AutoTokenizer`` as the model / tokenizer identifier.
    """

    def __init__(
        self, config_path: str = "", model_name: str = None, cache: bool = False, verbose: bool = False
    ) -> None:
        """
        Load the configuration entry for ``model_name`` and start a vLLM engine for it.

        :param config_path: Path to the configuration file. Defaults to an empty string.
        :type config_path: str
        :param model_name: Key of this model's entry in the configuration; also used as the
                           model and tokenizer identifier for vLLM and ``AutoTokenizer``.
                           Must be provided: the default ``None`` is used as a config key as-is.
        :type model_name: str
        :param cache: Flag to determine whether to cache responses. Defaults to False.
        :type cache: bool
        :param verbose: If True, print progress, prompts and responses. Defaults to False.
        :type verbose: bool
        """
        super().__init__(config_path, model_name, cache)
        # Presumably the base class loaded the whole configuration into self.config;
        # narrow it to this model's entry.
        self.config: Dict = self.config[model_name]

        self.verbose = verbose
        self.temperature: float = self.config["temperature"]
        self.top_k: int = self.config["top_k"]
        # Kept for callers that read it; query()/score() use their max_new_tokens argument.
        self.max_tokens: int = self.config["max_tokens"]

        # Number of GPUs to shard the model across. Overridable from the config entry;
        # defaults to the previously hard-coded value of 4.
        tensor_parallel_size: int = self.config.get("tensor_parallel_size", 4)
        self.model = LLM(
            model=model_name,
            tensor_parallel_size=tensor_parallel_size,
            tokenizer=model_name,
        )

        # Within this class, only embeddings() uses this tokenizer directly.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Score results are cached apart from query results: both are keyed by the prompt
        # string, so a shared cache would let score() return a query() tuple and vice versa.
        self._score_cache: Dict[str, List[Dict[str, float]]] = {}

    @staticmethod
    def _strip_sys_markers(prompt):
        """
        Return ``prompt`` with Llama-2 style ``<<SYS>>`` / ``<</SYS>>`` markers removed.

        Accepts a single string or a list of strings and returns the same kind.
        """
        if isinstance(prompt, str):
            return prompt.replace("<<SYS>>", "").replace("<</SYS>>", "")
        return [p.replace("<<SYS>>", "").replace("<</SYS>>", "") for p in prompt]

    def _sampling_params(self, temperature, top_p, top_k, max_new_tokens, **extra):
        """
        Build vLLM ``SamplingParams``, replacing ``None`` arguments with configured defaults.

        :param extra: Additional ``SamplingParams`` keyword arguments (e.g. ``n``, ``logprobs``).
        """
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.config["top_p"]
        if top_k is None:
            top_k = self.top_k
        return SamplingParams(
            top_k=top_k,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            **extra,
        )

    def query(self, query: str, num_responses: int = 1, temperature=None, top_p=None, top_k=None, repetition_penalty=None, max_new_tokens=500) -> tuple:
        """
        Sample one or more completions for a prompt.

        :param query: The prompt. ``<<SYS>>`` / ``<</SYS>>`` markers are removed before it
                      is sent to the model.
        :type query: str
        :param num_responses: Number of completions to sample. Defaults to 1.
        :type num_responses: int
        :param temperature: Sampling temperature; None uses the configured value.
        :param top_p: Nucleus-sampling threshold; None uses the configured value.
        :param top_k: Top-k cutoff; None uses the configured value.
        :param repetition_penalty: Accepted for interface compatibility; not currently
                                   forwarded to vLLM.
        :param max_new_tokens: Maximum number of tokens per completion. Defaults to 500.
        :return: ``(query, texts)``: the prompt exactly as passed in, and the list of
                 generated completion strings. With caching enabled, a repeated prompt
                 returns the cached tuple regardless of the other arguments.
        :rtype: tuple
        """
        if self.cache and query in self.respone_cache:
            return self.respone_cache[query]

        sampling_params = self._sampling_params(
            temperature, top_p, top_k, max_new_tokens, n=num_responses
        )
        if self.verbose:
            print("ANSWERING QUERY")

        # One prompt -> one RequestOutput, whose .outputs holds all n samples.
        request_output = self.model.generate(
            self._strip_sys_markers(query), sampling_params, use_tqdm=False
        )[0]
        response = (query, [completion.text for completion in request_output.outputs])

        if self.verbose:
            print("Prompt_ : ", response[0])
            print("Response : ", response[1])

        if self.cache:
            self.respone_cache[query] = response
        return response

    def score(self, query: str, temperature=None, top_p=None, top_k=None, repetition_penalty=None, max_new_tokens=1) -> List[Dict[str, float]]:
        """
        Generate a short continuation and report candidate log-probabilities per token.

        :param query: The prompt. ``<<SYS>>`` / ``<</SYS>>`` markers are removed before it
                      is sent to the model.
        :type query: str
        :param temperature: Sampling temperature; None uses the configured value.
        :param top_p: Nucleus-sampling threshold; None uses the configured value.
        :param top_k: Top-k cutoff; None uses the configured value.
        :param repetition_penalty: Accepted for interface compatibility; not currently
                                   forwarded to vLLM.
        :param max_new_tokens: Number of tokens to generate. Defaults to 1.
        :return: One dict per generated token mapping decoded candidate text to its
                 log-probability, for the candidates vLLM reports with ``logprobs=5``.
                 If two candidates decode to the same text, the later one wins.
        :rtype: List[Dict[str, float]]
        """
        if self.cache and query in self._score_cache:
            return self._score_cache[query]

        sampling_params = self._sampling_params(
            temperature, top_p, top_k, max_new_tokens, logprobs=5
        )
        if self.verbose:
            print("ANSWERING SCORE")

        completion = self.model.generate(
            self._strip_sys_markers(query), sampling_params, use_tqdm=False
        )[0].outputs[0]
        # Assumes a vLLM version where each step's logprobs is a dict of
        # token_id -> Logprob (with .logprob and .decoded_token) — confirm against the
        # installed vLLM.
        response = [
            {candidate.decoded_token: candidate.logprob for candidate in step.values()}
            for step in completion.logprobs
        ]

        if self.verbose:
            print("Scoring answer : ", completion.text)

        if self.cache:
            self._score_cache[query] = response
        return response

    def embeddings(self, query):
        """
        Tokenize ``query`` and run it through the model's embedding layer.

        NOTE(review): ``self.model`` is a vLLM ``LLM``; it is not clear that it exposes
        ``.model.model.embedding`` — confirm. With tensor parallelism the weights may also
        not live in this process.

        :return: The embedding output for the first (only) sequence as a NumPy array.
        """
        tokens = self.tokenizer(query, return_tensors='pt').input_ids.cuda()

        output = self.model.model.model.embedding(tokens)[0].cpu().detach().numpy()
        return output

    def get_response_texts(self, query_responses: Union[tuple, List[Dict]]) -> List[str]:
        """
        Extract the response texts from a query response.

        :param query_responses: Either the ``(prompt, texts)`` tuple returned by ``query``,
                                or a list of dicts each holding a ``"generated_text"`` key.
        :type query_responses: Union[tuple, List[Dict]]
        :return: List of response strings.
        :rtype: List[str]
        """
        if isinstance(query_responses, tuple):
            # Shape produced by query(): (prompt, [text, ...]).
            _prompt, texts = query_responses
            return list(texts)
        return [query_response["generated_text"] for query_response in query_responses]