import dataclasses
import json
import logging
import os
from pathlib import Path
from typing import Iterable, Optional, Sequence

import google.generativeai as genai
from google.generativeai.types import GenerateContentResponse
from tenacity import (
    Retrying,
    TryAgain,
    before_sleep_log,
    retry,
    stop_after_attempt,
    wait_exponential,
)

from llm_mcts.file_logging import get_logging_dir
from llm_mcts.llm_generation_interface import GenerationRequest, GenerationResult, Model
from llm_mcts.models.custom_json_encoder import CustomJSONEncoder

# Per-token list prices used to estimate the cost of each API call. Each value
# is the published price per 1M tokens divided by 1e6 (presumably USD, as on
# the page referenced below). The inner keys match the token-count field names
# read from ``usage_metadata`` when usage is logged.
# Reference: https://ai.google.dev/pricing#1_5pro
PRICING = {
    "gemini-1.5-pro-002": dict(
        prompt_token_count=1.25 / 1e6,
        candidates_token_count=5.0 / 1e6,
    ),
    # Free of charge during the free (experimental) period.
    "gemini-2.0-flash-thinking-exp-01-21": dict(
        prompt_token_count=0.0 / 1e6,
        candidates_token_count=0.0 / 1e6,
    ),
}

# Module-level logger; the retry machinery reports back-off sleeps through it.
logger = logging.getLogger(__name__)

# NOTE(review): this configures the *root* logger at import time, which affects
# every program that imports this module; kept as-is for backward compatibility.
logging.basicConfig(level=logging.INFO)


def _generate_once(
    self: "GeminiAPIModel", contents, generation_config
) -> GenerateContentResponse:
    """Issue one ``generate_content`` call and validate its candidates.

    Raises:
        TryAgain: if any returned candidate has no content parts, so that the
            surrounding retry loop re-issues the request.
    """
    chat_completion = self.model.generate_content(
        contents=contents,
        generation_config=genai.types.GenerationConfig(**generation_config),
    )

    # For some reason the content parts are occasionally not returned, so we
    # retry in that case.
    if any(len(candidate.content.parts) == 0 for candidate in chat_completion.candidates):
        raise TryAgain()

    return chat_completion


def try_generate(
    self: "GeminiAPIModel", contents, generation_config
) -> GenerateContentResponse:
    """Call the Gemini API, retrying with exponential back-off.

    A retry is triggered by any exception raised during the call, and by a
    response that contains a candidate without content parts. At most
    ``self.num_trial`` attempts are made. Without that bound, a permanent
    failure (e.g. an invalid API key or a malformed request) would retry forever.

    Args:
        self: the model wrapper; ``self.model`` and ``self.num_trial`` are read.
        contents: Gemini-formatted list of message dicts.
        generation_config: keyword arguments for ``genai.types.GenerationConfig``.

    Returns:
        The first response whose candidates all contain content parts.

    Raises:
        tenacity.RetryError: when every attempt failed. The last attempt's
            exception is chained as the cause.
    """
    retrying = Retrying(
        stop=stop_after_attempt(self.num_trial),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        before_sleep=before_sleep_log(logger, logging.INFO),
    )
    return retrying(_generate_once, self, contents, generation_config)


class GeminiAPIModel(Model):
    """``Model`` implementation backed by the Google Gemini API.

    Each API call made through :meth:`generate` writes a usage/price log
    (unless ``DISABLE_LOG_USAGE`` is set) and one response log per returned
    candidate into ``logging_dir``.
    """

    def __init__(
        self,
        model: str = "gemini-1.5-pro-002",
        temperature: float = 0.7,
        logging_dir: Optional[Path] = None,
        num_trial: int = 5,
    ) -> None:
        """Configure the Gemini client and the logging directory.

        Args:
            model: Gemini model name. It is also used to look up ``PRICING``
                and to name log files.
            temperature: sampling temperature sent with every request.
            logging_dir: directory for usage/response logs. Defaults to
                ``get_logging_dir(os.getpid())``. It is created if missing.
            num_trial: stored as ``self.num_trial``. It is not read by this
                class's own methods.

        Raises:
            KeyError: if the ``GEMINI_API_KEY`` environment variable is unset.
        """
        genai.configure(api_key=os.environ["GEMINI_API_KEY"])
        self.generation_config = dict(temperature=temperature)
        self.model_name = model
        self.model = genai.GenerativeModel(model_name=model)

        self.logging_dir = (
            get_logging_dir(os.getpid()) if logging_dir is None else logging_dir
        )
        # parents=True tolerates a missing parent directory. exist_ok=True
        # avoids a race between an existence check and the mkdir.
        self.logging_dir.mkdir(parents=True, exist_ok=True)

        self.call_count = 0
        self.num_trial = num_trial

    def call_api(
        self, request: GenerationRequest, num_samples: int
    ) -> GenerateContentResponse:
        """Convert ``request`` into Gemini ``contents`` and query the API.

        Args:
            request: conversation to send. Each message is converted with
                ``dataclasses.asdict``.
            num_samples: number of candidates to request.

        Returns:
            The raw API response returned by ``try_generate``.
        """
        contents = []
        for message in request.messages:
            content = dataclasses.asdict(message)
            # The Gemini API uses the role "model" instead of "assistant".
            # Assumes messages carry a ``role`` field; confirm against the
            # message dataclass.
            if content.get("role") == "assistant":
                content["role"] = "model"

            # Gemini names the message body "parts" rather than "content".
            if "content" in content:
                content["parts"] = content.pop("content")
            contents.append(content)

        # num_samples is called "candidate_count" in Gemini API
        generation_config = self.generation_config | {"candidate_count": num_samples}

        return try_generate(self, contents, generation_config)

    def generate(
        self, requests: Sequence[GenerationRequest], num_samples: int = 1
    ) -> Iterable[GenerationResult]:
        """Generate ``num_samples`` completions for each request.

        Args:
            requests: requests to send, one API call each.
            num_samples: number of candidates requested per call.

        Returns:
            A list with one ``GenerationResult`` per returned candidate, in
            request order. Only the text of each candidate's first part is kept.
        """
        results = []
        for request in requests:
            chat_completion = self.call_api(request, num_samples)
            if not os.getenv("DISABLE_LOG_USAGE"):
                self._log_usage(chat_completion)

            # NOTE(review): the usage log above uses the counter value *before*
            # this increment and the response logs use the value *after* it,
            # so their numeric suffixes differ by one for the same call.
            self.call_count += 1
            candidates = list(chat_completion.candidates)
            for index, candidate in enumerate(candidates):
                result = GenerationResult(
                    request=request, generation=candidate.content.parts[0].text
                )
                results.append(result)
                self._log_response(result, index, len(candidates))
        return results

    def _log_usage(self, chat_completion: GenerateContentResponse) -> None:
        """Write the token counts and estimated price of one API call as JSON."""
        # NOTE(review): ``assert`` is stripped under ``python -O``.
        assert chat_completion.usage_metadata
        usage = chat_completion.usage_metadata

        data = {
            "prompt_token_count": usage.prompt_token_count,
            "candidates_token_count": usage.candidates_token_count,
            "total_token_count": usage.total_token_count,
        }
        rates = PRICING.get(self.model_name)
        if rates is None:
            # Unknown model: record usage without a price instead of raising,
            # which would discard the response that was already received.
            logger.warning(
                "No pricing entry for model %s; price not recorded", self.model_name
            )
            data["price"] = None
        else:
            data["price"] = {key: int(data[key]) * rate for key, rate in rates.items()}
            data["price"]["total"] = sum(data["price"].values())

        (self.logging_dir / f"{self.model_name}_{self.call_count}.txt").write_text(
            json.dumps(data, indent=4, cls=CustomJSONEncoder)
        )

    def _log_response(
        self, result: GenerationResult, index: int, num_candidates: int
    ) -> None:
        """Write one generation result as JSON.

        A single-candidate response is written to
        ``response_<model>_<call_count>.log``. When there are several
        candidates, each file gets an extra ``_<index>`` suffix so that the
        candidates do not overwrite one another.
        """
        suffix = "" if num_candidates == 1 else f"_{index}"
        path = (
            self.logging_dir
            / f"response_{self.model_name}_{self.call_count}{suffix}.log"
        )
        path.write_text(
            json.dumps(dataclasses.asdict(result), indent=4, cls=CustomJSONEncoder)
        )
