#!/usr/bin/python3
"""
Azure OpenAI large language models (LLMs) as backbone optimizers.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import os
import tiktoken
from openai import AzureOpenAI
from pathlib import Path
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_random
from transformers import PreTrainedTokenizer
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from .openai import OpenAIOptimizer
from .state import OptimizerState
from ..envs.base import BaseTask


class AzureOpenAIOptimizer(OpenAIOptimizer):
    """
    `OpenAIOptimizer` variant whose API client is an `AzureOpenAI` instance.

    The client is configured from the `AZURE_API_VERSION`,
    `API_ENDPOINT_CHAT`, and `API_KEY_CHAT` environment variables. Each of
    them falls back to an empty string when it is not set.
    """

    # Placeholder value; subclasses are expected to set a concrete limit.
    max_context_length: int = NotImplemented

    def __init__(
        self,
        task: BaseTask,
        model_id: str,
        tokenizer: Union[PreTrainedTokenizer, tiktoken.Encoding],
        batch_size: int = 1,
        max_batch_size: int = 8,
        temperature: float = 1.0,
        max_new_tokens: int = 2048,
        seed: int = 2025,
        promptdir: Union[Path, str] = os.path.join(
            os.path.dirname(__file__), "prompts"
        ),
        system_prompt_version: str = "base",
        user_prompt_version: str = "base",
        reflection: bool = True,
        ablate_distribution_shift: bool = False,
        reasoning_effort: Optional[Literal["low", "medium", "high"]] = None,
        **kwargs: Dict[str, Any]
    ):
        """
        Build the Azure client and hand everything to the parent optimizer.

        Args:
            task: the optimization task.
            model_id: OpenAI model ID to use.
            tokenizer: the model tokenizer to use for encoding text.
            batch_size: batch size for optimization.
            max_batch_size: the maximum batch size per OpenAI API call.
            temperature: temperature parameter for text generation.
            max_new_tokens: maximum number of new tokens to generate.
            seed: random seed. Default 2025.
            promptdir: directory containing the prompt templates. Defaults
                to the `prompts` folder next to this file.
            system_prompt_version: version of the system prompt to load.
            user_prompt_version: version of the user prompt to load.
            reflection: whether to perform reflection as in Ma YJ, et al. Proc
                ICLR (2024).
            ablate_distribution_shift: whether to ablate knowledge of the
                distribution shift.
            reasoning_effort: reasoning effort for the model.
            **kwargs: any further keyword arguments, forwarded unchanged to
                the parent constructor.
        """
        # Connection settings come from the environment; an unset variable
        # is passed to the client as an empty string.
        api_version = os.getenv("AZURE_API_VERSION", "")
        endpoint = os.getenv("API_ENDPOINT_CHAT", "")
        api_key = os.getenv("API_KEY_CHAT", "")
        azure_client = AzureOpenAI(
            api_version=api_version,
            azure_endpoint=endpoint,
            api_key=api_key,
        )

        # `use_json_schema` is always switched on for this class; it is not
        # exposed as a constructor parameter here.
        super().__init__(
            task,
            model_id=model_id,
            tokenizer=tokenizer,
            client=azure_client,
            batch_size=batch_size,
            max_batch_size=max_batch_size,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            seed=seed,
            promptdir=promptdir,
            system_prompt_version=system_prompt_version,
            user_prompt_version=user_prompt_version,
            reflection=reflection,
            ablate_distribution_shift=ablate_distribution_shift,
            reasoning_effort=reasoning_effort,
            use_json_schema=True,
            **kwargs,
        )


class GPT4oMiniOptimizer(AzureOpenAIOptimizer):
    """
    Azure OpenAI optimizer pinned to the `gpt-4o-mini` model.

    The model ID and the tiktoken encoding are both derived from
    `optimizer_name`. The per-call maximum batch size is fixed at 32 and no
    reasoning effort is requested.
    """

    optimizer_name: str = "gpt-4o-mini"

    max_context_length: int = 64000

    def __init__(
        self,
        task: BaseTask,
        batch_size: int = 1,
        temperature: float = 1.0,
        max_new_tokens: int = 2048,
        seed: int = 2025,
        promptdir: Union[Path, str] = os.path.join(
            os.path.dirname(__file__), "prompts"
        ),
        system_prompt_version: str = "base",
        user_prompt_version: str = "base",
        reflection: bool = True,
        ablate_distribution_shift: bool = False,
        **kwargs: Dict[str, Any]
    ):
        """
        Configure the parent Azure optimizer for `gpt-4o-mini`.

        Args:
            task: the optimization task.
            batch_size: batch size for optimization.
            temperature: temperature parameter for text generation.
            max_new_tokens: maximum number of new tokens to generate.
            seed: random seed. Default 2025.
            promptdir: directory containing the prompt templates. Defaults
                to the `prompts` folder next to this file.
            system_prompt_version: version of the system prompt to load.
            user_prompt_version: version of the user prompt to load.
            reflection: whether to perform reflection as in Ma YJ, et al. Proc
                ICLR (2024).
            ablate_distribution_shift: whether to ablate knowledge of the
                distribution shift.
            **kwargs: any further keyword arguments, forwarded unchanged to
                the parent constructor.
        """
        name = self.optimizer_name
        # Let tiktoken pick the encoding registered for this model name.
        encoding = tiktoken.encoding_for_model(name)

        super().__init__(
            task=task,
            model_id=name,
            tokenizer=encoding,
            batch_size=batch_size,
            max_batch_size=32,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            seed=seed,
            promptdir=promptdir,
            system_prompt_version=system_prompt_version,
            user_prompt_version=user_prompt_version,
            reflection=reflection,
            ablate_distribution_shift=ablate_distribution_shift,
            reasoning_effort=None,
            **kwargs,
        )


class o4MiniOptimizer(AzureOpenAIOptimizer):
    """
    Azure OpenAI optimizer pinned to the `o4-mini` model.

    The per-call maximum batch size is fixed at 8, and `forward` is wrapped
    in a retry policy (see its docstring).
    """

    optimizer_name: str = "o4-mini"

    max_context_length: int = 100000

    def __init__(
        self,
        task: BaseTask,
        batch_size: int = 1,
        temperature: float = 1.0,
        max_new_tokens: int = 8192,
        seed: int = 2025,
        promptdir: Union[Path, str] = os.path.join(
            os.path.dirname(__file__), "prompts"
        ),
        system_prompt_version: str = "base",
        user_prompt_version: str = "base",
        reflection: bool = True,
        ablate_distribution_shift: bool = False,
        reasoning_effort: Literal["low", "medium", "high"] = "high",
        **kwargs: Dict[str, Any]
    ):
        """
        Configure the parent Azure optimizer for `o4-mini`.

        Args:
            task: the optimization task.
            batch_size: batch size for optimization.
            temperature: temperature parameter for text generation.
            max_new_tokens: maximum number of new tokens to generate.
            seed: random seed. Default 2025.
            promptdir: directory containing the prompt templates. Defaults
                to the `prompts` folder next to this file.
            system_prompt_version: version of the system prompt to load.
            user_prompt_version: version of the user prompt to load.
            reflection: whether to perform reflection as in Ma YJ, et al. Proc
                ICLR (2024).
            ablate_distribution_shift: whether to ablate knowledge of the
                distribution shift.
            reasoning_effort: reasoning effort for the model. Default "high".
            **kwargs: any further keyword arguments, forwarded unchanged to
                the parent constructor.
        """
        # tiktoken maps `o4-mini` to the `o200k_base` encoding, not
        # `cl100k_base`. The encoding is requested by name (rather than via
        # `encoding_for_model`) so that tiktoken releases that predate the
        # `o4-mini` model entry keep working.
        encoding = tiktoken.get_encoding("o200k_base")

        super().__init__(
            task=task,
            model_id=self.optimizer_name,
            tokenizer=encoding,
            batch_size=batch_size,
            max_batch_size=8,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            seed=seed,
            promptdir=promptdir,
            system_prompt_version=system_prompt_version,
            user_prompt_version=user_prompt_version,
            reflection=reflection,
            ablate_distribution_shift=ablate_distribution_shift,
            reasoning_effort=reasoning_effort,
            **kwargs,
        )

    @retry(wait=wait_random(min=45, max=180), stop=stop_after_attempt(6))
    def forward(
        self, state: OptimizerState, knowledge: Dict[str, str], **kwargs
    ) -> Tuple[List[BaseModel], Dict[str, int]]:
        """
        Call `_forward`, retrying when it raises.

        The tenacity policy allows up to 6 attempts and sleeps a random 45
        to 180 seconds between them. No `retry=` filter is given, so any
        exception triggers another attempt. `reraise` is not set either, so
        once the attempts are exhausted tenacity raises its own `RetryError`
        wrapping the last failure.

        Args:
            state: the optimizer state, passed through to `_forward`.
            knowledge: string-to-string mapping, passed through to
                `_forward`.
            **kwargs: forwarded unchanged to `_forward`.

        Returns:
            Whatever `_forward` returns (`_forward` is not defined in this
            class; presumably it is inherited from a parent).
        """
        return self._forward(state, knowledge, **kwargs)
