from typing import List, Optional

import PIL.Image
import google.generativeai as genai
from google.generativeai import GenerationConfig
from google.generativeai.types import ContentType

from src.llm_messenger.classes.content import Content
from src.llm_messenger.classes.image_content import ImageContent
from src.llm_messenger.classes.llm_messenger import LLMMessenger
from src.llm_messenger.classes.text_content import TextContent


class GoogleMessenger(LLMMessenger):
    """LLMMessenger implementation that talks to Google's Generative AI SDK.

    The messenger sends text and image contents to a ``genai.GenerativeModel``
    and, when a conversation context is open, stores the exchanged messages so
    they are sent again with the next request.
    """

    def __init__(
        self,
        api_key: str,
        system_prompt: Optional[str] = None,
        model_name: str = "gemini-1.5-flash",
        temperature: float = 0.0,
        log_directory: str = "",
        max_output_tokens: int = 1024,
    ):
        """Configure the SDK and build the underlying generative model.

        Args:
            api_key: API key passed to the module-level ``genai.configure``.
            system_prompt: Optional system instruction given to the model.
            model_name: Name of the model to instantiate.
            temperature: Sampling temperature used for generation.
            log_directory: Forwarded to the ``LLMMessenger`` base class.
            max_output_tokens: Upper bound on the number of tokens the model
                may generate per response.
        """
        super().__init__(model_name, temperature, log_directory)
        genai.configure(api_key=api_key)
        self.__model = genai.GenerativeModel(
            model_name,
            generation_config=GenerationConfig(
                max_output_tokens=max_output_tokens,
                temperature=temperature,
                response_mime_type="text/plain",
            ),
            system_instruction=system_prompt,
        )
        # None means no conversation context is currently open; in that state
        # __update_context does not record any history.
        self._context: Optional[List[ContentType]] = None

    def ask(self, contents: List[Content]) -> str:
        """Send ``contents`` to the model and return its text answer.

        The request and the answer are logged through ``self.log``. When a
        context is open, the sent messages and the answer are stored as the
        new context.

        Args:
            contents: Text and image contents that form the user's message.

        Returns:
            The text of the model's response.

        Raises:
            Exception: Whatever ``response.text`` raises when no text can be
                read from the response; diagnostic details are printed first
                and the original exception is re-raised unchanged.
        """
        if self._context is None:
            self.log("INFO", [TextContent("Opening new context")])

        self.log("USER", contents)

        messages = self.__get_messages_with_context(contents)
        response = self.__model.generate_content(messages)

        try:
            model_response = response.text
        except Exception:
            # The handler treats an unreadable response as a blocked prompt
            # and prints the feedback the API attached to it.
            print(
                "Response blocked",
                contents,
                response.prompt_feedback.block_reason,
                response.prompt_feedback.safety_ratings,
            )
            # Bare raise keeps the original exception and traceback intact.
            raise

        self.log("ASSISTANT", [TextContent(model_response)])
        self.__update_context(messages, model_response)
        return model_response

    def __get_messages_with_context(self, contents: List[Content]) -> List[ContentType]:
        """Return the current context followed by ``contents`` in SDK form.

        Images are converted with ``to_pil_image()`` and text contents are
        passed as plain strings. Contents of any other type are skipped.
        """
        # get_context() is defined outside this class. Copy its result so the
        # appends below can never mutate the stored history (which would leave
        # an unanswered user turn behind if the request fails), and tolerate a
        # None return value.
        messages: List[ContentType] = list(self.get_context() or [])
        for content in contents:
            if isinstance(content, ImageContent):
                messages.append(content.to_pil_image())
            elif isinstance(content, TextContent):
                messages.append(content.text)
        return messages

    def __update_context(self, messages: List, model_response: str):
        """Store ``messages`` plus the model's answer as the new context.

        Nothing is stored while no context is open (``self._context`` is
        None). Images are removed from the stored history unless
        ``self._keep_image_history`` is truthy; that attribute is not set in
        this class and is presumably provided by the base class.
        """
        if self._context is None:
            return

        if self._keep_image_history:
            history = list(messages)
        else:
            history = [m for m in messages if not isinstance(m, PIL.Image.Image)]
        history.append(model_response)
        self._context = history
