import dataclasses

from bd_mcts.search_algo.base import Result, SearchAlgo, Trial


@dataclasses.dataclass()
class BestOfNConfig:
    """Settings for the ``BestOfN`` search algorithm.

    Attributes:
        n: Sample count. ``BestOfN`` stores it but does not read it itself;
            presumably the code driving the search uses it — confirm.
        gen_length: Number of tokens each full rollout is asked to demask.
        num_func_eval_budget: Total function-evaluation budget for the whole
            search; ``BestOfN`` draws it down as results are reported.
    """

    n: int  # presumably the "N" in best-of-N
    gen_length: int  # tokens demasked per rollout
    num_func_eval_budget: int  # starting function-evaluation budget


class BestOfN(SearchAlgo):
    """Best-of-N search where every trial is an independent full rollout.

    Each trial handed out by :meth:`ask` has no parent token sequence and asks
    for all ``gen_length`` tokens to be demasked in one full rollout. Results
    reported through :meth:`tell` are accumulated and their function
    evaluations are charged against the budget; :meth:`top_k` ranks the
    accumulated results by reward.
    """

    def __init__(self, config: BestOfNConfig) -> None:
        """Initialise search state from ``config``.

        Args:
            config: Sample count, generation length and the total
                function-evaluation budget.
        """
        # NOTE(review): ``n`` is never read inside this class; presumably the
        # driver or the SearchAlgo base uses it to bound the number of trials.
        self.n = config.n
        self.gen_length = config.gen_length
        # Budget still available; decremented in ``tell`` as results arrive.
        self.num_func_eval_budget = config.num_func_eval_budget

        # (clean_pred, reward) pairs, in the order results were told.
        self.results: list[tuple[str, float]] = []

        # Last issued trial id; starts at -1 so the first id handed out is 0.
        self.trial_id = -1

    def next_trial_id(self) -> int:
        """Advance the trial counter and return the new id (0, 1, 2, ...)."""
        self.trial_id += 1
        return self.trial_id

    def ask(self) -> Trial:
        """Return a new trial describing one independent full rollout.

        The trial has no parent sequence, requests ``gen_length`` tokens to be
        demasked, and carries the budget currently left (which may already
        have been reduced by earlier ``tell`` calls).
        """
        return Trial(
            trial_id=str(self.next_trial_id()),
            parent_token_seq=None,
            num_tokens_to_demask=self.gen_length,
            remaining_func_evals=self.num_func_eval_budget,
            full_rollout=True,
        )

    def tell(self, result: Result) -> None:
        """Record ``result`` and charge its function evaluations to the budget.

        Args:
            result: Outcome of a trial; its ``num_func_evals``, ``clean_pred``
                and ``reward`` attributes are read.
        """
        self.num_func_eval_budget -= result.num_func_evals
        self.results.append((result.clean_pred, result.reward))

    def top_k(self, k: int) -> list[tuple[str, float]]:
        """Return the ``k`` highest-reward ``(clean_pred, reward)`` pairs.

        Pairs are ordered by descending reward; pairs with equal reward keep
        the order in which they were told (``sorted`` stays stable with
        ``reverse=True``). If fewer than ``k`` results exist, all are returned.

        Args:
            k: Number of results to return; must be non-negative.

        Returns:
            A new list; ``self.results`` is not modified.

        Raises:
            ValueError: If ``k`` is negative. Used directly as a slice bound,
                a negative ``k`` would silently drop the ``|k|`` lowest-ranked
                results instead of failing.
        """
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}")
        ranked = sorted(self.results, key=lambda pair: pair[1], reverse=True)
        return ranked[:k]
