#!/usr/bin/python3
"""
Optimization policy that always returns the human-generated designs from the
test dataset as the next candidate designs. This policy should be
(near-)optimal and is intended to serve as a positive control.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import numpy as np
from importlib import import_module
from gymnasium.envs.registration import registry
from pydantic import BaseModel
from typing import Any, Dict, Final, List, NamedTuple, Optional, Tuple

from .base import BaseOptimizer
from ..envs.base import BaseTask
from .state import OptimizerState


class HumanBaselineOptimizer(BaseOptimizer):
    """
    Positive-control policy that always proposes the human-generated design
    from the task's test split whose `id_` matches the reference individual.
    """

    optimizer_name: str = "human-baseline"

    def __init__(
        self,
        task: BaseTask,
        ref: NamedTuple,
        batch_size: int,
        seed: int = 2025,
        **kwargs: Any
    ):
        """
        Args:
            task: the optimization task.
            ref: the reference individual that we are optimizing for.
            batch_size: batch size to use for sampling per iteration.
            seed: random seed. Default 2025.
            **kwargs: additional keyword arguments; ignored.
        Raises:
            AssertionError: if the registered task spec has no `dataset`
                entry, or if the test split does not contain exactly one
                design whose `id_` matches that of `ref`.
        """
        del kwargs
        task_spec = registry[task.task_name].kwargs
        dataset_cls: Optional[str] = task_spec.get("dataset", None)
        assert dataset_cls is not None, (
            f"Task {task.task_name!r} does not register a `dataset` entry."
        )
        # The dataset is referenced as a "<module>:<attribute>" string.
        module, attr = dataset_cls.split(":", 1)
        dataset = getattr(import_module(module), attr)(task_spec["test_split"])

        # The getattr defaults differ on purpose (-1 here, None for dataset
        # items) so that two objects that both lack `id_` never compare equal.
        ref_id = getattr(ref, "id_", -1)
        # Index each element exactly once: the object that is checked is the
        # object that is kept, and dataset access is not needlessly doubled.
        designs = (dataset[i] for i in range(len(dataset)))
        ref_design = [
            design for design in designs
            if getattr(design, "id_", None) == ref_id
        ]
        assert len(ref_design) == 1, (
            f"Expected exactly one test design with id_ {ref_id!r}, "
            f"found {len(ref_design)}."
        )
        self._ref: Final[NamedTuple] = ref_design[0]
        super().__init__(task, batch_size, seed)

    def forward(
        self, state: OptimizerState, knowledge: Dict[str, str], **kwargs
    ) -> Tuple[List[BaseModel], Dict[str, int]]:
        """
        Returns a new batch of candidates to evaluate.
        Input:
            state: the current optimizer state.
            knowledge: the prior knowledge to use for the optimization.
                Unused by this policy.
        Returns:
            A batch of candidates to evaluate of shape BD, where B is the batch
            size and D is the number of design dimensions, and a dictionary of
            metadata (always empty). The batch is the output of reducing the
            stored human design, repeated `batch_size` times.
        Raises:
            AssertionError: if the individual in `state` is not the reference
                individual this optimizer was constructed for.
        """
        state_id = getattr(state.individual, "id_", None)
        ref_id = getattr(self._ref, "id_", -1)
        assert state_id == ref_id, (
            f"State individual id_ {state_id!r} does not match the "
            f"reference id_ {ref_id!r}."
        )
        del knowledge, kwargs
        return self.task.reduce([self._ref]) * self.batch_size, {}

    def fit(self, X: List[BaseModel], y: np.ndarray) -> None:
        """
        No-op: this policy always returns the stored human design, so there is
        nothing to fit from prior evaluations.
        Input:
            X: a list of BaseModel's of all prior evaluated designs.
            y: an array of shape N of all objective evaluations, where N is the
                number of designs.
        Returns:
            None.
        """
        del X, y
        return
