import copy
import dataclasses
from collections import defaultdict
from functools import partial
from math import log, sqrt
from typing import Dict, List, Literal, Optional, Tuple, TypeVar

import numpy as np
from scipy.stats import beta, invgamma

from llm_mcts.mcts_algo.hierarchical_thompson.data_types import NodeIdentifier
from llm_mcts.mcts_algo.hierarchical_thompson.node_indices import ThompsonNodeIndices
from llm_mcts.mcts_algo.node import Node

T = TypeVar("T")


@dataclasses.dataclass
class BetaPrior:
    """Shape parameters of a Beta prior.

    With both shape parameters left at 0.5 this is the Jeffreys prior.
    """

    # First shape parameter of the Beta distribution.
    a: float = 0.5
    # Second shape parameter of the Beta distribution.
    b: float = 0.5


@dataclasses.dataclass
class GaussianPrior:
    """Hyper-parameters of the prior used for Gaussian rewards.

    The four values parameterise a joint prior over the unknown mean and
    variance of the reward distribution. All defaults are float literals so
    that the runtime values match the declared ``float`` field types.
    """

    # Prior location of the mean.
    m: float = 0.0
    # Weight of the prior mean, expressed as a pseudo-observation count.
    kappa: float = 1.0
    # Weight of the prior variance, expressed as a pseudo-observation count.
    nu: float = 1.0
    # Prior scale of the variance.
    tau_square: float = 0.1


@dataclasses.dataclass
class PriorConfig:
    """Selects a reward distribution family together with its prior.

    Attributes:
        dist_type: Distribution family, either ``"beta"`` or ``"gaussian"``.
        prior: Prior hyper-parameters, given as the matching dataclass, as a
            dict of keyword arguments for that dataclass, or ``None`` to use
            the dataclass defaults.
    """

    dist_type: Literal["beta", "gaussian"] = "gaussian"
    prior: BetaPrior | GaussianPrior | Dict[str, float] | None = None

    def get_params(self) -> Dict[str, float]:
        """Return the prior hyper-parameters as a freshly created plain dict.

        Returns:
            A new dict holding the fields of the resolved prior dataclass.

        Raises:
            NotImplementedError: If ``dist_type`` is neither ``"gaussian"``
                nor ``"beta"``.
            ValueError: If ``prior`` is neither a dict nor the dataclass that
                matches ``dist_type``.
        """
        # If no prior is provided, the default one is created.
        if self.prior is None:
            if self.dist_type == "gaussian":
                return dataclasses.asdict(GaussianPrior())
            if self.dist_type == "beta":
                return dataclasses.asdict(BetaPrior())
            raise NotImplementedError(f"dist_type {self.dist_type} not supported.")

        if self.dist_type == "gaussian":
            prior_cls, label = GaussianPrior, "Gaussian"
        elif self.dist_type == "beta":
            prior_cls, label = BetaPrior, "Beta"
        else:
            raise NotImplementedError(f"Invalid dist_type {self.dist_type}")

        if isinstance(self.prior, dict):
            # Going through the dataclass validates the keys and fills in
            # defaults for the ones that were omitted.
            return dataclasses.asdict(prior_cls(**self.prior))
        if isinstance(self.prior, prior_cls):
            return dataclasses.asdict(self.prior)
        raise ValueError(f"Invalid prior {self.prior} for {label} Distribution.")

    def set_reward_average_prior(self, reward_average_prior: float) -> None:
        """Override the prior so that its mean equals ``reward_average_prior``.

        For a Gaussian prior ``m`` is replaced. For a Beta prior ``a`` and
        ``b`` are set to ``reward_average_prior`` and
        ``1 - reward_average_prior``. A ``None`` or dict prior is first turned
        into the dataclass selected by ``dist_type`` (``GaussianPrior`` for
        ``"gaussian"``, ``BetaPrior`` otherwise); an existing dataclass
        instance is modified in place. Afterwards ``self.prior`` is always a
        dataclass instance.

        Args:
            reward_average_prior: Desired prior mean of the reward.

        Raises:
            ValueError: If ``prior`` is not ``None``, a dict, a
                ``GaussianPrior`` or a ``BetaPrior``.
        """
        prior_cls = GaussianPrior if self.dist_type == "gaussian" else BetaPrior

        prior = self.prior
        if prior is None:
            prior = prior_cls()
        elif isinstance(prior, dict):
            # Build the dataclass itself. Converting it back to a dict here
            # would make the isinstance checks below fail for every dict
            # prior.
            prior = prior_cls(**prior)

        if isinstance(prior, GaussianPrior):
            prior.m = reward_average_prior
        elif isinstance(prior, BetaPrior):
            prior.a = reward_average_prior
            prior.b = 1 - reward_average_prior
        else:
            raise ValueError(f"Invalid prior {prior}.")

        self.prior = prior


class ProbabilisticDist:
    """Posterior over an expected reward that is updated one observation at a time.

    The distribution family and the prior hyper-parameters come from a
    ``PriorConfig``. ``params`` holds the current (posterior) parameters and
    ``prior_params`` keeps an untouched copy of the initial ones.
    """

    def __init__(self, prior_config: Optional[PriorConfig] = None):
        """Initialise the posterior from ``prior_config`` (default config if ``None``)."""
        config = PriorConfig() if prior_config is None else prior_config
        self.dist_type = config.dist_type

        # Raw observations; only the Gaussian update appends to this list.
        self.all_obs: List[float] = []
        self.params = config.get_params()
        self.prior_params = copy.deepcopy(self.params)

    def tell_observation(self, obs: float) -> None:
        """Fold a single observed reward into the posterior parameters.

        Raises:
            NotImplementedError: If ``dist_type`` is neither ``"beta"`` nor
                ``"gaussian"``.
        """
        if self.dist_type == "beta":
            assert obs >= 0 and obs <= 1
            # The observation is split between the two shape parameters.
            self.params["a"] += obs
            self.params["b"] += 1 - obs
            return

        if self.dist_type == "gaussian":
            # See Section 3.4.3.3 of Murphy's textbook "Probabilistic Machine Learning: Advanced Topics": https://probml.github.io/pml-book/book2.html
            self.all_obs.append(obs)

            count = len(self.all_obs)
            sample_mean = float(np.mean(self.all_obs))
            # Population variance times the count, i.e. the sum of squared
            # deviations from the sample mean.
            sum_sq_dev = float(np.var(self.all_obs, ddof=0)) * count

            # The posterior is always rebuilt from the original prior and the
            # full observation history, not from the previous posterior.
            m0 = self.prior_params["m"]
            kappa0 = self.prior_params["kappa"]
            nu0 = self.prior_params["nu"]
            tau_sq0 = self.prior_params["tau_square"]

            kappa_n = kappa0 + count
            nu_n = nu0 + count
            m_n = (kappa0 * m0 + count * sample_mean) / kappa_n
            tau_sq_n = (
                nu0 * tau_sq0
                + sum_sq_dev
                + count * kappa0 * (m0 - sample_mean) * (m0 - sample_mean) / (kappa0 + count)
            ) / nu_n

            self.params = {
                "m": m_n,
                "kappa": kappa_n,
                "nu": nu_n,
                "tau_square": tau_sq_n,
            }
            return

        raise NotImplementedError()

    def draw_sample(self) -> float:
        """Draw one sample of the expected reward from the current posterior.

        Raises:
            NotImplementedError: If ``dist_type`` is neither ``"beta"`` nor
                ``"gaussian"``.
        """
        current = self.params

        if self.dist_type == "beta":
            return beta.rvs(current["a"], current["b"])

        if self.dist_type == "gaussian":
            # Sample the variance first, then the mean conditioned on it.
            sigma_square = invgamma.rvs(
                current["nu"] / 2,
                scale=(current["nu"] * current["tau_square"]) / 2,
            )
            std_of_mean = np.sqrt(sigma_square / current["kappa"])
            return np.random.normal(current["m"], std_of_mean)

        raise NotImplementedError()


def build_default_prob_dist(prior_config):
    """
    Create a new ProbabilisticDist from the given prior config.

    This is deliberately a module-level function (not a lambda or a nested
    function) so that it can be used as a default factory without causing a
    pickle dump error.
    """
    fresh_dist = ProbabilisticDist(prior_config)
    return fresh_dist


class HierarchicalThompsonState:
    """Sampling state used to pick the next model or node to expand.

    The state keeps one posterior per model (``model_probas``), a pair of
    ``"GEN"``/``"CONT"`` posteriors (``gen_and_cont_probas``) and one
    posterior per node index (``node_probas``). With the ``"stack"`` strategy
    the GEN/CONT and node posteriors are kept separately for every model;
    with the ``"multiarm_bandit_*"`` strategies they are stored once under
    the key ``"shared"``.
    """

    def __init__(
        self,
        model_names: List[str],
        prior_config: Optional[PriorConfig] = None,
        reward_average_priors: Optional[float | Dict[str, float]] = None,
        model_selection_strategy: Literal[
            "stack", "multiarm_bandit_thompson", "multiarm_bandit_ucb"
        ] = "stack",
    ):
        """Build the initial posteriors for every model.

        Args:
            model_names: Names of the models that can generate new nodes.
            prior_config: Distribution family and prior shared by all
                posteriors; a default ``PriorConfig`` is used when ``None``.
            reward_average_priors: Optional override of the prior reward
                mean, either one float for all models or a dict keyed by
                model name. It only affects the per-model posteriors.
            model_selection_strategy: One of ``"stack"``,
                ``"multiarm_bandit_thompson"`` or ``"multiarm_bandit_ucb"``.

        Raises:
            ValueError: If ``model_selection_strategy`` is not supported.
            AssertionError: If an overriding prior mean is not strictly
                between 0 and 1.
        """
        if prior_config is None:
            prior_config = PriorConfig()

        self.model_names = model_names

        # Strategy for model selection:
        # "stack": Perform separate fits for each model (traditional approach)
        # "multiarm_bandit_thompson": Use Thompson Sampling for joint selection
        # "multiarm_bandit_ucb": Use UCB for joint selection
        if model_selection_strategy not in [
            "stack",
            "multiarm_bandit_thompson",
            "multiarm_bandit_ucb",
        ]:
            raise ValueError(
                f"Invalid model_selection_strategy: {model_selection_strategy}. "
                f"Must be one of: 'stack', 'multiarm_bandit_thompson', 'multiarm_bandit_ucb'"
            )
        self.model_selection_strategy = model_selection_strategy

        self.prior_for_models = dict()
        self.model_probas = dict()
        for model_name in model_names:
            prior_for_model = copy.deepcopy(prior_config)

            # Override the prior reward average when a value was supplied;
            # otherwise the copied default prior is used unchanged.
            if isinstance(reward_average_priors, (float, dict)):
                if isinstance(reward_average_priors, float):
                    reward_average_prior = reward_average_priors
                else:
                    reward_average_prior = reward_average_priors[model_name]

                assert (
                    reward_average_prior > 0 and reward_average_prior < 1
                ), f"reward_average_prior {reward_average_prior} is not in the range 0 < reward_average_prior < 1"

                prior_for_model.set_reward_average_prior(reward_average_prior)

            # Always remember the per-model prior: _select_best_model looks
            # it up for every model, also when no override was given.
            self.prior_for_models[model_name] = prior_for_model
            self.model_probas[model_name] = ProbabilisticDist(prior_for_model)

        if model_selection_strategy.startswith("multiarm_bandit_"):
            # For multiarm bandit strategies, use shared GEN/CONT distributions across all models
            self.gen_and_cont_probas = {
                "shared": {
                    "GEN": ProbabilisticDist(prior_config),
                    "CONT": ProbabilisticDist(prior_config),
                }
            }
            # Node posteriors are created lazily on first access. The factory
            # is a partial of a module-level function so the state can be pickled.
            self.node_probas: Dict[str, Dict[int, ProbabilisticDist]] = {
                "shared": defaultdict(partial(build_default_prob_dist, prior_config))
            }
        else:
            # For stack strategy, use separate GEN/CONT distributions for each model
            self.gen_and_cont_probas = {
                model_name: {
                    "GEN": ProbabilisticDist(prior_config),
                    "CONT": ProbabilisticDist(prior_config),
                }
                for model_name in model_names
            }

            self.node_probas = {
                model_name: defaultdict(partial(build_default_prob_dist, prior_config))
                for model_name in model_names
            }

        self.node_indices = ThompsonNodeIndices(model_names)

    def tell_reward(
        self,
        reward: float,
        node_identifier: NodeIdentifier,
        node_indices: Optional[ThompsonNodeIndices] = None,
    ) -> None:
        """Reflect an observed reward in the posteriors.

        A ``str`` identifier is treated as a model name and updates the GEN
        posterior; an ``int`` identifier is treated as a node index and
        updates that node's posterior and the CONT posterior. In both cases
        the posterior of the corresponding model is updated as well.

        Args:
            reward: Observed reward.
            node_identifier: Model name (``str``) or node index (``int``).
            node_indices: If given, replaces ``self.node_indices`` before the
                update.

        Raises:
            RuntimeError: If ``node_identifier`` is neither ``str`` nor ``int``.
        """
        if node_indices is not None:
            self.node_indices = node_indices

        is_bandit = self.model_selection_strategy.startswith("multiarm_bandit_")

        if isinstance(node_identifier, str):
            model_name = node_identifier
            # Bandit strategies share one GEN posterior; stack keeps one per model.
            group = "shared" if is_bandit else model_name
            self.gen_and_cont_probas[group]["GEN"].tell_observation(reward)

        elif isinstance(node_identifier, int):
            model_name = self.node_indices.get_model_name(node_identifier)
            group = "shared" if is_bandit else model_name
            # Update the node's own posterior, then the CONT posterior.
            self.node_probas[group][node_identifier].tell_observation(reward)
            self.gen_and_cont_probas[group]["CONT"].tell_observation(reward)

        else:
            raise RuntimeError(f"Invalid node_identifier {node_identifier}")

        self.model_probas[model_name].tell_observation(reward)

    def ask_next_idx(
        self, all_rewards_store: Dict[str, list[float]]
    ) -> Tuple["HierarchicalThompsonState", NodeIdentifier]:
        """
        Main part of the algorithm; calculate the probability and retrieve the next idx.
        The creation of the new node and adding that to node indices is the responsibility of the caller of this function.

        Args:
            all_rewards_store: Rewards observed so far, keyed by model name.
                Only used by the multiarm bandit strategies.

        Returns:
            This state together with either a model name (generate a new
            node) or a node index (continue from an existing node).

        Raises:
            ValueError: If the configured strategy is unknown.
        """
        if self.model_selection_strategy == "stack":
            return self._ask_next_idx_stack()
        elif self.model_selection_strategy == "multiarm_bandit_thompson":
            return self._ask_next_ids_multiarm_bandit(
                strategy="thompson", all_rewards_store=all_rewards_store
            )
        elif self.model_selection_strategy == "multiarm_bandit_ucb":
            return self._ask_next_ids_multiarm_bandit(
                strategy="ucb", all_rewards_store=all_rewards_store
            )
        else:
            raise ValueError(
                f"Unknown model_selection_strategy: {self.model_selection_strategy}"
            )

    def _ask_next_idx_stack(self) -> Tuple["HierarchicalThompsonState", NodeIdentifier]:
        """Sample a model, then GEN/CONT for that model, then (for CONT) a node."""
        model_name = self.thompson_sampling(self.model_probas)
        gen_or_cont = self.thompson_sampling(self.gen_and_cont_probas[model_name])
        # Without any node for this model there is nothing to continue from.
        if gen_or_cont == "GEN" or len(self.node_probas[model_name]) == 0:
            return self, model_name
        elif gen_or_cont == "CONT":
            node_idx = self.thompson_sampling(self.node_probas[model_name])
            return self, node_idx
        else:
            raise RuntimeError(f"Internal Error! Invalid gen_or_cont {gen_or_cont}")

    def _select_best_model(
        self, strategy: str, all_rewards_store: Dict[str, List[float]]
    ) -> str:
        """Pick the model that should generate the next node.

        Note that ``all_rewards_store`` is updated in place: every model of
        this state that is missing from it gets an empty reward list.

        Args:
            strategy: ``"thompson"`` or ``"ucb"``.
            all_rewards_store: Rewards observed so far, keyed by model name.

        Returns:
            The name of the selected model.

        Raises:
            ValueError: If ``strategy`` is neither ``"thompson"`` nor ``"ucb"``.
        """
        # Handle the case where reward for some models is empty
        for model_name in self.model_names:
            if model_name not in all_rewards_store:
                all_rewards_store[model_name] = []

        # For single model case, we just return that model.
        if len(all_rewards_store) == 1:
            return next(iter(all_rewards_store))

        model_scores = dict()
        if strategy == "thompson":
            for model_name, rewards in all_rewards_store.items():
                # Rebuild the posterior from the model's prior and its full
                # reward history, then score the model with a single draw.
                pd = ProbabilisticDist(self.prior_for_models[model_name])
                for reward in rewards:
                    pd.tell_observation(reward)
                model_scores[model_name] = pd.draw_sample()
        elif strategy == "ucb":
            # A model without any reward has no defined mean or confidence
            # bound (division by zero, log(0)); try such a model first.
            for model_name, rewards in all_rewards_store.items():
                if not rewards:
                    return model_name

            all_len = sum(len(rewards) for rewards in all_rewards_store.values())
            for model_name, rewards in all_rewards_store.items():
                model_scores[model_name] = sum(rewards) / len(rewards) + sqrt(2) * sqrt(
                    log(all_len) / len(rewards)
                )
        else:
            raise ValueError(
                f"Invalid strategy {strategy}, it should be either ucb or thompson"
            )
        return max(model_scores, key=model_scores.get)

    def _ask_next_ids_multiarm_bandit(
        self, strategy: str, all_rewards_store: Dict[str, List[float]]
    ) -> Tuple["HierarchicalThompsonState", NodeIdentifier]:
        """
        Multi-armed bandit approach: two-step decision process.

        First, decide between GEN or CONT using Thompson sampling.
        If CONT is chosen, return the node. If GEN is chosen, use multiarm bandit to pick the best model.

        Args:
            strategy: Either "thompson" for Thompson sampling or "ucb" for Upper Confidence Bound
            all_rewards_store: Rewards observed so far, keyed by model name

        Returns:
            This state together with either a model name (str) for generating a new node,
            or a node index (int) for continuing with an existing node
        """
        # Step 1: Decide between GEN or CONT using shared distributions based on Thompson Sampling
        gen_cont_options = self.gen_and_cont_probas["shared"]
        choice = self.thompson_sampling(gen_cont_options)

        if choice == "GEN" or len(self.node_probas["shared"]) == 0:
            # Step 2a: GEN was chosen (or no node exists yet), now pick the best model using multiarm bandit
            return self, self._select_best_model(
                strategy=strategy, all_rewards_store=all_rewards_store
            )
        else:  # CONT
            return self, self.thompson_sampling(self.node_probas["shared"])

    def thompson_sampling(self, probas: Dict[T, ProbabilisticDist]) -> T:
        """Draw one sample per entry and return the key with the largest draw.

        On ties the first key in iteration order wins. ``probas`` must not be
        empty.
        """
        max_name = None
        max_val = None
        for name in probas:
            val = probas[name].draw_sample()
            if max_val is None or val > max_val:
                max_name = name
                max_val = val
        assert max_name is not None

        return max_name

    def add_new_node(self, model_name: str, node: Node) -> None:
        """Register a new node and seed its posterior with the node's score.

        The score is the mean of ``get_score()`` over ``node.eval_results``.
        When there are no eval results (``None`` or empty) the score is 0.0.

        Args:
            model_name: Model that generated the node.
            node: The newly created node.
        """
        new_idx = self.node_indices.add_new_node(model_name)

        if node.eval_results is None:
            scores = []
        else:
            scores = [eval_result.get_score() for eval_result in node.eval_results]
        # The mean of an empty list is NaN, which would poison the posterior.
        new_node_score = float(np.mean(scores)) if scores else 0.0

        if self.model_selection_strategy.startswith("multiarm_bandit_"):
            self.node_probas["shared"][new_idx].tell_observation(new_node_score)
        else:
            self.node_probas[model_name][new_idx].tell_observation(new_node_score)
