import base64
import json
import os
import re
import time
from typing import List, Union, Optional, Dict

import requests

from src.llm_messenger.classes.content import Content
from src.llm_messenger.classes.exceptions import UnsupportedImageException
from src.llm_messenger.classes.image_content import ImageContent
from src.llm_messenger.classes.llm_messenger import LLMMessenger
from src.llm_messenger.classes.text_content import TextContent


class GPTMessenger(LLMMessenger):
    """LLMMessenger implementation backed by an OpenAI-style Chat Completions endpoint.

    Text and image contents are converted to the Chat Completions message
    format and POSTed with ``requests``. The content of the first returned
    choice is the reply. Error responses whose error payload contains
    "Please try again in <duration>" are retried after sleeping for that
    duration. Any other error is logged and raised.
    """

    # Wait hint in rate-limit error messages, e.g. "20s", "1.284s", "6ms" or
    # "1m30s". It only matches number+unit runs, so the sentence-ending period
    # is not swallowed and decimals are kept intact.
    _RETRY_TIME_PATTERN = re.compile(
        r"Please try again in ((?:\d+(?:\.\d+)?(?:ms|h|m|s))+)"
    )
    # One "<number><unit>" component of a duration. "ms" is tried before "m"
    # and "s" so milliseconds are not misread as minutes or seconds.
    _DURATION_PART_PATTERN = re.compile(r"(\d+(?:\.\d+)?)(ms|h|m|s)")
    _DURATION_UNIT_SECONDS = {"ms": 0.001, "s": 1.0, "m": 60.0, "h": 3600.0}
    # File extensions whose registered MIME subtype differs from the extension.
    _IMAGE_SUBTYPE_ALIASES = {"jpg": "jpeg"}

    def __init__(
        self,
        api_key: str,
        api_endpoint: str = "https://api.openai.com/v1/chat/completions",
        system_prompt: Optional[str] = None,
        model_name="gpt-4-turbo",
        temperature: float = 0.0,
        log_directory: str = "",
        request_timeout: Optional[float] = None,
    ):
        """Create a messenger.

        Args:
            api_key: Sent as a Bearer token in the Authorization header.
            api_endpoint: URL that chat-completion requests are POSTed to.
            system_prompt: Optional text sent as a system message with every request.
            model_name: Value of the "model" field in the request payload.
            temperature: Value of the "temperature" field in the request payload.
            log_directory: Forwarded to ``LLMMessenger.__init__``.
            request_timeout: ``timeout`` passed to ``requests.post``. ``None``
                (the default) waits indefinitely, as before.
        """
        super().__init__(model_name, temperature, log_directory)
        self.__api_key = api_key
        self.__api_endpoint = api_endpoint
        self.__request_timeout = request_timeout
        self.__system_prompt_content = (
            TextContent(system_prompt) if system_prompt is not None else None
        )
        # Prior chat messages (dicts). While this is None, __update_context
        # records nothing.
        self._context: Optional[List[dict]] = None

    def __get_headers(self) -> dict:
        """Return the JSON content-type and Bearer authorization headers."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.__api_key}",
        }

    def __get_text_content(self, message: str) -> dict:
        """Wrap *message* as a Chat Completions ``text`` content part."""
        return {"type": "text", "text": message}

    def __get_image_content(self, image_path: str) -> dict:
        """Embed the image at *image_path* as a base64 ``image_url`` content part.

        The MIME subtype is taken from the file extension. It is lower-cased
        (MIME types are case-insensitive), and ``jpg`` is mapped to the
        registered ``jpeg`` subtype.
        """
        extension = os.path.splitext(image_path)[1].replace(".", "").lower()
        subtype = self._IMAGE_SUBTYPE_ALIASES.get(extension, extension)
        return {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/{subtype};base64,{self.__encode_image(image_path)}"
            },
        }

    def __get_system_message(self) -> List[dict]:
        """Return ``[system message]`` if a system prompt was given, else ``[]``."""
        if self.__system_prompt_content is None:
            return []
        return [
            {
                "role": "system",
                "content": [
                    self.__get_text_content(self.__system_prompt_content.text)
                ],
            }
        ]

    def __get_payload(self, messages: List[dict]) -> dict:
        """Build the request body for *messages*. Output is capped at 4096 tokens."""
        return {
            "model": self._model_name,
            "messages": messages,
            "temperature": self._temperature,
            "max_tokens": 4096,
        }

    def __update_context(self, contents: List[Content], model_response: str) -> None:
        """Append this turn (user message + assistant reply) to ``self._context``.

        Does nothing while ``self._context`` is None. When
        ``self._keep_image_history`` is falsy, image parts are dropped from the
        stored user message. That attribute is not set in this class;
        presumably LLMMessenger defines it.
        """
        if self._context is None:
            return

        if not self._keep_image_history:
            contents = [
                content for content in contents if not isinstance(content, ImageContent)
            ]

        self._context.append(self.__get_message(contents))
        self._context.append({"role": "assistant", "content": model_response})

    def __convert_contents(self, contents: List[Content]) -> List[dict]:
        """Convert Image/TextContent items to API content parts.

        Items of any other Content type are skipped.
        """
        converted_contents = []
        for content in contents:
            if isinstance(content, ImageContent):
                converted_contents.append(self.__get_image_content(content.image_path))
            elif isinstance(content, TextContent):
                converted_contents.append(self.__get_text_content(content.text))
        return converted_contents

    def __get_message(self, contents: List[Content]) -> dict:
        """Build a ``user`` message from *contents*."""
        return {"role": "user", "content": self.__convert_contents(contents)}

    def __encode_image(self, image_path: str) -> str:
        """Return the file's bytes base64-encoded as an ASCII string."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    def __should_retry(self, response) -> bool:
        """Sleep and return True if *response* is an error carrying a retry hint.

        Only the error payload is inspected. A successful reply whose text
        happens to contain "Please try again in ..." must not trigger a retry.
        """
        retry_time = self.__get_retry_time(response)
        if retry_time is None:
            return False
        self.__wait(retry_time)
        return True

    def __wait(self, retry_time: str) -> None:
        """Log and sleep for *retry_time* (e.g. "20s", "1.5s", "6ms", "1m30s")."""
        seconds = sum(
            float(amount) * self._DURATION_UNIT_SECONDS[unit]
            for amount, unit in self._DURATION_PART_PATTERN.findall(retry_time)
        )
        self.log(
            "INFO",
            [TextContent(f"Rate limit reached. Retrying in {retry_time}...")],
        )
        time.sleep(seconds)

    def __get_retry_time(self, response) -> Optional[str]:
        """Extract the suggested wait duration from an error response, if any.

        Returns None for non-error responses or when no parseable duration
        is present.
        """
        if not isinstance(response, dict) or "error" not in response:
            return None
        # Serialize only the error payload so both dict and string errors are searched.
        match = self._RETRY_TIME_PATTERN.search(json.dumps(response["error"]))
        return match.group(1) if match else None

    def __check_error(self, response) -> None:
        """Log and raise if *response* contains an ``error`` entry.

        Raises:
            UnsupportedImageException: the error code contains ``image_parse_error``.
            Exception: any other error, with the API's error message.
        """
        if not isinstance(response, dict) or "error" not in response:
            return

        error = response["error"]
        if isinstance(error, dict):
            error_message = str(error.get("message") or json.dumps(error))
            error_code = error.get("code")
        else:
            error_message = str(error)
            error_code = None
        self.log("ERROR", [TextContent(error_message)])

        # The "code" field may be absent, null, or non-string; only match real strings.
        if isinstance(error_code, str) and "image_parse_error" in error_code:
            raise UnsupportedImageException()

        raise Exception(error_message)

    def __get_model_response(self, response) -> str:
        """Return the message content of the first choice."""
        return response["choices"][0]["message"]["content"]

    def ask(self, contents: List[Content]) -> str:
        """Send *contents* after the system prompt and prior context; return the reply.

        Rate-limit errors carrying a "Please try again in <duration>" hint are
        retried after sleeping for that duration.

        Raises:
            UnsupportedImageException: the API reported an ``image_parse_error``.
            Exception: any other error returned by the API.
            requests.exceptions.RequestException: transport failures, including
                timeouts when ``request_timeout`` is set.
        """
        if self._context is None:
            self.log("INFO", [TextContent("Opening new context")])

        if self.__system_prompt_content is not None:
            self.log("SYSTEM", [self.__system_prompt_content])
        self.log("USER", contents)

        # The request is identical on every retry attempt, so build it once.
        # This also avoids re-reading and re-encoding images on each attempt.
        headers = self.__get_headers()
        messages = (
            self.__get_system_message()
            + self.get_context()
            + [self.__get_message(contents)]
        )
        payload = self.__get_payload(messages)

        # Retry iteratively rather than recursing into ask(). Recursion
        # re-logged the prompt on every attempt and could exhaust the
        # recursion limit.
        while True:
            response = requests.post(
                self.__api_endpoint,
                headers=headers,
                json=payload,
                timeout=self.__request_timeout,
            ).json()
            if not self.__should_retry(response):
                break

        self.__check_error(response)

        model_response = self.__get_model_response(response)
        self.log("ASSISTANT", [TextContent(model_response)])
        self.__update_context(contents, model_response)
        return model_response
