from typing import Dict, Literal, Optional, Tuple

from llm_mcts.mcts_algo.algo_base import MCTSAlgo
from llm_mcts.mcts_algo.hierarchical_thompson.thompson_state import (
    HierarchicalThompsonState,
    PriorConfig,
)
from llm_mcts.mcts_algo.mcts_config import MCTSConfig
from llm_mcts.mcts_algo.node import Node
from llm_mcts.mcts_algo.solver.base import AggregatedSolver, Solver
from llm_mcts.mcts_scorer.base import MCTSScorer


class HierarchicalThompsonAlgo(MCTSAlgo):
    """
    Node expansion leveraging Thompson Sampling algorithm.
    We will retain GEN node for generating answers from LLM.

    One ``HierarchicalThompsonState`` is kept per expanded node (see
    ``self.thompson_states``). Each state is asked which child to descend
    into, or which model should generate a new child, and is told the
    resulting score during backpropagation.
    """

    def __init__(self, config: MCTSConfig, prior_config: Optional[PriorConfig] = None):
        """
        Args:
            config: MCTS configuration. Exactly one entry in ``config.actions``
                is supported.
            prior_config: Prior passed to every ``HierarchicalThompsonState``
                created by this algorithm. A default ``PriorConfig()`` is used
                when omitted.
        """
        if prior_config is None:
            prior_config = PriorConfig()

        self.prior_config = prior_config
        self.config = config
        # NOTE(review): this assert is stripped under ``python -O``; kept as an
        # assert so the raised exception type stays the same for callers.
        assert (
            len(self.config.actions) == 1
        ), "Only a single action is currently supported by Thompson Sampling algorithm"

        # Thompson sampling state of every node that has been expanded so far.
        self.thompson_states: Dict[Node, HierarchicalThompsonState] = dict()

        self.next_serial_number = 1  # NOTE: 0 is reserved for the root node

    @staticmethod
    def _require_aggregated_solver(solver: Solver) -> AggregatedSolver:
        """Return ``solver`` narrowed to ``AggregatedSolver`` or raise ``KeyError``."""
        if not isinstance(solver, AggregatedSolver):
            # KeyError is kept (rather than TypeError) for backward compatibility.
            raise KeyError("Only AggregatedSolver is supported")
        return solver

    def run_mcts_step(self, root: Node, solver: Solver, scorer: MCTSScorer) -> None:
        """
        Run a single MCTS step starting from ``root``.

        The tree is descended through expanded nodes. As soon as a new child is
        generated on the way down, its score is backpropagated and the step
        ends. Otherwise the unexpanded node that was reached is expanded and
        the child created there is backpropagated.

        Raises:
            KeyError: If ``solver`` is not an ``AggregatedSolver``.
        """
        self._require_aggregated_solver(solver)

        if root.is_root() and not root.is_expanded():
            # NOTE(review): the child created by this initial root expansion is
            # not backpropagated here; confirm that this is intended.
            self.expand(
                root,
                solver,
                scorer=scorer,
            )
        # When num_simulations is 0, only the root node is expanded
        if self.config.num_simulations == 0:
            return

        node = root

        while node.is_expanded():
            node, node_model_name = self.select_or_generate_child(
                node, solver=solver, scorer=scorer
            )
            # A model name is returned only when a new child was generated.
            if node_model_name is not None:
                self.backpropagate(node, scorer, node_model_name)
                return

        # Expand the last step node
        node, node_model_name = self.expand(
            node,
            solver,
            scorer=scorer,
        )

        self.backpropagate(node, scorer, node_model_name)

    def expand(
        self,
        node: Node,
        solver: Solver,
        scorer: MCTSScorer,
    ) -> Tuple[Node, str]:
        """
        Create the Thompson state for ``node`` and generate its first child.

        Returns:
            The generated child and the name of the model that produced it.

        Raises:
            KeyError: If ``solver`` is not an ``AggregatedSolver``.
        """
        assert node not in self.thompson_states
        aggregated_solver = self._require_aggregated_solver(solver)

        self.thompson_states[node] = HierarchicalThompsonState(
            model_names=aggregated_solver.get_solvers(),
            prior_config=self.prior_config,
        )
        thompson_state, node_identifier = self.thompson_states[node].ask_next_idx()
        assert isinstance(
            node_identifier, str
        ), "Internal Error: Something went wrong with Thompson Sampling algorithm: The method ask_next_idx for newly created HierarchicalThompsonState should return model_name rather than model index!"

        try:
            child = self.generate_new_child(
                node, solver=solver, scorer=scorer, model_name=node_identifier
            )
        except Exception:
            # If generation failed and the node is still unexpanded, the next
            # step will call expand() on it again. Drop the state registered
            # above so that retry does not trip the assertion at the top.
            if not node.is_expanded():
                self.thompson_states.pop(node, None)
            raise
        thompson_state.add_new_node(model_name=node_identifier, node=child)

        return child, node_identifier

    def select_or_generate_child(
        self,
        node: Node,
        solver: Solver,
        scorer: MCTSScorer,
    ) -> Tuple[Node, Optional[str]]:
        """
        Expansion is performed according to Thompson Sampling algorithm.
        In case the new node is generated, model_name is returned for the second argument.

        When the Thompson state answers with an ``int``, the existing child at
        that index of ``node.children`` is returned together with ``None``.
        """
        thompson_state = self.thompson_states[node]

        thompson_state, child_identifier = thompson_state.ask_next_idx()
        if isinstance(child_identifier, int):
            # An int identifies an already existing child.
            return node.children[child_identifier], None

        # Otherwise the identifier is the name of the model that should
        # generate a new child.
        child = self.generate_new_child(
            node, solver=solver, scorer=scorer, model_name=child_identifier
        )
        thompson_state.add_new_node(model_name=child_identifier, node=child)
        return child, child_identifier

    def backpropagate(self, node: Node, scorer: MCTSScorer, model_name: str) -> None:
        """
        Propagate the score of a newly generated ``node`` up to the root.

        The parent's Thompson state is told the score under ``model_name``.
        Every further ancestor state is told the same score under the index of
        the child lying on the path to ``node``.
        """
        score = scorer.get_score(node=node)

        # NOTE: For the newly created node, we always update the score for GEN node
        assert node.parent is not None
        self.thompson_states[node.parent].tell_reward(
            reward=score, node_identifier=model_name
        )
        now = node.parent

        # Update scores for all the ancestors, this time for the generated node, not GEN node.
        while not now.is_root():
            assert now.parent is not None

            now_idx = now.parent.children.index(now)
            self.thompson_states[now.parent].tell_reward(
                reward=score, node_identifier=now_idx
            )

            now = now.parent

    def generate_new_child(
        self, node: Node, solver: Solver, scorer: MCTSScorer, model_name: str
    ) -> Node:
        """
        Generate exactly one child of ``node`` using the solver ``model_name``.

        ``self.next_serial_number`` is replaced by the value returned from the
        solver.

        Raises:
            KeyError: If ``solver`` is not an ``AggregatedSolver``.
        """
        aggregated_solver = self._require_aggregated_solver(solver)

        # Choose which underlying solver produces the child.
        aggregated_solver.set_solver(solver_name=model_name)

        nodes, self.next_serial_number = aggregated_solver.generate_child_nodes(
            node,
            self.config.actions[0],
            num_samples=1,
            scorer=scorer,
            next_serial_number=self.next_serial_number,
        )
        assert len(nodes) == 1

        return nodes[0]
