import asyncio
import functools
import os
import re
import time
from concurrent.futures import TimeoutError as FutureTimeoutError
from threading import Lock
from typing import Any, Callable, List

import aiolimiter
from google import genai
from google.genai import types as genai_types
from google.genai.errors import APIError
from google.genai.types import SafetySetting
from pydantic import BaseModel

from llms.generation_config import GenerationConfig
from llms.providers.google.constants import (
    MAX_GENERATION_PER_BATCH,
    MAX_REQUESTS_PER_MINUTE,
    MAX_THINKING_BUDGETS,
    THINKING_MODELS_STRS,
)
from llms.providers.google.dummy_api_error import TestAPIError
from llms.providers.google.google_client_manager import get_client_manager
from llms.providers.google.prompter import GooglePrompter
from llms.retry_utils import retry_with_exponential_backoff
from llms.types import Cache, ContentItem, Message
from utils.image_utils import any_to_pil
from utils.logger_utils import logger
from utils.timing_utils import timeit

# ===============================================================================
# Globals
# ===============================================================================
# NOTE: It has been more convenient to have a file of functions and globals than
# an object-oriented solution, because the many differences between providers
# make little of the code reusable between them.
# The general structure is similar among providers though:
# 1) Convert prompt messages and generation config from uniform to provider-specific format
# 2) Call the API 'num_generations' times
# 3) Convert the API response back to list of messages in uniform format
# +: Provider-specific logic for error handling and retry with exponential backoff


# --- State control flow ---
RESET_PROMPT = False  # When True, `get_provider_msgs` rebuilds the provider prompt on its next call. Important if uploading files.
PAYLOAD_TOO_LARGE = False  # When True, `get_provider_msgs` uploads the prompt images on its next call instead of re-sending them as-is.

# Global cache storing the provider-specific prompt messages, gen configs, api responses.
# This reduces overhead of prompt conversions and also helps control flow during multiple generations.
cache = Cache()

# --- Handling retries with exponential backoff ---
# These values are forwarded to `retry_with_exponential_backoff` via `retry_exp_backoff`.
MAX_DELAY = 60 * 1.5  # Maximum delay between retries
MAX_RETRIES = 1  # Max retries before switching to a new API key
BASE_DELAY = 60 // 2  # Initial delay in the exponential backoffs
BASE_DELAY_THROTTLED = 30  # Delay in the exponential backoffs for throttled executions

MAX_WAIT_PER_GEN = 3 * 60  # Maximum wait time (seconds) for each generation
MAX_API_WAIT_TIME = 10 * 60  # Max wait time (seconds) for overall API call before flagging as failed. Usually less than this.
# timeout is given by `min(MAX_WAIT_PER_GEN * num_generations, MAX_API_WAIT_TIME)`

# --- Batch generation / throttled execution configs ---
# Config for THROTTLED_EXECUTION for concurrency synchronization
MIN_DELAY_RESET = MAX_WAIT_PER_GEN  # Seconds an API key must have been in use before `reset_client` rotates it again
LOCKS_PER_PROCESS = {}  # process id -> `threading.Lock` guarding key rotation in `reset_client`
THROTTLED_EXECUTION = False  # Selects the lock-protected (throttled) code paths in `reset_client` / `get_provider_msgs`
API_KEY_ERROR_COUNT = {}  # api key -> number of errors recorded for that key by `reset_client`

# Safety settings: every listed harm category uses the same BLOCK_ONLY_HIGH threshold.
safety_settings = [
    SafetySetting(category=harm_category, threshold="BLOCK_ONLY_HIGH")  # type: ignore
    for harm_category in (
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    )
]

# ==============================================================================
# LINK: Provider-specific Error handling and retry logic
# ==============================================================================
# Google API error documentation: https://github.com/googleapis/python-genai/blob/main/google/genai/errors.py

# By default, always retry, apply exponential backoff, increment retries
# The handlers below can override this behavior depending on error


def handle_custom_errors(e: Exception, *args: Any, **kwargs: Any) -> tuple[Exception, bool, bool, bool]:
    """Handle errors that are not due to the API call.

    Timeouts are retried immediately (no backoff delay); every other error is
    retried with exponential backoff. Both count towards the retry budget.

    Args:
        e (Exception): Error to handle

    Returns:
        tuple[Exception, bool, bool, bool]:
        `e`: Error to raise in case of no retry
        `should_retry`: Whether to retry the API call
        `apply_delay`: Whether to apply exp backoff delay before retrying
        `increment_retries`: Whether to increment the number of retries
    """
    # `concurrent.futures.TimeoutError` (the class registered as a custom error in
    # `retry_exp_backoff`) is only an alias of the builtin `TimeoutError` from
    # Python 3.11 on. Check both so the timeout branch also fires on older versions.
    if isinstance(e, (TimeoutError, FutureTimeoutError)):
        # If API just took too long to respond, and num_retries < max_retries, retry without delay
        logger.info(f"Google API didn't respond after {MAX_API_WAIT_TIME} seconds. Retrying...")
        should_retry, apply_delay, increment_retries = True, False, True
    else:
        # All other errors: retry with exponential backoff and increment the number of retries
        should_retry, apply_delay, increment_retries = True, True, True

    return e, should_retry, apply_delay, increment_retries


def parse_quota_error(e: APIError) -> str:
    """Classify a quota error by the time window of the violated quota.

    Looks for `google.rpc.QuotaFailure` entries under `e.details["error"]["details"]`
    and inspects the `quotaId` of each violation.

    Args:
        e (APIError): Error raised by the API call.

    Returns:
        str: `"day"` for a per-day quota, `"minute"` for a per-minute quota,
        or `""` when the error carries no recognizable quota violation.
    """
    payload = getattr(e, "details", None)
    # The payload is expected to be the decoded JSON error body; anything else
    # cannot contain quota details, so bail out instead of raising.
    if not isinstance(payload, dict):
        return ""
    error_info = payload.get("error")
    details = error_info.get("details") if isinstance(error_info, dict) else None
    if not isinstance(details, list):
        return ""

    for detail in details:
        if not isinstance(detail, dict):
            continue
        if detail.get("@type", "") != "type.googleapis.com/google.rpc.QuotaFailure":
            continue
        for violation in detail.get("violations") or []:
            if not isinstance(violation, dict):
                continue
            quota_id = str(violation.get("quotaId", "") or "")
            logger.info(f"Quota violation: details: {details}")
            logger.info(f"Quota violation: violation: {violation}")
            # Substring search: the previous patterns (".*PerDay*.") were malformed
            # and only matched by accident.
            if re.search(r"PerDay", quota_id, re.IGNORECASE):
                return "day"
            if re.search(r"PerMinute", quota_id, re.IGNORECASE):
                return "minute"
    return ""


def _test_api_key(p_id: int = 0, model: str = ""):
    """Probe the current API key of process `p_id` with a tiny generation request.

    Args:
        p_id (int): Process id used to look up the client manager.
        model (str): Model to query; defaults to "gemini-2.0-flash-001" when empty.

    Returns:
        bool: True if the request succeeded, False if it raised an `APIError`.
        Any other exception propagates to the caller.
    """
    probe_model = model if model else "gemini-2.0-flash-001"
    probe_prompt = "What's the capital of France?"
    try:
        probe_client = get_client_manager(p_id).get_client()
        probe_client.models.generate_content(model=probe_model, contents=[probe_prompt])
    except APIError:
        return False
    return True


def reset_client(
    p_id: int = 0,
    increment_api_key_retry_count: int = 1,
    meta_data: dict[str, Any] = {},
    model: str = "",
):
    """Switch the client of process `p_id` to another API key after a failure.

    Non-throttled mode: bump the retry counter of the current key, ask the client
    manager to reset the key and flag the prompt for a rebuild (`RESET_PROMPT`).

    Throttled mode (`THROTTLED_EXECUTION`): under a per-process lock, either
    - detect that another call already changed the key and just refresh the prompt,
    - rotate keys until `_test_api_key` accepts one, then refresh the prompt, or
    - record one more error for the current key.
    In the first and last case the caller is put to sleep afterwards.

    Args:
        p_id (int): Process id used to look up the client manager and the lock.
        increment_api_key_retry_count (int): Amount added to the retry counter of
            the failing key (callers also pass `float("inf")`).
        meta_data (dict[str, Any]): Per-call data; reads `call_id`,
            `api_key_used_in_call` and `provider_msgs`. Only read here, never mutated.
        model (str): Model used to validate a freshly selected key.
    """
    global THROTTLED_EXECUTION, RESET_PROMPT, API_KEY_ERROR_COUNT, LOCKS_PER_PROCESS
    call_id = meta_data.get("call_id", 0)

    if not THROTTLED_EXECUTION:
        client_manager = get_client_manager(p_id)
        client_manager.api_keys_retry_count[hash(client_manager.api_key)] += increment_api_key_retry_count
        client_manager.reset_api_key()
        RESET_PROMPT = True
        return

    # `setdefault` registers the per-process lock in a single dict operation, so two
    # threads arriving together cannot end up holding different locks.
    process_lock = LOCKS_PER_PROCESS.setdefault(p_id, Lock())
    # May be absent when the failing call errored before recording its messages.
    provider_msgs = meta_data.get("provider_msgs")

    sleep = False
    with process_lock:
        client_manager = get_client_manager(p_id)
        api_key = client_manager.api_key
        start_time = client_manager.get_start_time(api_key)
        elapsed = time.time() - start_time

        # If other async call reset the client, just wait for it to finish
        if api_key != meta_data.get("api_key_used_in_call", ""):
            sleep = True
            logger.info(f"process {p_id}, call_id {call_id}: API key changed. Resetting prompt.")
            if provider_msgs is not None:
                _ = GooglePrompter.reset_prompt(provider_msgs, p_id=p_id, all_images=False)

        # Elif, reset the key if sufficient time has passed
        elif elapsed > MIN_DELAY_RESET or API_KEY_ERROR_COUNT.get(api_key, 0) == 0:
            API_KEY_ERROR_COUNT[api_key] = 1

            valid_key = False
            num_tries = 1
            client_manager.api_keys_retry_count[hash(client_manager.api_key)] += increment_api_key_retry_count
            # Keep rotating until a key passes the probe request.
            while not valid_key:
                client_manager.reset_api_key()
                add_text = "" if num_tries == 1 else "Previous api_key not valid. "
                logger.info(f"Proccess {p_id}, call_id {call_id}: {add_text}Resetting client. Attempt {num_tries}.")
                num_tries += 1
                valid_key = _test_api_key(p_id, model=model)
                new_key = client_manager.api_key
                if not valid_key:
                    client_manager.api_keys_retry_count[hash(new_key)] += 1
                else:
                    API_KEY_ERROR_COUNT[new_key] = API_KEY_ERROR_COUNT.get(new_key, 0)
                    logger.info(f"Proccess {p_id}, call_id {call_id}: API reset successful.")
                    logger.info(f"Proccess {p_id}, call_id {call_id}: Resetting prompt.")
                    if provider_msgs is not None:
                        _ = GooglePrompter.reset_prompt(provider_msgs, p_id=p_id, all_images=False)
        else:
            sleep = True
            API_KEY_ERROR_COUNT[api_key] = API_KEY_ERROR_COUNT.get(api_key, 0) + 1

    if sleep:
        # `.get`: in the "key changed" branch the current key may never have been recorded.
        if API_KEY_ERROR_COUNT.get(api_key, 0) > 1:
            # Safeguard; shouldn't happen if code is correct
            logger.warning("ATTENTION: retried API key that is not working more than 1 times without reset.")
            # Clamp at zero: `elapsed` can exceed MIN_DELAY_RESET in the "key changed"
            # branch and `time.sleep` rejects negative values.
            time.sleep(max(0.0, MIN_DELAY_RESET - elapsed))
        else:
            time.sleep(10)

def handle_api_errors(e: APIError | TestAPIError, *args: Any, **kwargs: Any) -> tuple[Exception, bool, bool, bool]:
    """Handle errors from the provider API.

    The error is classified from its message/status, in this order: payload too
    large, quota exceeded, key expired, deadline exceeded, invalid argument,
    anything else. Reads `process_id`, `meta_data` and `model` from `kwargs`.

    Args:
        e (APIError): Error due to the API call

    Returns:
        tuple[Exception, bool, bool, bool]:
        `e`: Error to raise in case of no retry
        `should_retry`: Whether to retry the API call
        `apply_delay`: Whether to apply exp backoff delay before retrying
        `increment_retries`: Whether to increment the number of retries
        If the handler itself fails (e.g. while resetting the client), the
        handler's own error is returned with `should_retry=False`.
    """

    global PAYLOAD_TOO_LARGE
    try:
        p_id = kwargs.get("process_id", 0)
        meta_data = kwargs.get("meta_data", {})
        model = kwargs.get("model", "")
        call_id = meta_data.get("call_id", 0)
        api_key = meta_data.get("api_key_used_in_call", "")
        message = e.message or ""
        quota_limit = ""

        # If error due to payload too large, retry without exponential backoff.
        if re.search("payload", message, re.IGNORECASE):
            logger.error(f"Google API error: call_id {call_id}, proccess {p_id}: {e.message}. Payload too large.")
            PAYLOAD_TOO_LARGE = True
            should_retry, apply_delay, increment_retries = True, False, False

        # Parenthesized walrus: `quota_limit` must hold the parsed window ("day" /
        # "minute" / ""), not the boolean result of the whole `or` expression.
        elif (quota_limit := parse_quota_error(e)) or "quota" in message.lower():
            logger.error(
                f"Google API error: call_id {call_id}, proccess {p_id}: {e}.\nQuota limit: {quota_limit or 'unknown'}.\nAPI key: {api_key}",
                exc_info=False,
            )
            if quota_limit == "day":
                # Daily quota exhausted: the key is unusable, so mark it with an infinite retry count.
                reset_client(p_id, increment_api_key_retry_count=float("inf"), meta_data=meta_data, model=model)
                should_retry, apply_delay, increment_retries = True, False, False
            else:
                should_retry, apply_delay, increment_retries = True, True, True

        elif re.search("key expired", message, re.IGNORECASE):
            logger.error(f"Google API error: call_id {call_id}, proccess {p_id}: {e.message}. Key expired.")
            reset_client(p_id, increment_api_key_retry_count=float("inf"), meta_data=meta_data, model=model)
            should_retry, apply_delay, increment_retries = True, False, False

        elif re.search("Deadline", message, re.IGNORECASE):
            logger.error(f"Google API error: call_id {call_id}, proccess {p_id}: {e}.")
            # Treated like an oversized payload: the next prompt build uploads the images.
            PAYLOAD_TOO_LARGE = True
            should_retry, apply_delay, increment_retries = True, False, True
            # NOTE: this error seems prompt dependent and never recovers

        # If other invalid argument error, do not retry.
        elif e.status and re.search("invalid", e.status, re.IGNORECASE):
            logger.error(f"Google API error: call_id {call_id}, proccess {p_id}: {e}. Stopping generation.")
            should_retry, apply_delay, increment_retries = False, False, True
        else:
            # All other errors: retry with exponential backoff and increment the number of retries
            logger.error(
                f"Google API error: call_id {call_id}, proccess {p_id}: {e}. Retrying with exponential backoff..."
            )
            should_retry, apply_delay, increment_retries = True, True, True

        return e, should_retry, apply_delay, increment_retries
    except Exception as handler_error:
        # Separate name so the original API error `e` is not shadowed.
        logger.error(f"{handler_error}")
        return handler_error, False, False, True


def handle_max_retries(e: Exception, *args: Any, **kwargs: Any) -> tuple[Exception, bool, bool, bool]:
    """Switch API key once the exponential-backoff retry budget is exhausted.

    Reads `process_id`, `model` and `meta_data` from `kwargs` and hands them to
    `reset_client`.

    Args:
        e (Exception): Last error seen before the retry limit was hit

    Returns:
        tuple[Exception, bool, bool, bool]:
        `e`: Error to raise in case of no retry
        `should_retry`: Whether to retry the API call
        `apply_delay`: Whether to apply exp backoff delay before retrying
        `increment_retries`: Whether to increment the number of retries
        When the reset itself fails (e.g. no API keys left), the reset error is
        returned with `should_retry=False`.
    """
    try:
        logger.info(f"Max retries reached for API key. Last error: {e}.")
        # Update retry count for the current API key and move on to another one
        reset_client(
            kwargs.get("process_id", 0),
            model=kwargs.get("model", ""),
            meta_data=kwargs.get("meta_data", {}),
        )
    except Exception as reset_error:
        # If no API keys left or other errors, do not retry
        logger.error(f"{reset_error}")
        return reset_error, False, False, True

    return e, True, False, True


# Provider-specific retry decorator factory: `retry_with_exponential_backoff`
# pre-configured with this module's delays, retry limit, error classes and
# handlers. Call sites supply the remaining per-function arguments
# (see `sync_api_call`, which passes `timeout_getter` and `id_getter`).
retry_exp_backoff = functools.partial(
    retry_with_exponential_backoff,
    base_delay=BASE_DELAY,
    max_delay=MAX_DELAY,
    exp_base=1.5,
    jitter=True,
    max_retries=MAX_RETRIES,
    api_errors=(APIError, TestAPIError),
    custom_errors=(FutureTimeoutError,),
    handle_custom_errors=handle_custom_errors,
    handle_api_errors=handle_api_errors,
    handle_max_retries=handle_max_retries,
    max_workers=MAX_GENERATION_PER_BATCH,
    logger=logger,
)


# If API call doesnt return in min(MAX_WAIT_PER_GEN * num_generations, MAX_API_WAIT_TIME) seconds, retry
# This should be passed to the `retry_exp_backoff` decorator (see `sync_api_call`)
def timeout_getter(args: Any, kwargs: Any, key: str = "provider_gen_config") -> float:
    """Compute the timeout (seconds) for one API call from its keyword arguments.

    Args:
        args (Any): Positional arguments of the wrapped call (unused).
        kwargs (Any): Keyword arguments of the wrapped call.
        key (str): Name of the keyword holding the provider generation config.

    Returns:
        float: `min(MAX_WAIT_PER_GEN * candidate_count, MAX_API_WAIT_TIME)`, or
        `MAX_API_WAIT_TIME` when the config or its candidate count is missing.
    """
    provider_gen_config: genai_types.GenerateContentConfig | None = kwargs.get(key)
    # `getattr` with a default also covers a config that was not passed by keyword
    # (kwargs.get returns None), which previously raised AttributeError.
    candidate_count = getattr(provider_gen_config, "candidate_count", None)
    if not candidate_count:
        # Unknown number of generations: allow the maximum overall wait.
        return MAX_API_WAIT_TIME
    return min(MAX_WAIT_PER_GEN * candidate_count, MAX_API_WAIT_TIME)


def id_getter(args: Any, kwargs: Any, key: str = "process_id") -> str:
    """Return the value stored under `key` in `kwargs` as a string ("" if absent)."""
    return str(kwargs.get(key, ""))


# ==============================================================================
# LINK: Output conversion: provider-specific -> uniform format
# ==============================================================================
def convert_single_part(part: genai_types.Part) -> ContentItem | None:
    """Convert a single response part to one content item in uniform format.

    Args:
        part (genai_types.Part): Part of a candidate returned by the API.

    Returns:
        ContentItem | None: A "reasoning", "text", "image" or "function_call"
        item, or None when the part is unsupported or its image cannot be decoded.
    """
    # Thought parts carry their content in `text` with `thought=True`, so they
    # must be recognized before the plain-text branch; otherwise the model's
    # reasoning is returned as ordinary text output.
    if part.thought:
        return ContentItem(type="reasoning", data=part.to_json_dict(), raw_model_output=part)

    if part.text is not None:
        return ContentItem(type="text", data=part.text, raw_model_output=part)

    if part.inline_data is not None and part.inline_data.data is not None:
        try:
            img = any_to_pil(part.inline_data.data)
            return ContentItem(type="image", data=img, raw_model_output=part)
        except Exception as e:
            # Best effort: an undecodable image is dropped rather than failing the generation.
            logger.warning(f"Error converting inline_data to PIL Image: {e}")
            return None

    if part.function_call is not None:
        return ContentItem(type="function_call", data=part.to_json_dict(), raw_model_output=part)

    if part.thought is not None:
        # `thought` explicitly set (to a falsy value) with no other payload:
        # kept as a reasoning item, as before.
        return ContentItem(type="reasoning", data=part.to_json_dict(), raw_model_output=part)

    if part.executable_code is not None:
        logger.warning(f"Executable code generated but not implemented yet: {part.executable_code}")
        return None

    logger.warning(f"Part type not implemented: type {type(part)}; part: {part}")
    return None


def convert_single_generation(
    candidate: genai_types.Candidate,
) -> Message | None:
    """
    Convert a single candidate to one assistant message in uniform format.

    Returns None when the candidate has no content, no parts, or none of its
    parts could be converted.
    """
    content = candidate.content
    if content is None or content.parts is None:
        return None

    # Convert every part of this generation, dropping the ones that yield nothing
    items = [item for item in map(convert_single_part, content.parts) if item is not None]
    if not items:
        return None
    return Message(role="assistant", name="", contents=items)


def convert_generations(response: genai_types.GenerateContentResponse, response_schema: bool = False) -> list[Message]:
    """Convert every candidate of an API response to a uniform-format message.

    Candidates that cannot be converted are skipped. `response_schema` is
    accepted but not used here.
    """
    if not response.candidates:
        return []
    converted = (convert_single_generation(candidate) for candidate in response.candidates)
    return [message for message in converted if message]


# ==============================================================================
# LINK: Prompt messages conversion: uniform -> provider-specific format
# ==============================================================================


def get_provider_msgs(
    messages: List[Message],
    use_cache: bool = True,
    reset_prompt: bool = False,
    p_id: int = 0,
    force_upload: bool = False,
) -> List[genai_types.Content]:
    """
    Build (or fetch from cache) the provider-format prompt for `messages`.

    Steps:
    - Convert the prompt with GooglePrompter unless a cached conversion exists
    - Rebuild the prompt when `reset_prompt` or the global RESET_PROMPT is set
    - Otherwise upload the images when the global PAYLOAD_TOO_LARGE is set

    The global flags are cleared once acted upon, and the cache is refreshed
    whenever `use_cache` is True.
    """
    global RESET_PROMPT, PAYLOAD_TOO_LARGE, cache, THROTTLED_EXECUTION

    # Reuse the cached conversion when allowed and available; otherwise convert now
    provider_msgs = cache.messages_to_provider if use_cache else None
    if not provider_msgs:
        provider_msgs = GooglePrompter.convert_prompt(messages, p_id=p_id, force_upload=force_upload)
        if use_cache:
            cache.messages_to_provider = provider_msgs

    # Re-create prompt only on specific flags
    if RESET_PROMPT or reset_prompt:
        logger.info("Resetting prompt...")
        # TODO debug and clean this away to use only reset_prompt with all_images=False
        reset_options = {"all_images": False} if THROTTLED_EXECUTION else {}
        provider_msgs = GooglePrompter.reset_prompt(provider_msgs, p_id=p_id, **reset_options)
        if not reset_prompt:
            # Only the global request is consumed; an explicit argument leaves it alone
            RESET_PROMPT = False
        if use_cache:
            cache.messages_to_provider = provider_msgs

    elif PAYLOAD_TOO_LARGE:
        logger.info("Payload too large. Uploading images...")
        upload_forced = True if THROTTLED_EXECUTION else False
        provider_msgs = GooglePrompter.upload_all_images(provider_msgs, p_id=p_id, force_upload=upload_forced)
        PAYLOAD_TOO_LARGE = False
        if use_cache:
            cache.messages_to_provider = provider_msgs

    return provider_msgs


# ===============================================================================
# LINK: Generation config conversion: uniform -> provider-specific format
# ===============================================================================


def regularize_thinking_budget(gen_config: GenerationConfig):
    """Clamp `gen_config.thinking_budget` to the model's cap from MAX_THINKING_BUDGETS.

    The config is updated in place when the budget exceeds the cap of a key that
    is a substring of the model name; scanning stops at the first clamp.

    Returns:
        The (possibly clamped) thinking budget, or None when no budget is set.
    """
    if gen_config.thinking_budget is None:
        return None

    for model_key, budget_cap in MAX_THINKING_BUDGETS.items():
        if model_key in gen_config.model and gen_config.thinking_budget > budget_cap:
            logger.warning(
                f"Thinking budget regularized for model {gen_config.model}: {gen_config.thinking_budget} -> {budget_cap}"
            )
            gen_config.thinking_budget = budget_cap
            break
    return gen_config.thinking_budget


def clean_schema(schema: dict) -> dict:
    """
    Return a copy of `schema` without any "additionalProperties" key.

    Nested dicts are cleaned recursively, as are dicts that sit directly inside
    a list value; every other value is carried over unchanged.
    """

    def _scrub(node):
        # Only dicts can hold the unwanted key; anything else passes through.
        return clean_schema(node) if isinstance(node, dict) else node

    result = {}
    for name, node in schema.items():
        if name == "additionalProperties":
            # Dropped to prevent validation errors.
            continue
        result[name] = [_scrub(entry) for entry in node] if isinstance(node, list) else _scrub(node)
    return result


def regularize_response_schema(original_schema: Any) -> dict[str, Any]:
    """Turn any supported response-schema spec into a JSON Schema dict.

    Supported inputs: a dict (returned as-is), `list[Model]` or `list[<dict>]`
    (wrapped as an array schema), a Pydantic model class, or a Pydantic model
    instance.

    Raises:
        ValueError: If the schema, or the item type of a list schema, is unsupported.
    """
    # Tools for inspecting classes and generic aliases
    import inspect
    from typing import get_args, get_origin

    def _is_model_class(candidate: Any) -> bool:
        return inspect.isclass(candidate) and issubclass(candidate, BaseModel)

    # Already a JSON Schema dict: nothing to do.
    if isinstance(original_schema, dict):
        return original_schema

    # Generic alias of the form list[Something]
    if get_origin(original_schema) is list:
        inner = get_args(original_schema)[0]
        if _is_model_class(inner):
            items_schema = inner.model_json_schema()
        elif isinstance(inner, dict):
            # list[<dict>] where the dict itself is a JSON schema
            items_schema = inner
        else:
            raise ValueError(f"Invalid list response schema type: {inner}")
        return {"type": "array", "items": items_schema}

    # Pydantic model class
    if _is_model_class(original_schema):
        return original_schema.model_json_schema()

    # Pydantic model instance
    if isinstance(original_schema, BaseModel):
        return type(original_schema).model_json_schema()

    raise ValueError(f"Invalid response schema: {original_schema}")


def gen_config_to_provider(gen_config: GenerationConfig) -> genai_types.GenerateContentConfig:
    """
    Build a Google `GenerateContentConfig` from the uniform generation config.

    Sampling arguments are copied over, the module-level safety settings are
    attached, a response schema (if any) switches the output to JSON, and a
    thinking budget (if any) is clamped and attached as a thinking config.
    """
    gen_args = dict(
        candidate_count=gen_config.num_generations,
        max_output_tokens=gen_config.max_tokens,
        top_p=gen_config.top_p,
        temperature=gen_config.temperature,
        stop_sequences=gen_config.stop_sequences,
        top_k=gen_config.top_k,
        seed=gen_config.seed,
        presence_penalty=gen_config.presence_penalty,
        frequency_penalty=gen_config.frequency_penalty,
        safety_settings=safety_settings,
        response_modalities=gen_config.modalities,
    )

    # Structured output: normalize the schema, then strip keys the API rejects
    if gen_config.response_schema:
        normalized_schema = regularize_response_schema(gen_config.response_schema)
        gen_args["response_schema"] = clean_schema(normalized_schema)
        gen_args["response_mime_type"] = "application/json"

    if gen_config.thinking_budget is not None:
        budget = regularize_thinking_budget(gen_config)
        if budget is not None:
            gen_args["thinking_config"] = genai_types.ThinkingConfig(thinking_budget=budget)

    return genai_types.GenerateContentConfig(**gen_args)  # type: ignore


def regularize_provider_gen_config_for_model(
    model: str,
    provider_gen_config: genai_types.GenerateContentConfig,
) -> genai_types.GenerateContentConfig:
    """Adjust the provider generation config, in place, to what `model` accepts.

    Args:
        model (str): Model name
        provider_gen_config (genai_types.GenerateContentConfig): Provider-specific generation configuration

    Returns:
        genai_types.GenerateContentConfig: The same config object after adjustment
    """
    is_image_generation_model = model == "gemini-2.0-flash-exp-image-generation"
    if is_image_generation_model:
        # This model gets no system instruction and a single candidate per call
        provider_gen_config.system_instruction = None
        provider_gen_config.candidate_count = 1
    else:
        # Every other model is restricted to text output
        provider_gen_config.response_modalities = ["Text"]

    # Drop the thinking config for models not listed as thinking-capable
    supports_thinking = any(marker in model for marker in THINKING_MODELS_STRS)
    if not supports_thinking:
        provider_gen_config.thinking_config = None

    return provider_gen_config


def get_provider_gen_config(
    gen_config: GenerationConfig,
    provider_msgs: List[genai_types.Content],
    use_cache: bool = True,
) -> genai_types.GenerateContentConfig:
    """
    Return the provider generation config for the API call.

    With `use_cache=True` the converted config is stored in (and later reused
    from) the global cache; otherwise a fresh conversion is returned every time.
    `provider_msgs` is accepted but not used here.
    """
    global cache
    if use_cache and cache.gen_config:
        return cache.gen_config

    provider_gen_config = gen_config_to_provider(gen_config)
    if use_cache:
        cache.gen_config = provider_gen_config
    return provider_gen_config


# ==============================================================================
# Non-batch Generation
# ==============================================================================


def generate_from_google_chat_completion(
    messages: List[Message],
    gen_config: GenerationConfig,
    meta_data: dict[str, Any] = {},
) -> tuple[list[dict[str, Any]], list[Message]]:
    """Synchronous generation from Google API.

    This function:
     - Converts prompt messages and genconfig from uniform to provider-specific format.
     - Applies model-specific regularizations to genconfigs and messages.
     - Handles multiple generations of different modalities.
     - Converts model outputs back to uniform format.

    The API is called repeatedly until `gen_config.num_generations` messages were
    collected. If the API returns several consecutive responses without any
    usable candidate (e.g. the prompt is blocked), generation stops and the
    messages collected so far are returned instead of looping forever.

    Args:
        messages (List[Message]): List of Message objects in uniform format to send to the model.
        gen_config (GenerationConfig): Generation configuration.
        meta_data (dict[str, Any]): Currently not forwarded to the API call.

    Returns:
        tuple[list[Dict[str, Any]], list[Message]]: List of API responses and list of generated messages in uniform format
    """
    global MAX_GENERATION_PER_BATCH, cache
    cache.reset()

    # Upper bound on back-to-back responses that yield no message before giving up.
    max_consecutive_empty_responses = 3
    consecutive_empty_responses = 0

    # Number of generations remaining to be generated
    remaining_generation_count = gen_config.num_generations

    # Build provider messages and generation config on first call
    provider_gen_config = get_provider_gen_config(gen_config, get_provider_msgs(messages))
    # (obs.: needs to rebuild prompt on each retry to cover API resets; see `sync_api_call`)

    # Generate outputs
    logger.info(f"CALLING MODEL: `{gen_config.model}`: generating {gen_config.num_generations} outputs...")

    while remaining_generation_count > 0:
        provider_gen_config.candidate_count = min(MAX_GENERATION_PER_BATCH, remaining_generation_count)

        # Regularizing here handles where model supports only `num_generations=1` by calling the API `n` times
        provider_gen_config = regularize_provider_gen_config_for_model(gen_config.model, provider_gen_config)

        # Call the API
        response: genai_types.GenerateContentResponse
        response = sync_api_call(
            model=gen_config.model, messages=messages, provider_gen_config=provider_gen_config, meta_data={}
        )
        model_messages = convert_generations(response, gen_config.response_schema is not None)

        # Update cache and decrement remaining generation count
        if model_messages:
            consecutive_empty_responses = 0
            cache.api_responses.append(response.model_dump(mode="json"))
            cache.model_messages.extend(model_messages)
            remaining_generation_count -= len(model_messages)
        else:
            # Nothing usable came back; without this bound the loop would call the
            # API indefinitely for a prompt that never produces candidates.
            consecutive_empty_responses += 1
            if consecutive_empty_responses >= max_consecutive_empty_responses:
                logger.error(
                    f"Model `{gen_config.model}` returned {consecutive_empty_responses} consecutive responses "
                    f"without usable candidates. Stopping with {len(cache.model_messages)} of "
                    f"{gen_config.num_generations} generations."
                )
                break

    return cache.api_responses, cache.model_messages


# NOTE: Keeping the final API call separate from main generation function
# for more isolated application of `retry_with_exponential_backoff`.
@timeit(custom_name=f"LLM:sync_{os.path.basename(__file__)}_api_call")
@retry_exp_backoff(timeout_getter=timeout_getter, id_getter=id_getter)
def sync_api_call(
    model: str,
    messages: List[Message],
    provider_gen_config: genai_types.GenerateContentConfig,
    process_id: int = 0,
    use_cache: bool = True,
    meta_data: dict[str, Any] = {},
) -> genai_types.GenerateContentResponse:
    """Synchronous API call to Google API.

    Wrapped by `retry_exp_backoff` (retries, key rotation, timeout) and `timeit`.
    The prompt is (re)built on every invocation so that retries pick up prompt
    resets and image uploads requested by the error handlers.

    Args:
        model (str): Model name.
        messages (List[Message]): Prompt messages in uniform format.
        provider_gen_config (genai_types.GenerateContentConfig): Provider generation
            config. It is modified in place (system instruction, model-specific settings).
        process_id (int): Id used to select the client manager and to build the prompt.
        use_cache (bool): Whether the converted prompt is read from / written to the global cache.
        meta_data (dict[str, Any]): Optional per-call data; only the "api_key" entry is read here.

    Returns:
        genai_types.GenerateContentResponse: Raw API response.
    """

    # An explicit API key bypasses the prompt cache.
    # NOTE(review): `process_id` is replaced by the API key itself here, so the key is
    # what gets passed as the id to `get_client_manager` and `get_provider_msgs` — confirm
    # this is intended.
    if "api_key" in meta_data:
        process_id = meta_data["api_key"]
        use_cache = False

    # Get global client
    client_manager = get_client_manager(process_id, api_key=meta_data.get("api_key"))
    client = client_manager.get_client()

    # Get provider messages. Obs.: This caches, re-uploads, and resets prompts if needed.
    provider_msgs = get_provider_msgs(messages, use_cache=use_cache, p_id=process_id)

    # Obs.: this redundancy in system_instruction and regularization handles cases
    # where the API is redefined, which must reset the system_prompt message
    # (not currently the case, but possible if it contains files or context caching)

    # For Google, system prompt goes into the generation config:
    # the first part of the first provider message is used as the system instruction.
    if provider_msgs and provider_msgs[0].parts:
        provider_gen_config.system_instruction = provider_msgs[0].parts[0]

    # Regularize provider generation config for model-specific settings.
    # (some models don't support a system_instruction)
    provider_gen_config = regularize_provider_gen_config_for_model(model, provider_gen_config)

    # @debug
    # # raise FutureTimeoutError("test")
    # raise TestAPIError("test")
    # Call the API
    response = client.models.generate_content(
        model=model,
        contents=provider_msgs[1:],  # type: ignore # all msgs except sys_prompt
        config=provider_gen_config,
    )
    return response


# ==============================================================================
# Batch generation
# ==============================================================================


async def async_api_call(
    model: str,
    messages: List[Message],
    provider_gen_config: genai_types.GenerateContentConfig,
    process_id: int = 0,
    use_cache: bool = False,
    meta_data: dict[str, Any] | None = None,
) -> genai_types.GenerateContentResponse:
    """Asynchronous API call to Google API.

    Args:
        model (str): Model name.
        messages (List[Message]): Prompt messages in uniform format.
        provider_gen_config (genai_types.GenerateContentConfig): Provider generation
            config. It is modified in place (system instruction, model-specific settings).
        process_id (int): Id used to select the client manager and to build the prompt.
        use_cache (bool): Whether the converted prompt is read from / written to the global cache.
        meta_data (dict[str, Any] | None): Per-call data shared with the error handlers.
            The API key used and the provider messages are written into it under
            "api_key_used_in_call" and "provider_msgs".

    Returns:
        genai_types.GenerateContentResponse: Raw API response.

    Raises:
        Exception: If no API key could be obtained within 30 seconds.
    """
    # A fresh dict per call: a shared `{}` default would leak the recorded key and
    # messages from one call into the next.
    if meta_data is None:
        meta_data = {}

    # Get provider messages. Obs.: This caches, re-uploads, and resets prompts if needed.
    provider_msgs = get_provider_msgs(messages, use_cache=use_cache, p_id=process_id)

    # Obs.: this redundancy in system_instruction and regularization handles cases
    # where the API is redefined, which must reset the system_prompt message
    # (not currently the case, but possible if it contains files or context caching)

    # For Google, system prompt goes into the generation config
    if provider_msgs and provider_msgs[0].parts:
        provider_gen_config.system_instruction = provider_msgs[0].parts[0]

    # Regularize provider generation config for model-specific settings.
    # (some models don't support a system_instruction)
    provider_gen_config = regularize_provider_gen_config_for_model(model, provider_gen_config)

    # Get global client; wait (up to 30s) until the manager exposes an API key
    api_key = None
    start_time = time.time()
    while not api_key:
        client_manager = get_client_manager(process_id)
        client = client_manager.get_client()
        api_key = client_manager.api_key
        if time.time() - start_time > 30:
            raise Exception("Failed to get API key")
        if not api_key:
            # Yield to the event loop instead of busy-spinning while another call resets the key.
            await asyncio.sleep(0.1)
    meta_data["api_key_used_in_call"] = api_key
    meta_data["provider_msgs"] = provider_msgs

    # Call the API in a worker thread: the SDK call is blocking, and running it
    # directly inside the coroutine would stall the event loop and keep an
    # `asyncio.wait_for` timeout around this coroutine from ever firing.
    # (On timeout the coroutine is cancelled; the worker thread finishes on its own.)
    response = await asyncio.to_thread(
        client.models.generate_content,
        model=model,
        contents=provider_msgs[1:],  # type: ignore # all msgs except sys_prompt
        config=provider_gen_config,
    )
    return response


async def _throttled_google_agenerate(
    limiter: aiolimiter.AsyncLimiter,
    messages: List[Message],
    gen_config: GenerationConfig,
    dump_conversation_fun=None,
    dump_usage_fun=None,
    process_id: int = 0,
    call_id: int = 0,
) -> genai_types.GenerateContentResponse | dict[str, Any]:
    """Run one rate-limited generation with retries, never raising on API failure.

    Args:
        limiter: Rate limiter; the whole call (including retries) runs inside it.
        messages: Prompt messages in the uniform format.
        gen_config: Uniform generation config; `model` and `num_generations` are read.
        dump_conversation_fun: Optional callback invoked as
            `(messages, model_messages, gen_config_dict)` after a successful call
            that produced model messages.
        dump_usage_fun: Optional callback invoked as
            `([response_json], gen_config_dict)` after a successful call.
        process_id: Identifier forwarded to the prompt/client helpers.
        call_id: Index of this call within the batch (used in logs and meta data).

    Returns:
        The API response on success, otherwise a dict
        `{"failed": True, "request_kwargs": <provider generation config>}`.
    """
    async with limiter:
        provider_msgs = get_provider_msgs(messages, use_cache=False, p_id=process_id)
        provider_gen_config = get_provider_gen_config(gen_config, provider_msgs, use_cache=False)
        failure: dict[str, Any] = {"failed": True, "request_kwargs": provider_gen_config}
        num_retries = 0
        try:
            # Loop-invariant: the per-call timeout does not change between attempts.
            call_timeout = min(MAX_WAIT_PER_GEN * gen_config.num_generations, MAX_API_WAIT_TIME)
            while num_retries < MAX_RETRIES:
                # NOTE(review): every attempt counts towards MAX_RETRIES; the
                # `increment_retries` flag returned by the error handlers is not
                # honored here, which keeps the number of attempts bounded.
                num_retries += 1
                # Fresh per attempt; filled by `async_api_call` with the key and
                # provider messages actually used.
                meta_data: dict[str, Any] = {"call_id": call_id}
                try:
                    resp = await asyncio.wait_for(
                        async_api_call(
                            model=gen_config.model,
                            messages=messages,
                            provider_gen_config=provider_gen_config,
                            process_id=process_id,
                            use_cache=False,
                            meta_data=meta_data,
                        ),
                        timeout=call_timeout,
                    )
                    logger.info(f"[{__file__}] 1 generation successful for model {gen_config.model}")
                    if dump_conversation_fun or dump_usage_fun:
                        gen_config_dict = gen_config.to_dict()
                        if dump_conversation_fun:
                            model_messages = convert_generations(resp)
                            if model_messages:
                                dump_conversation_fun(messages, model_messages, gen_config_dict)
                        if dump_usage_fun:
                            dump_usage_fun([resp.model_dump(mode="json")], gen_config_dict)
                    return resp

                except APIError as api_err:
                    api_err, should_retry, apply_delay, _ = handle_api_errors(
                        api_err, model=gen_config.model, meta_data=meta_data, process_id=process_id
                    )
                    if not should_retry:
                        logger.error(f"Error in async generation: {api_err}")
                        return failure

                    # `status` may be missing or None; only match on real strings so
                    # a retryable error is not turned into a TypeError.
                    status = getattr(api_err, "status", None)
                    if isinstance(status, str) and re.search("permission", status, re.IGNORECASE):
                        logger.info(
                            f"Proccess {process_id}, call_id {call_id}: Permission denied during throttled generation."
                        )
                        try:
                            logger.info(f"Proccess {process_id}, call_id {call_id}: Resetting prompt.")
                            _ = GooglePrompter.reset_prompt(
                                meta_data["provider_msgs"], p_id=process_id, all_images=False
                            )
                            # After a prompt reset, retry immediately.
                            apply_delay = False
                        except Exception as reset_err:
                            logger.error(
                                f"Proccess {process_id}, call_id {call_id}: Error in async generation: {reset_err}"
                            )
                            continue

                    if apply_delay:
                        logger.info(
                            f"Proccess {process_id}, call_id {call_id}: Sleeping for {BASE_DELAY_THROTTLED} seconds"
                        )
                        await asyncio.sleep(BASE_DELAY_THROTTLED)

                # Before Python 3.11 `asyncio.TimeoutError` and
                # `concurrent.futures.TimeoutError` are distinct from the builtin,
                # so all three are listed to treat timeouts uniformly.
                except (TimeoutError, asyncio.TimeoutError, FutureTimeoutError) as timeout_err:
                    timeout_err, should_retry, _, _ = handle_custom_errors(timeout_err, process_id)
                    if not should_retry:
                        logger.error(f"Error in async generation: {timeout_err}")
                        return failure

                except Exception as e:
                    logger.error(f"Proccess {process_id}, call_id {call_id}: Error in async generation: {e}")
                    return failure
            # Retries exhausted.
            return failure
        except Exception as e:
            # Safety net for errors raised inside the handlers above.
            logger.error(f"Error in async generation: {e}")
            return failure


def batch_generate_from_google(
    messages_list: list[list[Message]],
    gen_config: GenerationConfig,
    requests_per_minute: int = MAX_REQUESTS_PER_MINUTE,
    dump_conversation_funs=None,
    dump_usage_funs=None,
    process_id: int = 0,
) -> tuple[List[dict[str, Any]], List[List[Message]]]:
    """Run one rate-limited async generation per prompt and collect the results.

    Args:
        messages_list: A list where each element is a list of messages to be sent to the API.
        gen_config: Generation configuration. `num_generations` is forced to 1
            (the passed object is modified).
        requests_per_minute: Rate-limit for async requests.
        dump_conversation_funs: Optional sequence of callbacks, indexed like
            `messages_list`, forwarded to each call.
        dump_usage_funs: Optional sequence of callbacks, indexed like
            `messages_list`, forwarded to each call.
        process_id: Identifier forwarded to each call.

    Returns:
        A tuple of:
          - List of raw JSON response objects.
          - List of generated model messages.
        Calls that failed or produced no generations are skipped, so the
        returned lists can be shorter than `messages_list`.

    Raises:
        Exception: Any error raised while setting up or running the batch is re-raised.
    """
    global cache, THROTTLED_EXECUTION
    cache.reset()
    # NOTE(review): this flag is reset only on error below and stays True after a
    # successful batch; confirm whether that is intended.
    THROTTLED_EXECUTION = True

    try:
        if gen_config.num_generations > 1:
            gen_config.num_generations = 1
            logger.warning("Setting num_generations to 1 for batch generation.")

        # Create a rate-limiter shared by all calls of this batch.
        limiter = aiolimiter.AsyncLimiter(requests_per_minute)

        async def _async_generate() -> tuple[List[dict[str, Any]], List[List[Message]]]:
            tasks = [
                _throttled_google_agenerate(
                    limiter=limiter,
                    messages=messages,
                    gen_config=gen_config,
                    dump_conversation_fun=None if dump_conversation_funs is None else dump_conversation_funs[i],
                    dump_usage_fun=None if dump_usage_funs is None else dump_usage_funs[i],
                    process_id=process_id,
                    call_id=i,
                )
                for i, messages in enumerate(messages_list)
            ]
            logger.info(
                f"[{__file__}] Generating {len(messages_list)} calls in batch mode for model {gen_config.model}"
            )
            results: list[genai_types.GenerateContentResponse | dict[str, Any]] = await asyncio.gather(*tasks)
            all_api_responses = []
            all_model_messages = []
            # `gather` preserves task order, so the index is the call_id.
            for call_id, response in enumerate(results):
                if isinstance(response, dict) and response.get("failed", False):
                    logger.warning(f"No generations returned for {response['request_kwargs']}")
                    continue

                model_messages = convert_generations(response)
                if not model_messages:
                    # `response` is an API response object here, not a failure
                    # dict, so it has no 'request_kwargs' to subscript.
                    logger.warning(f"No generations returned for call_id {call_id} (model {gen_config.model})")
                    continue

                all_api_responses.append(response.model_dump(mode="json"))
                all_model_messages.extend(model_messages)
            return all_api_responses, all_model_messages

        # Run the asyncio task synchronously.
        return asyncio.run(_async_generate())
    except Exception:
        THROTTLED_EXECUTION = False
        raise


# ==============================================================================
# Token counting
# ==============================================================================


def google_count_tokens(model: str, lm_input: str) -> int:
    """Return the `total_tokens` value the Google API reports for `lm_input` under `model`."""
    manager = get_client_manager()
    count_response = manager.get_client().models.count_tokens(model=model, contents=lm_input)
    return count_response.total_tokens
