from math import log, sqrt
from typing import Dict, List

from llm_mcts.mcts_algo.algo_base import MCTSAlgo
from llm_mcts.mcts_algo.mcts_config import MCTSConfig
from llm_mcts.mcts_algo.node import Node
from llm_mcts.mcts_algo.solver.base import AggregatedSolver, Solver
from llm_mcts.mcts_scorer.base import MCTSScorer


class MultiArmedBanditUCBAlgo(MCTSAlgo):
    """
    Node expansion leveraging the MultiArmedBanditUCB algorithm.

    Every solver exposed by an ``AggregatedSolver`` is treated as one bandit
    arm. Each step picks one arm via UCB and uses it to generate exactly one
    new child of the root node.
    We will retain GEN node for generating answers from LLM.
    """

    def __init__(self, config: MCTSConfig) -> None:
        """Store the configuration and initialise the per-solver statistics.

        Args:
            config: MCTS configuration; it must declare exactly one action.
        """
        self.config = config
        assert (
            len(self.config.actions) == 1
        ), "Only a single action is currently supported by MultiArmedBanditUCB algorithm"
        # Observed scores, keyed by solver name (one list per bandit arm).
        self.score_by_solvers: Dict[str, List[float]] = dict()
        self.next_serial_number = 1  # NOTE: 0 is reserved for the root node

    def run_mcts_step(self, root: Node, solver: Solver, scorer: MCTSScorer) -> None:
        """Run one bandit step: choose a solver, generate a child, record its score.

        Args:
            root: Root node of the tree; new children are generated from it.
            solver: Must be an ``AggregatedSolver``.
            scorer: Scorer used during generation and to score the new node.

        Raises:
            KeyError: If ``solver`` is not an ``AggregatedSolver``.
            ValueError: If no solver can be selected (see ``select_solver``).
        """
        if not isinstance(solver, AggregatedSolver):
            raise KeyError("Only AggregatedSolver is supported")
        assert (
            root.is_root()
        ), "Only root node is supported for MultiArmedBanditUCB algorithm"

        if not root.is_expanded():
            # Fresh tree: rebuild the table so that arms left over from an
            # earlier solver set cannot be selected or skew the bookkeeping.
            # The dict is cleared in place to keep the same object alive.
            self.score_by_solvers.clear()
            for solver_name in solver.get_solvers():
                self.score_by_solvers[solver_name] = []

        next_solver = self.select_solver(node=root)
        assert (
            next_solver in solver.get_solvers()
        ), "Internal Error: Something went wrong with MultiArmedBanditUCB algorithm: selected solver is not in the list of available solvers"

        new_node = self.generate_new_child(
            root, solver=solver, scorer=scorer, model_name=next_solver
        )
        self.score_by_solvers[next_solver].append(scorer.get_score(node=new_node))

    def generate_new_child(
        self, node: Node, solver: Solver, scorer: MCTSScorer, model_name: str
    ) -> Node:
        """Generate a single child of ``node`` using the solver named ``model_name``.

        Args:
            node: Parent node to generate the child from.
            solver: Must be an ``AggregatedSolver``; it is switched to
                ``model_name`` before generating.
            scorer: Scorer forwarded to the solver.
            model_name: Name of the solver to activate.

        Returns:
            The single newly generated node.

        Raises:
            KeyError: If ``solver`` is not an ``AggregatedSolver``.
        """
        if not isinstance(solver, AggregatedSolver):
            raise KeyError("Only AggregatedSolver is supported")
        solver.set_solver(solver_name=model_name)
        # The solver hands back the updated serial-number counter together
        # with the generated nodes; keep it for the next call.
        nodes, self.next_serial_number = solver.generate_child_nodes(
            node,
            self.config.actions[0],
            num_samples=1,
            scorer=scorer,
            next_serial_number=self.next_serial_number,
        )
        assert len(nodes) == 1, "Only a single node is expected to be generated"
        return nodes[0]

    def select_solver(self, node: Node) -> str:
        """Pick the solver to sample next according to UCB.

        A solver that has not been sampled yet is always preferred (the first
        such solver in registration order). Otherwise the solver maximising
        ``mean(scores) + sqrt(2) * sqrt(log(N) / n)`` is returned, where ``N``
        is the number of children of ``node`` and ``n`` the number of scores
        recorded for that solver. Ties go to the earliest registered solver.

        Args:
            node: Node whose children count is used as the total sample count.

        Returns:
            The name of the selected solver.

        Raises:
            ValueError: If no solver is registered, or no solver has a
                comparable UCB score (e.g. every score is NaN).
        """
        assert sum(len(scores) for scores in self.score_by_solvers.values()) == len(
            node.children
        ), "The number of completions does not match the number of scores"

        # NOTE: We prefer to select the solver that has not been sampled yet
        for solver_name, scores in self.score_by_solvers.items():
            if not scores:
                return solver_name

        if not self.score_by_solvers:
            raise ValueError(
                "No solvers are registered for MultiArmedBanditUCB algorithm"
            )

        # Every arm has at least one sample here, so the total is >= 1 and the
        # logarithm is well defined; it is the same for all arms.
        log_total = log(len(node.children))

        max_score = -float("inf")
        max_solver = None
        for solver_name, scores in self.score_by_solvers.items():
            # NOTE: We don't restrict the exploitation term to be in [0, 1] range (e.g., CodeContests task)
            # NOTE: We fixed the multiplier constant for exploration term to sqrt(2)
            ucb_score = sum(scores) / len(scores) + sqrt(2) * sqrt(
                log_total / len(scores)
            )
            if max_score < ucb_score:
                max_score = ucb_score
                max_solver = solver_name

        if max_solver is None:
            # Only reachable when no UCB score compares greater than -inf.
            raise ValueError(
                "Could not select a solver: no solver has a comparable UCB score"
            )
        return max_solver
