import torch
from typing import List, Dict
from .abstract_language_model import AbstractLanguageModel
from vllm import LLM, SamplingParams


class _LLM(AbstractLanguageModel):
    """
    Language model wrapper that generates text through a vLLM ``LLM`` engine.

    Sampling defaults (temperature, top_k, top_p, repetition_penalty,
    max_tokens) are read from the ``model_name`` section of the configuration
    loaded by the base class. When ``cache`` is enabled, generated responses
    are memoised per prompt string.
    """

    def __init__(
        self,
        config_path: str = "./source/language_models/config.json",
        model_name: str = "FT_model",
        weight_path=None,
        cache: bool = False,
        verbose: bool = False,
        GPUs: list = None,
    ) -> None:
        """
        Load the configuration and instantiate the vLLM engine.

        :param config_path: Path to the JSON configuration file.
        :type config_path: str
        :param model_name: Key of the model section inside the configuration.
            Also used as the weight path when ``weight_path`` is None, and as
            the tokenizer name when it contains ``"meta-llama"``.
        :type model_name: str
        :param weight_path: Location of the model weights passed to vLLM.
            Defaults to ``model_name``.
        :param cache: Whether to cache responses per prompt.
        :type cache: bool
        :param verbose: Stored on the instance as ``self.verbose``.
        :type verbose: bool
        :param GPUs: Accepted for interface compatibility; it is not used by
            this class (the engine is always created with
            ``tensor_parallel_size=4``). ``None`` replaces the former mutable
            list default.
        :type GPUs: list
        """
        super().__init__(config_path, model_name, cache)
        self.config: Dict = self.config[model_name]
        self.verbose = verbose
        self.temperature: float = self.config["temperature"]
        self.top_k: int = self.config["top_k"]
        self.max_tokens: int = self.config["max_tokens"]
        self.top_p: float = self.config["top_p"]
        self.repetition_penalty: float = self.config["repetition_penalty"]
        self.cache = cache
        if weight_path is None:
            weight_path = model_name

        # Fine-tuned checkpoints fall back to the Llama-3 instruct tokenizer;
        # official meta-llama checkpoints use their own.
        if "meta-llama" in model_name:
            print("Llama model")
            tokenizer_name = model_name
        else:
            tokenizer_name = "meta-llama/Meta-Llama-3-8B-Instruct"

        self.model = LLM(
            model=weight_path,
            tensor_parallel_size=4,
            tokenizer=tokenizer_name,
        )

        if cache:
            # The original code created ``response_cache`` here but read and
            # wrote ``respone_cache`` in get_action. Both names are bound to
            # the same dict so either spelling keeps working.
            # NOTE(review): the base class may already define the misspelled
            # attribute; if so, it is reused - confirm against
            # AbstractLanguageModel.
            shared_cache = getattr(self, "respone_cache", None)
            if shared_cache is None:
                shared_cache = {}
            self.response_cache = shared_cache
            self.respone_cache = shared_cache

    def get_action(
        self,
        query: list,
        num_responses: int = 1,
        temperature=None,
        top_p=None,
        top_k=None,
        repetition_penalty=None,
        max_new_tokens=300,
        use_tqdm=False
    ) -> List[List]:
        """
        Generate responses for a batch of prompts with the vLLM engine.

        Sampling arguments left as None fall back to the values loaded from
        the configuration. ``max_new_tokens`` has its own default of 300; the
        configured ``max_tokens`` is not consulted here.

        When caching is enabled, the cache is keyed by prompt text only:
        prompts that are already cached are returned as stored, regardless of
        the sampling arguments or ``num_responses`` of the current call.

        :param query: Prompts to send to the model. A single string is treated
            as a one-element batch.
        :type query: list
        :param num_responses: Number of completions sampled per prompt.
        :type num_responses: int
        :param temperature: Sampling temperature override.
        :param top_p: Nucleus sampling override.
        :param top_k: Top-k sampling override.
        :param repetition_penalty: Repetition penalty override.
        :param max_new_tokens: Maximum number of tokens generated per completion.
        :param use_tqdm: Forwarded to ``LLM.generate``.
        :return: One list of completion texts per prompt, in input order.
            Empty list when ``query`` is empty.
        :rtype: List[List[str]]
        """
        # A bare string would otherwise be iterated character by character
        # in the cache path below.
        if isinstance(query, str):
            query = [query]

        if self.cache:
            all_query = list(query)
            # Only prompts missing from the cache are generated; duplicates
            # within the batch are generated once (order preserved).
            pending = [
                q for q in dict.fromkeys(all_query)
                if q not in self.response_cache
            ]
        else:
            pending = query

        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if top_k is None:
            top_k = self.top_k

        # Initialised up front so an empty batch returns [] instead of
        # raising UnboundLocalError.
        new_outputs = []
        if pending:
            sampling_params = SamplingParams(
                top_k=top_k,
                max_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                n=num_responses,
            )

            outputs = self.model.generate(
                pending, sampling_params=sampling_params, use_tqdm=use_tqdm
            )

            new_outputs = [
                [completion.text for completion in request_output.outputs]
                for request_output in outputs
            ]

        if self.cache:
            for q, texts in zip(pending, new_outputs):
                self.response_cache[q] = texts
            # Rebuild the result in the caller's original order, mixing
            # freshly generated and previously cached responses.
            new_outputs = [self.response_cache[q] for q in all_query]

        return new_outputs
