#!/usr/bin/python3
"""
Implements the baseline CMA-ES generative optimization policy.

Author(s):
    Anonymized Authors @anonymized-authors

Citation(s):
    [1] Hansen N. The CMA evolution strategy: A tutorial. arXiv Preprint.
        (2016). doi: 10.48550/arXiv.1604.00772
    [2] Hansen N, Ostermeier A. Adapting arbitrary normal mutation
        distributions in evolution strategies: The covariance matrix
        adaptation. Proc IEEE Intern Conf Evolutionary Computation: 312-7.
        (1996). doi: 10.1109/ICEC.1996.542381

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import asyncio
import cma  # type: ignore[import-untyped]
import numpy as np
import threading
from concurrent.futures import ThreadPoolExecutor
from pydantic import BaseModel
from typing import Any, Dict, List, Optional, Tuple

from .base import BaseOptimizer
from .state import OptimizerState
from ..envs.base import BaseTask
from ..utils import Pydantic2Rd, Rd2Pydantic


class CMAESOptimizer(BaseOptimizer):
    """
    Baseline generative optimization policy built on CMA-ES.

    `fit()` keeps the highest-scoring prior designs as starting points, and
    `forward()` launches one `cma.fmin2` run per starting point on a thread
    pool, returning the first element of each run's result as the new batch.
    """

    optimizer_name: str = "CMAES"

    def __init__(
        self,
        task: BaseTask,
        batch_size: int,
        seed: int = 2025,
        max_workers: Optional[int] = None,
        **kwargs: Any
    ):
        """
        Args:
            task: the optimization task.
            batch_size: the number of new designs to return per sampling step.
            seed: random seed. Default 2025.
            max_workers: the maximum number of threads to use. Falls back to
                the batch size when None (or 0).
            **kwargs: forwarded unchanged to the parent constructor.
        """
        super().__init__(
            task=task, batch_size=batch_size, seed=seed, **kwargs
        )
        # Serializes every call into the task made from the worker threads.
        # NOTE(review): presumably the task is not thread-safe -- confirm.
        self._task_lock = threading.Lock()
        self._executor = ThreadPoolExecutor(
            max_workers=(max_workers or self.batch_size)
        )

    def forward(
        self,
        state: OptimizerState,
        knowledge: Dict[str, str],
        **kwargs: Any
    ) -> Tuple[List[BaseModel], Dict[str, int]]:
        """
        Returns a new batch of candidates to evaluate.

        Reads the starting points stored by `fit()`, so `fit()` must be called
        first. The work is driven through `asyncio.run`, which means this
        method has to be called from synchronous code (i.e., not from inside
        an already running event loop).
        Input:
            state: the current optimizer state.
            knowledge: the prior knowledge to use for the optimization.
                Unused by this optimizer.
            verbose: optional verbosity level for CMA-ES. Default -9.
            maxfevals: optional maximum number of function evaluations per
                CMA-ES run. Default 32.
        Returns:
            A batch of candidates to evaluate and a dictionary of metadata
            (always empty for this optimizer).
        """
        del knowledge
        candidates = asyncio.run(
            self._forward(
                state,
                verbose=int(kwargs.get("verbose", -9)),
                maxfevals=int(kwargs.get("maxfevals", 32))
            )
        )
        return candidates, {}

    async def _forward(
        self, state: OptimizerState, verbose: int = -9, maxfevals: int = 32
    ) -> List[BaseModel]:
        """
        Returns a new batch of candidates to evaluate.
        Input:
            state: the current optimizer state.
            verbose: verbosity level for the CMA-ES optimizer.
            maxfevals: maximum number of function evaluations.
        Returns:
            A batch of candidates to evaluate.
        """
        def wrapper(x: np.ndarray) -> Any:
            """
            Objective handed to CMA-ES: the negated task score of `x`.

            Returns a plain float when called with a single solution vector
            that yields exactly one score, and an array otherwise.
            """
            single = x.ndim == 1
            x = x[np.newaxis, ...] if single else x
            with self._task_lock:
                designs = self.task.extend(
                    Rd2Pydantic(self.task, x, continuous=True),
                    state.individual
                )
                y = np.array(self.task(designs))
            # fit() keeps the largest scores while cma.fmin2 minimizes, so
            # the score is negated to turn it into a cost.
            cost = -1.0 * y
            if single and cost.size == 1:
                # CMA-ES expects one scalar per solution; handing back a
                # (1,)- or (1, 1)-shaped array relies on its fallback
                # handling of non-scalar function values.
                return float(cost.item())
            return cost

        loop = asyncio.get_running_loop()
        # NOTE(review): every run below receives the same seed -- confirm
        # that this is intended.
        options = {
            "verbose": verbose,
            "maxfevals": maxfevals,
            "seed": self.seed,
        }
        # One independent CMA-ES run per stored starting point, each with an
        # initial step size of 1.0.
        futures = [
            loop.run_in_executor(
                self._executor, cma.fmin2, wrapper, x0, 1.0, options
            )
            for x0 in self.X
        ]
        results = await asyncio.gather(*futures)
        # cma.fmin2 returns a pair; only its first element (the solution
        # vector) is kept.
        candidates = np.vstack([sol for sol, _ in results])
        return Rd2Pydantic(self.task, candidates, continuous=True)

    def fit(self, X: List[BaseModel], y: np.ndarray) -> None:
        """
        Fits the generative policy and performs any pre-acquisition steps.

        Stores in `self.X` the (up to) `batch_size` designs with the largest
        objective values, ordered from lowest to highest score among them.
        Input:
            X: a list of BaseModel's of all prior evaluated designs.
            y: an array of shape N of all objective evaluations, where N is the
                number of designs.
        Returns:
            None.
        """
        # Assumes Pydantic2Rd returns an array with one row per design --
        # TODO confirm.
        Xval = Pydantic2Rd(self.task, X, continuous=True)
        # Flatten explicitly rather than squeeze, so that a single
        # observation stays 1-D instead of collapsing to a 0-d array.
        scores = np.asarray(y).reshape(-1)
        top = np.argsort(scores)[-self.batch_size:]
        self.X = Xval[top]
        if self.X.ndim == 1:
            # Promote a flat vector to a column so iteration yields 1-D rows.
            self.X = self.X.reshape(-1, 1)
