from abc import ABC, abstractmethod
from typing import List

from llm_mcts.mcts_algo.node import Node
from llm_mcts.mcts_algo.solver.base import AggregatedSolver, Solver
from llm_mcts.mcts_scorer.base import MCTSScorer


class MCTSAlgo(ABC):
    """Abstract base for MCTS algorithms; subclasses implement ``run_mcts_step``."""

    @abstractmethod
    def run_mcts_step(self, root: Node, solver: Solver, scorer: MCTSScorer) -> None:
        """Run a single MCTS step starting from ``root``.

        Concrete subclasses must override this method; the base implementation
        only raises ``NotImplementedError``.
        """
        raise NotImplementedError()

    def next_node_to_generate_child(self, root: Node) -> Node:
        """Return the node that the next MCTS step would generate children for.

        How it works:
            ``run_mcts_step`` is invoked with a ``MockSolver`` and a
            ``MockScorer``. ``MockSolver.generate_child_nodes`` raises
            ``MockSolverException`` carrying the node it was asked to expand;
            that exception is caught here and its node is returned.

        NOTE:
            This is an experimental convenience method and may not work for
            every concrete ``MCTSAlgo``. So far it has been checked to work
            with HierarchicalThompsonAlgo.

        Raises:
            RuntimeError: if ``run_mcts_step`` returns without
                ``generate_child_nodes`` having raised ``MockSolverException``.
        """
        try:
            probe_solver = MockSolver()
            probe_scorer = MockScorer()
            self.run_mcts_step(root, probe_solver, probe_scorer)
        except MockSolverException as signal:
            # The exception is only a carrier for the node selected for expansion.
            return signal.node
        else:
            # The step finished without the mock solver ever being asked to expand.
            raise RuntimeError(
                f"generate_child_nodes was not called while running run_mcts_step: Something went wrong with MCTSAlgo {self}"
            )


class MockSolverException(Exception):
    """
    Convenience exception whose sole purpose is to carry a Node object.

    It is raised only by MockSolver; the carried node is available as the
    ``node`` attribute.
    """

    def __init__(self, node: Node):
        # No message is passed to the base class; the payload is the node itself.
        super().__init__()
        self.node = node


class MockSolver(AggregatedSolver):
    """
    Mock AggregatedSolver that never generates anything.

    When an MCTSAlgo concrete class calls generate_child_nodes, this solver
    raises MockSolverException carrying the node it was given.

    It is intended to be used only by next_node_to_generate_child.
    """

    def generate_child_nodes(self, node, kind, num_samples, scorer, next_serial_number):
        """Raise MockSolverException with ``node``; every other argument is ignored."""
        raise MockSolverException(node)

    def get_solvers(self) -> List[str]:
        """Return the single placeholder solver name."""
        solver_names = ["mock_agent"]
        return solver_names

    def set_solver(self, solver_name):
        """Do nothing; the mock has no solver to select."""
        return None


class MockScorer(MCTSScorer):
    """
    Mock scorer handed to MCTSAlgo so that a scorer argument is available.

    The score it returns is a constant placeholder and should not be used.
    """

    def get_score(self, node):
        """Return a constant placeholder score, regardless of ``node``."""
        placeholder_score = 0.0
        return placeholder_score
