#!/usr/bin/python3
"""
Large language models (LLM) as backbone optimizers.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import abc
import os
import numpy as np
from pathlib import Path
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_random_exponential
from typing import Any, Dict, Final, List, Tuple, Union

from .base import BaseOptimizer
from .state import OptimizerState
from ..envs.base import BaseTask


class BaseLLMOptimizer(BaseOptimizer, abc.ABC):
    """
    Abstract base class for optimizers that use a large language model (LLM)
    as their backbone.

    On construction the text-generation settings are stored and the system
    and user prompt templates are read from `promptdir`. Concrete subclasses
    must implement `forward()`; the base `chat()` only raises
    `NotImplementedError` and is meant to be overridden.
    """

    optimizer_name: str = NotImplemented

    def __init__(
        self,
        task: BaseTask,
        batch_size: int,
        temperature: float,
        max_new_tokens: int,
        seed: int = 2025,
        promptdir: Union[Path, str] = os.path.join(
            os.path.dirname(__file__), "prompts"
        ),
        system_prompt_version: str = "base",
        user_prompt_version: str = "base",
        reflection: bool = True,
        ablate_distribution_shift: bool = False,
        **kwargs: Any
    ):
        """
        Args:
            task: the optimization task.
            batch_size: batch size for optimization.
            temperature: temperature parameter for text generation.
            max_new_tokens: maximum number of new tokens to generate.
            seed: random seed. Default 2025.
            promptdir: directory containing the prompt templates.
            system_prompt_version: version of the system prompt to load.
            user_prompt_version: version of the user prompt to load.
            reflection: whether to perform reflection as in Ma YJ, et al. Proc
                ICLR (2024).
            ablate_distribution_shift: whether to ablate knowledge of the
                distribution shift.
            kwargs: any additional keyword arguments; accepted and discarded.
        Raises:
            FileNotFoundError: if the system or user prompt file is missing.
        """
        # Extra keyword arguments are tolerated but intentionally unused.
        del kwargs
        super().__init__(task, batch_size, seed=seed)

        self.temperature: Final[float] = temperature
        self.max_new_tokens: Final[int] = max_new_tokens
        self.promptdir: Final[Union[Path, str]] = promptdir
        self.system_prompt_version: Final[str] = system_prompt_version
        self.user_prompt_version: Final[str] = user_prompt_version
        self.reflection: Final[bool] = reflection
        self.ablate_distribution_shift: Final[bool] = ablate_distribution_shift

        # The prompt loaders read `self.promptdir` and `self.reflection`, so
        # they must run only after the attributes above have been assigned.
        self.system_prompt: Final[str] = self._load_system_prompt(
            system_prompt_version=self.system_prompt_version,
            task_name=self.task.task_name
        )
        self.user_prompt: Final[str] = self._load_user_prompt(
            user_prompt_version=self.user_prompt_version
        )
        # Prior observations; empty until `fit()` is called.
        self._x: List[BaseModel] = []
        self._y: np.ndarray = np.array([])

    @abc.abstractmethod
    def forward(
        self, state: OptimizerState, knowledge: Dict[str, str], **kwargs
    ) -> Tuple[List[BaseModel], Dict[str, int]]:
        """
        Returns a new batch of candidates to evaluate.
        Input:
            state: the current optimizer state.
            knowledge: the prior knowledge to use for the optimization.
        Returns:
            A batch of candidates to evaluate of shape B, where B is the
            sampling batch size, and a dictionary of metadata.
        Note:
            Decorators are not inherited by overriding methods. A subclass
            that wants its implementation retried on failure must apply its
            own `@retry` decorator to the override; decorating this abstract
            stub would only cause its NotImplementedError to be retried.
        """
        raise NotImplementedError

    def fit(self, x: List[BaseModel], y: np.ndarray) -> None:
        """
        Stores the prior evaluated designs and their objective values on the
        optimizer. May only be called while no observations are stored.
        Input:
            x: a list of BaseModel's of all prior evaluated designs.
            y: an array of shape N of all objective evaluations, where N is the
                number of designs.
        Returns:
            None.
        """
        assert len(self._x) == 0 and len(self._y) == 0, (
            "fit() requires that no prior observations are stored."
        )
        self._x, self._y = x, y

    @staticmethod
    def _read_prompt_file(prompt_fn: str) -> str:
        """
        Read a prompt template from disk.
        Input:
            prompt_fn: path to the prompt file.
        Returns:
            The full contents of the file, decoded as UTF-8.
        Raises:
            FileNotFoundError: if `prompt_fn` is not an existing file.
        """
        # Raise explicitly rather than `assert` so the check survives
        # `python -O` and the error names the missing path.
        if not os.path.isfile(prompt_fn):
            raise FileNotFoundError(f"Prompt file not found: {prompt_fn}")
        with open(prompt_fn, encoding="utf-8") as f:
            return f.read()

    def _load_system_prompt(
        self, system_prompt_version: str, task_name: str
    ) -> str:
        """
        Load the system prompt from the prompt directory.
        Input:
            system_prompt_version: version of the system prompt to load.
            task_name: name of the task.
        Returns:
            The system prompt.
        Raises:
            FileNotFoundError: if the system prompt file does not exist.
        """
        system_fn = os.path.join(
            self.promptdir, "system", system_prompt_version, f"{task_name}.txt"
        )
        return self._read_prompt_file(system_fn)

    def _load_user_prompt(self, user_prompt_version: str) -> str:
        """
        Load the user prompt from the prompt directory. When reflection is
        enabled, the `<version>_reflection.txt` variant is loaded instead of
        `<version>.txt`.
        Input:
            user_prompt_version: version of the user prompt to load.
        Returns:
            The user prompt.
        Raises:
            FileNotFoundError: if the user prompt file does not exist.
        """
        # Build the file name directly instead of string-replacing ".txt" in
        # the joined path, which would corrupt a prompt directory whose own
        # path happens to contain ".txt".
        suffix = "_reflection" if self.reflection else ""
        user_fn = os.path.join(
            self.promptdir, "user", f"{user_prompt_version}{suffix}.txt"
        )
        return self._read_prompt_file(user_fn)

    def chat(
        self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
    ) -> Tuple[Dict[str, Any], Dict[str, int]]:
        """
        Chat with the parent LLM.
        Input:
            messages: the input messages to the LLM.
            tools: the tools available to the LLM.
        Returns:
            The response from the LLM and associated metadata.
        Raises:
            NotImplementedError: always in this base class; subclasses are
                expected to override this method.
        Note:
            Decorators are not inherited by overriding methods. A subclass
            that wants retries with backoff must apply its own `@retry`
            decorator to the override.
        """
        del messages, tools
        raise NotImplementedError


def load_reflection_prompt() -> str:
    """
    Loads the reflection prompt, as used by the EUREKA optimizer, from
    `prompts/user/reflection.txt` located next to this module.
    Input:
        None.
    Returns:
        The reflection prompt.
    Raises:
        FileNotFoundError: if the reflection prompt file does not exist.
    Citation(s):
        [1] Ma YJ, Liang W, Wang G, et al. Eureka: Human-level reward design
            via coding large language models. Proc ICLR. (2024). URL:
            https://openreview.net/forum?id=IEduRUO55F
    """
    reflection_fn = Path(__file__).parent / "prompts" / "user" / "reflection.txt"
    # Decode explicitly as UTF-8 so the prompt text does not depend on the
    # platform's locale default encoding.
    return reflection_fn.read_text(encoding="utf-8")
