from collections import defaultdict
from typing import Dict, Literal, Optional, Tuple

from llm_mcts.mcts_algo.algo_base import MCTSAlgo
from llm_mcts.mcts_algo.hierarchical_thompson.thompson_state import (
    HierarchicalThompsonState,
    PriorConfig,
)
from llm_mcts.mcts_algo.mcts_config import MCTSConfig
from llm_mcts.mcts_algo.node import Node
from llm_mcts.mcts_algo.solver.base import AggregatedSolver, Solver
from llm_mcts.mcts_scorer.base import MCTSScorer


class HierarchicalThompsonAlgo(MCTSAlgo):
    """
    Node expansion leveraging the Thompson Sampling algorithm.

    Every expanded node owns a ``HierarchicalThompsonState``. Asking that state
    for the next index yields either an ``int`` (descend into the existing child
    at that position in ``node.children``) or a ``str`` (generate a new child
    with the solver of that name, i.e. the GEN node that produces answers from
    the LLM). Rewards of newly generated nodes are reported back to the states
    of the generating node and of all its ancestors.
    """

    # Accepted values for ``model_selection_strategy``:
    #   "stack": perform separate fits for each model (traditional approach)
    #   "multiarm_bandit_thompson": use Thompson Sampling for joint selection
    #   "multiarm_bandit_ucb": use UCB for joint selection
    _MODEL_SELECTION_STRATEGIES = (
        "stack",
        "multiarm_bandit_thompson",
        "multiarm_bandit_ucb",
    )

    def __init__(
        self,
        config: MCTSConfig,
        prior_config: Optional[PriorConfig] = None,
        reward_average_priors: Optional[float | Dict[str, float]] = None,
        model_selection_strategy: Literal[
            "stack", "multiarm_bandit_thompson", "multiarm_bandit_ucb"
        ] = "stack",
    ):
        """
        Args:
            config: MCTS configuration; must declare exactly one action.
            prior_config: Passed to every ``HierarchicalThompsonState`` that is
                created; a default ``PriorConfig()`` is used when omitted.
            reward_average_priors: Passed unchanged to every
                ``HierarchicalThompsonState`` that is created.
            model_selection_strategy: One of "stack",
                "multiarm_bandit_thompson" or "multiarm_bandit_ucb".

        Raises:
            ValueError: If ``model_selection_strategy`` is not a supported
                value, or if ``config`` does not declare exactly one action.
        """
        if prior_config is None:
            prior_config = PriorConfig()

        self.prior_config = prior_config
        self.reward_average_priors = reward_average_priors

        if model_selection_strategy not in self._MODEL_SELECTION_STRATEGIES:
            raise ValueError(
                f"Invalid model_selection_strategy: {model_selection_strategy}. "
                "Must be one of: 'stack', 'multiarm_bandit_thompson', 'multiarm_bandit_ucb'"
            )
        self.model_selection_strategy = model_selection_strategy

        self.config = config
        # Raised explicitly (rather than asserted) so that the check on
        # user-supplied configuration is still enforced under ``python -O``.
        if len(self.config.actions) != 1:
            raise ValueError(
                "Only a single action is currently supported by Thompson Sampling algorithm"
            )

        # One sampling state per expanded node.
        self.thompson_states: Dict[Node, HierarchicalThompsonState] = dict()
        # Every backpropagated score, grouped by the model that generated the node.
        self.all_rewards_store: Dict[str, list[float]] = defaultdict(list)

        self.next_serial_number = 1  # NOTE: 0 is reserved for the root node

    @staticmethod
    def _require_aggregated_solver(solver: Solver) -> AggregatedSolver:
        """Return ``solver`` narrowed to ``AggregatedSolver`` or raise ``KeyError``."""
        # KeyError (not TypeError) is kept on purpose: existing callers may
        # already depend on this exception type.
        if not isinstance(solver, AggregatedSolver):
            raise KeyError("Only AggregatedSolver is supported")
        return solver

    def run_mcts_step(self, root: Node, solver: Solver, scorer: MCTSScorer) -> None:
        """
        Run one MCTS step starting from ``root``.

        The step walks down the tree until a new node is generated (either
        because a Thompson state asks for generation, or because an unexpanded
        node is reached and expanded) and then backpropagates that node's score.

        Raises:
            KeyError: If ``solver`` is not an ``AggregatedSolver``.
        """
        self._require_aggregated_solver(solver)

        if root.is_root() and not root.is_expanded():
            # NOTE(review): the child created by this initial expansion is not
            # handed to ``backpropagate``, so within this class its score is
            # neither appended to ``all_rewards_store`` nor reported through
            # ``tell_reward`` - confirm that this is intended.
            self.expand(root, solver, scorer=scorer)

        # When num_simulations is 0, only the root node is expanded
        if self.config.num_simulations == 0:
            return

        node = root
        while node.is_expanded():
            node, generated_by = self.select_or_generate_child(
                node, solver=solver, scorer=scorer
            )
            if generated_by is not None:
                # A fresh node was generated on the way down: the step ends here.
                self.backpropagate(node, scorer, generated_by)
                return

        # Reached a node without a Thompson state yet: expand it.
        node, generated_by = self.expand(node, solver, scorer=scorer)
        self.backpropagate(node, scorer, generated_by)

    def expand(
        self,
        node: Node,
        solver: Solver,
        scorer: MCTSScorer,
    ) -> Tuple[Node, str]:
        """
        Create the Thompson state for ``node`` and generate its first child.

        Returns:
            The newly generated child and the name of the model that produced it.

        Raises:
            KeyError: If ``solver`` is not an ``AggregatedSolver``.
        """
        assert node not in self.thompson_states
        aggregated = self._require_aggregated_solver(solver)

        self.thompson_states[node] = HierarchicalThompsonState(
            model_names=aggregated.get_solvers(),
            prior_config=self.prior_config,
            reward_average_priors=self.reward_average_priors,
            model_selection_strategy=self.model_selection_strategy,
        )
        # The new child is registered on the state object that ask_next_idx returns.
        thompson_state, node_identifier = self.thompson_states[node].ask_next_idx(
            self.all_rewards_store
        )
        assert isinstance(
            node_identifier, str
        ), "Internal Error: Something went wrong with Thompson Sampling algorithm: The method ask_next_idx for newly created HierarchicalThompsonState should return model_name rather than model index!"

        child = self.generate_new_child(
            node, solver=solver, scorer=scorer, model_name=node_identifier
        )
        thompson_state.add_new_node(model_name=node_identifier, node=child)

        return child, node_identifier

    def select_or_generate_child(
        self,
        node: Node,
        solver: Solver,
        scorer: MCTSScorer,
    ) -> Tuple[Node, Optional[str]]:
        """
        Pick the next node below ``node`` according to the Thompson Sampling algorithm.

        Returns:
            ``(child, None)`` when an existing child is selected, or
            ``(child, model_name)`` when a new child is generated with the
            model called ``model_name``.
        """
        thompson_state, child_identifier = self.thompson_states[node].ask_next_idx(
            self.all_rewards_store
        )

        # An int identifier is a position in node.children.
        if isinstance(child_identifier, int):
            return node.children[child_identifier], None

        # Any other identifier is treated as a model name to generate with.
        child = self.generate_new_child(
            node, solver=solver, scorer=scorer, model_name=child_identifier
        )
        thompson_state.add_new_node(model_name=child_identifier, node=child)
        return child, child_identifier

    def backpropagate(self, node: Node, scorer: MCTSScorer, model_name: str) -> None:
        """
        Report the score of the newly generated ``node`` up the tree.

        The parent's state receives the reward under ``model_name`` (the GEN
        node); each further ancestor's state receives it under the index of the
        child that lies on the path to ``node``.
        """
        score = scorer.get_score(node=node)

        # Store reward to all_rewards_store for multiarm bandit case
        self.all_rewards_store[model_name].append(score)

        # NOTE: For the newly created node, we always update the score for GEN node
        assert node.parent is not None
        self.thompson_states[node.parent].tell_reward(
            reward=score, node_identifier=model_name
        )

        # Update scores for all the ancestors, this time for the generated node, not GEN node.
        ancestor = node.parent
        while not ancestor.is_root():
            assert ancestor.parent is not None

            position = ancestor.parent.children.index(ancestor)
            self.thompson_states[ancestor.parent].tell_reward(
                reward=score, node_identifier=position
            )

            ancestor = ancestor.parent

    def generate_new_child(
        self, node: Node, solver: Solver, scorer: MCTSScorer, model_name: str
    ) -> Node:
        """
        Generate exactly one child of ``node`` with the solver named ``model_name``.

        ``next_serial_number`` is replaced by the value the solver returns.

        Raises:
            KeyError: If ``solver`` is not an ``AggregatedSolver``.
        """
        aggregated = self._require_aggregated_solver(solver)
        aggregated.set_solver(solver_name=model_name)

        nodes, self.next_serial_number = aggregated.generate_child_nodes(
            node,
            self.config.actions[0],
            num_samples=1,
            scorer=scorer,
            next_serial_number=self.next_serial_number,
        )
        assert len(nodes) == 1

        return nodes[0]
