from __future__ import annotations

import dataclasses

from bd_mcts.search_algo.multimodel_mcts import MultiModelMCTS, MultiModelMCTSConfig, MCTSNode


def _format_temperature_action(value: float) -> str:
    return str(float(value))


@dataclasses.dataclass()
class TemperatureMCTSConfig:
    """Configuration consumed by ``TemperatureMCTS``.

    ``temperatures`` supplies the candidate sampling temperatures; every other
    field is forwarded by ``TemperatureMCTS`` to ``MultiModelMCTSConfig`` under
    the same name. The first three fields have no default and must be given.
    """

    # Candidate temperatures; TemperatureMCTS turns each one into an action.
    temperatures: list[float]
    # The remaining fields are passed through to MultiModelMCTSConfig; see that
    # class for their exact semantics.
    gen_length: int
    num_func_eval_budget: int
    full_rollout: bool = True
    enable_rollout_cache: bool = True
    # Bound ``list.copy`` acts as the factory, so each instance receives its
    # own fresh list rather than sharing one mutable default.
    demask_schedule: list[float] = dataclasses.field(
        default_factory=[0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.2].copy
    )
    min_unmask_num: int = 3
    exploration_const: float = 1.0
    max_selection_attempts: int = 20


class TemperatureMCTS(MultiModelMCTS):
    """
    A MultiModelMCTS variant where the MCTS action represents a demask sampling
    temperature (while keeping the remask algorithm fixed in the demask config).

    The selected temperature is passed through `Trial.action` as a string and is
    expected to be interpreted by the experiment runner / generator.
    """

    def __init__(self, config: TemperatureMCTSConfig) -> None:
        """Validate the temperatures and initialise the underlying MCTS.

        Each temperature is coerced to ``float`` and mapped to its canonical
        action string; those strings are handed to ``MultiModelMCTSConfig`` as
        its ``models`` list, with the remaining settings forwarded unchanged
        apart from explicit type coercion.

        Raises:
            ValueError: If ``config.temperatures`` is empty, contains an entry
                that cannot be converted to ``float``, contains a negative or
                NaN value, or contains two entries that normalise to the same
                action string.
        """
        if not config.temperatures:
            raise ValueError("TemperatureMCTS requires at least one temperature")

        temperatures: list[float] = []
        for idx, value in enumerate(config.temperatures):
            try:
                temp = float(value)
            except (TypeError, ValueError) as exc:
                raise ValueError(
                    f"temperatures entries must be numbers, got {value!r} at index {idx}"
                ) from exc
            # Deliberately `not temp >= 0.0` rather than `temp < 0.0`: every
            # ordered comparison with NaN is False, so the latter would let a
            # NaN temperature through and make the ordering in _select_action
            # unreliable.
            if not temp >= 0.0:
                raise ValueError(
                    f"temperatures must be >= 0, got {temp} at index {idx}"
                )
            temperatures.append(temp)

        # Keying by the formatted string collapses numerically equal entries,
        # so a size mismatch means the caller supplied duplicates.
        action_by_temp = {_format_temperature_action(t): t for t in temperatures}
        if len(action_by_temp) != len(temperatures):
            raise ValueError(
                "temperatures contain duplicates after normalization; "
                "use distinct values (e.g., 0.1 and 0.10 are considered the same)"
            )
        # Assigned before super().__init__() so the mapping already exists if
        # the base initialiser ends up invoking _select_action.
        self._temp_by_action = dict(action_by_temp)

        super().__init__(
            MultiModelMCTSConfig(
                models=list(self._temp_by_action.keys()),
                gen_length=int(config.gen_length),
                num_func_eval_budget=int(config.num_func_eval_budget),
                full_rollout=bool(config.full_rollout),
                enable_rollout_cache=bool(config.enable_rollout_cache),
                demask_schedule=list(config.demask_schedule),
                min_unmask_num=int(config.min_unmask_num),
                exploration_const=float(config.exploration_const),
                max_selection_attempts=int(config.max_selection_attempts),
            )
        )

    def _select_action(self, node: MCTSNode, actions: list[str]) -> str:
        """Return the lowest-temperature action among the candidates.

        When ``self.enable_rollout_cache`` is truthy, the candidates are the
        actions for which ``_get_cached_rollout(node, action)`` returns a
        truthy value; if there are none, or caching is disabled, every entry
        of ``actions`` is a candidate. Actions missing from the temperature
        map are ranked as temperature 0.0, and ties go to the earliest
        candidate.

        Raises:
            ValueError: If ``actions`` is empty (raised by ``min``).
        """
        # NOTE(review): enable_rollout_cache is presumably set by the base
        # class from its config -- confirm against MultiModelMCTS.
        if self.enable_rollout_cache:
            cached = [
                action for action in actions if self._get_cached_rollout(node, action)
            ]
            candidates = cached if cached else actions
        else:
            candidates = actions
        return min(candidates, key=lambda action: self._temp_by_action.get(action, 0.0))

