from typing import Dict, Optional, Tuple

from llm_mcts.mcts_algo.algo_base import MCTSAlgo
from llm_mcts.mcts_algo.linucb.linucb_state import LinUCBState
from llm_mcts.mcts_algo.mcts_config import MCTSConfig
from llm_mcts.mcts_algo.node import Node
from llm_mcts.mcts_algo.solver.base import AggregatedSolver, Solver
from llm_mcts.mcts_scorer.base import MCTSScorer


class LinUCBWithGenNodeAlgo(MCTSAlgo):
    """
    MCTS node expansion driven by the LinUCB algorithm.

    Every expanded node owns a ``LinUCBState``. Asking that state for the next
    index yields either an ``int`` (descend into the existing child at that
    position) or a ``str`` (a model name: generate a brand-new child with that
    model, i.e. pick a GEN node).
    """

    def __init__(self, config: MCTSConfig):
        """
        Store the configuration and initialise the bookkeeping.

        Raises AssertionError unless ``config.actions`` holds exactly one action.
        """
        self.config = config
        num_actions = len(self.config.actions)
        assert (
            num_actions == 1
        ), "Only a single action is currently supported by LinUCB algorithm"

        # One LinUCB state per node that has been expanded so far.
        self.linucb_states: Dict[Node, LinUCBState] = {}

        # NOTE: serial number 0 is reserved for the root node, so start at 1.
        self.next_serial_number = 1

    @staticmethod
    def _require_aggregated_solver(solver: Solver) -> AggregatedSolver:
        """
        Return ``solver`` unchanged if it is an ``AggregatedSolver``.

        Raises KeyError otherwise.
        NOTE(review): TypeError would describe this failure better; KeyError is
        kept so that existing callers catching it keep working.
        """
        if isinstance(solver, AggregatedSolver):
            return solver
        raise KeyError("Only AggregatedSolver is supported")

    def run_mcts_step(self, root: Node, solver: Solver, scorer: MCTSScorer) -> None:
        """
        Perform a single MCTS step starting from ``root``.

        The tree is descended according to LinUCB until either a new child is
        generated on the way (its score is backpropagated and the step ends), or
        an unexpanded node is reached (it is expanded and the resulting child is
        backpropagated).

        Raises KeyError if ``solver`` is not an ``AggregatedSolver``.
        """
        self._require_aggregated_solver(solver)

        needs_initial_expansion = root.is_root() and not root.is_expanded()
        if needs_initial_expansion:
            # NOTE(review): the child created here is not backpropagated.
            # Confirm this is intended.
            self.expand(root, solver, scorer=scorer)

        # With num_simulations == 0 nothing beyond the root expansion happens.
        if self.config.num_simulations == 0:
            return

        current = root
        while current.is_expanded():
            current, generated_by = self.select_or_generate_child(
                current, solver=solver, scorer=scorer
            )
            if generated_by is not None:
                # A fresh child was generated during the descent: report it and stop.
                self.backpropagate(current, scorer, generated_by)
                return

        # The descent ended on a node without children: expand it.
        new_child, generated_by = self.expand(current, solver, scorer=scorer)
        self.backpropagate(new_child, scorer, generated_by)

    def expand(
        self,
        node: Node,
        solver: Solver,
        scorer: MCTSScorer,
    ) -> Tuple[Node, str]:
        """
        Expand a node that has not been expanded before.

        A new ``LinUCBState`` is registered for ``node``. Since the node has no
        children yet, the state is expected to name a model (a GEN node), which
        is then used to generate the first child.

        Returns the generated child together with the model name that produced it.
        Raises KeyError if ``solver`` is not an ``AggregatedSolver``.
        """
        assert node not in self.linucb_states
        aggregated = self._require_aggregated_solver(solver)

        state = LinUCBState(
            model_names=aggregated.get_solvers(),
            alpha=self.config.alpha,
        )
        self.linucb_states[node] = state

        node_indices, model_name = state.ask_next_idx()
        assert isinstance(
            model_name, str
        ), "Internal Error: Something went wrong with LinUCB algorithm: The method ask_next_idx for newly created LinUCBState should return model_name rather than model index!"

        # Generate first; register the new node with the state only afterwards.
        first_child = self.generate_new_child(
            node, solver=solver, scorer=scorer, model_name=model_name
        )
        node_indices.add_new_node(model_name=model_name)

        return first_child, model_name

    def select_or_generate_child(
        self,
        node: Node,
        solver: Solver,
        scorer: MCTSScorer,
    ) -> Tuple[Node, Optional[str]]:
        """
        Let the LinUCB state of ``node`` decide where to go next.

        Returns ``(existing_child, None)`` when an existing child is selected,
        or ``(new_child, model_name)`` when a new child is generated.
        """
        node_indices, choice = self.linucb_states[node].ask_next_idx()

        # An integer choice is the position of an already existing child.
        if isinstance(choice, int):
            return node.children[choice], None

        # Otherwise the choice is a model name: generate a child with that model.
        new_child = self.generate_new_child(
            node, solver=solver, scorer=scorer, model_name=choice
        )
        node_indices.add_new_node(model_name=choice)
        return new_child, choice

    def backpropagate(self, node: Node, scorer: MCTSScorer, model_name: str) -> None:
        """
        Propagate the score of a newly generated ``node`` up to the root.

        The direct parent is told the reward under ``model_name`` (the GEN node);
        every ancestor above it is told the same reward under the index of the
        child lying on the path.
        """
        reward = scorer.get_score(node=node)

        # TODO: Investigate whether this implementation is effective!
        # NOTE: For a newly created node the reward is credited to the GEN node
        # (identified by model name), not to the created node itself.
        parent = node.parent
        assert parent is not None
        self.linucb_states[parent].tell_reward(
            reward=reward, node_identifier=model_name
        )

        # Higher up, credit the already generated nodes (by child index), not GEN nodes.
        ancestor = parent
        while not ancestor.is_root():
            upper = ancestor.parent
            assert upper is not None

            child_position = upper.children.index(ancestor)
            self.linucb_states[upper].tell_reward(
                reward=reward, node_identifier=child_position
            )

            ancestor = upper

    def generate_new_child(
        self, node: Node, solver: Solver, scorer: MCTSScorer, model_name: str
    ) -> Node:
        """
        Generate exactly one child of ``node`` with the solver named ``model_name``.

        Updates ``self.next_serial_number`` with the value handed back by the solver.
        Raises KeyError if ``solver`` is not an ``AggregatedSolver``.
        """
        aggregated = self._require_aggregated_solver(solver)
        aggregated.set_solver(solver_name=model_name)

        generated, self.next_serial_number = aggregated.generate_child_nodes(
            node,
            self.config.actions[0],
            num_samples=1,
            scorer=scorer,
            next_serial_number=self.next_serial_number,
        )
        # A single sample was requested, so a single node must come back.
        assert len(generated) == 1

        return generated[0]
