import json
import os

from PIL import Image
from ale_bench.utils import pil_to_base64jpeg
from dotenv import load_dotenv
from google import genai
from google.genai import types as google_genai_types

from common_resource import (
    EXP_ROOT_DIR, SYSTEM_PROMPT,
    BaseLLM, BaseNamespace,
    CustomJSONDecoder, CustomJSONEncoder,
    get_common_argument_parser, main_loop,
)


class GoogleGenAINamespace(BaseNamespace):
    """Typed namespace for the Google GenAI runner's command-line arguments.

    Adds the Google GenAI-specific options on top of the shared ``BaseNamespace``.
    """

    # Model name given via --model.
    model: str
    # Thinking budget given via --thinking_budget; None when the flag is omitted.
    thinking_budget: int | None


class GoogleGenAIModel(BaseLLM):
    """``BaseLLM`` implementation backed by the Google GenAI (Gemini) chat API.

    The conversation is kept in ``self.messages`` as the curated chat history
    (``google_genai_types.Content`` objects from ``chat.get_history``). Each
    turn opens a new chat session seeded with that history, sends the user
    contents with retries, and dumps the raw response as JSON into
    ``self.log_dir``.
    """

    def __init__(
        self,
        client: genai.Client,
        model: str,
        thinking_budget: int | None,
        system_prompt: str | None,
        log_dir: str | os.PathLike
    ) -> None:
        """Configure the wrapper and build the generation config.

        Args:
            client: Google GenAI client used to open chat sessions.
            model: Model name to query.
            thinking_budget: Budget forwarded through ``ThinkingConfig``. When
                ``None``, no thinking config is set. Ignored for models whose
                name starts with ``gemini-1.5-`` or ``gemini-2.0-``.
            system_prompt: System instruction attached to every request.
            log_dir: Directory for per-turn response logs; forwarded to
                ``BaseLLM.__init__``.
        """
        super().__init__(system_prompt, log_dir)
        self.client = client
        self.model = model
        if model.startswith(("gemini-1.5-", "gemini-2.0-")):
            # These model families are treated as non-thinking models: no
            # thinking_config is sent, so ``thinking_budget`` is ignored.
            self.generate_config = google_genai_types.GenerateContentConfig(
                system_instruction=system_prompt,
            )
        else:  # Thinking models
            # Attach a ThinkingConfig only when a budget was explicitly given;
            # otherwise leave it unset (presumably the API default applies).
            self.generate_config = google_genai_types.GenerateContentConfig(
                system_instruction=system_prompt,
                thinking_config=google_genai_types.ThinkingConfig(
                    thinking_budget=thinking_budget,
                ) if thinking_budget is not None else None,
            )
        # Running message counter (+2 per successful turn) used to number the
        # per-turn log files. `self.messages` is presumably initialized by
        # BaseLLM.__init__. Starting a new thread does not reset this counter.
        self.num_total_messages = len(self.messages)
        # Text of the most recent successful response; None until the first one.
        self._last_response_text: str | None = None

    def send_user_message(self, contents: list[str | Image.Image]) -> None:
        """Send ``contents`` as the next user turn and record the model's reply.

        A chat session is created from ``self.messages``, and the message is
        sent up to ``self.num_retry`` times. A trial fails if ``send_message``
        raises or the response carries no text. On the first successful trial:
        - the response text is cached for ``get_last_response``;
        - ``self.messages`` is replaced by the chat's curated history;
        - the raw response is written to
          ``self.log_dir / "response_turnXXXXXXXX.json"``.

        Args:
            contents: Text and/or PIL images forming the user message.

        Raises:
            ValueError: If no trial produced a non-empty text response.
        """
        # NOTE: if `history` is None, set to empty list by default
        chat = self.client.chats.create(
            model=self.model, config=self.generate_config, history=self.messages,
        )
        for trial in range(1, self.num_retry + 1):
            try:
                response = chat.send_message(message=contents)
            except Exception as e:  # Retry boundary: any API/network error triggers another trial.
                print(f"Trial {trial} failed: {e}")
                continue
            response_text = None if response is None else response.text
            if not response_text:
                print(f"Trial {trial} failed: no response from Google GenAI.")
                continue
            self._last_response_text = response_text
            self.messages = chat.get_history(curated=True)
            self.num_total_messages += 2
            log_path = self.log_dir / f"response_turn{self.num_total_messages:08d}.json"
            with open(log_path, "w", encoding="utf-8") as f:
                json.dump(response.model_dump(), f, ensure_ascii=False, cls=CustomJSONEncoder)
            return
        # Reached only when every trial raised or returned empty text. Checking
        # just `response is None` would let all-empty responses pass silently and
        # leave the previous turn's text in `_last_response_text`.
        raise ValueError("Failed to get a valid response from Google GenAI.")

    def send_user_message_new_thread(self, contents: list[str | Image.Image]) -> None:
        """Discard the conversation history and send ``contents`` as the first turn."""
        self.messages = []
        self.send_user_message(contents)

    def get_last_response(self) -> str:
        """Return the text of the most recent successful response.

        Raises:
            ValueError: If no response has been received yet.
        """
        if self._last_response_text is None:
            raise ValueError("No messages to get response from.")
        return self._last_response_text

    def load_history(self, file_path: str | os.PathLike) -> None:
        """Restore ``self.messages`` and ``self.generate_config`` from a file written by ``save_history``.

        ``num_total_messages`` and the cached last-response text are not restored.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f, cls=CustomJSONDecoder)
        self.messages = [google_genai_types.Content.model_validate(c) for c in data["messages"]]
        self.generate_config = google_genai_types.GenerateContentConfig.model_validate(data["generate_config"])

    def save_history(self, file_path: str | os.PathLike) -> None:
        """Write the message history and generation config to ``file_path`` as JSON.

        The file holds two keys: ``messages``, a list of dumped ``Content``
        objects, and ``generate_config``.
        """
        messages = [c.model_dump() for c in self.messages]
        # UTF-8 explicitly: ensure_ascii=False may emit non-ASCII text, which the
        # platform default encoding cannot always represent.
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump({
                "messages": messages,
                "generate_config": self.generate_config.model_dump(),
            }, f, ensure_ascii=False, cls=CustomJSONEncoder)


def parse_args() -> GoogleGenAINamespace:
    """Parse the shared ALE-Bench CLI options plus the Google GenAI-specific ones.

    Returns:
        A ``GoogleGenAINamespace`` instance holding the parsed values.
    """
    parser = get_common_argument_parser(description="Run Google GenAI model on ALE-Bench.")
    parser.add_argument("--model", type=str, required=True, help="Model name to run.")
    parser.add_argument("--thinking_budget", type=int, default=None, help="Thinking budget for the thinking models.")
    # Pass an *instance*. Passing the class makes argparse store values as class
    # attributes and return the class object itself. Defaults are also skipped
    # on later calls, because the attributes already exist.
    return parser.parse_args(namespace=GoogleGenAINamespace())


def get_client() -> genai.Client:
    """Build a Google GenAI client using the ``GEMINI_API_KEY`` environment variable."""
    api_key = os.environ.get("GEMINI_API_KEY")  # None when the variable is unset
    return genai.Client(api_key=api_key)


def main() -> None:
    """Build the Gemini-backed LLM wrapper from the CLI arguments and run the ALE-Bench main loop."""
    # Environment variables (e.g. GEMINI_API_KEY) may come from EXP_ROOT_DIR/.env.
    load_dotenv(EXP_ROOT_DIR / ".env")

    args = parse_args()

    # Wrap the API client in the BaseLLM interface consumed by main_loop.
    client = get_client()
    system_prompt = SYSTEM_PROMPT[args.prompt_language]
    log_dir = args.exp_dir / f"llm_log_{args.problem_id}"
    llm = GoogleGenAIModel(
        client=client,
        model=args.model,
        thinking_budget=args.thinking_budget,
        system_prompt=system_prompt,
        log_dir=log_dir,
    )

    main_loop(args, llm)


if __name__ == "__main__":
    # Run only when executed as a script, not on import.
    main()
