import base64
import copy
import dataclasses
import io
import json
import os
import random
import time
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Union

from anthropic import AnthropicBedrock, APIStatusError, BadRequestError, RateLimitError
from anthropic.types import Message, Usage
from PIL import Image

from llm_mcts.file_logging import get_logging_dir
from llm_mcts.llm_generation_interface import GenerationRequest, GenerationResult, Model
from llm_mcts.models.custom_json_encoder import CustomJSONEncoder


# Reference: https://aws.amazon.com/bedrock/pricing/
def _usd_per_token(input_per_million: float, output_per_million: float) -> Dict[str, float]:
    """Turn USD-per-million-token list prices into a per-token pricing entry."""
    return {
        "input_tokens": input_per_million / 1e6,
        "output_tokens": output_per_million / 1e6,
    }


# Per-token prices in USD, keyed by Bedrock model ID. The inner keys match the
# usage fields ("input_tokens", "output_tokens") that they are multiplied with.
PRICING = {
    "us.anthropic.claude-3-5-sonnet-20240620-v1:0": _usd_per_token(3.0, 15.0),
    "anthropic.claude-3-5-sonnet-20240620-v1:0": _usd_per_token(3.0, 15.0),
    # NOTE: US region only (at 2024-12-21)
    "us.anthropic.claude-3-5-sonnet-20241022-v2:0": _usd_per_token(3.0, 15.0),
    # NOTE: US region only (at 2024-12-21)
    "us.anthropic.claude-3-5-haiku-20241022-v1:0": _usd_per_token(1.0, 5.0),
}


class ClaudeBedrockAPIModel(Model):
    """Generation model that calls Anthropic Claude through AWS Bedrock.

    Every request is sent as one ``messages.create`` call. Token usage (with
    an estimated price) and the resulting generation are written as JSON
    files into ``logging_dir``.
    """

    def __init__(
        self,
        model: str = "us.anthropic.claude-3-5-sonnet-20240620-v1:0",
        temperature: float = 0.3,
        logging_dir: Optional[Path] = None,
        # Number of API attempts per request. With the default, the back-off
        # sleeps between attempts add up to roughly 2.3 hours on average.
        num_trial: int = 14,
    ) -> None:
        """Create the Bedrock client and prepare the logging directory.

        Args:
            model: Bedrock model ID; also used as the key into ``PRICING``.
            temperature: Sampling temperature passed to the API.
            logging_dir: Directory for usage/response logs. When ``None``,
                ``get_logging_dir(os.getpid())`` is used.
            num_trial: Maximum number of API attempts per request.
        """
        self.model = self.model_name = model
        self.client = AnthropicBedrock(aws_region="us-east-1")
        self.temperature = temperature
        self.num_trial = num_trial

        self.logging_dir = (
            get_logging_dir(os.getpid()) if logging_dir is None else logging_dir
        )
        # exist_ok avoids a race between checking and creating the directory;
        # parents allows a nested logging directory to be created in one go.
        self.logging_dir.mkdir(parents=True, exist_ok=True)

        # Number of completed API calls; used to index the log files.
        self.call_count = 0

    def _encode_image(self, image: Image.Image) -> str:
        """Return ``image`` serialized as JPEG and encoded as a base64 string."""
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return img_str

    def _process_message(self, request: GenerationRequest) -> List[Dict[str, Any]]:
        """Convert a request's messages into plain dicts for the Messages API.

        The request is deep-copied first, so the caller's object is not
        modified. When a message's content is a list, string items become
        ``text`` blocks and PIL images become base64 JPEG ``image`` blocks;
        items of any other type are dropped.

        Args:
            request: The generation request whose messages are converted.

        Returns:
            One dict per message, produced by ``dataclasses.asdict``.
        """
        request_copy = copy.deepcopy(request)
        messages = []
        for message in request_copy.messages:
            # if the message contains image tokens, the content is a list
            # image should be encoded as base64
            if isinstance(message.content, list):
                full_content = []
                for content in message.content:
                    if isinstance(content, str):
                        full_content.append(
                            {
                                "type": "text",
                                "text": content,
                            }
                        )
                    elif isinstance(content, Image.Image):
                        full_content.append(
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": "image/jpeg",
                                    # Convert to RGB before saving as JPEG.
                                    "data": self._encode_image(content.convert("RGB")),
                                },
                            }
                        )
                message.content = full_content
            messages.append(dataclasses.asdict(message))
        return messages

    def call_api(self, messages: List[Dict[str, Any]]) -> Message:
        """Call the Messages API, retrying on status codes 400, 424 and 429.

        Between attempts the method sleeps for ``10 + 2**i * jitter`` seconds,
        where ``i`` is the zero-based attempt index and ``jitter`` is drawn
        uniformly from [0.5, 1.5]. No sleep happens after the final attempt.

        Args:
            messages: Message dicts as produced by ``_process_message``.

        Returns:
            The API response message.

        Raises:
            APIStatusError: Immediately, for any status code other than
                400, 424 or 429.
            RuntimeError: When all ``num_trial`` attempts failed; the last
                API error is attached as ``__cause__``.
        """
        base_delay = 10  # Always wait at least 10 seconds
        last_error: Optional[APIStatusError] = None
        for i in range(self.num_trial):
            try:
                return self.client.messages.create(
                    max_tokens=2048,
                    messages=messages,
                    model=self.model,
                    temperature=self.temperature,
                    # n=num_samples, # n parameter not supported for Claude
                )
            except (APIStatusError, RateLimitError, BadRequestError) as e:
                if e.status_code not in (400, 424, 429):
                    raise
                last_error = e

                # Nothing is retried after the final attempt, so sleeping
                # here would only delay the failure report.
                if i == self.num_trial - 1:
                    break

                print(
                    f"Hit an API error with status code {e.status_code}, retrying..."
                )

                # Exponential Backoff and Jitter
                # See: https://aws.amazon.com/jp/blogs/architecture/exponential-backoff-and-jitter/
                #
                # Exponential Backoff:
                #   2 ** i
                #
                # Jitter:
                #   Multiply by a random factor between 0.5 and 1.5
                exp_factor = 2**i
                jitter_factor = random.uniform(0.5, 1.5)

                # Add base_delay
                current_backoff = exp_factor * jitter_factor
                final_delay = base_delay + current_backoff

                print(
                    f"Sleeping {final_delay:.2f} seconds before retry (attempt {i+1}/{self.num_trial})..."
                )
                time.sleep(final_delay)

        raise RuntimeError(
            f"Maximum number of trial {self.num_trial} reached in calling Claude API"
        ) from last_error

    def calculate_price(self, usage: Usage) -> Dict[str, float]:
        """Estimate the USD price of one call from its token usage.

        Args:
            usage: Usage object of an API response.

        Returns:
            A dict with one price per key in ``PRICING[self.model_name]``
            plus their sum under ``"total"``.

        Raises:
            KeyError: If ``self.model_name`` has no entry in ``PRICING``.
        """
        data = usage.model_dump()
        price_data = {
            key: int(data[key]) * PRICING[self.model_name][key]
            for key in PRICING[self.model_name]
        }
        price_data["total"] = sum(price_data.values())

        return price_data

    def generate(
        self, requests: Sequence[GenerationRequest]
    ) -> Iterable[GenerationResult]:
        """Run every request sequentially and log usage and responses.

        For each request, usage and price are written to
        ``{model_name}_{index}.txt`` (unless the ``DISABLE_LOG_USAGE``
        environment variable is set) and the result to
        ``response_{model_name}_{index}.log``, both inside ``logging_dir``.

        Args:
            requests: Requests to send, one API call each.

        Returns:
            A list with one ``GenerationResult`` per request, in order.
        """
        results = []
        for request in requests:
            messages = self._process_message(request)
            chat_completion = self.call_api(messages)

            # Use one index for both log files of this call so that the usage
            # log and the response log can be matched by number.
            call_index = self.call_count

            if not os.getenv("DISABLE_LOG_USAGE"):
                assert chat_completion.usage

                data = chat_completion.usage.model_dump()
                data["price"] = self.calculate_price(chat_completion.usage)

                (
                    self.logging_dir / f"{self.model_name}_{call_index}.txt"
                ).write_text(json.dumps(data, indent=4, cls=CustomJSONEncoder))

            result = GenerationResult(
                request=request, generation=chat_completion.content[0].text
            )
            results.append(result)

            (
                self.logging_dir / f"response_{self.model_name}_{call_index}.log"
            ).write_text(
                json.dumps(dataclasses.asdict(result), indent=4, cls=CustomJSONEncoder)
            )
            self.call_count += 1
        return results
