import os

# Avoid file lock error: give each process its own PyTensor compile directory
# (suffixed with the PID) under ~/.pytensor/compiledir_llm-mcts, presumably so
# concurrent processes do not contend for one shared compilation-cache lock.
# This has to stay above the pymc/pytensor imports below, since the flags are
# presumably read when pytensor is first imported. Note: this assignment
# replaces any PYTENSOR_FLAGS value already present in the environment.
os.environ["PYTENSOR_FLAGS"] = (
    f"compiledir_format=compiledir_{os.getpid()},base_compiledir={os.path.expanduser('~')}/.pytensor/compiledir_llm-mcts"
)


from collections import Counter
from dataclasses import dataclass, replace
from typing import Dict, List, Literal, Optional, Tuple

import jax
import numpy as np
import numpyro
import pandas as pd
import pymc as pm
from pymc.sampling.jax import sample_numpyro_nuts
from xarray import DataArray

from llm_mcts.mcts_algo.node import Node

# Configure numpyro (the backend of sample_numpyro_nuts) once, at import time.
# NOTE(review): numpyro documents that these settings only take effect if they
# are applied before JAX initializes its backend, so keep them at module level.
numpyro.set_host_device_count(1)  # Use only 1 core for sample_numpyro_nuts
numpyro.set_platform("cpu")  # Use CPU rather than GPU for sample_numpyro_nuts


@dataclass(frozen=True)
class Observation:
    """A single reward observation recorded for a node of the search tree.

    Instances are immutable (``frozen=True``); the helpers below derive copies
    with a different ``child_idx`` via ``dataclasses.replace``.
    """

    # Observed reward for ``node``.
    reward: float
    # Name of the model the observation is attributed to; ``to_pandas`` filters on it.
    model_name: str
    # Search-tree node this observation belongs to.
    node: Node

    # Linear model group index to be used by PyMC; it is not used or set from the MCTS side.
    # -1 means "not assigned"; the collect_* helpers return copies with it filled in.
    child_idx: int = -1

    @classmethod
    def to_pandas(
        cls, observations: List["Observation"], model_name: str
    ) -> pd.DataFrame:
        """
        Return a DataFrame with one row per observation of ``model_name``.

        Columns are ``obs_id`` (position in ``observations``), ``reward``,
        ``child_idx`` and ``model_name``. The columns are present even when no
        observation matches, so the result can always be indexed by column.
        """
        # Build rows field by field: dataclasses.asdict would recursively copy
        # every field (including ``node``), which is time-consuming.
        rows = [
            {
                "obs_id": idx,
                "reward": observation.reward,
                "child_idx": observation.child_idx,
                "model_name": observation.model_name,
            }
            for idx, observation in enumerate(observations)
            if observation.model_name == model_name
        ]
        return pd.DataFrame(
            rows, columns=["obs_id", "reward", "child_idx", "model_name"]
        )

    @classmethod
    def collect_all_observations_of_descendant(
        cls,
        parent: Node,
        all_observations: Dict[Node, "Observation"],
    ) -> List["Observation"]:
        """
        A helper method to collect the observations of all descendants of ``parent``.

        ``parent``'s own observation is not included. Each returned observation
        is a copy whose ``child_idx`` is the index, within ``parent.children``,
        of the direct child whose subtree it came from.

        Raises:
            KeyError: if a descendant has no entry in ``all_observations``.
        """
        observations: List["Observation"] = []
        for child_idx, child in enumerate(parent.children):
            cls.collect_all_observations(
                child,
                all_observations,
                child_idx_override=child_idx,
                observations=observations,
            )

        return observations

    @classmethod
    def collect_all_observations(
        cls,
        parent: Node,
        all_observations: Dict[Node, "Observation"],
        child_idx_override: int,
        observations: List["Observation"],
    ) -> None:
        """
        A helper method to append the observations of ``parent`` and all its descendants (including ``parent`` itself) to ``observations``.

        Nodes are visited in pre-order: a node comes before its children, and
        children are visited in ``children`` order. Every appended observation is
        a copy with ``child_idx`` set to ``child_idx_override``. An explicit
        stack is used instead of recursion, so deep trees cannot exceed
        Python's recursion limit.

        Raises:
            KeyError: if a visited node has no entry in ``all_observations``.
        """
        stack: List[Node] = [parent]
        while stack:
            current = stack.pop()
            observations.append(
                replace(all_observations[current], child_idx=child_idx_override)
            )
            # Push children reversed so the first child is popped next,
            # reproducing the order of a recursive pre-order walk.
            stack.extend(reversed(list(current.children)))


@dataclass
class PruningConfig:
    """Thresholds that decide when a subtree is dropped from further search.

    A subtree is pruned only when both conditions hold:

    * it contains at least ``min_subtree_size_for_pruning`` nodes (so with the
      default of 4, subtrees of size >= 4 are eligible), and
    * (largest number of nodes sharing one score) / (number of nodes in the
      subtree) >= ``same_score_proportion_threshold``.
    """

    min_subtree_size_for_pruning: int = 4
    same_score_proportion_threshold: float = 0.75


def is_prunable(
    node: Node, observations: List["Observation"], pruning_config: PruningConfig
) -> bool:
    """
    Check if the subtree rooted at ``node`` should be excluded from the next search.

    Collects the reward of every observation whose node is ``node`` or one of
    its descendants, buckets each reward as ``int(round(reward * 100))``, and
    reports the subtree as prunable when it has at least
    ``pruning_config.min_subtree_size_for_pruning`` observations and the most
    frequent bucket accounts for at least
    ``pruning_config.same_score_proportion_threshold`` of them.

    Returns False when no observation falls inside the subtree.
    """
    scores: List[int] = []
    for obs in observations:
        obs_node = obs.node
        # Walk from the observation's node towards the root; the root itself
        # (whose parent is None) is never compared against ``node``.
        while obs_node.parent is not None:
            if obs_node == node:
                scores.append(int(round(obs.reward * 100)))
                break
            obs_node = obs_node.parent

    # An empty subtree has nothing to prune; without this guard
    # most_common(1) returns [] and indexing it raises IndexError.
    if not scores:
        return False

    _max_element, max_count = Counter(scores).most_common(1)[0]
    return (
        max_count / len(scores) >= pruning_config.same_score_proportion_threshold
        and len(scores) >= pruning_config.min_subtree_size_for_pruning
    )


class PyMCInterface:
    """
    We leverage PyMC to perform (1) parameter fitting and (2) prediction.

    We use the fact that different statistical models can be used for (1) and (2) to make predictions for an unobserved group (i.e. a GEN node).

    See https://www.pymc-labs.com/blog-posts/out-of-model-predictions-with-pymc/ for a pedagogical introduction to using different models for (1) and (2),
    and see https://www.pymc.io/projects/examples/en/latest/generalized_linear_models/multilevel_modeling.html for hierarchical modeling in PyMC.
    """

    # Number of posterior fits performed by `calculate_score`; used to clear
    # JAX caches every 10 fits. `self.called_number += 1` rebinds it on the
    # instance, so each PyMCInterface keeps its own count.
    called_number: int = 0

    def __init__(
        self,
        algo: Literal["thompson"] = "thompson",
        enable_pruning: bool = True,
        pruning_config: Optional[PruningConfig] = None,
    ):
        """
        Args:
            algo: how a single score is taken from predictive samples. Only
                "thompson" (one uniformly random sample) is supported; other
                values raise NotImplementedError when a score is computed.
            enable_pruning: if True, `run` skips children whose subtree is
                reported prunable by `is_prunable`.
            pruning_config: pruning thresholds; `PruningConfig()` when None.
        """
        self.algo = algo
        self.enable_pruning = enable_pruning
        self.pruning_config = (
            pruning_config if pruning_config is not None else PruningConfig()
        )

    def run(
        self, observations: List[Observation], model_names: List[str], node: Node
    ) -> str | int:
        """
        Main entry point of the Mixed Thompson algorithm.

        Scores every option with `calculate_score`, run once per model:
        - each child index observed for that model;
        - a new GEN node for that model, keyed by the model name.

        The highest-scoring option is returned:
        - a model name (str) if a GEN node wins;
        - otherwise the child's index into ``node.children`` (int).

        When pruning is enabled, children whose subtree `is_prunable` are
        skipped in favour of the next-best option.

        Raises:
            RuntimeError: if no option remains, e.g. ``model_names`` is empty
                and every scored child is pruned.
        """
        scores: Dict[str | int, float] = dict()
        for model_name in model_names:
            # NOTE(review): child indices are keys shared by all models. If the
            # same child has observations under several models, the score from
            # the model processed last overwrites the earlier ones. Confirm
            # this is intended.
            scores |= self.calculate_score(observations, model_name)

        sorted_scores = dict(
            sorted(scores.items(), key=lambda item: item[1], reverse=True)
        )
        for identifier in sorted_scores:
            if isinstance(identifier, str):
                return identifier
            child_node = node.children[identifier]
            if self.enable_pruning and is_prunable(
                child_node, observations, self.pruning_config
            ):
                continue
            return identifier

        raise RuntimeError(
            f"Internal Error: Failed to get best option from {sorted_scores}"
        )

    def calculate_score(
        self, observations: List[Observation], model_name: str
    ) -> Dict[str | int, float]:
        """
        Return Thompson scores derived from the observations of ``model_name``.

        The result maps each child index observed for ``model_name`` to a score
        drawn from that child's posterior predictive. It also maps
        ``model_name`` itself to a score for a brand-new (GEN) node.

        When ``model_name`` has no observations, only the ``model_name`` key is
        returned, scored from the prior predictive.
        """
        _child_indices, _rewards, coords = self.preprocess_observations(
            observations, model_name=model_name
        )

        # In case observations for model_name are empty, we sample from the prior predictive
        # Prior Predictive Sampling START
        if len(coords) == 0:
            prior_model = self.build_fitting_model(
                observations, model_name, is_prior_model=True
            )
            with prior_model:
                prior_model_trace = pm.sample_prior_predictive(var_names=["y"])

            return {model_name: self.get_score(prior_model_trace.prior.y)}
        # Prior Predictive Sampling END

        fitting_model = self.build_fitting_model(observations, model_name)
        pred_model = self.build_prediction_model(observations, model_name)

        # We use numpyro for sampling; it may use some amount of CPU resource
        with fitting_model:
            model_trace = sample_numpyro_nuts(
                chains=4,
                compute_convergence_checks=False,
                idata_kwargs=dict(log_likelihood=False),
                progressbar=False,
            )

        # model_trace holds the sampled posterior of the parameters. Using it,
        # we predict y (reward) values for the GEN node and the children.
        # y is the child node reward, and y_new is the GEN node reward.
        with pred_model:
            pred_model_trace = pm.sample_posterior_predictive(
                model_trace, var_names=["y", "y_new"], progressbar=False
            )

        scores: Dict[str | int, float] = dict()
        for child_idx in coords["child_idx"]:
            scores[child_idx] = self.get_score(
                pred_model_trace.posterior_predictive.y.sel(child_idx=child_idx)
            )

        scores[model_name] = self.get_score(pred_model_trace.posterior_predictive.y_new)

        self.called_number += 1
        # Clear jax cache to avoid memory leak:
        # numpyro sampling leads to a memory leak, so we delete cached jax arrays here.
        # NOTE(review): jax.live_arrays() returns every live array in the
        # process, so this also deletes JAX arrays owned by any other code.
        if self.called_number % 10 == 0:
            jax.clear_caches()
            for x in jax.live_arrays():
                x.delete()

        return scores

    def get_score(self, arr: DataArray) -> float:
        """
        Reduce predictive samples to one score.

        For "thompson", this is one value drawn uniformly at random (using
        numpy's global RNG) from all samples in ``arr``.

        Raises:
            NotImplementedError: if ``self.algo`` is not "thompson".
        """
        if self.algo == "thompson":
            return np.random.choice(arr.values.flatten())
        else:
            raise NotImplementedError(
                f"Algo type {self.algo} is not supported by PyMCInterface"
            )

    def build_fitting_model(
        self,
        observations: List[Observation],
        model_name: str,
        is_prior_model: bool = False,
    ) -> pm.Model:
        """
        Build Hierarchical PyMC model for (1) parameter fitting.
        See tests/figures/fitting_model.jpg for model structure.

        With ``is_prior_model=True``, build a variant without observed data
        or coordinates, used for prior predictive sampling.
        """
        return self._build_model_impl(
            observations,
            model_name,
            is_prediction_model=False,
            is_prior_model=is_prior_model,
        )

    def build_prediction_model(
        self,
        observations: List[Observation],
        model_name: str,
    ) -> pm.Model:
        """
        Build Hierarchical PyMC model for (2) prediction.
        See tests/figures/prediction_model.jpg for the model structure.
        """
        return self._build_model_impl(
            observations, model_name, is_prediction_model=True
        )

    def _build_model_impl(
        self,
        observations: List[Observation],
        model_name: str,
        is_prediction_model: bool,
        is_prior_model: bool = False,
    ) -> pm.Model:
        """
        Build the hierarchical model shared by fitting, prediction and prior sampling.

        Group intercepts use a non-centered parameterization,
        ``alpha = mu_alpha + z_alpha * sigma_alpha``. Rewards are modeled as
        ``Normal(alpha[group], sigma_y)``.

        Args:
            is_prediction_model: build the prediction variant. It has an
                unobserved ``y`` per ``child_idx`` plus ``y_new`` for a new
                group (GEN node).
            is_prior_model: build a coordinate-free variant with one unobserved
                scalar ``y``. It is checked before ``is_prediction_model``.
        """
        child_indices, rewards, coords = self.preprocess_observations(
            observations, model_name=model_name
        )

        with pm.Model(coords=coords if not is_prior_model else None) as model:
            # Priors START
            # Overall goodness of the model itself; mu is set to 0.5 (50% prob of solving the problem)
            mu_alpha = pm.Normal("mu_alpha", mu=0.5, sigma=0.2)

            # expresses the strength of score fluctuation across answers
            sigma_alpha = pm.HalfNormal("sigma_alpha", sigma=0.2)

            # expresses the strength of score fluctuation inside answers
            sigma_y = pm.HalfNormal("sigma_y", sigma=0.3)
            # Priors END

            group_dims = "child_idx" if not is_prior_model else None
            # We use non-centered parameterization (see https://sjster.github.io/introduction_to_computational_statistics/docs/Production/Reparameterization.html)
            z_alpha = pm.Normal("z_alpha", mu=0, sigma=1, dims=group_dims)
            alpha = mu_alpha + z_alpha * sigma_alpha

            # Random variables register themselves with the enclosing model
            # context, so the pm.Normal results below need no local names.
            if is_prior_model:
                # Expected value
                y_hat = alpha

                pm.Normal("y", mu=y_hat, sigma=sigma_y)
            elif not is_prediction_model:
                # Expected value: each observation uses its group's intercept
                y_hat = alpha[child_indices]

                pm.Normal("y", mu=y_hat, sigma=sigma_y, observed=rewards)
            else:
                # Expected value
                # For prediction, we sample the distribution of y for each child_idx
                y_hat = alpha[list(range(len(coords["child_idx"])))]

                pm.Normal("y", mu=y_hat, sigma=sigma_y, dims="child_idx")

                # Prediction for unseen data (i.e. GEN node): a fresh group
                # intercept drawn from the same hyper-priors
                z_alpha_new = pm.Normal("z_alpha_new", mu=0, sigma=1)
                alpha_new = mu_alpha + z_alpha_new * sigma_alpha
                pm.Normal("y_new", mu=alpha_new, sigma=sigma_y)

        return model

    def preprocess_observations(
        self, observations: List[Observation], model_name: str
    ) -> Tuple[List[int], List[float], Dict[str, List[int]]]:
        """
        Extract the data needed by the PyMC models from the observations of ``model_name``.

        Returns:
            A tuple ``(child_indices, rewards, coords)``:

            - ``child_indices``: for each observation row, the position of its
              ``child_idx`` within ``coords["child_idx"]`` (a numpy integer
              array produced by ``Series.factorize``).
            - ``rewards``: the matching rewards.
            - ``coords["child_idx"]``: the distinct ``child_idx`` values as
              plain ints, in order of first appearance.

            All three are empty (``coords`` is ``{}``) when ``model_name`` has
            no observations.
        """
        df = Observation.to_pandas(observations, model_name=model_name)
        if len(df) == 0:
            return [], [], dict()

        child_indices, unique_child_indices = df["child_idx"].factorize()

        rewards = list(df["reward"].values)

        # Convert numpy integers to built-in ints so that the identifiers
        # returned by `run` really are `int`, as annotated (numpy.int64 is
        # not an instance of int).
        coords = {"child_idx": [int(idx) for idx in unique_child_indices]}

        return child_indices, rewards, coords
