#!/usr/bin/python3
"""
Implements the baseline genetic algorithm optimization method.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import numpy as np
import pygad  # type: ignore[import-untyped]
from pydantic import BaseModel
from typing import Any, Dict, Final, List, Tuple, Union

from .base import BaseOptimizer
from .state import OptimizerState
from ..envs.base import BaseTask
from ..utils import Pydantic2Rd, Rd2Pydantic


class GAOptimizer(BaseOptimizer):
    """Baseline optimizer that proposes designs with a genetic algorithm.

    Each call to `forward()` runs a single generation of `pygad.GA`, seeded
    with the population retained by the most recent `fit()` call, and returns
    up to `batch_size` of the solutions the GA evaluated.
    """

    optimizer_name: str = "Genetic-Algorithm"

    def __init__(
        self,
        task: BaseTask,
        batch_size: int,
        seed: int = 2025,
        **kwargs: Any
    ):
        """
        Args:
            task: the optimization task.
            batch_size: the number of new designs to return per sampling step.
            seed: random seed. Default 2025.
            kwargs: may contain `num_generations` (default 50) and
                `num_parents_mating` (default 4). All keyword arguments are
                also forwarded unchanged to `BaseOptimizer.__init__`.
        """
        # int(str(...)) accepts ints and integer strings; note that a float
        # such as 4.0 is rejected because "4.0" is not an integer literal.
        # NOTE(review): forward() always runs exactly one generation and does
        # not read num_generations; confirm whether another component uses it.
        self.num_generations: Final[int] = int(
            str(kwargs.get("num_generations", 50))
        )
        self.num_parents_mating: Final[int] = int(
            str(kwargs.get("num_parents_mating", 4))
        )
        super(GAOptimizer, self).__init__(
            task=task, batch_size=batch_size, seed=seed, **kwargs
        )

    def forward(
        self,
        state: OptimizerState,
        knowledge: Dict[str, str],
        **kwargs: Any
    ) -> Tuple[List[BaseModel], Dict[str, int]]:
        """
        Returns a new batch of candidates to evaluate.
        Input:
            state: the current optimizer state.
            knowledge: the prior knowledge to use for the optimization
                (ignored by this optimizer).
        Returns:
            A batch of at most `batch_size` candidates to evaluate and an
            (empty) dictionary of metadata.
        """
        del knowledge, kwargs

        def wrapper(
            ga_instance: pygad.GA, x: np.ndarray, sol_idx: int
        ) -> Union[float, np.ndarray]:
            """Score one solution (pygad fitness callback) or a 2-D batch.

            Returns a Python float when the task yields a single value and
            the raw objective array otherwise. Higher is treated as better,
            consistent with fit(), which keeps the highest-scoring designs.
            """
            del ga_instance, sol_idx
            # NOTE(review): task.extend presumably combines the GA's
            # continuous encoding with state.individual; see BaseTask.
            designs = self.task.extend(
                Rd2Pydantic(self.task, np.atleast_2d(x), continuous=True),
                state.individual
            )
            y = np.array(self.task(designs))
            if y.size == 1:
                return y.item()
            return y

        population = getattr(self, "X", None)
        num_parents_mating = self.num_parents_mating
        if population is not None:
            # pygad rejects num_parents_mating larger than the population, and
            # fit() keeps at most batch_size designs, so clamp to what exists.
            num_parents_mating = min(num_parents_mating, len(population))

        ga = pygad.GA(
            num_generations=1,
            num_parents_mating=num_parents_mating,
            fitness_func=wrapper,
            initial_population=population,
            save_solutions=True,
            random_seed=self.seed
        )
        ga.run()
        # With save_solutions=True, ga.solutions records the solutions
        # evaluated during the run (initial population included).
        candidates = np.array(ga.solutions)
        num_candidates = candidates.shape[0]
        if num_candidates > self.batch_size:
            ypred = np.atleast_1d(np.squeeze(wrapper(ga, candidates, -1)))
            # Discard the worst-scoring candidates (up to batch_size of them)
            # but always keep at least batch_size, so the sampling below
            # without replacement never requests more items than exist.
            num_drop = min(self.batch_size, num_candidates - self.batch_size)
            idxs = np.argsort(ypred)[num_drop:]
            idxs = self._rng.choice(idxs, self.batch_size, replace=False)
            candidates = candidates[idxs]

        return Rd2Pydantic(self.task, candidates, continuous=True), {}

    def fit(self, X: List[BaseModel], y: np.ndarray) -> None:
        """
        Fits the generative policy and performs any pre-acquisition steps.

        Keeps the `batch_size` highest-scoring designs (as a 2-D array) to
        serve as the GA's initial population in the next `forward()` call.
        Input:
            X: a list of BaseModel's of all prior evaluated designs.
            y: an array of shape N of all objective evaluations, where N is the
                number of designs.
        Returns:
            None.
        """
        Xval = Pydantic2Rd(self.task, X, continuous=True)
        # atleast_1d: with a single design, squeeze() yields a 0-d array,
        # which argsort-then-slice cannot handle.
        scores = np.atleast_1d(np.squeeze(y))
        self.X = Xval[np.argsort(scores)[-self.batch_size:]]
        if self.X.ndim == 1:
            self.X = self.X.reshape(-1, 1)
