from typing import List, Optional, Dict

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

from src.llm_messenger.classes.content import Content
from src.llm_messenger.classes.image_content import ImageContent
from src.llm_messenger.classes.llm_messenger import LLMMessenger
from src.llm_messenger.classes.text_content import TextContent


class HuggingfaceMessenger(LLMMessenger):
    """LLM messenger backed by a locally loaded Hugging Face chat model.

    The model is loaded with ``trust_remote_code=True`` and queried through its
    custom ``chat(image=..., msgs=..., tokenizer=...)`` method, so
    ``model_name`` must refer to a checkpoint whose remote code provides that
    method. Images are passed inline inside ``msgs`` as PIL images.
    """

    def __init__(
        self,
        model_name: str,
        temperature: float = 0.0,
        log_directory: str = "",
        device: str = "cuda",
    ):
        """Load the model (bfloat16, SDPA attention) and its tokenizer.

        Args:
            model_name: Hugging Face hub id or local path of the checkpoint.
            temperature: Forwarded to ``LLMMessenger``. NOTE(review): it is not
                passed to ``model.chat`` in ``ask``, so it does not appear to
                influence generation here -- confirm whether that is intended.
            log_directory: Forwarded to ``LLMMessenger``.
            device: Device the model is moved to. The default ``"cuda"`` matches
                the previously hard-coded ``.cuda()`` call.
        """
        super().__init__(model_name, temperature, log_directory)
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            attn_implementation="sdpa",  # sdpa or flash_attention_2, no eager
            torch_dtype=torch.bfloat16,
        )
        # eval() disables training-only behavior (e.g. dropout) before inference.
        self.model = model.eval().to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        # None means "no conversation context is being kept".
        self._context: Optional[List[Dict]] = None

    def __update_context(self, contents: List[Content], model_response: str) -> None:
        """Append one user/assistant exchange to the stored context.

        Does nothing when no context is open (``self._context is None``).
        Image contents are dropped from the stored user turn unless
        ``self._keep_image_history`` is truthy.

        Args:
            contents: The user contents that were just sent.
            model_response: The model's reply to those contents.
        """
        if self._context is None:
            return

        if not self._keep_image_history:
            contents = [c for c in contents if not isinstance(c, ImageContent)]

        self._context.append(self.__get_messages(contents))
        self._context.append({"role": "assistant", "content": [model_response]})

    def __get_messages(self, contents: List[Content]) -> Dict:
        """Build a single user-role message from ``contents``.

        ``ImageContent`` items are loaded from ``image_path`` and converted to
        RGB PIL images; ``TextContent`` items contribute their ``text``. Any
        other ``Content`` type is silently skipped.

        Returns:
            ``{"role": "user", "content": [...]}`` with parts in input order.
        """
        parts: List = []
        for content in contents:
            if isinstance(content, ImageContent):
                # Image.open is lazy and keeps the file open; use a context
                # manager so the handle is released. convert() loads the pixel
                # data and returns a new image that does not need the file.
                with Image.open(content.image_path) as img:
                    parts.append(img.convert("RGB"))
            elif isinstance(content, TextContent):
                parts.append(content.text)
        return {"role": "user", "content": parts}

    def ask(self, contents: List[Content]) -> str:
        """Send ``contents`` as a user turn after the current context and return the reply.

        The exchange is logged (USER, then ASSISTANT) and, if a context is
        open, appended to it.

        Args:
            contents: Text and/or image contents making up the user turn.

        Returns:
            The model's text response.
        """
        if self._context is None:
            self.log("INFO", [TextContent("Opening new context")])

        self.log("USER", contents)

        context = self.get_context()
        msgs = context + [self.__get_messages(contents)]

        # Images travel inside msgs, so the separate `image` argument is unused.
        model_response = self.model.chat(
            image=None, msgs=msgs, tokenizer=self.tokenizer
        )
        print(model_response)

        self.log("ASSISTANT", [TextContent(model_response)])
        self.__update_context(contents, model_response)
        return model_response
