import json
import time

from langsmith import traceable

from logger import logger

from .model import ChatModel, ModelType


class GPTModelType(ModelType):
    """Identifiers for the GPT model variants supported by the ``GPT`` wrapper.

    Each member pairs a symbolic name with a short string identifier.
    NOTE(review): how these strings are resolved into a concrete model
    configuration (model name, token limits, system-prompt support) is
    presumably handled by ``ModelType`` / ``ChatModel``, which are defined
    outside this file -- confirm there.
    """

    GPT4 = "gpt-4"
    GPT4TURBO = "gpt-4-turbo"
    GPT4O = "gpt-4o"
    GPT4OMINI = "gpt-4o-mini"
    # NOTE(review): "gpt-35" / "gpt-35-ft1" look like internal aliases rather
    # than public OpenAI model ids -- confirm against the model config.
    GPT35TURBO = "gpt-35"
    GPT35TURBO_CUSTOMFT1 = "gpt-35-ft1"
    # Presumably the o1-family models; ``GPT.query`` has a dedicated request
    # path (commented "Support for gpt o1") for models whose config reports
    # no system-message support -- verify these are the ones it targets.
    GPT4O1 = "gpt-4o1"
    GPT4O1MINI = "gpt-4o1-mini"


class GPT(ChatModel):
    """Chat model backed by the OpenAI chat-completions API.

    The OpenAI client is created lazily on first use and, when
    ``enable_langsmith`` is set, wrapped with LangSmith's ``wrap_openai``.
    The client itself is configured with ``max_retries=0``; retrying is done
    explicitly by :meth:`query_with_retry`.

    NOTE(review): ``self.config`` (``model_name``, ``supports_system``,
    ``default_temperature``, ``max_output_tokens``) is read here but never
    assigned in this class -- presumably ``ChatModel.__init__`` provides it.
    """

    # Maximum number of attempts made by ``query_with_retry`` before giving up.
    MAX_RETRIES = 10

    def __init__(self, model_type: GPTModelType, enable_langsmith: bool = False):
        """Store the LangSmith flag and defer client creation.

        Args:
            model_type: which GPT variant to use; forwarded to ``ChatModel``.
            enable_langsmith: wrap the OpenAI client with LangSmith tracing.
        """
        self.enable_langsmith = enable_langsmith
        super().__init__(model_type, enable_langsmith)
        # Created on first access of ``client`` so that constructing a GPT
        # instance does not require the OpenAI SDK to be importable/configured.
        self._client = None

    def setup_client(self):
        """Create the OpenAI client (40s timeout, SDK-level retries disabled)."""
        from openai import OpenAI

        client = OpenAI(timeout=40.0, max_retries=0)
        if self.enable_langsmith:
            # Only needed when tracing is on, so import it only in that case.
            from langsmith.wrappers import wrap_openai

            client = wrap_openai(client)
        self._client = client

    @property
    def client(self):
        """Return the OpenAI client, creating it on first access."""
        if self._client is None:
            self.setup_client()
        # Narrowing for type checkers; setup_client always assigns a client.
        assert self._client is not None
        return self._client

    @traceable
    def query(self, messages: list[dict[str, str]]) -> str | None:
        """Query the model once, without retrying.

        For models whose config reports no system-message support, every
        ``system`` message is re-sent with the ``user`` role, and the request
        omits ``temperature`` and ``max_tokens``.

        Args:
            messages: the messages to pass to the model

        Returns:
            str | None: the content of the first choice of the LLM's response
            (``None`` if the API returned no content)
        """
        logger.debug(f"Querying model: {self.config.model_name}")
        logger.debug(f"MESSAGES:\n{json.dumps(messages, indent=2)}")

        supports_system = self.config.supports_system

        gpt_messages = []
        for message in messages:
            if (
                not supports_system
                and isinstance(message, dict)
                and message["role"] == "system"
            ):
                # Demote the system prompt to a user message; only the
                # content is carried over.
                gpt_messages.append({"role": "user", "content": message["content"]})
            else:
                gpt_messages.append(message)

        if supports_system:
            completion = self.client.chat.completions.create(
                model=self.config.model_name,
                temperature=self.config.default_temperature,
                max_tokens=self.config.max_output_tokens,
                messages=gpt_messages,
            )
        else:
            # Support for gpt o1: no temperature / max_tokens are sent.
            completion = self.client.chat.completions.create(
                model=self.config.model_name,
                messages=gpt_messages,
            )

        ret = completion.choices[0].message.content
        logger.debug(f"RESPONSE:\n{ret}")

        return ret

    @traceable
    def query_with_retry(
        self, messages: list[dict[str, str]], backoff: int = 4
    ) -> str | None:
        """Query the model, retry if you hit a ratelimit or timeout error.

        Up to ``GPT.MAX_RETRIES`` attempts are made. After each failed attempt
        except the last, the call sleeps for the current backoff, which then
        grows by a factor of 1.5. Errors other than the retryable ones listed
        below propagate immediately.

        Args:
            messages: the messages to pass to the model
            backoff: the initial time (in seconds) to wait between
                unsuccessful calls

        Returns:
            str | None: the LLM's response

        Raises:
            Exception: if every attempt failed with a retryable error; the
                last such error is attached as ``__cause__``.
        """
        from httpcore import ReadTimeout
        from openai import (
            APIConnectionError,
            APITimeoutError,
            InternalServerError,
            RateLimitError,
        )

        retryable_errors = (
            RateLimitError,
            TimeoutError,
            APIConnectionError,
            APITimeoutError,
            InternalServerError,
            ReadTimeout,
        )

        # Becomes a float after the first multiplication by 1.5.
        current_backoff: float = backoff
        last_error: Exception | None = None

        for attempt in range(1, GPT.MAX_RETRIES + 1):
            try:
                return self.query(messages)
            except retryable_errors as e:
                last_error = e
                if attempt >= GPT.MAX_RETRIES:
                    # No attempt follows, so sleeping would only delay the
                    # failure (by ~150s with the default settings).
                    logger.warning(
                        f"GPT hit rate limit or timeout. Error was {e}. No retries left"
                    )
                    break
                logger.warning(
                    f"GPT hit rate limit or timeout. Error was {e}. Will try again in {current_backoff}s"
                )
                time.sleep(current_backoff)
                current_backoff *= 1.5

        # Chain the last API error so the root cause survives in the traceback.
        raise Exception("Failed to query model after max retries") from last_error
