#!/usr/bin/python3
"""
Implements dual annealing for optimization.

Author(s):
    Anonymized Authors @anonymized-authors

Citation(s):
    [1] Xiang Y, Gubian S, Suomela B, Hoeng J. Global optimization:
        The GenSA package. The R Journal 5/1. (2013). URL:
        https://journal.r-project.org/archive/2013/RJ-2013-002/RJ-2013-002.pdf

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import numpy as np
from pydantic import BaseModel
from scipy.optimize import dual_annealing
from tqdm import tqdm
from typing import Any, Dict, Final, List, NamedTuple, Optional, Tuple, Union

from .base import BaseOptimizer
from .state import OptimizerState
from ..envs.base import BaseTask
from ..utils import Pydantic2Rd, Rd2Pydantic


class DualAnnealingOptimizer(BaseOptimizer):
    optimizer_name: str = "DualAnnealing"

    # (lower, upper) box bounds applied to every continuous dimension of the
    # search space passed to `dual_annealing`.
    annealing_bounds: Tuple[float, float] = (-3.0, 3.0)

    # Maximum number of objective evaluations per dual annealing run
    # (forwarded as `maxfun` to `scipy.optimize.dual_annealing`).
    annealing_maxfun: int = 128

    def __init__(
        self,
        task: BaseTask,
        ref: NamedTuple,
        batch_size: int,
        seed: int = 2025,
        **kwargs: Dict[str, Any]
    ):
        """
        Args:
            task: the optimization task.
            ref: the reference individual that we are optimizing for.
            batch_size: the number of new designs to return per sampling step.
            seed: random seed. Default 2025.
            kwargs: additional keyword arguments. Accepted for interface
                compatibility with other optimizers and otherwise ignored.
        """
        del kwargs
        super().__init__(task, batch_size, seed=seed)
        self.ref: Final[NamedTuple] = ref
        # Starting points for dual annealing in continuous coordinates.
        # Populated by `fit()` and reset to None by `forward()`.
        self.X: Optional[np.ndarray] = None

    def forward(
        self,
        state: OptimizerState,
        knowledge: Dict[str, str],
        **kwargs: Dict[str, Any]
    ) -> Tuple[List[BaseModel], Dict[str, int]]:
        """
        Returns a new batch of candidates to evaluate.
        Input:
            state: the current optimizer state (unused).
            knowledge: the prior knowledge to use for the optimization
                (unused).
        Returns:
            A batch of candidates to evaluate of shape B, where B is the
            sampling batch size, and a dictionary of metadata (always empty).
        Raises:
            RuntimeError: if `fit()` has not been called since the previous
                call to `forward()`.
        Notes:
            Dual annealing is run once per starting point only on the first
            call. The resulting candidates are cached on `self.candidates`
            and returned unchanged by every later call.
        """
        del state, knowledge, kwargs
        if self.X is None:
            raise RuntimeError(
                "fit() must be called before forward() to provide starting "
                "points for dual annealing."
            )
        bounds = [self.annealing_bounds] * self.X.shape[1]

        def _wrapped_eval(x: np.ndarray) -> float:
            # dual_annealing minimizes, so the task objective is negated in
            # order to maximize it. The point is evaluated as a batch of one.
            designs = self.task.extend(
                Rd2Pydantic(self.task, x.reshape(1, -1), continuous=True),
                self.ref
            )
            return -1.0 * np.array(self.task(designs)).item()

        if not hasattr(self, "candidates"):
            # One independent dual annealing run per starting point; every
            # run is seeded with the same `self.seed`.
            candidates = np.vstack([
                dual_annealing(
                    _wrapped_eval,
                    bounds,
                    rng=self.seed,
                    x0=x0,
                    maxfun=self.annealing_maxfun,
                ).x
                for x0 in tqdm(self.X)
            ])
            self.candidates = Rd2Pydantic(
                self.task, candidates, continuous=True
            )
        self.X = None
        return self.candidates, {}

    def fit(self, X: List[BaseModel], y: np.ndarray) -> None:
        """
        Performs any pre-acquisition steps: selects the starting points for
        dual annealing.
        Input:
            X: a list of all prior evaluated designs.
            y: an array of shape N (or N x 1) of all objective evaluations,
                where N is the number of designs.
        Returns:
            None. The (up to) `batch_size` designs with the highest objective
            values are stored on `self.X` in continuous coordinates.
        """
        Xval = Pydantic2Rd(self.task, X, continuous=True)
        # atleast_1d guards the single-observation case: squeeze() would
        # otherwise yield a 0-d array that argsort/slicing cannot handle.
        scores = np.atleast_1d(np.asarray(y).squeeze())
        # argsort is ascending, so the trailing indices are the best designs.
        top = np.argsort(scores)[-self.batch_size:]
        self.X = Xval[top]
        if self.X.ndim == 1:
            self.X = self.X.reshape(-1, 1)
