from dataclasses import dataclass
from typing import Dict, Literal, Optional, Tuple

from llm_mcts.mcts_algo.algo_base import MCTSAlgo
from llm_mcts.mcts_algo.mcts_config import MCTSConfig
from llm_mcts.mcts_algo.node import Node
from llm_mcts.mcts_algo.pymc_mixed.pymc_interface import Observation, PyMCInterface
from llm_mcts.mcts_algo.solver.base import AggregatedSolver, Solver
from llm_mcts.mcts_scorer.base import MCTSScorer


@dataclass
class PyMCMixedAlgoConfig:
    """
    Options for PyMCMixedAlgo.

    PyMCMixedAlgo.__init__ passes every field, unchanged, to the PyMCInterface
    constructor as the keyword argument of the same name. The fields are not
    interpreted anywhere else in this file, so their exact effect is defined by
    PyMCInterface.
    """

    # NOTE(review): presumably toggles pruning inside PyMCInterface -- confirm there.
    enable_pruning: bool = True
    # "thompson" is the only value the annotation allows.
    algo: Literal["thompson"] = "thompson"

    # Either a single float or a str -> float mapping; None is the default.
    # NOTE(review): the dict keys are presumably model names -- verify against
    # PyMCInterface.
    reward_average_priors: Optional[float | Dict[str, float]] = None
    # One of the three strategy names in the annotation; defaults to "stack".
    model_selection_strategy: Literal[
        "stack", "multiarm_bandit_thompson", "multiarm_bandit_ucb"
    ] = "stack"


class PyMCMixedAlgo(MCTSAlgo):
    """
    Sampling Method based on Mixed Models using PyMC.

    At each visited node the algorithm asks ``PyMCInterface.run`` for an
    identifier. An ``int`` identifier is used as an index into
    ``node.children`` and the walk continues from that child; any other
    identifier is used as the name of the solver that generates one new child,
    which ends the step. Every generated child is recorded in
    ``all_observations`` together with its score and the solver name.
    """

    def __init__(
        self, config: MCTSConfig, pymc_config: Optional[PyMCMixedAlgoConfig] = None
    ):
        """
        Args:
            config: MCTS configuration. ``config.actions`` must hold exactly one
                action; ``config.num_simulations`` is read by ``run_mcts_step``.
            pymc_config: Options forwarded to ``PyMCInterface``. When ``None``,
                a default ``PyMCMixedAlgoConfig()`` is used.

        Raises:
            AssertionError: If ``config.actions`` does not contain exactly one
                action.
        """
        self.config = config
        assert (
            len(self.config.actions) == 1
        ), "Only a single action is currently supported by PyMCMixedAlgo"

        if pymc_config is None:
            pymc_config = PyMCMixedAlgoConfig()

        # One Observation per node generated by this algorithm instance.
        self.all_observations: Dict[Node, Observation] = dict()

        self.pymc_interface = PyMCInterface(
            algo=pymc_config.algo,
            enable_pruning=pymc_config.enable_pruning,
            reward_average_priors=pymc_config.reward_average_priors,
            model_selection_strategy=pymc_config.model_selection_strategy,
        )
        self.next_serial_number = 1  # NOTE: 0 is reserved for the root node

    @staticmethod
    def _require_aggregated_solver(solver: Solver) -> AggregatedSolver:
        """Return ``solver`` unchanged if it is an AggregatedSolver, else raise KeyError."""
        if not isinstance(solver, AggregatedSolver):
            # KeyError is kept (rather than switching to TypeError) so that
            # callers already catching it keep working.
            raise KeyError("Only AggregatedSolver is supported")
        return solver

    def _ask_interface(self, node: Node, solver: AggregatedSolver):
        """
        Ask the PyMC interface which child of ``node`` to visit or which solver
        to generate a new child with.

        The result is returned untouched; callers decide how to interpret it.
        """
        return self.pymc_interface.run(
            Observation.collect_all_observations_of_descendant(
                node, self.all_observations
            ),
            model_names=solver.get_solvers(),
            node=node,
            all_observations=list(self.all_observations.values()),
        )

    def run_mcts_step(self, root: Node, solver: Solver, scorer: MCTSScorer) -> None:
        """
        Run one search step starting from ``root``.

        An unexpanded root is expanded first. Unless ``num_simulations`` is 0,
        the tree is then walked through expanded nodes until either a new child
        is generated on the way or an unexpanded node is reached and expanded.

        Raises:
            KeyError: If ``solver`` is not an ``AggregatedSolver``.
        """
        self._require_aggregated_solver(solver)

        if root.is_root() and not root.is_expanded():
            self.expand(
                root,
                solver,
                scorer=scorer,
            )
        # When num_simulations is 0, only the root node is expanded
        if self.config.num_simulations == 0:
            return

        node = root

        while node.is_expanded():
            node, node_model_name = self.select_or_generate_child(
                node, solver=solver, scorer=scorer
            )
            # New node is generated by selecting GEN node, so the step is finished
            if node_model_name is not None:
                return

        # Expand the last step node; its return value is not needed here.
        self.expand(
            node,
            solver,
            scorer=scorer,
        )

    def expand(
        self,
        node: Node,
        solver: Solver,
        scorer: MCTSScorer,
    ) -> Tuple[Node, str]:
        """
        Generate one new child of ``node``.

        Returns:
            The new child and the name of the solver that produced it.

        Raises:
            KeyError: If ``solver`` is not an ``AggregatedSolver``.
            AssertionError: If the interface does not return a ``str``.
        """
        aggregated_solver = self._require_aggregated_solver(solver)

        node_identifier = self._ask_interface(node, aggregated_solver)

        assert isinstance(
            node_identifier, str
        ), f"Internal Error: Something went wrong with Sampling: The method ask_next_idx for newly created HierarchicalThompsonState should return model_name rather than model index {node_identifier}!"

        child = self.generate_new_child(
            node, solver=solver, scorer=scorer, model_name=node_identifier
        )

        return child, node_identifier

    def select_or_generate_child(
        self,
        node: Node,
        solver: Solver,
        scorer: MCTSScorer,
    ) -> Tuple[Node, Optional[str]]:
        """
        Select an existing child of ``node`` or generate a new one.

        In case a new node is generated, model_name is returned for the second argument.
        When an existing child is selected, the second element is ``None``.

        Raises:
            KeyError: If ``solver`` is not an ``AggregatedSolver``.
        """
        aggregated_solver = self._require_aggregated_solver(solver)

        child_identifier = self._ask_interface(node, aggregated_solver)

        # An int is an index into the existing children; anything else is
        # treated as the name of the solver to generate a new child with.
        if isinstance(child_identifier, int):
            return node.children[child_identifier], None

        child = self.generate_new_child(
            node, solver=solver, scorer=scorer, model_name=child_identifier
        )
        return child, child_identifier

    def generate_new_child(
        self, node: Node, solver: Solver, scorer: MCTSScorer, model_name: str
    ) -> Node:
        """
        Generate exactly one child of ``node`` with the solver ``model_name``.

        The solver is switched to ``model_name``, asked for one sample of the
        single configured action, and the resulting node is recorded in
        ``all_observations`` with its score. ``next_serial_number`` is updated
        from the value the solver returns.

        Raises:
            KeyError: If ``solver`` is not an ``AggregatedSolver``.
            AssertionError: If the solver does not return exactly one node.
        """
        aggregated_solver = self._require_aggregated_solver(solver)

        aggregated_solver.set_solver(solver_name=model_name)

        nodes, self.next_serial_number = aggregated_solver.generate_child_nodes(
            node,
            self.config.actions[0],
            num_samples=1,
            scorer=scorer,
            next_serial_number=self.next_serial_number,
        )
        assert len(nodes) == 1

        new_node = nodes[0]
        self.all_observations[new_node] = Observation(
            reward=scorer.get_score(new_node), model_name=model_name, node=new_node
        )
        return new_node
