from llm_mcts.mcts_algo.algo_base import MCTSAlgo
from llm_mcts.mcts_algo.mcts_config import MCTSConfig
from llm_mcts.mcts_algo.node import Node
from llm_mcts.mcts_algo.score_funcs import ScoreFunc
from llm_mcts.mcts_algo.solver.base import Solver
from llm_mcts.mcts_scorer.base import MCTSScorer


class StandardAlgo(MCTSAlgo):
    """MCTS with greedy child selection and a fixed number of samples per action.

    Each step descends from the root by repeatedly following the child with the
    highest ``score_func`` score. It then expands the reached leaf once per
    configured action and backpropagates the scorer's score for that leaf.
    """

    def __init__(self, config: MCTSConfig, score_func: ScoreFunc):
        """
        Args:
            config: Search configuration. ``num_expand_samples`` must be set;
                ``actions``, ``initial_expand_samples`` and ``num_simulations``
                are also read by this class.
            score_func: Scores a (parent, child) pair during selection; the
                highest-scoring child is followed.
        """
        # NOTE: kept as an assert for backward compatibility; it is skipped when
        # Python runs with -O.
        assert config.num_expand_samples is not None, "num_expand_samples must be set"
        self.config = config
        self.score_func = score_func
        self.next_serial_number = 1  # NOTE: 0 is reserved for the root node

    def run_mcts_step(self, root: Node, solver: Solver, scorer: MCTSScorer) -> None:
        """Run one selection / expansion / backpropagation pass starting at ``root``.

        If ``root`` is the tree root and has not been expanded yet, it is expanded
        first. When ``config.num_simulations`` is 0, the step stops right after
        that initial expansion.

        Args:
            root: Node the descent starts from.
            solver: Generates child nodes during expansion.
            scorer: Forwarded to the solver during expansion, and used to score
                the leaf whose score is backpropagated.
        """
        if root.is_root() and not root.is_expanded():
            self.expand(root, solver, scorer=scorer)

        # When num_simulations is 0, only the root node is expanded
        if self.config.num_simulations == 0:
            return

        # Selection: follow the best-scoring child until reaching a node that
        # has not been expanded yet.
        node = root
        while node.is_expanded():
            node = self.select_child(node)

        # Expansion of the selected leaf.
        self.expand(node, solver, scorer=scorer)

        # Backpropagation: the score of the expanded leaf itself (not of its
        # newly generated children) is passed to ``backpropagate``.
        node.backpropagate(scorer.get_score(node=node))

    def expand(
        self,
        node: Node,
        solver: Solver,
        scorer: MCTSScorer,
    ) -> None:
        """Ask the solver to generate children of ``node`` once per configured action.

        The root node uses ``config.initial_expand_samples`` samples per action
        when that value is set. In every other case, ``config.num_expand_samples``
        is used. ``self.next_serial_number`` is replaced by the counter value the
        solver returns after each call.

        Args:
            node: Node to expand.
            solver: Object whose ``generate_child_nodes`` produces the children.
            scorer: Forwarded unchanged to the solver.
        """
        # The sample count depends only on the node, not on the action, so it is
        # computed once rather than on every loop iteration.
        if node.is_root() and self.config.initial_expand_samples is not None:
            num_samples = self.config.initial_expand_samples
        else:
            num_samples = self.config.num_expand_samples

        for action in self.config.actions:
            # Only the second element of the returned pair (the next serial
            # number) is used here; the first is presumably the generated nodes.
            _, self.next_serial_number = solver.generate_child_nodes(
                node,
                action,
                num_samples=num_samples,
                scorer=scorer,
                next_serial_number=self.next_serial_number,
            )

    def select_child(self, node: Node) -> Node:
        """Return the child of ``node`` with the highest selection score.

        Each child is scored exactly once with ``score_func.score(node, child)``.
        Ties are broken in favour of the earliest child in ``node.children``.

        Args:
            node: Parent whose children are compared.

        Returns:
            The highest-scoring child.

        Raises:
            ValueError: If ``node`` has no children.
        """
        if not node.children:
            raise ValueError("select_child called on a node without children")
        # max() returns the first maximal element, so ties go to the earliest
        # child. No sentinel starting value is involved, so arbitrarily negative
        # scores (e.g. below -1e6) are still compared correctly.
        return max(node.children, key=lambda child: self.score_func.score(node, child))
