from functools import total_ordering
from heapq import heappush, heappop
from typing import List

from llm_mcts.mcts_algo.algo_base import MCTSAlgo
from llm_mcts.mcts_algo.mcts_config import MCTSConfig
from llm_mcts.mcts_algo.node import Node
from llm_mcts.mcts_algo.solver.base import Solver
from llm_mcts.mcts_scorer.base import MCTSScorer


@total_ordering
class BestFirstSearchHeapItem:
    """Heap entry pairing a search-tree node with its score.

    ``heapq`` is a min-heap, so the ordering is inverted on purpose: an item
    compares as *smaller* (and is therefore popped earlier) when its score is
    higher. Ties on score are broken in favour of the shallower node.

    Equality only looks at ``score`` and ``node_depth``, not at the node
    itself. Because ``__eq__`` is defined without ``__hash__``, instances are
    unhashable.
    """

    def __init__(self, node: "Node", score: float):
        """Store ``node`` and ``score`` and record the node's depth.

        The depth is the number of ``parent`` links followed from ``node``
        until a node whose ``is_root()`` is true is reached (0 for the root).
        It is computed once here and not refreshed afterwards.
        """
        self.node = node
        self.score = score
        depth = 0
        current_node = node
        while not current_node.is_root():
            current_node = current_node.parent
            depth += 1
        self.node_depth = depth

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}"
            f"(score={self.score!r}, node_depth={self.node_depth!r})"
        )

    def __eq__(self, other: object) -> bool:
        # Defer to the other operand for foreign types instead of failing
        # with AttributeError on a missing ``node_depth``/``score``.
        if not isinstance(other, BestFirstSearchHeapItem):
            return NotImplemented
        return self.node_depth == other.node_depth and self.score == other.score

    def __lt__(self, other: object) -> bool:  # heapq is a min heap
        if not isinstance(other, BestFirstSearchHeapItem):
            return NotImplemented
        if self.score == other.score:
            # NOTE: We define that shallower nodes are better
            return self.node_depth < other.node_depth
        return self.score > other.score  # Higher score is better


class BestFirstSearchAlgo(MCTSAlgo):
    """Best-first search over the tree of generated nodes.

    Every generated child is pushed onto ``self.leaves``, a heap of
    ``BestFirstSearchHeapItem``. Each step pops the item that orders first
    in that heap and expands its node.
    """

    def __init__(self, config: MCTSConfig):
        self.config = config
        # Heap (managed with heapq) of nodes that have not been expanded yet.
        self.leaves: List[BestFirstSearchHeapItem] = []
        self.next_serial_number = 1  # NOTE: 0 is reserved for the root node

    def run_mcts_step(self, root: Node, solver: Solver, scorer: MCTSScorer) -> None:
        """Run one search step.

        Expands ``root`` first if it is an unexpanded root node. Unless
        ``config.num_simulations`` is 0, it then pops the best leaf from the
        heap and expands it.

        Raises:
            IndexError: if there is no leaf left to pop.
            AssertionError: if the popped node is already expanded.
        """
        if root.is_root() and not root.is_expanded():
            self.expand(root, solver, scorer=scorer)
        # When num_simulations is 0, only the root node is expanded
        if self.config.num_simulations == 0:
            return

        # Fail with an explicit message rather than heappop's bare
        # "index out of range" when the frontier is exhausted.
        if not self.leaves:
            raise IndexError(
                "Best First Search algorithm has no leaf nodes left to expand."
            )

        # Select the best leaf node
        # NOTE: If we don't want to restrict expansion to leaf nodes, we can
        # remove the assertion below and peek instead of popping:
        # node = self.leaves[0].node
        node = heappop(self.leaves).node
        assert (
            not node.is_expanded()
        ), "Best First Search algorithm went wrong: Leaf node should not be expanded."

        # Expand the selected leaf node
        self.expand(node, solver, scorer=scorer)

    def expand(
        self,
        node: Node,
        solver: Solver,
        scorer: MCTSScorer,
    ) -> None:
        """Generate children of ``node`` and push them onto the leaf heap.

        For each action in ``config.actions`` the solver is asked for child
        nodes; it also returns the next free serial number, which is stored
        back on ``self``. Each child is scored with ``scorer.get_score`` and
        pushed onto ``self.leaves``.
        """
        # The sample count does not depend on the action, so compute it once.
        # The root uses initial_expand_samples when that is configured.
        if node.is_root() and self.config.initial_expand_samples is not None:
            num_samples = self.config.initial_expand_samples
        else:
            num_samples = self.config.num_expand_samples

        for action in self.config.actions:
            generated_nodes, self.next_serial_number = solver.generate_child_nodes(
                node,
                action,
                num_samples=num_samples,
                scorer=scorer,
                next_serial_number=self.next_serial_number,
            )
            for child in generated_nodes:
                score = scorer.get_score(child)
                heappush(self.leaves, BestFirstSearchHeapItem(child, score))
