import copy
import dataclasses
import functools
from collections import defaultdict
from typing import Dict, List, Literal, Optional, Tuple, TypeVar

import numpy as np
from scipy.stats import beta, invgamma

from llm_mcts.mcts_algo.hierarchical_thompson.data_types import NodeIdentifier
from llm_mcts.mcts_algo.hierarchical_thompson.node_indices import ThompsonNodeIndices
from llm_mcts.mcts_algo.node import Node

# Generic key type for HierarchicalThompsonState.thompson_sampling: the value it
# returns has the same type as the keys of the mapping it samples over
# (model names, "GEN"/"CONT" labels, or node indices).
T = TypeVar("T")


@dataclasses.dataclass
class BetaPrior:
    """Hyperparameters ``(a, b)`` of a Beta prior.

    The defaults ``a = b = 0.5`` give the Jeffreys prior for a Bernoulli
    likelihood.
    """

    a: float = dataclasses.field(default=0.5)  # pseudo-count of successes
    b: float = dataclasses.field(default=0.5)  # pseudo-count of failures


@dataclasses.dataclass
class GaussianPrior:
    """Hyperparameters of a normal-inverse-chi-squared prior.

    ``ProbabilisticDist`` samples from it as
    ``sigma^2 ~ InvGamma(nu / 2, scale=nu * tau_square / 2)`` followed by
    ``mu | sigma^2 ~ N(m, sigma^2 / kappa)``.
    """

    m: float = dataclasses.field(default=0)  # prior mean of mu
    kappa: float = dataclasses.field(default=1)  # pseudo-count behind m
    nu: float = dataclasses.field(default=1)  # pseudo-count behind tau_square
    tau_square: float = dataclasses.field(default=1)  # prior scale of sigma^2


@dataclasses.dataclass
class PriorConfig:
    """Chooses the posterior family and, optionally, its prior hyperparameters.

    Attributes:
        dist_type: ``"gaussian"`` or ``"beta"``.
        prior: a ``GaussianPrior``/``BetaPrior`` matching ``dist_type``, a dict
            of keyword arguments for that dataclass, or ``None`` for its
            defaults.
    """

    dist_type: Literal["beta", "gaussian"] = "gaussian"
    prior: BetaPrior | GaussianPrior | Dict[str, float] | None = None

    def get_params(self) -> Dict[str, float]:
        """Resolve the configured prior into a fresh hyperparameter dict.

        Returns:
            A new dict (via ``dataclasses.asdict``) of the prior's fields.

        Raises:
            NotImplementedError: if ``dist_type`` is neither "gaussian" nor "beta".
            ValueError: if ``prior`` is an object of the wrong prior class.
            TypeError: if a dict ``prior`` has keys the dataclass does not accept.
        """
        # Pick the dataclass for this family first; the two unsupported-type
        # messages differ depending on whether a prior was supplied.
        if self.dist_type == "gaussian":
            prior_cls, family_label = GaussianPrior, "Gaussian"
        elif self.dist_type == "beta":
            prior_cls, family_label = BetaPrior, "Beta"
        elif self.prior is None:
            raise NotImplementedError(f"dist_type {self.dist_type} not supported.")
        else:
            raise NotImplementedError(f"Invalid dist_type {self.dist_type}")

        if self.prior is None:
            resolved = prior_cls()
        elif isinstance(self.prior, dict):
            resolved = prior_cls(**self.prior)
        elif isinstance(self.prior, prior_cls):
            resolved = self.prior
        else:
            raise ValueError(
                f"Invalid prior {self.prior} for {family_label} Distribution."
            )

        return dataclasses.asdict(resolved)


class ProbabilisticDist:
    """Conjugate posterior over a scalar reward, sampled for Thompson sampling.

    The family is chosen by ``PriorConfig.dist_type``:

    * ``"beta"``: a Beta(a, b) posterior for rewards in [0, 1]. An observation
      ``obs`` counts as a fractional success: ``a += obs``, ``b += 1 - obs``.
    * ``"gaussian"``: a normal-inverse-chi-squared posterior over the unknown
      mean and variance of the rewards, with hyperparameters ``m``, ``kappa``,
      ``nu`` and ``tau_square``.

    Attributes:
        dist_type: the family in use ("beta" or "gaussian").
        all_obs: observations received so far (only recorded for "gaussian").
        params: current posterior hyperparameters.
        prior_params: independent copy of the prior hyperparameters; the
            Gaussian posterior is recomputed from these on every observation.
    """

    def __init__(self, prior_config: Optional["PriorConfig"] = None):
        """Initialise the posterior to the prior described by ``prior_config``.

        Args:
            prior_config: family and hyperparameters; ``None`` means
                ``PriorConfig()``.
        """
        if prior_config is None:
            prior_config = PriorConfig()
        self.dist_type = prior_config.dist_type

        self.all_obs: List[float] = []
        self.params = prior_config.get_params()
        # Deep copy so later updates to self.params never touch the prior that
        # the Gaussian update restarts from.
        self.prior_params = copy.deepcopy(self.params)

    def tell_observation(self, obs: float) -> None:
        """Update the posterior with one observed reward.

        Args:
            obs: the reward. Must lie in [0, 1] for "beta" and be finite for
                "gaussian".

        Raises:
            ValueError: if ``obs`` is not acceptable for the family; the
                posterior is left unchanged in that case.
            NotImplementedError: if ``dist_type`` is not supported.
        """
        if self.dist_type == "beta":
            # Explicit check instead of ``assert`` so it survives ``python -O``;
            # out-of-range values could push a or b to non-positive values.
            if not 0 <= obs <= 1:
                raise ValueError(f"Beta observations must be in [0, 1], got {obs}")
            self.params["a"] += obs
            self.params["b"] += 1 - obs
        elif self.dist_type == "gaussian":
            # A NaN/inf observation would make every later posterior (and
            # therefore every sample) NaN; reject it before it is recorded.
            if not np.isfinite(obs):
                raise ValueError(f"Gaussian observations must be finite, got {obs}")
            # See Section 3.4.3.3 of Murphy's textbook "Probabilistic Machine Learning: Advanced Topics": https://probml.github.io/pml-book/book2.html
            # The posterior is recomputed from the prior and the full history,
            # so each call costs O(len(self.all_obs)).
            self.all_obs.append(obs)

            n = len(self.all_obs)
            ave = float(np.mean(self.all_obs))
            # Sum of squared deviations from the sample mean.
            sum_sq_dev = float(np.var(self.all_obs, ddof=0)) * n

            m = self.prior_params["m"]
            kappa = self.prior_params["kappa"]
            nu = self.prior_params["nu"]
            tau_square = self.prior_params["tau_square"]

            new_kappa = kappa + n
            new_nu = nu + n
            new_m = (kappa * m + n * ave) / new_kappa
            new_tau_square = (
                nu * tau_square
                + sum_sq_dev
                + n * kappa * (m - ave) * (m - ave) / (kappa + n)
            ) / new_nu

            self.params = {
                "m": new_m,
                "kappa": new_kappa,
                "nu": new_nu,
                "tau_square": new_tau_square,
            }
        else:
            raise NotImplementedError()

    def draw_sample(self) -> float:
        """Draw one sample of the mean reward from the current posterior.

        For "beta" this is a draw from Beta(a, b). For "gaussian" it draws
        ``sigma^2 ~ InvGamma(nu / 2, scale=nu * tau_square / 2)`` and then
        returns ``mu ~ N(m, sigma^2 / kappa)``. Randomness comes from numpy's
        global random state (no generator is passed explicitly).

        Raises:
            NotImplementedError: if ``dist_type`` is not supported.
        """
        if self.dist_type == "beta":
            return beta.rvs(self.params["a"], self.params["b"])
        elif self.dist_type == "gaussian":
            sigma_square = invgamma.rvs(
                self.params["nu"] / 2,
                scale=(self.params["nu"] * self.params["tau_square"]) / 2,
            )
            mu = np.random.normal(
                self.params["m"], np.sqrt(sigma_square / self.params["kappa"])
            )
            return mu
        else:
            raise NotImplementedError()


class HierarchicalThompsonState:
    """Three-level Thompson-sampling state: model -> GEN/CONT -> existing node.

    ``ask_next_idx`` first samples a model from ``model_probas``. It then
    samples "GEN" or "CONT" for that model from ``gen_and_cont_probas``. For
    "CONT" (when the model has nodes), it finally samples one of the model's
    nodes from ``node_probas``. ``tell_reward`` feeds a reward back into every
    posterior on the path identified by its ``node_identifier``.
    """

    def __init__(
        self,
        model_names: List[str],
        prior_config: Optional[PriorConfig] = None,
    ):
        """Create prior-initialised posteriors for every model.

        Args:
            model_names: models selectable at the first level.
            prior_config: prior shared by every posterior; ``None`` means
                ``PriorConfig()``.
        """
        if prior_config is None:
            prior_config = PriorConfig()

        # Level 1: one posterior per model.
        self.model_probas = {
            model_name: ProbabilisticDist(prior_config) for model_name in model_names
        }

        # Level 2: per model, one posterior each for "GEN" and "CONT".
        self.gen_and_cont_probas = {
            model_name: {
                "GEN": ProbabilisticDist(prior_config),
                "CONT": ProbabilisticDist(prior_config),
            }
            for model_name in model_names
        }

        # Level 3: per model, one posterior per node index, created on first
        # access. functools.partial (unlike a lambda) keeps these defaultdicts
        # picklable.
        self.node_probas: Dict[str, Dict[int, ProbabilisticDist]] = {
            model_name: defaultdict(functools.partial(ProbabilisticDist, prior_config))
            for model_name in model_names
        }

        self.node_indices = ThompsonNodeIndices(model_names)

    def tell_reward(
        self,
        reward: float,
        node_identifier: NodeIdentifier,
        node_indices: Optional[ThompsonNodeIndices] = None,
    ) -> None:
        """Reflect a reward in the Thompson state, optionally replacing the node indices.

        Args:
            reward: observed reward, passed to ``ProbabilisticDist.tell_observation``.
            node_identifier: a ``str`` model name updates that model's "GEN"
                posterior; an ``int`` node index updates that node's posterior
                and its model's "CONT" posterior. Either way the model's
                level-1 posterior is updated too.
            node_indices: if given, replaces ``self.node_indices`` before the
                node index is resolved to a model name.

        Raises:
            RuntimeError: if ``node_identifier`` is neither ``str`` nor ``int``.
            KeyError: if the resolved model name is unknown.
        """
        if node_indices is not None:
            self.node_indices = node_indices

        if isinstance(node_identifier, str):
            model_name = node_identifier
            self.gen_and_cont_probas[model_name]["GEN"].tell_observation(reward)
        elif isinstance(node_identifier, int):
            model_name = self.node_indices.get_model_name(node_identifier)

            self.node_probas[model_name][node_identifier].tell_observation(reward)
            self.gen_and_cont_probas[model_name]["CONT"].tell_observation(reward)
        else:
            raise RuntimeError(f"Invalid node_identifier {node_identifier}")

        self.model_probas[model_name].tell_observation(reward)

    def ask_next_idx(self) -> Tuple["HierarchicalThompsonState", NodeIdentifier]:
        """
        Main part of the algorithm: sample each level and return the next identifier.

        Returns ``(self, model_name)`` when "GEN" is sampled or the chosen model
        has no nodes yet, otherwise ``(self, node_idx)`` for the sampled node.
        The creation of the new node and adding that to node indices is the responsibility of the caller of this function.
        """
        model_name = self.thompson_sampling(self.model_probas)
        gen_or_cont = self.thompson_sampling(self.gen_and_cont_probas[model_name])
        # With no existing nodes there is nothing to continue from, so fall back to GEN.
        if gen_or_cont == "GEN" or not self.node_probas[model_name]:
            return self, model_name
        elif gen_or_cont == "CONT":
            node_idx = self.thompson_sampling(self.node_probas[model_name])
            return self, node_idx
        else:
            raise RuntimeError(f"Internal Error! Invalid gen_or_cont {gen_or_cont}")

    def thompson_sampling(self, probas: Dict[T, ProbabilisticDist]) -> T:
        """Return the key whose posterior produces the largest single sample.

        Exactly one sample is drawn per key, in iteration order; on ties the
        first such key wins.

        Raises:
            ValueError: if ``probas`` is empty.
        """
        # Explicit check (not ``assert``) so an empty mapping cannot slip
        # through under ``python -O``.
        if not probas:
            raise ValueError("thompson_sampling needs at least one candidate")
        return max(probas, key=lambda name: probas[name].draw_sample())

    def add_new_node(self, model_name: str, node: Node) -> None:
        """Register ``node`` under ``model_name`` and seed its posterior.

        The new node's posterior receives one observation: the mean of
        ``get_score()`` over ``node.eval_results``, or 0.0 if there are none.
        The index used is whatever ``self.node_indices.add_new_node`` returns.
        """
        new_idx = self.node_indices.add_new_node(model_name)

        # Treat both None and an empty collection as "no results": np.mean([])
        # is NaN, which is not a valid observation for either distribution.
        if node.eval_results is None or len(node.eval_results) == 0:
            new_node_score = 0.0
        else:
            new_node_score = float(
                np.mean([eval_result.get_score() for eval_result in node.eval_results])
            )
        self.node_probas[model_name][new_idx].tell_observation(new_node_score)
