from abc import ABC, abstractmethod
from typing import List, Tuple

from llm_mcts.data_types import Action
from llm_mcts.mcts_algo.node import Node
from llm_mcts.mcts_scorer.base import MCTSScorer


class Solver(ABC):
    """Abstract interface for objects that expand an MCTS node into children.

    Concrete subclasses must override :meth:`generate_child_nodes`; this base
    class cannot be instantiated directly because that method is abstract.
    """

    @abstractmethod
    def generate_child_nodes(
        self,
        node: Node,
        kind: Action,
        num_samples: int,
        scorer: MCTSScorer,
        next_serial_number: int,
    ) -> Tuple[List[Node], int]:
        """Produce child nodes for ``node``.

        Args:
            node: The parent node being expanded.
            kind: The ``Action`` requested for this expansion (its exact
                meaning is defined by ``Action`` elsewhere — confirm there).
            num_samples: Sample count for the expansion; presumably an upper
                bound on how many children are generated — TODO confirm.
            scorer: The ``MCTSScorer`` handed to the implementation,
                presumably used to score the generated children.
            next_serial_number: A serial-number counter supplied by the
                caller; presumably the number to assign to the first new
                node — verify against implementations.

        Returns:
            A ``(children, int)`` pair. The int is presumably the updated
            serial-number counter after numbering the new children — confirm
            with concrete solvers.

        Raises:
            NotImplementedError: Always, in this abstract base.
        """
        raise NotImplementedError


class AggregatedSolver(Solver):
    """Abstract ``Solver`` that exposes a set of named solvers.

    On top of the ``Solver`` contract (``generate_child_nodes`` remains
    abstract here), subclasses must implement a way to list the available
    solver names and to select one by name. Presumably ``generate_child_nodes``
    then delegates to the selected solver — confirm in implementations.
    """

    @abstractmethod
    def get_solvers(self) -> List[str]:
        """Return the names of the solvers this aggregate can switch between.

        Raises:
            NotImplementedError: Always, in this abstract base.
        """
        raise NotImplementedError

    @abstractmethod
    def set_solver(self, solver_name: str) -> None:
        """Select the active solver by name.

        Args:
            solver_name: Name of the solver to activate; presumably one of the
                values returned by :meth:`get_solvers` — TODO confirm.

        Raises:
            NotImplementedError: Always, in this abstract base.
        """
        raise NotImplementedError
