import logging
from pathlib import Path
from typing import Any

import backoff
import numpy as np
from anthropic import Anthropic, APIError, APITimeoutError, RateLimitError

from agent_studio.config.config import Config
from agent_studio.llm.base_model import BaseModel
from agent_studio.llm.utils import anthropic_encode_image
from agent_studio.utils.types import MessageList

config = Config()
logger = logging.getLogger(__name__)


class AnthropicProvider(BaseModel):
    name = "claude"

    def __init__(self, **kwargs: Any) -> None:
        super().__init__()
        self.client = Anthropic(api_key=config.anthropic_api_key)
        self.system_prompt = None

    def _format_messages(
        self,
        raw_messages: MessageList,
    ) -> list[dict[str, Any]]:
        """
        Composes the messages to be sent to the model.
        """
        model_message: list[dict[str, Any]] = []
        past_role = None
        for msg in raw_messages:
            if msg.role == "system":
                self.system_prompt = msg.content
                continue
            if isinstance(msg.content, np.ndarray) or isinstance(msg.content, Path):
                content: dict = {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": anthropic_encode_image(msg.content),
                    },
                }
            elif isinstance(msg.content, str):
                content = {"type": "text", "text": msg.content}
            current_role = msg.role
            if past_role != current_role:
                model_message.append(
                    {
                        "role": current_role,
                        "content": [content],
                    }
                )
                past_role = current_role
            else:
                model_message[-1]["content"].append(content)
        return model_message

    def generate_response(
        self, messages: MessageList, **kwargs
    ) -> tuple[str, dict[str, Any]]:
        """Creates a chat completion using the Anthropic API."""

        model = kwargs.get("model", None)
        if model is None:
            raise ValueError("Model name is not set")
        temperature = kwargs.get("temperature", config.temperature)
        max_tokens = kwargs.get("max_tokens", config.max_tokens)
        model_message = self._format_messages(raw_messages=messages)
        logger.info(f"Creating chat completion with model {model}.")

        @backoff.on_exception(
            backoff.constant,
            (APIError, RateLimitError, APITimeoutError),
            max_tries=config.max_retries,
            interval=10,
        )
        def _generate_response_with_retry() -> tuple[str, dict[str, int]]:
            if self.system_prompt is not None:
                response = self.client.messages.create(
                    model=model,
                    messages=model_message,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    system=self.system_prompt,
                )
            else:
                response = self.client.messages.create(
                    model=model,
                    messages=model_message,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )

            if response is None:
                logger.error("Failed to get a response from Anthropic. Try again.")

            response_message = response.content[0].text
            if response.usage is None:
                info = {}
                logger.warn("Failed to get usage information from OpenAI.")
            else:
                info = {
                    "prompt_tokens": response.usage.input_tokens,
                    "completion_tokens": response.usage.output_tokens,
                    "total_tokens": response.usage.input_tokens
                    + response.usage.output_tokens,
                }

            logger.info(f"\nReceived response:\n{response_message}\nInfo:\n{info}")

            return response_message, info

        return _generate_response_with_retry()
