import json
import time

from langsmith import traceable

from logger import logger

from .model import ChatModel, ModelType


class ClaudeModelType(ModelType):
    """Identifiers for the Anthropic Claude models that ``Claude`` accepts.

    Each member's value is a short string naming one Claude model.
    """

    # NOTE(review): presumably these strings are resolved to concrete API
    # model names / limits by the ModelType/ChatModel machinery in `.model`
    # (``Claude.query`` reads ``self.config.model_name``) — confirm there.
    CLAUDE_35_SONNET = "claude-35-sonnet"
    CLAUDE_3_HAIKU = "claude-3-haiku"
    CLAUDE_3_OPUS = "claude-3-opus"
    CLAUDE_3_SONNET = "claude-3-sonnet"


class Claude(ChatModel):
    """Chat model wrapper around the Anthropic Messages API.

    The Anthropic client is created lazily on first access to ``client``.
    The SDK's own retries are disabled (``max_retries=0``); retrying is done
    by ``query_with_retry`` instead.
    """

    # Maximum number of attempts made by ``query_with_retry``.
    MAX_RETRIES = 10

    def __init__(self, model_type: ClaudeModelType, enable_langsmith: bool = False):
        """Create the wrapper without contacting the API.

        Args:
            model_type: which Claude model to use
            enable_langsmith: stored on the instance and forwarded to the
                base class; ``setup_client`` rejects ``True``
        """
        # NOTE(review): set before super().__init__() — presumably so the
        # flag is already available if the base initializer reads it;
        # confirm against ChatModel.
        self.enable_langsmith = enable_langsmith
        super().__init__(model_type, enable_langsmith)
        self._client = None

    def setup_client(self):
        """Instantiate the Anthropic client and store it on the instance.

        Raises:
            NotImplementedError: if langsmith support was requested
        """
        # Check the unsupported configuration *before* creating the client.
        # Otherwise the client would be cached on ``self._client`` and every
        # later access to ``client`` would silently succeed despite the
        # first one having raised.
        if self.enable_langsmith:
            raise NotImplementedError(
                "Haven't implemented langchain for anthropic models"
            )

        # Imported lazily so the module can be imported without anthropic.
        from anthropic import Anthropic

        self._client = Anthropic(timeout=40.0, max_retries=0)

    @property
    def client(self):
        """The Anthropic client, created on first use."""
        if self._client is None:
            self.setup_client()
        assert self._client is not None
        return self._client

    @traceable
    def query(self, messages: list[dict[str, str]]) -> str | None:
        """Query the model.

        The single allowed ``"system"`` message is passed through the API's
        dedicated ``system`` parameter, and consecutive messages from the
        same role are merged (joined with a newline) into one message.

        Args:
            messages: the messages to pass to the model; each is a dict with
                ``"role"`` and ``"content"`` keys

        Returns:
            str: the text of the first content block of the LLM's response

        Raises:
            NotImplementedError: if more than one system message is given
            AssertionError: if the first content block is not a text block
        """
        from anthropic._types import (
            NOT_GIVEN,  # I know this is a bit bad, but its the cleanest I think
        )
        from anthropic.types import TextBlock

        logger.debug(f"Querying model: {self.config.model_name}")
        logger.debug(f"MESSAGES:\n{json.dumps(messages, indent=2)}")

        # Anthropic wants system messages in a different way. Handle that here
        system_message = NOT_GIVEN
        claude_messages = []
        last_message_role = None
        for message in messages:
            if message["role"] == "system":
                # NOT_GIVEN is a sentinel object, so compare by identity.
                if system_message is not NOT_GIVEN:
                    raise NotImplementedError(
                        "Anthropic does not support multiple system messages"
                    )
                # The system prompt goes in the separate ``system`` argument
                # rather than in the message list, so it does not take part
                # in the same-role merging below.
                system_message = message["content"].strip()
                continue

            # Anthropic doesn't handle repeated messages from the same role
            message_role = message["role"]
            message_content = message["content"].strip()

            if message_role == last_message_role:
                # Only dicts created below are modified; the caller's
                # message dicts are left untouched.
                claude_messages[-1]["content"] += "\n" + message_content
            else:
                claude_messages.append(
                    {"role": message_role, "content": message_content}
                )
            last_message_role = message_role

        response = self.client.messages.create(
            model=self.config.model_name,
            temperature=self.config.default_temperature,
            max_tokens=self.config.max_output_tokens,
            messages=claude_messages,
            system=system_message,
        )
        ret = response.content[0]
        logger.debug(f"RESPONSE:\n{ret}")

        assert isinstance(ret, TextBlock), "Haven't implemented tool use here"
        return ret.text

    @traceable
    def query_with_retry(
        self, messages: list[dict[str, str]], backoff: float = 4
    ) -> str | None:
        """Query the model, retry if you hit a ratelimit or timeout error.

        Up to ``MAX_RETRIES`` attempts are made. After each failed attempt
        except the last, the call sleeps for the current backoff, which then
        grows by a factor of 1.5.

        Args:
            messages: the messages to pass to the model
            backoff: the time (in seconds) to wait after the first
                unsuccessful call

        Returns:
            str: the LLM's response

        Raises:
            Exception: if every attempt failed with a retryable error; the
                last such error is attached as ``__cause__``
        """
        from anthropic import (
            APIConnectionError,
            APITimeoutError,
            InternalServerError,
            RateLimitError,
        )

        last_error = None
        current_backoff = backoff
        for attempt in range(1, Claude.MAX_RETRIES + 1):
            try:
                return self.query(messages)
            except (
                RateLimitError,
                TimeoutError,
                InternalServerError,
                APITimeoutError,
                APIConnectionError,
            ) as e:
                last_error = e
                if attempt == Claude.MAX_RETRIES:
                    # Nothing left to retry: don't sleep just to fail.
                    logger.warning(
                        f"Claude hit rate limit or timeout. Error was {e}. No retries left"
                    )
                    break
                logger.warning(
                    f"Claude hit rate limit or timeout. Error was {e}. Will try again in {current_backoff}s"
                )
                time.sleep(current_backoff)
                current_backoff *= 1.5
        # Chain the last API error so the root cause shows in the traceback.
        raise Exception("Failed to query model after max retries") from last_error
