import json
import os

from typing import (
    Any,
    Dict,
    Generator,
    Optional,
)

from anthropic import (
    AnthropicBedrock,
    APIConnectionError,
    InternalServerError,
    RateLimitError,
)
from dotenv import load_dotenv
from tenacity import (
    retry,
    retry_if_exception_type,
    wait_random_exponential,
)

from src.generator.base_generator import BaseGenerator
from src.utils.log import log_exception_in_retry


# Load variables from a .env file (if one is found) into the process
# environment at import time, so the os.getenv / os.environ lookups in
# ClaudeGenerator below can see them.
load_dotenv()


class ClaudeGenerator(BaseGenerator):
    """Generate text with an Anthropic Claude model through the ``AnthropicBedrock`` client.

    The client is configured from the ``AWS_ACCESS_KEY``, ``AWS_SECRET_KEY`` and
    ``AWS_REGION`` environment variables. Requests that fail with a rate-limit,
    connection or internal-server error are retried with randomized exponential
    backoff.
    """

    def __init__(
        self,
        model_name: str,
        temperature: float,
        api_key: Optional[str] = None,
        max_tokens: int = 1000,
        timeout: float = 120,
    ) -> None:
        """Create the Bedrock client and store the sampling settings.

        Args:
            model_name: Model identifier sent as ``model`` on every request.
            temperature: Sampling temperature sent on every request.
            api_key: Accepted for backward compatibility only and ignored; the
                client below is built exclusively from the ``AWS_*`` environment
                variables. The default is ``None`` rather than an environment
                lookup so that no secret is captured in the function defaults
                at import time.
            max_tokens: Upper bound on generated tokens per request.
            timeout: Per-request timeout, in seconds.
        """
        self.client = AnthropicBedrock(
            aws_access_key=os.getenv("AWS_ACCESS_KEY"),
            aws_secret_key=os.getenv("AWS_SECRET_KEY"),
            aws_region=os.getenv("AWS_REGION"),
        )
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout

    # NOTE: no ``stop=`` condition is configured, so a retryable error is
    # retried until a call succeeds or a non-retryable exception is raised.
    @retry(
        retry=retry_if_exception_type((RateLimitError, APIConnectionError, InternalServerError)),
        wait=wait_random_exponential(min=1, max=120),
        before_sleep=log_exception_in_retry,
    )
    def _chat_completion_with_backoff(self, **kwargs):
        """Call ``client.messages.create(**kwargs)``, retrying on transient API errors."""
        return self.client.messages.create(**kwargs)

    @staticmethod
    def _json_instruction(json_mode: bool, json_schema: Optional[Dict[str, Any]]) -> Optional[str]:
        """Return the system instruction that requests JSON output, or ``None`` if not needed.

        A schema takes precedence over plain ``json_mode``.
        """
        if json_schema is not None:
            return (
                "Respond with a single valid JSON value and nothing else "
                "(no prose, no Markdown code fences). "
                "The JSON must conform to this schema:\n"
                + json.dumps(json_schema, ensure_ascii=False)
            )
        if json_mode:
            return (
                "Respond with a single valid JSON object and nothing else "
                "(no prose, no Markdown code fences)."
            )
        return None

    def generate(self, prompt: str, json_mode: bool = False, json_schema: Optional[Dict[str, Any]] = None) -> str:
        """Send ``prompt`` as a single user message and return the model's text reply.

        Args:
            prompt: Content of the user message.
            json_mode: When true, instruct the model to answer with a JSON object.
            json_schema: When given, instruct the model to answer with JSON that
                conforms to this schema; takes precedence over ``json_mode``.

        Returns:
            The concatenated text of all text blocks in the response (an empty
            string if the response contains none).
        """
        kwargs: Dict[str, Any] = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "stream": False,
            "timeout": self.timeout,
        }
        # The Anthropic Messages API has no ``response_format`` parameter
        # (passing one raises TypeError), so the JSON requirement is expressed
        # through the supported ``system`` prompt instead.
        instruction = self._json_instruction(json_mode, json_schema)
        if instruction is not None:
            kwargs["system"] = instruction
        response = self._chat_completion_with_backoff(**kwargs)
        # Do not assume the first content block exists or is a text block.
        return "".join(block.text for block in response.content if getattr(block, "type", None) == "text")

    def generate_async(self, prompt: str, json_mode: bool = False) -> Generator[str, None, None]:
        """Not implemented for this generator; always raises ``NotImplementedError``."""
        raise NotImplementedError()
