import requests
import json
from LLMServices.LLMService import LLMService
import logging 
from LLMServices import Message
    

class VllamaLLMService(LLMService):
    """LLM service that talks to a chat-completions HTTP endpoint.

    Settings are read from a JSON config file (``.local/vllama.json`` by
    default).

    Required config keys: ``api_url``, ``model``, ``max_tokens``,
    ``temperature``, ``top_p``, ``top_k``, ``frequency_penalty``,
    ``presence_penalty``, ``timeout``.

    Optional config keys:
      * ``chat_url``   -- full URL of the chat-completions endpoint; defaults
                          to ``DEFAULT_CHAT_URL`` (the previously hard-coded
                          endpoint).
      * ``verify_ssl`` -- value passed as ``verify=`` to ``requests``; defaults
                          to ``False`` to preserve historical behaviour. Set it
                          to ``True`` (or a CA-bundle path) to enable TLS
                          certificate verification.
    """

    # Endpoint that used to be hard-coded in _generate_completion; kept as the
    # default so existing config files keep working unchanged.
    DEFAULT_CHAT_URL = "https://wire.westpoint.edu/vllm/v1/chat/completions"

    # Text placed under the "response" key when no completion could be obtained.
    FALLBACK_MESSAGE = (
        " We are unable to answer your prompt. We apologize. "
        "Could you simplify this?"
    )

    def __init__(self, config_path: str = ".local/vllama.json") -> None:
        """Load connection and sampling settings from ``config_path``.

        Raises:
            OSError: if the config file cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
            KeyError: if a required key is missing.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            vllama_info = json.load(f)

        self.api_url = vllama_info["api_url"]
        self.model = vllama_info["model"]
        # NOTE: the sampling settings below are loaded but not currently sent
        # with chat requests (see _generate_completion), so server defaults apply.
        self.max_tokens = vllama_info["max_tokens"]
        self.temperature = vllama_info["temperature"]
        self.top_p = vllama_info["top_p"]
        self.top_k = vllama_info["top_k"]
        self.frequency_penalty = vllama_info["frequency_penalty"]
        self.presence_penalty = vllama_info["presence_penalty"]
        self.timeout = vllama_info["timeout"]
        self.chat_url = vllama_info.get("chat_url", self.DEFAULT_CHAT_URL)
        self.verify_ssl = vllama_info.get("verify_ssl", False)

    def do_prompt_get_text(self, prompt, name):
        """Send ``prompt`` as a single chat message and return the reply text.

        Args:
            prompt: message content.
            name: used as the message ``role``.

        Returns:
            ``choices[0].message.content`` from the server reply. When the
            request itself failed, the fallback text produced by
            ``_generate_completion``. ``None`` if the reply has an unexpected
            shape.
        """
        response = self._generate_completion(prompt=prompt, name=name)
        try:
            return response["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as err:
            if isinstance(response, dict) and "response" in response:
                # _generate_completion signals failure with {"response": text};
                # the underlying error has already been logged there.
                return response["response"]
            logging.error("Response %r yields error %r", response, err)
            return None

    def get_model_stats(self):
        """Return the ``/ps`` listing entry for the configured model.

        Queries ``{api_url}/ps`` (presumably an Ollama-style endpoint that
        lists loaded models -- confirm) and returns the first entry whose
        ``name`` contains ``self.model``, or ``{}`` if there is none.

        Raises:
            requests.HTTPError: on a 4xx/5xx response.
            requests.RequestException: on connection errors or timeouts.
        """
        response = requests.get(
            url=f"{self.api_url}/ps",
            verify=self.verify_ssl,
            # Without a timeout, an unresponsive server would block forever.
            timeout=self.timeout,
        )
        response.raise_for_status()
        for model in response.json().get("models", []):
            # Substring match, e.g. a configured name without a tag suffix.
            if self.model in model.get("name", ""):
                return model
        return {}

    def _generate_completion(
            self,
            prompt: str,
            name: str,
            model: str = None,
            ) -> dict:
        """POST a single-message chat request and return the decoded JSON.

        Args:
            prompt: message content.
            name: value used as the message ``role``.
            model: overrides ``self.model`` for this request when not None.

        Returns:
            The decoded JSON body on HTTP 200. Otherwise a dict of the form
            ``{"response": FALLBACK_MESSAGE}``. Request errors are logged,
            not raised.
        """
        # Only model and messages are sent. The sampling settings loaded in
        # __init__ (max_tokens, temperature, top_p, top_k, penalties) were
        # commented out here and are still not forwarded.
        request_params = {
            "model": model if model is not None else self.model,
            "messages": [{"role": name, "content": prompt}],
        }
        try:
            response = requests.post(
                url=self.chat_url,
                json=request_params,
                verify=self.verify_ssl,
                timeout=self.timeout,
            )
            if response.status_code == 200:
                return response.json()
            # Raises for 4xx/5xx; other non-200 codes fall through to the fallback.
            response.raise_for_status()
        except requests.exceptions.SSLError as ssl_err:
            logging.error("SSL Error: %s", ssl_err)
        except requests.exceptions.Timeout:
            logging.error("Request timed out")
        except requests.exceptions.RequestException as e:
            logging.error("Request failed: %s", e)
        except ValueError as e:
            # Older requests versions raise a plain ValueError for a non-JSON body.
            logging.error("Invalid JSON in response: %s", e)
        return {"response": self.FALLBACK_MESSAGE}