import dataclasses
import json
import os
import time
from pathlib import Path
from typing import Iterable, Optional, Sequence

import google.generativeai as genai
from google.generativeai.types import GenerateContentResponse

from llm_mcts.file_logging import get_logging_dir
from llm_mcts.llm_generation_interface import GenerationRequest, GenerationResult, Model
from llm_mcts.models.custom_json_encoder import CustomJSONEncoder


# Per-token prices keyed first by model name, then by the usage-metadata field
# they are multiplied with when estimating the cost of a call. The pricing page
# quotes prices per 1M tokens, hence the division by 1e6.
# Reference: https://ai.google.dev/pricing#1_5pro
# NOTE(review): these look like the <=128k-token prompt tier rates; longer
# prompts are billed at a higher rate — confirm before relying on totals.
PRICING = {
    "gemini-1.5-pro-002": dict(
        prompt_token_count=1.25 / 1e6,
        candidates_token_count=5.0 / 1e6,
    ),
}


class GeminiAPIModel(Model):
    """``Model`` implementation backed by the Google Gemini API.

    Each request is sent as a single ``generate_content`` call asking for
    ``num_samples`` candidates. Token usage (with an estimated price when the
    model is listed in ``PRICING``) and every generated result are written as
    JSON to files under ``logging_dir``.
    """

    def __init__(
        self,
        model: str = "gemini-1.5-pro-002",
        temperature: float = 0.7,
        logging_dir: Optional[Path] = None,
        num_trial: int = 5,
    ) -> None:
        """Configure the Gemini client.

        Args:
            model: Model name passed to ``genai.GenerativeModel``.
            temperature: Sampling temperature used for every request.
            logging_dir: Directory for usage/response logs. Defaults to
                ``get_logging_dir(os.getpid())``; created (with parents) if
                it does not exist.
            num_trial: Maximum number of attempts per request in ``call_api``.

        Raises:
            KeyError: If the ``GEMINI_API_KEY`` environment variable is unset.
        """
        genai.configure(api_key=os.environ["GEMINI_API_KEY"])
        self.generation_config = dict(temperature=temperature)
        self.model_name = model
        self.model = genai.GenerativeModel(model_name=model)

        self.logging_dir = (
            get_logging_dir(os.getpid()) if logging_dir is None else logging_dir
        )
        # exist_ok removes the check-then-create race; parents allows nested
        # logging paths whose parents do not exist yet.
        self.logging_dir.mkdir(parents=True, exist_ok=True)

        # Number of API calls made so far; used to index the log files.
        self.call_count = 0
        self.num_trial = num_trial

    @staticmethod
    def _to_contents(request: GenerationRequest) -> list:
        """Convert ``request.messages`` into Gemini ``contents`` dicts."""
        # NOTE(review): assumes each message is a dataclass with "role" and
        # "content" fields — verify against GenerationRequest.
        contents = []
        for message in request.messages:
            content = dataclasses.asdict(message)
            # Gemini uses "model" instead of "assistant" as a role. The role
            # is the *value* of the "role" field, so compare the value rather
            # than testing for an "assistant" key.
            if content.get("role") == "assistant":
                content["role"] = "model"

            # Gemini names the message payload "parts" instead of "content".
            if "content" in content:
                content["parts"] = content.pop("content")
            contents.append(content)
        return contents

    def call_api(
        self, request: GenerationRequest, num_samples: int
    ) -> GenerateContentResponse:
        """Send ``request`` to Gemini, retrying responses without usable content.

        Args:
            request: Request whose messages are sent as the conversation.
            num_samples: Number of candidates to request (Gemini's
                ``candidate_count``).

        Returns:
            A response with at least one candidate, each of which has at least
            one content part.

        Raises:
            RuntimeError: If no acceptable response is received within
                ``num_trial`` attempts.
        """
        contents = self._to_contents(request)

        # num_samples is called "candidate_count" in the Gemini API.
        generation_config = genai.types.GenerationConfig(
            **(self.generation_config | {"candidate_count": num_samples})
        )

        for attempt in range(self.num_trial):
            chat_completion = self.model.generate_content(
                contents=contents,
                generation_config=generation_config,
            )

            # Content parts are occasionally missing from a candidate, and a
            # response with no candidates would silently yield zero results;
            # both are treated as failed attempts and retried.
            candidates = chat_completion.candidates
            if candidates and all(
                len(candidate.content.parts) > 0 for candidate in candidates
            ):
                return chat_completion

            # Linear back-off between attempts; no pointless sleep after the
            # final attempt, since we are about to raise anyway.
            if attempt + 1 < self.num_trial:
                time.sleep(10 * (attempt + 1))

        raise RuntimeError(f"Gemini request failed after {self.num_trial} attempts")

    def _log_usage(
        self, chat_completion: GenerateContentResponse, call_index: int
    ) -> None:
        """Write token counts and estimated price of one API call to a file.

        Raises:
            RuntimeError: If the response carries no usage metadata.
        """
        usage = chat_completion.usage_metadata
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not usage:
            raise RuntimeError("Gemini response is missing usage_metadata")

        data = {
            "prompt_token_count": usage.prompt_token_count,
            "candidates_token_count": usage.candidates_token_count,
            "total_token_count": usage.total_token_count,
        }

        # Models absent from PRICING get a null price instead of raising a
        # KeyError after the (already paid for) API call has completed.
        pricing = PRICING.get(self.model_name)
        if pricing is None:
            data["price"] = None
        else:
            price = {key: int(data[key]) * rate for key, rate in pricing.items()}
            price["total"] = sum(price.values())
            data["price"] = price

        (self.logging_dir / f"{self.model_name}_{call_index}.txt").write_text(
            json.dumps(data, indent=4, cls=CustomJSONEncoder)
        )

    @staticmethod
    def _candidate_text(candidate) -> str:
        """Concatenate the text of every content part of a response candidate."""
        return "".join(part.text for part in candidate.content.parts)

    def generate(
        self, requests: Sequence[GenerationRequest], num_samples: int = 1
    ) -> Iterable[GenerationResult]:
        """Generate ``num_samples`` completions for each request.

        Unless the ``DISABLE_LOG_USAGE`` environment variable is set, token
        usage of each API call is written to ``<model>_<i>.txt``. Every result
        is written to ``response_<model>_<i>_<j>.log``, where ``i`` is the same
        call index as the usage file and ``j`` is the candidate index, so
        multiple samples no longer overwrite each other.

        Args:
            requests: Requests to send, one API call each.
            num_samples: Number of candidates requested per call.

        Returns:
            A list of ``GenerationResult``, one per returned candidate, in
            request order.

        Raises:
            RuntimeError: Propagated from ``call_api`` / usage logging.
        """
        results = []
        for request in requests:
            chat_completion = self.call_api(request, num_samples)
            # Use one index for both the usage log and the response logs of
            # this call so the files can be paired up.
            call_index = self.call_count
            self.call_count += 1

            if not os.getenv("DISABLE_LOG_USAGE"):
                self._log_usage(chat_completion, call_index)

            for candidate_index, candidate in enumerate(chat_completion.candidates):
                result = GenerationResult(
                    request=request, generation=self._candidate_text(candidate)
                )
                results.append(result)
                log_name = (
                    f"response_{self.model_name}_{call_index}_{candidate_index}.log"
                )
                (self.logging_dir / log_name).write_text(
                    json.dumps(
                        dataclasses.asdict(result), indent=4, cls=CustomJSONEncoder
                    )
                )
        return results
