from __future__ import annotations

import dataclasses

from bd_mcts.search_algo.multimodel_mcts import MultiModelMCTS, MultiModelMCTSConfig


@dataclasses.dataclass
class MaskStrategyMCTSConfig:
    """Settings consumed by ``MaskStrategyMCTS``.

    ``MaskStrategyMCTS.__init__`` copies every field into a
    ``MultiModelMCTSConfig``. ``mask_strategies`` is forwarded as that
    config's ``models``; every other field is forwarded under its own name.
    The precise meaning of each value is defined by the underlying
    ``MultiModelMCTS`` implementation, which is not part of this module.
    """

    # --- Required (positional order matters for callers) -------------------
    # Identifiers of the candidate mask strategies; these act as the MCTS
    # actions in place of model names.
    mask_strategies: list[str]
    # Presumably the generation length handed to the search -- confirm
    # against MultiModelMCTSConfig.
    gen_length: int
    # Presumably the cap on function evaluations the search may spend --
    # confirm against MultiModelMCTSConfig.
    num_func_eval_budget: int

    # --- Optional tuning knobs ----------------------------------------------
    full_rollout: bool = True
    enable_rollout_cache: bool = True
    # A factory is required here: a bare list default would be rejected by
    # dataclasses, and the factory gives each instance its own list.
    demask_schedule: list[float] = dataclasses.field(
        default_factory=lambda: [
            0.9,
            0.8,
            0.7,
            0.6,
            0.5,
            0.4,
            0.2,
        ],
    )
    min_unmask_num: int = 3
    exploration_const: float = 1.0
    max_selection_attempts: int = 20


class MaskStrategyMCTS(MultiModelMCTS):
    """
    MultiModelMCTS specialisation whose actions are mask strategies.

    Rather than choosing among several models, each MCTS action names a
    "mask strategy" (i.e., the demask/remask algorithm). The strategy names
    are supplied to the base class through its `models` configuration slot.

    The concrete strategy value is passed through `Trial.action` and is
    expected to be interpreted by the experiment runner / generator.
    """

    def __init__(self, config: MaskStrategyMCTSConfig) -> None:
        """
        Translate `config` into a `MultiModelMCTSConfig` and initialise the base.

        Every field is coerced to its builtin type (`list`, `int`, `bool`,
        `float`) before being forwarded; the coercions run in the order the
        fields are listed below.

        Args:
            config: Mask-strategy search settings. Its `mask_strategies`
                become the base configuration's `models`.
        """
        # list(...) builds new list objects, so the base config does not
        # share the caller's list instances.
        strategies = list(config.mask_strategies)
        length = int(config.gen_length)
        budget = int(config.num_func_eval_budget)
        rollout_fully = bool(config.full_rollout)
        cache_rollouts = bool(config.enable_rollout_cache)
        schedule = list(config.demask_schedule)
        unmask_floor = int(config.min_unmask_num)
        exploration = float(config.exploration_const)
        selection_attempts = int(config.max_selection_attempts)

        base_config = MultiModelMCTSConfig(
            models=strategies,
            gen_length=length,
            num_func_eval_budget=budget,
            full_rollout=rollout_fully,
            enable_rollout_cache=cache_rollouts,
            demask_schedule=schedule,
            min_unmask_num=unmask_floor,
            exploration_const=exploration,
            max_selection_attempts=selection_attempts,
        )
        super().__init__(base_config)

