#!/usr/bin/python3
"""
Google Gemini large language models (LLM) as backbone optimizers.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import json
import logging
import numpy as np
import os
from google import genai
from google.genai import types
from hashlib import sha256
from pathlib import Path
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_random
from typing import Any, Dict, Final, List, Optional, Tuple, Type, Union

from .llm import BaseLLMOptimizer, load_reflection_prompt
from .state import OptimizerState
from ..envs.base import BaseTask


# Module-level logger used by the optimizer defined in this file.
logger = logging.getLogger(__name__)


# Unless GEMINI_LOGGING is explicitly enabled, quieten the Gemini SDK's
# `google_genai.types` logger. The SDK logger is configured directly rather
# than by rebinding `logger`, so this module's own log records are neither
# suppressed nor attributed to the SDK.
if os.getenv("GEMINI_LOGGING", "False").title() != "True":
    logging.getLogger("google_genai.types").setLevel(level=logging.ERROR)


class GeminiOptimizer(BaseLLMOptimizer):
    """
    LLM optimizer that proposes candidate designs with a Google Gemini model.

    Candidates are requested as JSON constrained to the task's design schema
    (`forward`); free-form and tool-calling conversations go through `chat`.
    """

    optimizer_name: str = "gemini-2.5-flash-preview-05-20"

    def __init__(
        self,
        task: BaseTask,
        model_id: Optional[str] = None,
        batch_size: int = 1,
        temperature: float = 1.0,
        max_new_tokens: int = 4096,
        seed: int = 2025,
        promptdir: Union[Path, str] = os.path.join(
            os.path.dirname(__file__), "prompts"
        ),
        system_prompt_version: str = "base",
        user_prompt_version: str = "base",
        reflection: bool = True,
        ablate_distribution_shift: bool = False,
        thinking_budget: int = -1,
        **kwargs: Dict[str, Any]
    ):
        """
        Args:
            task: the optimization task.
            model_id: Google Gemini model ID to use. Default
                `gemini-2.5-flash-preview-05-20`.
            batch_size: batch size for optimization.
            temperature: temperature parameter for text generation.
            max_new_tokens: maximum number of new tokens to generate.
            seed: random seed. Default 2025.
            promptdir: directory containing the prompt templates.
            system_prompt_version: version of the system prompt to load.
            user_prompt_version: version of the user prompt to load.
            reflection: whether to perform reflection as in Ma YJ, et al. Proc
                ICLR (2024).
            ablate_distribution_shift: whether to ablate knowledge of the
                distribution shift.
            thinking_budget: number of allowed internal thinking tokens.
                Default dynamic.
        """
        # Override the class-level default model on this instance only.
        if model_id is not None:
            self.optimizer_name = model_id

        super().__init__(
            task,
            batch_size,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            seed=seed,
            promptdir=promptdir,
            system_prompt_version=system_prompt_version,
            user_prompt_version=user_prompt_version,
            reflection=reflection,
            ablate_distribution_shift=ablate_distribution_shift,
            **kwargs
        )
        # Pydantic model used both as the response schema sent to Gemini and
        # to validate the returned JSON candidates.
        self.design_schema: Final[Type[BaseModel]] = task.design_schema

        self._client = genai.Client(
            api_key=os.getenv("GEMINI_API_KEY", "")
        )

        # Request `BLOCK_NONE` for every listed harm category.
        self.safety_config = [
            types.SafetySetting(
                category=category,
                threshold=types.HarmBlockThreshold.BLOCK_NONE
            )
            for category in [
                types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                types.HarmCategory.HARM_CATEGORY_HARASSMENT,
                types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
                types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                types.HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY
            ]
        ]

        self.max_batch_size: Final[int] = 8  # Set by Gemini API.
        self.thinking_budget: Final[int] = thinking_budget

    @retry(wait=wait_random(min=45, max=90), stop=stop_after_attempt(6))
    def forward(
        self,
        state: OptimizerState,
        knowledge: Dict[str, str],
        **kwargs: Dict[str, Any]
    ) -> Tuple[List[BaseModel], Dict[str, int]]:
        """
        Returns a new batch of candidates to evaluate.
        Input:
            state: the current optimizer state.
            knowledge: the prior knowledge to use for the optimization. When
                reflection is enabled, the key `reflection` is written into
                this dictionary in place.
        Returns:
            A batch of candidates to evaluate of shape B, where B is the
            sampling batch size, and a dictionary of metadata. Candidates
            whose finish reason is not `STOP` are dropped, so fewer than B
            candidates may be returned.
        """
        del kwargs
        num_input_tokens, num_output_tokens = 0, 0
        if self.reflection:
            reflection_prompt = load_reflection_prompt()
            # Only reflect on the most recent batch of designs.
            reflection_query = reflection_prompt.format(
                memory=state.memory.iloc[-self.batch_size:].to_markdown(),
                task=self.task.task_description(self.ablate_distribution_shift)
            )
            response, reflection_metadata = self.chat(
                messages=[
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": reflection_query}
                ],
                tools=[]
            )
            # `chat` returns a plain dict whose counts are -1 on failure, so
            # read it by key and clamp at zero before accumulating.
            num_input_tokens += max(
                reflection_metadata.get("num_input_tokens") or 0, 0
            )
            num_output_tokens += max(
                reflection_metadata.get("num_output_tokens") or 0, 0
            )
            knowledge["reflection"] = response["content"]
            logger.info("  Reflection: %s", knowledge["reflection"])

        optim_prompt = self.user_prompt.format(
            memory=state.memory.to_markdown(), **knowledge
        )

        # Split the requested batch into chunks of at most `max_batch_size`
        # candidates per API call.
        num_full, remainder = divmod(self.batch_size, self.max_batch_size)
        chunk_sizes = [self.max_batch_size] * num_full
        if remainder:
            chunk_sizes.append(remainder)

        out: List[BaseModel] = []
        for batch_idx, candidate_count in enumerate(chunk_sizes):
            resp = self._client.models.generate_content(
                model=self.optimizer_name,
                config=types.GenerateContentConfig(
                    system_instruction=self.system_prompt,
                    max_output_tokens=self.max_new_tokens,
                    temperature=self.temperature,
                    candidate_count=candidate_count,
                    # Every earlier chunk is a full one, so offsetting by
                    # `max_batch_size` gives each chunk a distinct seed.
                    # Scaling by the current chunk size instead can collide
                    # (e.g. batch_size=20 -> seeds s, s+8, s+8).
                    seed=(self.seed + (batch_idx * self.max_batch_size)),
                    response_mime_type="application/json",
                    response_schema=self.design_schema,
                    safety_settings=self.safety_config,
                    thinking_config=types.ThinkingConfig(
                        thinking_budget=self.thinking_budget
                    )
                ),
                contents=optim_prompt
            )
            # Keep only candidates that terminated normally; a missing
            # candidate list raises and is handled by the retry decorator.
            out.extend(
                self.design_schema.model_validate_json(
                    candidate.content.parts[0].text  # type: ignore
                )
                for candidate in resp.candidates  # type: ignore
                if candidate.finish_reason == types.FinishReason.STOP
            )
            # Usage fields may be absent or None; count those as zero.
            usage = resp.usage_metadata
            num_input_tokens += (
                getattr(usage, "prompt_token_count", None) or 0
            )
            num_output_tokens += (
                getattr(usage, "candidates_token_count", None) or 0
            )

        self._x, self._y = [], np.array([])
        return out, {
            "num_input_tokens": num_input_tokens,
            "num_output_tokens": num_output_tokens
        }

    @retry(wait=wait_random(min=45, max=90), stop=stop_after_attempt(6))
    def chat(
        self,
        messages: List[Dict[str, str]],
        tools: List[Dict[str, Any]],
        max_retries: int = 4
    ) -> Tuple[Dict[str, Any], Dict[str, int]]:
        """
        Chat with the parent LLM.
        Input:
            messages: the input messages to the LLM. Supported roles are
                `system`, `user`, `assistant`, and `tool`.
            tools: the tools available to the LLM. The caller's tool
                definitions are not modified.
            max_retries: maximum number of retries for the LLM call when the
                response carries no usable parts.
        Returns:
            The response from the LLM and associated metadata. If no usable
            response is obtained, the content is empty and both token counts
            are -1.
        """
        # Convert function-style tool specs to Gemini function declarations.
        # Work on shallow copies so that dropping the `strict` and
        # `additionalProperties` keys does not mutate the caller's tools.
        funcs: List[Dict[str, Any]] = []
        for tool in tools or []:
            if tool["type"] != "function":
                continue
            func = {
                key: val for key, val in tool["function"].items()
                if key != "strict"
            }
            params = func.get("parameters")
            if isinstance(params, dict):
                func["parameters"] = {
                    key: val for key, val in params.items()
                    if key != "additionalProperties"
                }
            funcs.append(func)

        gtools: Optional[types.Tool] = None
        if funcs:
            gtools = types.Tool(
                function_declarations=funcs  # type: ignore
            )
        tool_config = types.ToolConfig(
            function_calling_config=types.FunctionCallingConfig(
                mode="AUTO",  # type: ignore
            )
        )
        config = types.GenerateContentConfig(
            max_output_tokens=self.max_new_tokens,
            temperature=0,
            seed=self.seed,
            tools=([gtools] if gtools else None),
            tool_config=(tool_config if gtools else None),
            safety_settings=self.safety_config
        )

        gmsgs: List[Union[types.Content, str]] = []
        for i, msg in enumerate(messages):
            role = msg["role"]
            if role == "system":
                config.system_instruction = msg["content"]
            elif role == "user":
                gmsgs.append(
                    types.Content(
                        role="user", parts=[types.Part(text=msg["content"])]
                    )
                )
            elif role == "assistant":
                content = msg["content"]
                # Tool-call turns produced by this method already carry a
                # `types.Content`. Plain-text turns are wrapped with the
                # `model` role; a bare string would otherwise be treated as
                # user input by the SDK.
                if isinstance(content, str):
                    content = types.Content(
                        role="model", parts=[types.Part(text=content)]
                    )
                gmsgs.append(content)
            elif role == "tool":
                # The tool result is paired with the first tool call of the
                # immediately preceding message.
                prev_tool_call = messages[i - 1]["tool_calls"][0]
                func_name = prev_tool_call["function"]["name"]  # type: ignore
                gmsgs.append(
                    types.Content(
                        role="user",
                        parts=[
                            types.Part.from_function_response(
                                name=func_name,
                                response={"result": msg["content"]}
                            )
                        ]
                    )
                )

        # Query the model, retrying up to `max_retries` times while the
        # response has no usable parts (missing or empty candidates, missing
        # content, or an empty part list).
        parts: Optional[List[types.Part]] = None
        resp = None
        for _ in range(max_retries + 1):
            resp = self._client.models.generate_content(
                model=self.optimizer_name,
                config=config,
                contents=gmsgs  # type: ignore
            )
            try:
                parts = resp.candidates[0].content.parts  # type: ignore
            except (TypeError, AttributeError, IndexError):
                parts = None
            if parts:
                break

        if not parts or resp is None:
            fail = {"role": "assistant", "content": "", "tool_calls": None}
            return fail, {"num_input_tokens": -1, "num_output_tokens": -1}

        func_call = parts[0].function_call
        if func_call is None:
            out: Dict[str, Any] = {
                "role": "assistant",
                "content": parts[0].text,
                "tool_calls": None
            }
        else:
            id_ = func_call.id
            if id_ is None:
                # No ID supplied: derive a deterministic one from the
                # function name.
                id_ = sha256((func_call.name or "").encode()).hexdigest()
            out = {
                "role": "assistant",
                # Keep the full Content object so that it can be replayed
                # verbatim as an assistant message in a later call.
                "content": resp.candidates[0].content,  # type: ignore
                "tool_calls": [{
                    "id": id_,
                    "function": {
                        "name": func_call.name,
                        "arguments": json.dumps(func_call.args)
                    }
                }]
            }

        # Report -1 for any usage count that is absent or None.
        usage = resp.usage_metadata
        prompt_tokens = getattr(usage, "prompt_token_count", None)
        output_tokens = getattr(usage, "candidates_token_count", None)
        return out, {
            "num_input_tokens": (
                -1 if prompt_tokens is None else prompt_tokens
            ),
            "num_output_tokens": (
                -1 if output_tokens is None else output_tokens
            )
        }
