from collections import deque
from functools import total_ordering
from heapq import heappush, heappop
from typing import List

from llm_mcts.mcts_algo.algo_base import MCTSAlgo
from llm_mcts.mcts_algo.mcts_config import MCTSConfig
from llm_mcts.mcts_algo.node import Node
from llm_mcts.mcts_algo.solver.base import Solver
from llm_mcts.mcts_scorer.base import MCTSScorer


@total_ordering
class TreeOfThoughtsBFSHeapItem:
    """Ranking entry for an unexpanded node considered by ToT-BFS selection.

    ``a < b`` means ``a`` should be selected before ``b`` (``heapq`` is a
    min-heap, hence the "reversed" comparisons in ``__lt__``):

      1. deeper nodes come first, so the newest layer (S_{t-1}) is always
         preferred over leftover shallower leaves;
      2. within the same depth, a higher ``score`` comes first.

    Two items compare equal when both their depth and score are equal.
    Comparisons with objects of any other type return ``NotImplemented``.
    """

    def __init__(self, node: "Node", score: float):
        self.node = node
        self.score = score
        # Depth = number of parent hops from ``node`` up to the root (root is 0).
        depth = 0
        current_node = node
        while not current_node.is_root():
            current_node = current_node.parent
            depth += 1
        self.node_depth = depth

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TreeOfThoughtsBFSHeapItem):
            return NotImplemented
        return self.node_depth == other.node_depth and self.score == other.score

    def __lt__(self, other: object) -> bool:  # heapq is a min heap
        if not isinstance(other, TreeOfThoughtsBFSHeapItem):
            return NotImplemented
        if self.node_depth != other.node_depth:
            # Deeper nodes are better because we need to get S_{t-1}
            return self.node_depth > other.node_depth
        return self.score > other.score  # Higher score is better

    def __repr__(self) -> str:
        return f"{type(self).__name__}(depth={self.node_depth}, score={self.score!r})"


# Tree of Thoughts Breadth-First Search (ToT-BFS) algorithm
# Original paper: https://proceedings.neurips.cc/paper_files/paper/2023/hash/271db9922b8d1f4dd7aaef84ed5ac703-Abstract-Conference.html
# Original implementation: https://github.com/princeton-nlp/tree-of-thought-llm/blob/master/src/tot/methods/bfs.py
class TreeOfThoughtsBFSAlgo(MCTSAlgo):
    """Tree of Thoughts breadth-first search (ToT-BFS) driven one step at a time.

    Config mapping (paper notation in parentheses):
      * ``config.num_simulations``        -> step limit ``T``
      * ``config.initial_expand_samples`` -> size limit ``k`` (children per expansion)
      * ``config.num_expand_samples``     -> breadth limit ``b`` (nodes kept per layer)
      * ``config.actions``                -> must contain exactly one action

    The first call to :meth:`run_mcts_step` expands the root; every later call
    selects the ``b`` best leaves of the deepest layer and expands each of them.
    """

    def __init__(self, config: MCTSConfig):
        # NOTE(review): these are `assert`s, so they are skipped under `python -O`.
        assert (
            config.num_simulations > 0
        ), "num_simulations (step limit T) should be greater than 0"
        assert (
            config.initial_expand_samples is not None
            and config.initial_expand_samples > 0
        ), "initial_expand_samples (size limit k) should be greater than 0"
        assert (
            config.num_expand_samples is not None and config.num_expand_samples > 0
        ), "num_expand_samples (breadth limit b) should be greater than 0"
        assert (
            config.initial_expand_samples >= config.num_expand_samples
        ), "initial_expand_samples (size limit k) should be greater than or equal to num_expand_samples (breadth limit b)"
        assert (
            len(config.actions) == 1
        ), "Only a single action is currently supported by TreeOfThoughtsBFSAlgo"
        self.step_limit = config.num_simulations  # T
        self.size_limit = config.initial_expand_samples  # k
        self.breadth_limit = config.num_expand_samples  # b
        self.action = config.actions[0]
        self.next_serial_number = 1  # NOTE: 0 is reserved for the root node

    def run_mcts_step(self, root: Node, solver: Solver, scorer: MCTSScorer) -> None:
        """
        Corresponds to a single time step in the ToT-BFS algorithm.

        If ``root`` has not been expanded yet, only the root is expanded.
        Otherwise, the ``b`` selected leaves (see :meth:`select_nodes`) are
        each expanded with ``k`` children.
        """
        if root.is_root() and not root.is_expanded():
            self.expand(root, solver, scorer=scorer)
            return  # The first step is to expand the root node

        # NOTE: We swap the order of the processes in the for loop, but the logic is the same
        # Original: (Generate x `k` -> Select x `b`) -> (Generate x `k` -> ... -> Select x `b`) -> Final Answer
        # Ours: Generate x `k` -> (Select x `b` -> Generate x `k`) -> ... -> Generate x `k`) -> Final Answer
        selected_nodes = self.select_nodes(root, scorer)
        for node in selected_nodes:
            self.expand(node, solver, scorer=scorer)

    def expand(self, node: Node, solver: Solver, scorer: MCTSScorer) -> None:
        """
        Generate `k` children nodes for the given node.

        The solver returns a pair whose second element is the next free serial
        number. It is stored so that serial numbers stay unique across calls.
        """
        _, self.next_serial_number = solver.generate_child_nodes(
            node,
            self.action,
            num_samples=self.size_limit,
            scorer=scorer,
            next_serial_number=self.next_serial_number,
        )

    def select_nodes(self, root_node: Node, scorer: MCTSScorer) -> List[Node]:
        """
        Extract S_{t-1} from the tree (in total, `b` nodes).

        Every unexpanded node below ``root_node`` is scored with
        ``scorer.get_score``. Candidates are ranked by depth (deepest first),
        then by score (highest first). Ties are broken by BFS discovery order.

        Raises:
            RuntimeError: if fewer than `b` unexpanded nodes exist, if the `b`
                best candidates do not all share one depth, or if that depth
                has already reached the step limit `T`.
        """
        # Breadth-first walk collecting every leaf (unexpanded node) in BFS order.
        candidates: List[TreeOfThoughtsBFSHeapItem] = []
        frontier: deque[Node] = deque(root_node.children)
        while frontier:
            current_node = frontier.popleft()
            if current_node.is_expanded():
                frontier.extend(current_node.children)
            else:
                eval_score = scorer.get_score(node=current_node)
                candidates.append(TreeOfThoughtsBFSHeapItem(current_node, eval_score))

        # NOTE: We should set `k` to be the same or larger than `b`
        if len(candidates) < self.breadth_limit:
            raise RuntimeError("ToT-BFS algorithm went wrong: Priority queue is empty")

        # `sorted` is stable, so items that compare equal (same depth and score)
        # keep their BFS discovery order. `heapq` does NOT guarantee this.
        ranked = sorted(candidates)[: self.breadth_limit]
        selected_nodes = [item.node for item in ranked]
        selected_nodes_depth = {item.node_depth for item in ranked}

        if len(selected_nodes_depth) != 1:
            raise RuntimeError(
                "ToT-BFS algorithm went wrong: All selected nodes should have the same depth"
            )
        if selected_nodes_depth.pop() >= self.step_limit:
            raise RuntimeError(
                "ToT-BFS algorithm went wrong: The selected nodes should be within the step limit"
            )
        return selected_nodes
