#!/usr/bin/python3
"""
Azure OpenAI large language models (LLM) as backbone optimizers.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import logging
import numpy as np
import os
import tiktoken
from openai import AzureOpenAI, OpenAI
from pathlib import Path
from pydantic import BaseModel, ValidationError
from tenacity import retry, stop_after_attempt, wait_random_exponential
from transformers import PreTrainedTokenizer
from typing import (
    Any, Dict, Final, List, Literal, Optional, Tuple, Type, Union
)

from .llm import BaseLLMOptimizer, load_reflection_prompt
from .state import OptimizerState
from ..envs.base import BaseTask
from ..utils import json_loads


logger = logging.getLogger(__name__)


class OpenAIOptimizer(BaseLLMOptimizer):
    max_context_length: int = NotImplemented

    def __init__(
        self,
        task: BaseTask,
        model_id: str,
        tokenizer: Union[PreTrainedTokenizer, tiktoken.Encoding],
        client: Union[OpenAI, AzureOpenAI],
        batch_size: int = 1,
        max_batch_size: int = 8,
        temperature: float = 1.0,
        max_new_tokens: int = 2048,
        seed: int = 2025,
        promptdir: Union[Path, str] = os.path.join(
            os.path.dirname(__file__), "prompts"
        ),
        system_prompt_version: str = "base",
        user_prompt_version: str = "base",
        reflection: bool = True,
        ablate_distribution_shift: bool = False,
        reasoning_effort: Optional[Literal["low", "medium", "high"]] = None,
        use_json_schema: bool = True,
        **kwargs: Dict[str, Any]
    ):
        """
        Args:
            task: the optimization task; its `design_schema` is stored as the
                schema used to validate generated designs.
            model_id: OpenAI model ID; stored as the optimizer name and sent
                as the `model` of every API request.
            tokenizer: tokenizer used to count prompt tokens.
            client: OpenAI or AzureOpenAI client used for API calls.
            batch_size: number of candidates to propose per optimization step.
            max_batch_size: upper bound on the number of candidates requested
                in a single API call.
            temperature: sampling temperature for candidate generation.
            max_new_tokens: maximum number of tokens to generate per request.
            seed: random seed. Default 2025.
            promptdir: directory containing the prompt templates.
            system_prompt_version: version of the system prompt to load.
            user_prompt_version: version of the user prompt to load.
            reflection: whether to perform reflection as in Ma YJ, et al. Proc
                ICLR (2024).
            ablate_distribution_shift: whether to ablate knowledge of the
                distribution shift.
            reasoning_effort: optional reasoning effort level for the model.
            use_json_schema: whether to request a JSON-schema constrained
                response format instead of a plain JSON object.
            **kwargs: additional keyword arguments forwarded unchanged to the
                parent constructor.
        """
        # These attributes are assigned before the parent constructor is
        # invoked, exactly as in the original ordering.
        self.optimizer_name: str = model_id
        self._tokenizer: Union[PreTrainedTokenizer, tiktoken.Encoding] = (
            tokenizer
        )
        self.reasoning_effort: Optional[str] = reasoning_effort
        self.max_batch_size: Final[int] = max_batch_size
        self.use_json_schema: Final[bool] = use_json_schema

        super().__init__(
            task,
            batch_size,
            seed=seed,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            promptdir=promptdir,
            system_prompt_version=system_prompt_version,
            user_prompt_version=user_prompt_version,
            reflection=reflection,
            ablate_distribution_shift=ablate_distribution_shift,
            **kwargs
        )

        # Assigned only after the parent constructor has completed.
        self._client: Final[Union[OpenAI, AzureOpenAI]] = client
        self.design_schema: Final[Type[BaseModel]] = task.design_schema

    @retry(
        stop=stop_after_attempt(6),
        wait=wait_random_exponential(min=1, max=60)
    )
    def forward(
        self, state: OptimizerState, knowledge: Dict[str, str], **kwargs
    ) -> Tuple[List[BaseModel], Dict[str, int]]:
        """
        Retrying entry point for proposing a new batch of candidates.
        Delegates to `_forward()`; a call that raises is re-attempted (at
        most 6 attempts in total) with a randomized exponential wait
        configured with min=1 and max=60.
        Input:
            state: the current optimizer state.
            knowledge: the prior knowledge to use for the optimization.
        Returns:
            Whatever `_forward()` returns: the batch of candidates and a
            dictionary of metadata.
        """
        result = self._forward(state, knowledge, **kwargs)
        return result

    def _forward(
        self, state: OptimizerState, knowledge: Dict[str, str], **kwargs
    ) -> Tuple[List[BaseModel], Dict[str, int]]:
        """
        Returns a new batch of candidates to evaluate.
        Input:
            state: the current optimizer state.
            knowledge: the prior knowledge to use for the optimization. When
                reflection is enabled, the generated reflection is written to
                `knowledge["reflection"]` in place.
            **kwargs: accepted for interface compatibility; these are not
                forwarded to the API.
        Returns:
            A batch of candidates to evaluate of shape B, where B is the
            sampling batch size, and a dictionary of metadata with the keys
            `num_input_tokens` and `num_output_tokens`. Choices returned
            without content are skipped, so fewer than B candidates may be
            returned.
        Raises:
            TypeError: if a generated response does not parse to a JSON
                object.
            ValidationError: if a generated design cannot be repaired into a
                valid instance of the design schema.
        """
        memory = state.memory.copy(deep=True)

        num_input_tokens, num_output_tokens = 0, 0
        if self.reflection:
            reflection_prompt = load_reflection_prompt()
            # Only reflect on the most recent batch of designs.
            reflection_query = reflection_prompt.format(
                memory=state.memory.iloc[-self.batch_size:].to_markdown(),
                task=self.task.task_description(self.ablate_distribution_shift)
            )
            response, reflection_metadata = self.chat(
                messages=[
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": reflection_query}
                ],
                tools=[]
            )
            # `chat()` returns its metadata as a dictionary, so the counts
            # must be read by key. Negative values mean "unknown".
            num_input_tokens += max(
                reflection_metadata.get("num_input_tokens", -1), 0
            )
            num_output_tokens += max(
                reflection_metadata.get("num_output_tokens", -1), 0
            )
            # The message content can be missing or None; fall back to an
            # empty reflection so that the tokenizer is never given None.
            knowledge["reflection"] = response.get("content") or ""
            logger.info(f"  Reflection: {knowledge['reflection']}")

        optim_prompt = self._build_optim_prompt(memory, knowledge)

        # Named separately from the (unused) `**kwargs` parameter so that the
        # parameter is not shadowed.
        request_kwargs: Dict[str, Any] = {}
        if self.reasoning_effort is not None:
            request_kwargs["reasoning_effort"] = self.reasoning_effort
        response_format: Dict[str, Any] = {"type": "json_object"}
        if self.use_json_schema:
            response_format = {
                "type": "json_schema",
                "json_schema": {
                    "name": "design",
                    "schema": self.design_schema.model_json_schema()
                }
            }

        out: List[BaseModel] = []
        qbatch, rbatch = divmod(self.batch_size, self.max_batch_size)
        candidate_counts = [self.max_batch_size] * qbatch
        if rbatch > 0:
            candidate_counts.append(rbatch)
        for batch_idx, candidate_count in enumerate(candidate_counts):
            response = self._client.chat.completions.create(
                messages=[
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": optim_prompt}
                ],
                model=self.optimizer_name,
                max_completion_tokens=self.max_new_tokens,
                n=candidate_count,
                response_format=response_format,  # type: ignore
                # Offset by the full-batch stride so that the remainder batch
                # can never reuse the seed of an earlier batch (e.g., with a
                # stride of 8 and a remainder of 4, 2 * 4 == 1 * 8).
                seed=(self.seed + (batch_idx * self.max_batch_size)),
                temperature=self.temperature,
                **request_kwargs
            )
            for choice in response.choices:
                if choice.message.content is None:
                    logger.warning(
                        "Skipping a choice without content in batch %d.",
                        batch_idx
                    )
                    continue
                content = json_loads(choice.message.content)
                if not isinstance(content, dict):
                    raise TypeError(
                        "Expected the response to parse to a JSON object, "
                        f"got {type(content).__name__}."
                    )
                out.append(self._validate_design(content))

            num_input_tokens += max(
                getattr(response.usage, "prompt_tokens", -1), 0
            )
            num_output_tokens += max(
                getattr(response.usage, "completion_tokens", -1), 0
            )

        self._x, self._y = [], np.array([])
        return out, {
            "num_input_tokens": num_input_tokens,
            "num_output_tokens": num_output_tokens
        }

    def _build_optim_prompt(
        self, memory: Any, knowledge: Dict[str, str]
    ) -> str:
        """
        Formats the user prompt, dropping the oldest rows of the memory until
        the request fits within the context length (when one is defined).
        Input:
            memory: a copy of the optimizer memory that is safe to truncate.
            knowledge: the prior knowledge to format into the prompt.
        Returns:
            The formatted optimization prompt.
        """
        # The class-level default is the `NotImplemented` placeholder, which
        # cannot be compared against an integer; treat it as "no limit".
        limit = getattr(self, "max_context_length", NotImplemented)
        if limit is NotImplemented or limit is None:
            return self.user_prompt.format(
                memory=memory.to_markdown(), **knowledge
            )

        enc = self._tokenizer.encode
        buffer_tokens = len(enc(self.system_prompt)) + self.max_new_tokens
        if self.reflection:
            buffer_tokens += len(enc(knowledge["reflection"]))
        # Account for model overhead. See link below for more details:
        # https://github.com/openai/openai-cookbook/blob/main/examples
        buffer_tokens += (4 * 2) + 2

        # Always drop at least one row per iteration to guarantee progress.
        step = max(self.batch_size, 1)
        while True:
            optim_prompt = self.user_prompt.format(
                memory=memory.to_markdown(), **knowledge
            )
            if len(enc(optim_prompt)) + buffer_tokens <= limit:
                return optim_prompt
            if len(memory) == 0:
                # Nothing is left to truncate; send the prompt as is rather
                # than looping forever.
                logger.warning(
                    "Prompt exceeds the maximum context length of %s tokens "
                    "even with an empty memory.", limit
                )
                return optim_prompt
            memory = memory.iloc[step:]

    def _validate_design(self, content: Dict[str, Any]) -> BaseModel:
        """
        Validates a generated design, replacing invalid fields in place with
        randomly sampled fallback values.
        Input:
            content: the parsed JSON object generated by the model.
        Returns:
            A validated instance of the design schema.
        Raises:
            ValidationError: if an error does not map onto a known field, if
                no fallback can be sampled for a field, or if the design is
                still invalid after the maximum number of repair rounds.
        """
        fields = self.design_schema.model_fields
        max_repair_rounds = 100
        for _ in range(max_repair_rounds):
            try:
                return self.design_schema.model_validate(content)
            except ValidationError as e:
                for err in e.errors():
                    loc = err["loc"]
                    # Model-level errors have an empty location, and errors
                    # on unknown fields cannot be repaired by resampling.
                    if not loc or str(loc[0]) not in fields:
                        raise
                    key = str(loc[0])
                    content[key] = self._sample_fallback(key, fields[key], e)
        # Final attempt; raises the ValidationError if still invalid.
        return self.design_schema.model_validate(content)

    def _sample_fallback(
        self, key: str, field: Any, error: ValidationError
    ) -> Any:
        """
        Samples a random replacement value for a field that failed validation.
        Input:
            key: the name of the field in the design schema.
            field: the schema field definition for `key`.
            error: the validation error to raise if no fallback is available.
        Returns:
            A random boolean for `bool` fields, a normally distributed value
            (using the task's `mu` and `std` for `key`) for `float` fields,
            and otherwise a random member value of the annotated type.
        """
        if not hasattr(field, "annotation"):
            raise error
        annot = field.annotation
        if annot is bool:
            # Convert to a native bool instead of a numpy scalar.
            return bool(self._rng.choice([True, False]))
        if annot is float:
            value = float(
                self._rng.normal(
                    loc=self.task.mu[key], scale=self.task.std[key]
                )
            )
            if self.task.task_name == "IWPCWarfarin-v0":
                value = max(value, 0.0)
            return value
        try:
            options = [member.value for member in annot]
        except (TypeError, AttributeError):
            # The annotation is not an iterable of members with a `value`.
            options = []
        if not options:
            raise error
        # Sample an index so that the native member value is returned rather
        # than a numpy scalar.
        return options[int(self._rng.choice(len(options)))]

    # NOTE: `stop_after_attempt(1)` permits a single attempt only, so no
    # retry (and therefore no wait) actually takes place here.
    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(1)
    )
    def chat(
        self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
    ) -> Tuple[Dict[str, Any], Dict[str, int]]:
        """
        Chat with the parent LLM.
        Input:
            messages: the input messages to the LLM.
            tools: the tools available to the LLM. May be empty, in which
                case no tool-related options are sent with the request.
        Returns:
            The response from the LLM (the first choice's message as a
            dictionary) and associated metadata with the keys
            `num_input_tokens` and `num_output_tokens`, each of which is -1
            when the usage is not reported.
        """
        kwargs: Dict[str, Any] = {}
        # Tool options are only sent alongside a non-empty tool list, since
        # the API rejects an empty `tools` array and tool options without
        # tools.
        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = "auto"
        if self.reasoning_effort is None:
            kwargs["temperature"] = 0
            if tools:
                kwargs["parallel_tool_calls"] = False
        else:
            # Keep the configured reasoning effort consistent with the
            # candidate generation requests.
            kwargs["reasoning_effort"] = self.reasoning_effort
        response = self._client.chat.completions.create(  # type: ignore
            messages=messages,
            model=self.optimizer_name,
            max_completion_tokens=self.max_new_tokens,
            seed=self.seed,
            **kwargs
        )
        return response.choices[0].message.to_dict(), {
            "num_input_tokens": getattr(response.usage, "prompt_tokens", -1),
            "num_output_tokens": getattr(
                response.usage, "completion_tokens", -1
            )
        }
