from LLMServices.LLMService import LLMService
import google.generativeai as genai
from google.generativeai.types import generation_types
import json
from google.cloud import aiplatform
# Import-time side effect: initialise the Vertex AI SDK against a hard-coded
# project. This runs as soon as the module is imported.
aiplatform.init(project='nl-to-sql-model-eval')



class VertexLLMService(LLMService):
    """LLMService implementation that sends prompts through ``google.generativeai``.

    Settings are read from ``.local/vertex.json``, which must provide the keys
    ``api_key``, ``temperature``, ``top_p``, ``top_k`` and ``max_tokens``.
    """

    def __init__(self) -> None:
        """Load the local settings file and prepare generation/safety settings.

        Raises:
            OSError: if ``.local/vertex.json`` cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
            KeyError: if one of the required keys is missing.
        """
        super().__init__()
        # The context manager closes the file even when parsing fails.
        with open(".local/vertex.json", encoding="utf-8") as vertex_info_file:
            vertex_info = json.load(vertex_info_file)
        genai.configure(api_key=vertex_info["api_key"])
        self.generation_config = {
            "temperature": vertex_info["temperature"],
            "top_p": vertex_info["top_p"],
            "top_k": vertex_info["top_k"],
            "max_output_tokens": vertex_info["max_tokens"],
        }
        # NOTE(review): HARM_CATEGORY_SEXUALLY_EXPLICIT is the only category set
        # to BLOCK_NONE -- confirm this is intentional.
        self.safety_settings = [
            {
                "category": "HARM_CATEGORY_HARASSMENT",
                "threshold": "BLOCK_MEDIUM_AND_ABOVE"
            },
            {
                "category": "HARM_CATEGORY_HATE_SPEECH",
                "threshold": "BLOCK_MEDIUM_AND_ABOVE"
            },
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": "BLOCK_NONE"
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": "BLOCK_MEDIUM_AND_ABOVE"
            },
        ]

    def do_prompt_get_text(self, prompt: str, name = None) -> str:
        """Send ``prompt`` to the default model and return the reply text.

        Args:
            prompt: Text sent to the model.
            name: Accepted but not used by this implementation.

        Returns:
            Whatever ``_call_vertex`` returns for ``prompt``.
        """
        return self._call_vertex(prompt=prompt)

    def _call_vertex(
        self,
        prompt: str,
        model: str = "gemini-1.5-pro-latest"
    ) -> str:
        """Send ``prompt`` in a fresh chat session and return the reply text.

        Args:
            prompt: Text sent as the single message of a new chat.
            model: Name of the generative model to use.

        Returns:
            The text of the last chat message. If the SDK raises
            ``StopCandidateException``, the exception is printed and its
            ``message`` attribute is returned, or a fixed fallback string
            when the exception has no such attribute.
        """
        # Use the requested model name rather than a hard-coded one, and keep
        # the ``model`` argument distinct from the client object built from it.
        generative_model = genai.GenerativeModel(
            model_name=model,
            generation_config=self.generation_config,
            safety_settings=self.safety_settings,
        )
        # Every call starts from an empty history; no state is carried over.
        convo = generative_model.start_chat(history=[])
        try:
            convo.send_message(prompt)
        except generation_types.StopCandidateException as e:
            print(e)
            return getattr(
                e,
                'message',
                "Encountered exception without a message attribute.",
            )
        return convo.last.text