import requests
import json
from LLMServices.LLMService import LLMService
import logging 
class OllamaLLMService(LLMService):
    """LLM service backed by an Ollama HTTP server.

    Connection and sampling settings are read once, at construction time,
    from ``.local/ollama.json`` (resolved relative to the current working
    directory).

    NOTE(review): every request is sent with ``verify=False``, which disables
    TLS certificate verification. This is presumably to tolerate a
    self-signed certificate on the server; confirm, and prefer passing a CA
    bundle path to ``verify`` instead.
    """

    # Text returned (under the "response" key) by _generate_completion when
    # the request fails, so callers always get a dict with a "response" key.
    _FALLBACK_TEXT = " We are unable to answer your prompt. We apologize. Could you simplify this?"

    def __init__(self) -> None:
        """Load connection and sampling settings from ``.local/ollama.json``.

        Raises:
            FileNotFoundError: if the config file does not exist.
            json.JSONDecodeError: if the file is not valid JSON.
            KeyError: if any of the required keys is missing.
        """
        with open(".local/ollama.json", "r", encoding="utf-8") as f:
            ollama_info = json.load(f)
        # Base URL of the API; "/ps" and "/generate" are appended to it, so it
        # presumably ends in "/api" (e.g. http://host:11434/api) — confirm.
        self.api_url = ollama_info["api_url"]
        self.model = ollama_info["model"]
        self.max_tokens = ollama_info["max_tokens"]
        self.temperature = ollama_info["temperature"]
        self.top_p = ollama_info["top_p"]
        self.top_k = ollama_info["top_k"]
        self.frequency_penalty = ollama_info["frequency_penalty"]
        self.presence_penalty = ollama_info["presence_penalty"]
        # Seconds; passed as the requests timeout for HTTP calls.
        self.timeout = ollama_info["timeout"]

    def do_prompt_get_text(self, prompt, name=None):
        """Send *prompt* to the configured model and return the generated text.

        Args:
            prompt: The prompt text.
            name: Unused here; presumably part of the ``LLMService``
                interface — confirm against the base class.

        Returns:
            The model's generated text, or a fixed apology message if the
            request failed (see ``_generate_completion``).
        """
        response = self._generate_completion(prompt=prompt)
        return response["response"]

    def get_model_stats(self):
        """Return the ``/ps`` entry for this service's model, if it is loaded.

        Queries ``{api_url}/ps`` and returns the first entry in its
        ``"models"`` list whose ``"name"`` contains ``self.model`` as a
        substring.

        Returns:
            The matching model dict, or ``{}`` if no entry matches (including
            when the list is empty or missing).

        Raises:
            requests.HTTPError: if the server answers with a 4xx/5xx status.
            requests.RequestException: on connection failure or timeout.
        """
        response = requests.get(
            url=f"{self.api_url}/ps",
            verify=False,
            timeout=self.timeout,
        )
        # raise_for_status() is a no-op below 400, so calling it
        # unconditionally is equivalent to guarding it with "!= 200".
        response.raise_for_status()
        # NOTE: substring match, so e.g. "llama3" also matches "llama3.1:8b";
        # the first match in server order wins.
        for model in response.json().get("models", []):
            if self.model in model["name"]:
                return model
        # Nothing matched: return {} rather than falling off the end (None).
        return {}

    def _generate_completion(
            self,
            prompt: str,
            model: "str | None" = None
            ) -> dict:
        """POST *prompt* to ``{api_url}/generate`` and return the JSON reply.

        Args:
            prompt: The prompt text.
            model: Model name to use for this call instead of ``self.model``.

        Returns:
            The decoded JSON body on HTTP 200. Because ``"stream"`` is False,
            this is a single object; its ``"response"`` key holds the
            generated text. On a request failure, or on a non-200 status that
            is not an HTTP error, returns ``{"response": <apology text>}``.
            Failures are logged and printed, not raised.
        """
        request_params = {
            "model": model if model is not None else self.model,
            "prompt": prompt,
            # Ollama's /api/generate only honours sampling parameters nested
            # under "options"; top-level keys such as "temperature" are
            # ignored. "num_predict" is Ollama's name for the max token count.
            "options": {
                "num_predict": self.max_tokens,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "top_k": self.top_k,
                "frequency_penalty": self.frequency_penalty,
                "presence_penalty": self.presence_penalty,
            },
            # Request one complete JSON object rather than a chunked stream.
            "stream": False,
        }
        try:
            response = requests.post(
                url=f"{self.api_url}/generate",
                json=request_params,
                verify=False,
                timeout=self.timeout,
            )
            if response.status_code == 200:
                return response.json()
            # Raises HTTPError for 4xx/5xx (handled below). Other non-200
            # statuses do not raise and fall through to the fallback reply.
            response.raise_for_status()
        except requests.exceptions.SSLError as ssl_err:
            logging.error(f"SSL Error: {ssl_err}")
            print(f"SSL Error: {ssl_err}")
        except requests.exceptions.Timeout:
            logging.error("Request timed out")
            print("Request timed out")
        except requests.exceptions.RequestException as e:
            logging.error(f"Request failed: {e}")
            print(f"Request failed: {e}")
        return {"response": self._FALLBACK_TEXT}
