import os
import random
from math import log, sqrt

# Avoid file lock error: give every process its own PyTensor compile directory,
# named compiledir_<pid>, under ~/.pytensor/compiledir_llm-mcts, so concurrent
# processes do not share (and lock) the same compilation cache.
# NOTE(review): this presumably must run before `import pymc` further down, since
# PyTensor reads PYTENSOR_FLAGS when it is first imported — keep it above that import.
os.environ["PYTENSOR_FLAGS"] = (
    f"compiledir_format=compiledir_{os.getpid()},base_compiledir={os.path.expanduser('~')}/.pytensor/compiledir_llm-mcts"
)


from collections import Counter
from dataclasses import dataclass, replace
from typing import Dict, List, Literal, Optional, Tuple

import jax
import numpy as np
import numpyro
import pandas as pd
import pymc as pm
from pymc.sampling.jax import sample_numpyro_nuts
from xarray import DataArray

from llm_mcts.mcts_algo.node import Node

# Module-level numpyro configuration used by the sample_numpyro_nuts calls in this file.
# NOTE(review): with a single host device, the multiple chains requested by the samplers
# presumably run sequentially rather than in parallel — confirm if sampling speed matters.
numpyro.set_host_device_count(1)  # Use only 1 core for sample_numpyro_nuts
numpyro.set_platform("cpu")  # Use CPU rather than GPU for sample_numpyro_nuts


@dataclass(frozen=True)
class Observation:
    """
    A single reward observed at a tree node, tagged with the model that produced the node.

    Instances are immutable; `child_idx` is only filled in (through `dataclasses.replace`)
    when observations are grouped per child subtree for the PyMC models.
    """

    reward: float
    model_name: str
    node: Node

    # Linear model group index to be used by PyMC; It is not used or set from MCTS side
    child_idx: int = -1

    @classmethod
    def to_pandas(
        cls, observations: List["Observation"], model_name: Optional[str]
    ) -> pd.DataFrame:
        """
        Convert observations into a DataFrame with columns obs_id, reward, child_idx, model_name.

        Args:
            observations: Observations to convert.
            model_name: If given, keep only observations whose model_name equals it;
                if None, keep every observation.

        Returns:
            One row per kept observation; `obs_id` is the position in `observations`
            (before filtering). The columns are declared even when no row is kept, so
            an empty result can still be indexed by column name.
        """
        columns = ["obs_id", "reward", "child_idx", "model_name"]

        # Rows are built by hand to avoid the time-consuming dataclasses.asdict operation
        rows = [
            {
                "obs_id": idx,
                "reward": observation.reward,
                "child_idx": observation.child_idx,
                "model_name": observation.model_name,
            }
            for idx, observation in enumerate(observations)
            if model_name is None or observation.model_name == model_name
        ]
        return pd.DataFrame(rows, columns=columns)

    @classmethod
    def collect_all_observations_of_descendant(
        cls,
        parent: Node,
        all_observations: Dict[Node, "Observation"],
    ) -> List["Observation"]:
        """
        Collect the observations of every node strictly below `parent`.

        Each returned observation is a copy whose `child_idx` is the index (in
        `parent.children`) of the direct child whose subtree contains it.

        Args:
            parent: Node whose descendants are collected (its own observation is not included).
            all_observations: Observation of every node, keyed by node.

        Returns:
            The collected observations, subtree by subtree in `parent.children` order,
            each subtree in pre-order.

        Raises:
            KeyError: If a descendant node has no entry in `all_observations`.
        """
        observations: List["Observation"] = []
        for child_idx, child in enumerate(parent.children):
            cls.collect_all_observations(
                child,
                all_observations,
                child_idx_override=child_idx,
                observations=observations,
            )

        return observations

    @classmethod
    def collect_all_observations(
        cls,
        parent: Node,
        all_observations: Dict[Node, "Observation"],
        child_idx_override: int,
        observations: List["Observation"],
    ) -> None:
        """
        Append the observations of `parent` and of all its descendants (pre-order) to `observations`.

        Every appended observation is a copy with `child_idx` set to `child_idx_override`,
        so the whole subtree is grouped under the same child index.

        Args:
            parent: Root of the subtree to collect (its own observation is included).
            all_observations: Observation of every node, keyed by node.
            child_idx_override: Group index stamped on every collected observation.
            observations: Output list, extended in place.

        Raises:
            KeyError: If a node of the subtree has no entry in `all_observations`.
        """
        observation = replace(all_observations[parent], child_idx=child_idx_override)
        observations.append(observation)
        for child in parent.children:
            # The override is propagated unchanged: descendants belong to the same group
            cls.collect_all_observations(
                child,
                all_observations,
                child_idx_override=child_idx_override,
                observations=observations,
            )


@dataclass
class PruningConfig:
    """
    Thresholds that decide when a child subtree is excluded from further search.

    Both conditions must hold for a subtree to be pruned: it is large enough, and one
    score dominates the observations inside it.
    """

    # Smallest number of observed nodes a subtree needs before it may be pruned
    # (with the default of 4, subtrees of size 4 or more qualify).
    min_subtree_size_for_pruning: int = 4

    # Pruning triggers when (count of the most frequent score in the subtree) / (nodes in the subtree)
    # is at least this ratio.
    same_score_proportion_threshold: float = 0.75


def is_prunable(
    node: Node, observations: List["Observation"], pruning_config: PruningConfig
) -> bool:
    """
    Check if the subtree where the node is parent is prunable from next search.

    The subtree is prunable when it contains at least
    `pruning_config.min_subtree_size_for_pruning` observations and a single score
    (reward compared at 0.01 granularity) accounts for at least
    `pruning_config.same_score_proportion_threshold` of them.

    Args:
        node: Root of the subtree to test.
        observations: Candidate observations; only those whose node is `node` or one of
            its descendants are counted.
        pruning_config: Thresholds to apply.

    Returns:
        True if the subtree should be pruned. False otherwise, including when no
        observation falls inside the subtree.
    """
    scores = []
    for obs in observations:
        obs_node = obs.node
        # Walk up towards the root. The root itself (parent is None) is never compared,
        # so a root `node` never matches any observation.
        while obs_node.parent is not None:
            if obs_node == node:
                # Compare rewards as integer percentages (0.01 granularity)
                scores.append(int(round(obs.reward * 100)))
                break
            obs_node = obs_node.parent

    # Nothing observed inside the subtree: no evidence to prune it. Also guards
    # most_common(1)[0], which raises IndexError on an empty Counter.
    if not scores:
        return False

    _most_common_score, max_count = Counter(scores).most_common(1)[0]
    return (
        max_count / len(scores) >= pruning_config.same_score_proportion_threshold
        and len(scores) >= pruning_config.min_subtree_size_for_pruning
    )


class PyMCInterface:
    """
    We leverage PyMC to perform (1) parameter fitting and (2) prediction.

    We use the fact that we can use different statistical models for (1) and (2) to make predictions for unobserved group (i.e. GEN node).

    See https://www.pymc-labs.com/blog-posts/out-of-model-predictions-with-pymc/ for pedagogical introduction for using different models for (1) and (2),
    and see https://www.pymc.io/projects/examples/en/latest/generalized_linear_models/multilevel_modeling.html for hierarchical modeling in PyMC.
    """

    # Number of posterior fits run so far; drives the periodic jax clean-up in _release_jax_memory.
    # Incrementing it through `self` creates a per-instance counter that shadows this class default.
    called_number: int = 0

    # Clear jax caches / live arrays once every this many posterior fits
    _JAX_CLEAR_INTERVAL: int = 10

    # Accepted values of `model_selection_strategy`
    _MODEL_SELECTION_STRATEGIES: Tuple[str, ...] = (
        "stack",
        "multiarm_bandit_thompson",
        "multiarm_bandit_ucb",
    )

    def __init__(
        self,
        algo: Literal["thompson"] = "thompson",
        enable_pruning: bool = True,
        pruning_config: Optional[PruningConfig] = None,
        reward_average_priors: Optional[float | Dict[str, float]] = None,
        model_selection_strategy: str = "stack",
    ):
        """
        Args:
            algo: How a score is drawn from a predictive distribution; only "thompson"
                is supported by get_score.
            enable_pruning: Whether children judged prunable by is_prunable are skipped.
            pruning_config: Pruning thresholds; defaults to PruningConfig().
            reward_average_priors: Prior mean of the reward, either one number for every
                model or a per-model dict (models missing from the dict use 0.5).
            model_selection_strategy: One of "stack", "multiarm_bandit_thompson",
                "multiarm_bandit_ucb" (see run).

        Raises:
            ValueError: If model_selection_strategy is not a supported value.
        """
        self.algo = algo
        self.enable_pruning = enable_pruning
        self.pruning_config = (
            pruning_config if pruning_config is not None else PruningConfig()
        )
        self.reward_average_priors = (
            reward_average_priors if reward_average_priors is not None else dict()
        )

        # Strategy for model selection:
        # "stack": Perform separate fits for each model (traditional approach)
        # "multiarm_bandit_thompson": Thompson Sampling for child-vs-GEN, then for the model
        # "multiarm_bandit_ucb": Thompson Sampling for child-vs-GEN, then UCB1 for the model
        if model_selection_strategy not in self._MODEL_SELECTION_STRATEGIES:
            raise ValueError(
                f"Invalid model_selection_strategy: {model_selection_strategy}. "
                f"Must be one of: 'stack', 'multiarm_bandit_thompson', 'multiarm_bandit_ucb'"
            )
        self.model_selection_strategy = model_selection_strategy

    def run(
        self,
        observations: List[Observation],
        model_names: List[str],
        node: Node,
        all_observations: List[Observation],
    ) -> str | int:
        """
        Main entry point of ABMCTS-M.
        Returns the model_name in case GEN Node is chosen; otherwise return child_idx

        Three strategies are supported:
        1. "stack": Perform separate fits for each model and select the highest scoring option
           (either a child node or new node generation).
        2. "multiarm_bandit_thompson": First decide whether to select a child or generate a new node
           using Thompson Sampling. If generating a new node, use Thompson Sampling across
           model scores using all observations from the tree.
        3. "multiarm_bandit_ucb": Same first step as 2 (Thompson Sampling). If generating a new
           node, pick the model with UCB1 over the per-model rewards of all observations in the tree.

        Args:
            observations: Observations of the subtrees below `node`, with child_idx set to the
                index of the child each subtree hangs from (presumably built by
                Observation.collect_all_observations_of_descendant — confirm at call site).
            model_names: Candidate models for generating a new node.
            node: The node whose child (or a new child) is being chosen.
            all_observations: Every observation in the tree; only used by the multiarm strategies.

        Raises:
            ValueError: If the configured strategy is unknown.
            RuntimeError: If no option could be scored.
        """
        if self.model_selection_strategy == "stack":
            return self._run_stacked_strategy(observations, model_names, node)
        elif self.model_selection_strategy == "multiarm_bandit_thompson":
            return self._run_multiarm_bandit_strategy(
                observations, model_names, node, all_observations, strategy="thompson"
            )
        elif self.model_selection_strategy == "multiarm_bandit_ucb":
            return self._run_multiarm_bandit_strategy(
                observations, model_names, node, all_observations, strategy="ucb"
            )
        else:
            raise ValueError(
                f"Unknown model_selection_strategy: {self.model_selection_strategy}"
            )

    def _is_pruned(
        self, node: Node, child_idx: int, observations: List[Observation]
    ) -> bool:
        """Return True if pruning is enabled and the subtree of `node.children[child_idx]` is prunable."""
        child_node = node.children[child_idx]
        return self.enable_pruning and is_prunable(
            child_node, observations, self.pruning_config
        )

    def _run_stacked_strategy(
        self,
        observations: List[Observation],
        model_names: List[str],
        node: Node,
    ) -> str | int:
        """
        Fit one hierarchical model per model_name and return the best-scoring option.

        The scores of all fits are merged. Int keys are child indices; str keys are
        "generate a new node with this model". Options are visited from highest to lowest
        score, and the first GEN option or non-pruned child is returned.
        """
        scores: Dict[str | int, float] = dict()
        for model_name in model_names:
            # If several fits score the same child index, the later fit overwrites the earlier one.
            scores |= self.calculate_score(observations, model_name)

        sorted_scores = dict(
            sorted(scores.items(), key=lambda item: item[1], reverse=True)
        )
        for identifier in sorted_scores:
            # str identifiers are GEN options (a model name); they are never pruned
            if isinstance(identifier, str):
                return identifier
            if self._is_pruned(node, identifier, observations):
                continue
            return identifier

        raise RuntimeError(
            f"Internal Error: Failed to get best option from {sorted_scores}"
        )

    def _select_best_model(
        self,
        all_observations: List[Observation],
        model_names: List[str],
        strategy: Literal["thompson", "ucb"],
    ) -> str:
        """
        Select the best model using Thompson Sampling or UCB across all observations.
        This is used when we've decided to generate a new node.

        Args:
            all_observations: List of all observations from the tree
            model_names: Candidate model names
            strategy: "thompson" fits a hierarchical model over the per-model rewards and
                Thompson-samples it; any other value uses UCB1 on the per-model rewards.

        Returns:
            The model_name to use for generating a new node. A uniformly random one is
            returned when the tree has no observation yet.
        """
        # Nothing observed yet: there is no information to prefer any model
        if len(all_observations) == 0:
            return random.choice(model_names)

        if strategy == "thompson":
            model_scores = self._thompson_model_scores(all_observations, model_names)
        else:
            model_scores = self._ucb_model_scores(all_observations, model_names)

        # Return the best action
        return max(model_scores, key=model_scores.get)

    def _thompson_model_scores(
        self, all_observations: List[Observation], model_names: List[str]
    ) -> Dict[str, float]:
        """
        Thompson-sample one reward per model from a hierarchical fit over all observations.

        Models with observations are scored from their own posterior predictive `y`.
        Unobserved models are scored from `y_new`, the prediction for a new group.
        `all_observations` must be non-empty.
        """
        _model_codes, _rewards, coords = (
            self.preprocess_observations_for_multiarm_bandit(all_observations)
        )
        # coords["model_name"] holds the distinct observed names. The first returned value is
        # the integer factorize codes, so testing a str model name against it would always fail.
        observed_model_names = set(coords["model_name"])

        fitting_model = self._build_model_for_multiarm_bandit(
            all_observations, is_prediction_model=False
        )
        pred_model = self._build_model_for_multiarm_bandit(
            all_observations, is_prediction_model=True
        )
        pred_model_trace = self._fit_and_predict(fitting_model, pred_model)

        posterior_predictive = pred_model_trace.posterior_predictive
        model_scores: Dict[str, float] = dict()
        for model_name in model_names:
            if model_name in observed_model_names:
                samples = posterior_predictive.y.sel(model_name=model_name)
            else:
                samples = posterior_predictive.y_new
            model_scores[model_name] = self.get_score(samples)

        # Scores are plain numbers now, so the sampler's jax memory can be released
        self._release_jax_memory()
        return model_scores

    def _ucb_model_scores(
        self, all_observations: List[Observation], model_names: List[str]
    ) -> Dict[str, float]:
        """
        UCB1 score per model: mean reward + sqrt(2) * sqrt(ln(N) / n).

        N is the total number of observations and n the number produced by the model.
        Models without any observation score +inf so that they are tried first.
        `all_observations` must be non-empty.
        """
        total = len(all_observations)
        model_scores: Dict[str, float] = dict()
        for model_name in model_names:
            rewards = [
                observation.reward
                for observation in all_observations
                if observation.model_name == model_name
            ]
            if not rewards:
                # Untried arm: the mean is undefined (division by zero), so explore it first
                model_scores[model_name] = float("inf")
                continue
            model_scores[model_name] = sum(rewards) / len(rewards) + sqrt(2) * sqrt(
                log(total) / len(rewards)
            )
        return model_scores

    def _run_multiarm_bandit_strategy(
        self,
        observations: List[Observation],
        model_names: List[str],
        node: Node,
        all_observations: List[Observation],
        strategy: Literal["thompson", "ucb"],
    ) -> str | int:
        """
        Two-step multiarm bandit selection.

        1. Decide between selecting an existing child and generating a new node (GEN),
           using Thompson Sampling on a single fit with model_name=None.
        2. Only if GEN is chosen, decide which model_name to use, with Thompson Sampling
           or UCB across all observations from the tree (see _select_best_model).

        Returns:
            A child index, or the model_name to generate a new node with.
        """
        # First step: scores for the children and for the GEN node (key None, since model_name=None)
        scores = self.calculate_score(observations, model_name=None)
        if not scores:
            raise RuntimeError("Internal Error: No scores calculated")

        # Options from highest to lowest sampled score (ties keep insertion order)
        ranked = [
            identifier
            for identifier, _score in sorted(
                scores.items(), key=lambda item: item[1], reverse=True
            )
        ]

        best_option = ranked[0]
        if best_option is None:
            # Second step: decide which model generates the new node
            return self._select_best_model(
                all_observations, model_names, strategy=strategy
            )

        if not self._is_pruned(node, best_option, observations):
            return best_option

        # The top child is pruned: fall back to the best-ranked child that is not pruned.
        # NOTE(review): GEN is skipped here even when it outranks the remaining children
        # (unlike _run_stacked_strategy, which stops at the first GEN option). GEN is only
        # used below, once every child is pruned. Confirm this asymmetry is intended.
        for identifier in ranked[1:]:
            if identifier is None:
                continue
            if not self._is_pruned(node, identifier, observations):
                return identifier

        # If all child nodes are pruned, select best action for GEN node
        return self._select_best_model(all_observations, model_names, strategy=strategy)

    def calculate_score(
        self, observations: List[Observation], model_name: Optional[str]
    ) -> Dict[str | int | None, float]:
        """
        Thompson-sample a score for every observed child and for the GEN node.

        Args:
            observations: Observations with child_idx set to the child subtree they belong to.
            model_name: Fit only the observations produced by this model; None fits all of them.

        Returns:
            {child_idx: score} for each child_idx in the (filtered) observations, plus
            {model_name: score} for the GEN node (key None when model_name is None).
            When no observation matches, only the GEN entry is returned, drawn from the
            prior predictive.
        """
        _child_indices, _rewards, coords = self.preprocess_observations(
            observations, model_name=model_name
        )

        # In case observations for model_name is empty, we sample from prior predictive
        if len(coords) == 0:
            prior_model = self.build_fitting_model(
                observations, model_name, is_prior_model=True
            )
            with prior_model:
                prior_model_trace = pm.sample_prior_predictive(var_names=["y"])

            return {model_name: self.get_score(prior_model_trace.prior.y)}

        fitting_model = self.build_fitting_model(observations, model_name)
        pred_model = self.build_prediction_model(observations, model_name)
        pred_model_trace = self._fit_and_predict(fitting_model, pred_model)

        # y represents the child node reward, and y_new represents the GEN node reward
        posterior_predictive = pred_model_trace.posterior_predictive
        scores: Dict[str | int | None, float] = {
            child_idx: self.get_score(posterior_predictive.y.sel(child_idx=child_idx))
            for child_idx in coords["child_idx"]
        }
        scores[model_name] = self.get_score(posterior_predictive.y_new)

        # Scores are plain numbers now, so the sampler's jax memory can be released
        self._release_jax_memory()
        return scores

    def _fit_and_predict(self, fitting_model: pm.Model, pred_model: pm.Model):
        """
        Fit `fitting_model` with NUTS (numpyro), then sample "y" and "y_new" of `pred_model`
        from that posterior. Returns the posterior-predictive InferenceData.
        """
        # We use numpyro for sampling; It may use some amount of CPU resource
        with fitting_model:
            model_trace = sample_numpyro_nuts(
                chains=4,
                compute_convergence_checks=False,
                idata_kwargs=dict(log_likelihood=False),
                progressbar=False,
            )

        # Using the model_trace, which includes the sampled posterior of the parameters, we predict
        # y (reward) values for the observed groups and y_new for an unseen group (GEN node).
        # Variables of pred_model absent from the trace (e.g. z_alpha_new) are drawn from their prior.
        with pred_model:
            return pm.sample_posterior_predictive(
                model_trace, var_names=["y", "y_new"], progressbar=False
            )

    def _release_jax_memory(self) -> None:
        """
        Count one posterior fit and, every _JAX_CLEAR_INTERVAL fits, drop jax caches and
        delete all live jax arrays. Call only after results have been extracted from the traces.
        """
        self.called_number += 1
        # numpyro sampling leads to memory leak, so we delete cached jax arrays here
        if self.called_number % self._JAX_CLEAR_INTERVAL == 0:
            jax.clear_caches()
            for x in jax.live_arrays():
                x.delete()

    def get_score(self, arr: DataArray) -> float:
        """
        Reduce an array of predictive samples to one score.

        With algo == "thompson", this is one sample picked uniformly at random (numpy global
        RNG) from all samples, i.e. Thompson Sampling.

        Raises:
            NotImplementedError: For any other algo.
        """
        if self.algo == "thompson":
            return np.random.choice(arr.values.flatten())
        else:
            raise NotImplementedError(
                f"Algo type {self.algo} is not supported by PyMCInterface"
            )

    def build_fitting_model(
        self,
        observations: List[Observation],
        model_name: Optional[str],
        is_prior_model: bool = False,
    ) -> pm.Model:
        """
        Build Hierarchical PyMC model for (1) parameter fitting.
        See tests/figures/fitting_model.jpg for model structure.
        """
        return self._build_model_impl(
            observations,
            model_name,
            is_prediction_model=False,
            is_prior_model=is_prior_model,
        )

    def build_prediction_model(
        self,
        observations: List[Observation],
        model_name: Optional[str],
    ) -> pm.Model:
        """
        Build Hierarchical PyMC model for (2) prediction.
        See tests/figures/prediction_model.jpg for the model structure.
        """
        return self._build_model_impl(
            observations, model_name, is_prediction_model=True
        )

    def _build_model_for_multiarm_bandit(
        self,
        observations: List[Observation],
        is_prior_model: bool = False,
        is_prediction_model: bool = False,
    ) -> pm.Model:
        """
        Build the hierarchical model whose groups are model names (one alpha per model).

        Modes:
            is_prior_model: no coords, a single ungrouped "y" (for prior predictive sampling).
            fitting (default): "y" observed from the rewards, indexed by model code.
            is_prediction_model: unobserved "y" per model_name plus "y_new" for an unseen model.
        """
        model_codes, rewards, coords = self.preprocess_observations_for_multiarm_bandit(
            observations
        )

        with pm.Model(coords=coords if not is_prior_model else None) as model:
            # Priors START
            # Overall difficulty of the problem itself; mu is set to be 0.5 (50% prob of solving the problem)
            mu_alpha = pm.Normal("mu_alpha", mu=0.5, sigma=0.2)

            # expresses the strength of score fluctuation across models
            sigma_alpha = pm.HalfNormal("sigma_alpha", sigma=0.3)

            # expresses the strength of score fluctuation inside a model
            sigma_y = pm.HalfNormal("sigma_y", sigma=0.2)
            # Priors END

            group_dims = "model_name" if not is_prior_model else None

            # We use non-centered parameterization (see https://sjster.github.io/introduction_to_computational_statistics/docs/Production/Reparameterization.html)
            z_alpha = pm.Normal("z_alpha", mu=0, sigma=1, dims=group_dims)
            alpha = mu_alpha + z_alpha * sigma_alpha

            # Random variables are registered on the model by name; no local binding is needed
            if is_prior_model:
                # This value is used by sample_prior_predictive
                pm.Normal("y", mu=alpha, sigma=sigma_y)
            elif not is_prediction_model:
                # This observation is used for fitting; each reward picks its model's alpha
                pm.Normal("y", mu=alpha[model_codes], sigma=sigma_y, observed=rewards)
            else:
                # For prediction, we sample distribution of y for each model_name
                y_hat = alpha[list(range(len(coords["model_name"])))]
                pm.Normal("y", mu=y_hat, sigma=sigma_y, dims="model_name")

                # Prediction for unseen data (i.e. a model without observations)
                z_alpha_new = pm.Normal("z_alpha_new", mu=0, sigma=1)
                alpha_new = mu_alpha + z_alpha_new * sigma_alpha
                pm.Normal("y_new", mu=alpha_new, sigma=sigma_y)

        return model

    def _build_model_impl(
        self,
        observations: List[Observation],
        model_name: Optional[str],
        is_prediction_model: bool,
        is_prior_model: bool = False,
    ) -> pm.Model:
        """
        Build the hierarchical model whose groups are child subtrees (one alpha per child_idx).

        Modes:
            is_prior_model: no coords, a single ungrouped "y" (for prior predictive sampling).
            fitting (is_prediction_model=False): "y" observed from the rewards, indexed by child code.
            is_prediction_model: unobserved "y" per child_idx plus "y_new" for a new child (GEN node).
        """
        child_indices, rewards, coords = self.preprocess_observations(
            observations, model_name=model_name
        )

        with pm.Model(coords=coords if not is_prior_model else None) as model:
            # Priors START
            # Overall goodness of the model itself; the prior mean comes from
            # get_reward_average_prior (0.5, i.e. 50% prob of solving the problem, unless configured)
            mu_alpha = pm.Normal(
                "mu_alpha", mu=self.get_reward_average_prior(model_name), sigma=0.2
            )

            # expresses the strength of score fluctuation across answers
            sigma_alpha = pm.HalfNormal("sigma_alpha", sigma=0.2)

            # expresses the strength of score fluctuation inside answers
            sigma_y = pm.HalfNormal("sigma_y", sigma=0.3)
            # Priors END

            group_dims = "child_idx" if not is_prior_model else None
            # We use non-centered parameterization (see https://sjster.github.io/introduction_to_computational_statistics/docs/Production/Reparameterization.html)
            z_alpha = pm.Normal("z_alpha", mu=0, sigma=1, dims=group_dims)
            alpha = mu_alpha + z_alpha * sigma_alpha

            # Random variables are registered on the model by name; no local binding is needed
            if is_prior_model:
                pm.Normal("y", mu=alpha, sigma=sigma_y)
            elif not is_prediction_model:
                # Each reward picks the alpha of the child subtree it belongs to
                pm.Normal("y", mu=alpha[child_indices], sigma=sigma_y, observed=rewards)
            else:
                # For prediction, we sample distribution of y for each child_idx
                y_hat = alpha[list(range(len(coords["child_idx"])))]
                pm.Normal("y", mu=y_hat, sigma=sigma_y, dims="child_idx")

                # Prediction for unseen data (i.e. GEN node)
                z_alpha_new = pm.Normal("z_alpha_new", mu=0, sigma=1)
                alpha_new = mu_alpha + z_alpha_new * sigma_alpha
                pm.Normal("y_new", mu=alpha_new, sigma=sigma_y)

        return model

    def preprocess_observations_for_multiarm_bandit(
        self, observations: List[Observation]
    ) -> Tuple[List[int], List[float], Dict[str, List[str]]]:
        """
        Extract the model-grouped inputs of the multiarm bandit model.

        Returns:
            (model_codes, rewards, coords), where model_codes[i] is the integer index of
            observation i's model_name in coords["model_name"] (the distinct names in
            first-seen order). All three are empty when there is no observation.
        """
        df = Observation.to_pandas(observations, model_name=None)
        if len(df) == 0:
            return [], [], dict()

        # factorize returns (integer codes, unique values)
        model_codes, unique_model_names = df["model_name"].factorize()

        rewards = list(df["reward"].values)

        coords = {"model_name": list(unique_model_names)}

        return model_codes, rewards, coords

    def preprocess_observations(
        self, observations: List[Observation], model_name: Optional[str]
    ) -> Tuple[List[int], List[float], Dict[str, List[int]]]:
        """
        Extract the child-grouped inputs of the per-child model.

        Args:
            observations: Observations with child_idx set.
            model_name: Keep only this model's observations; None keeps all.

        Returns:
            (child_codes, rewards, coords), where child_codes[i] is the integer index of
            observation i's child_idx in coords["child_idx"] (the distinct child indices in
            first-seen order). All three are empty when no observation is kept.
        """
        df = Observation.to_pandas(observations, model_name=model_name)
        if len(df) == 0:
            return [], [], dict()

        # factorize returns (integer codes, unique values)
        child_codes, unique_child_indices = df["child_idx"].factorize()

        rewards = list(df["reward"].values)

        coords = {"child_idx": list(unique_child_indices)}

        return child_codes, rewards, coords

    def get_reward_average_prior(self, model_name: Optional[str]) -> float:
        """
        Prior parameter for reward average value for each model.

        A single number configured in reward_average_priors applies to every model;
        otherwise the per-model dict is consulted, falling back to 0.5 (also for None).
        """
        if isinstance(self.reward_average_priors, float | int):
            return float(self.reward_average_priors)
        else:
            return self.reward_average_priors.get(model_name, 0.5)