from __future__ import annotations

import abc
from dataclasses import dataclass

from bd_mcts.token_sequence import TokenSequence


@dataclass
class Trial:
    """A unit of work handed out by ``SearchAlgo.ask()``.

    Field declaration order is also the positional argument order of the
    generated ``__init__``; required fields come first, defaulted ones last.
    """

    # Identifier for this trial; presumably echoed back as ``Result.trial_id``
    # so the search algorithm can pair results with trials — confirm.
    trial_id: str
    parent_token_seq: TokenSequence | None  # None stands for a root parent
    # Number of tokens to demask for this trial (semantics inferred from the
    # name — confirm with the executor).
    num_tokens_to_demask: int
    # Remaining NFE (number of function evaluations) budget. It is provided to
    # the diffusion model as information it may act on (e.g. early stopping).
    # A value of -1 signals that the whole budget has been consumed.
    remaining_func_evals: int
    # Defaults to True. NOTE(review): presumably requests a rollout to the end
    # rather than a partial step — confirm.
    full_rollout: bool = True
    # Optional free-form action label; meaning is defined by the search
    # algorithm that created the trial — confirm.
    action: str | None = None
    # Optional integer targets for the rollout; presumably mask counts per
    # rollout step — confirm.
    rollout_mask_targets: list[int] | None = None
    # Optional precomputed Result for this trial; presumably lets the executor
    # reuse it instead of re-running the model — confirm.
    cached_result: Result | None = None


@dataclass
class Result:
    """Outcome of evaluating a ``Trial``, reported via ``SearchAlgo.tell()``.

    Field declaration order is also the positional argument order of the
    generated ``__init__``; all rollout/trace fields are optional.
    """

    # Presumably the ``trial_id`` of the Trial this result answers — confirm.
    trial_id: str
    new_token_seq: TokenSequence
    # Prediction text for the new sequence (presumably the model's clean,
    # fully-decoded prediction — confirm).
    clean_pred: str
    num_func_evals: int  # the amount of NFE budget consumed by MLD models
    reward: float
    # Optional rollout path for DTS-style algorithms (first element should be new_token_seq).
    rollout_token_seqs: list[TokenSequence] | None = None
    # Optional per-rollout-step predictions; individual entries may be None.
    # Presumably parallel to ``rollout_token_seqs`` — confirm.
    rollout_clean_preds: list[str | None] | None = None
    # Optional final token length (meaning inferred from the name — confirm).
    final_token_len: int | None = None
    # Optional histories of token lengths, without and with mask tokens
    # respectively (inferred from the names — confirm).
    token_len_history: list[int] | None = None
    token_len_with_mask_history: list[int] | None = None
    # Optional detailed demasking trace (per diffusion step) for analysis/visualization.
    demask_step_indices: list[int] | None = None
    demask_mask_counts: list[int] | None = None


class SearchAlgo(abc.ABC):
    """Abstract ask/tell interface implemented by every search algorithm.

    Concrete subclasses must override all three methods; the abstract
    versions only raise ``NotImplementedError``.  Intended usage (ask/tell
    pattern): obtain a :class:`Trial` from :meth:`ask`, evaluate it, then
    report the :class:`Result` back through :meth:`tell`.
    """

    @abc.abstractmethod
    def ask(self) -> Trial:
        """Produce the next :class:`Trial` to evaluate."""
        raise NotImplementedError

    @abc.abstractmethod
    def tell(self, result: Result) -> None:
        """Feed the :class:`Result` of an evaluated trial back to the search."""
        raise NotImplementedError

    @abc.abstractmethod
    def top_k(self, k: int) -> list[tuple[str, float]]:
        """Return ``(str, float)`` pairs for the top *k* entries.

        NOTE(review): the pairs are presumably (prediction, reward), best
        first — confirm against concrete implementations.
        """
        raise NotImplementedError
