#!/usr/bin/python3
"""
Base optimizer class.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import abc
import numpy as np
import torch.nn as nn
from pydantic import BaseModel
from typing import Any, Dict, Final, List, Tuple

from ..envs.base import BaseTask
from .state import OptimizerState


class BaseOptimizer(abc.ABC, nn.Module):
    """
    Abstract base class for optimizers.

    Concrete optimizers must implement `forward()`, which proposes a new
    batch of candidate designs from the current optimizer state, and `fit()`,
    which updates the optimizer from all previously evaluated designs.
    """

    # Name identifying the optimizer. `NotImplemented` is only a placeholder;
    # concrete subclasses are expected to override it.
    optimizer_name: str = NotImplemented

    def __init__(
        self,
        task: BaseTask,
        batch_size: int,
        seed: int = 2025,
        **kwargs: Any
    ) -> None:
        """
        Args:
            task: the optimization task.
            batch_size: batch size to use for sampling per iteration.
            seed: random seed. Default 2025.
            kwargs: extra keyword arguments. The base class accepts and
                discards them (presumably so a shared configuration can be
                forwarded to any subclass without raising on unknown keys).
        """
        # Extra options are intentionally ignored by the base class.
        del kwargs
        super(BaseOptimizer, self).__init__()
        self.task: Final[BaseTask] = task
        self.batch_size: Final[int] = batch_size
        self.seed: Final[int] = seed
        # Dedicated generator seeded from `seed`; using it instead of the
        # module-level `np.random` functions leaves NumPy's global state alone.
        self._rng = np.random.RandomState(seed=self.seed)

    @abc.abstractmethod
    def forward(
        self, state: OptimizerState, knowledge: Dict[str, str], **kwargs
    ) -> Tuple[List[BaseModel], Dict[str, int]]:
        """
        Returns a new batch of candidates to evaluate.
        Input:
            state: the current optimizer state.
            knowledge: the prior knowledge to use for the optimization.
            kwargs: additional subclass-specific keyword arguments.
        Returns:
            A tuple `(candidates, metadata)`, where `candidates` is a list of
            candidate designs (BaseModel instances) to evaluate, intended to
            hold `batch_size` entries, and `metadata` is a dictionary of
            integer-valued metadata about the proposal step.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def fit(self, X: List[BaseModel], y: np.ndarray) -> None:
        """
        Fits the generative policy and performs any pre-acquisition steps.
        Input:
            X: a list of BaseModel instances, one per previously evaluated
                design.
            y: an array of shape N holding the objective evaluations, where N
                is the number of designs (i.e., `len(X)`).
        Returns:
            None.
        """
        raise NotImplementedError
