#!/usr/bin/python3
"""
Simple baseline optimization policy that always returns the majority class
of the training data as the next candidate designs (or the mean of the
training data if optimizing over a continuous design space). Should perform
quite poorly and be thought of as a negative control.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import numpy as np
import torch
from pydantic import BaseModel
from typing import Any, Dict, List, Tuple, Union

from .base import BaseOptimizer
from .state import OptimizerState
from ..envs.base import BaseTask


class MajorityBaselineOptimizer(BaseOptimizer):
    """
    Negative-control optimizer that proposes one fixed design at every step.

    The fixed design is computed once, at construction, from the task's
    training designs. It is the per-dimension mean when the stacked training
    tensor is floating point, and the per-dimension mode otherwise. The
    optimizer state, prior knowledge, and observed data passed to `forward`
    and `fit` are ignored.
    """

    optimizer_name: str = "majority-baseline"

    def __init__(
        self,
        task: BaseTask,
        batch_size: int,
        seed: int = 2025,
        **kwargs: Any
    ):
        """
        Args:
            task: the optimization task.
            batch_size: batch size to use for sampling per iteration.
            seed: random seed. Default 2025.
            kwargs: ignored; accepted so this optimizer can be constructed
                with the same keyword arguments as other optimizers.
        Raises:
            ValueError: if the task has no training designs, or if the number
                of design dimensions does not match the number of fields in
                the task's design schema.
        """
        del kwargs
        super().__init__(task, batch_size, seed)

        num_train = len(task.train)
        if num_train == 0:
            # torch.vstack would otherwise fail with an opaque error on [].
            raise ValueError(
                f"{type(self).__name__} requires at least one training design."
            )
        # NOTE(review): assumes `as_tensor()` yields one row of design values
        # per training example, so the stack is (N, D) — TODO confirm.
        train_designs = torch.vstack([
            task.train[i].as_tensor() for i in range(num_train)
        ])

        # A stacked tensor has a single dtype, so every column takes the same
        # branch: mean for floating-point data, mode otherwise.
        use_mean = torch.is_floating_point(train_designs)
        majority: List[Union[float, int, bool, str]] = []
        for i in range(train_designs.size(dim=-1)):
            column = train_designs[..., i]
            if use_mean:
                majority.append(torch.mean(column).item())
            else:
                mode_values, _ = torch.mode(column)
                # torch.mode returns a 0-d tensor for a 1-D column; promote it
                # to 1-D so that indexing the first value always works.
                majority.append(torch.atleast_1d(mode_values)[0].item())

        field_names = list(self.task.design_schema.model_fields.keys())
        if len(majority) != len(field_names):
            raise ValueError(
                f"Training designs have {len(majority)} dimensions but the "
                f"design schema has {len(field_names)} fields."
            )
        # NOTE(review): assumes tensor columns follow the schema's field
        # declaration order — TODO confirm against `as_tensor()`.
        self.majority = self.task.design_schema(
            **dict(zip(field_names, majority))
        )

    def forward(
        self, state: OptimizerState, knowledge: Dict[str, str], **kwargs: Any
    ) -> Tuple[List[BaseModel], Dict[str, int]]:
        """
        Returns a new batch of candidates to evaluate.
        Input:
            state: the current optimizer state (ignored).
            knowledge: the prior knowledge to use for the optimization
                (ignored).
        Returns:
            A list of `batch_size` candidate designs, each an independent copy
            of the fixed majority design, and an empty metadata dictionary.
        """
        del state, knowledge, kwargs
        # Return copies rather than `[self.majority] * batch_size`, which
        # would alias a single mutable model across the whole batch (and
        # across calls), so that mutating one candidate changed them all.
        return [
            self.majority.model_copy() for _ in range(self.batch_size)
        ], {}

    def fit(self, X: List[BaseModel], y: np.ndarray) -> None:
        """
        No-op: this baseline does not learn from evaluated designs.
        Input:
            X: a list of BaseModel's of all prior evaluated designs (ignored).
            y: an array of shape N of all objective evaluations, where N is the
                number of designs (ignored).
        Returns:
            None.
        """
        del X, y
        return
