from typing import Any, Dict, Union

import numpy as np
import tqdm
from numpy.random import RandomState
from sklearn.utils import check_random_state


class SystematicSampler:
    """Monte Carlo permutation sampler for per-point marginal contributions.

    For every configuration the sampler walks random permutations of the
    training points, grows a coalition one point at a time and records how
    much the utility changes when each point is added. Contributions are
    accumulated per (position in permutation, point) pair. Sampling of a
    configuration stops once its Gelman-Rubin statistic drops to
    ``gr_threshold`` or below, or after ``max_mc_epochs`` epochs.

    Each configuration is a mapping that must provide:

    * ``"trainset"``: an object exposing ``get_len()`` and indexable ``X`` /
      ``y`` attributes;
    * ``"utility"``: an object exposing
      ``compute_utility(X_subset, y_subset, configuration)``.
    """

    # Sentinel meaning "not converged / not enough samples yet".
    GR_MAX = np.inf
    # Number of consecutive negligible increments that triggers truncation
    # of the current permutation.
    TRUNCATION_PATIENCE = 10
    # Relative-change tolerance of the truncation test; also used as the
    # floor of its denominator.
    TRUNCATION_TOLERANCE = 1e-8

    def __init__(
        self,
        configurations: Dict[str, Any],
        gr_threshold: float = 1.05,
        max_mc_epochs: int = 1000,
        permutations_per_epoch: int = 100,
        min_permutations: int = 500,
        random_state: Union[int, RandomState, None] = 42,
    ):
        """Store the sampling parameters and create empty accumulators.

        Args:
            configurations: Mapping from configuration key to configuration.
                The number of points is read from the first configuration's
                ``"trainset"`` via ``get_len()``.
            gr_threshold: A configuration counts as converged once its GR
                statistic is less than or equal to this value.
            max_mc_epochs: Upper bound on the number of sampling epochs.
            permutations_per_epoch: Permutations processed per configuration
                in each epoch.
            min_permutations: Minimum number of processed permutations before
                a GR statistic is computed at all.
            random_state: Seed or generator, normalized with
                ``sklearn.utils.check_random_state``.

        Raises:
            ValueError: If ``configurations`` is empty.
        """
        if not configurations:
            # Without this guard ``next(iter(...))`` leaks a bare
            # StopIteration, which is confusing and easily swallowed.
            raise ValueError("configurations must contain at least one entry")
        self.configurations = configurations
        first_config = next(iter(self.configurations.values()))
        self.num_points = first_config["trainset"].get_len()
        self.gr_threshold = gr_threshold
        self.max_mc_epochs = max_mc_epochs
        self.permutations_per_epoch = permutations_per_epoch
        self.min_permutations = min_permutations
        self.random_state = check_random_state(random_state)
        self.permutations = None
        self.marginal_contrib_sum = {}
        self.marginal_count = {}
        self.gr_stats = {}
        self.marginal_increment_array_stack = {}
        self.utility_cache = {}

    def initialize_contributions(self) -> None:
        """Reset the per-configuration accumulators.

        The utility cache is left untouched.
        """
        shape = (self.num_points, self.num_points)
        for config_key in self.configurations:
            # Rows: position in the permutation (cardinality); columns: point.
            self.marginal_contrib_sum[config_key] = np.zeros(shape)
            # Tiny offset prevents division by 0 for never-visited pairs.
            self.marginal_count[config_key] = np.zeros(shape) + 1e-8
            self.gr_stats[config_key] = SystematicSampler.GR_MAX
            # One row of per-point increments is appended per permutation.
            self.marginal_increment_array_stack[config_key] = np.zeros(
                (0, self.num_points)
            )

    def compute_marginal_contributions_for_all(
        self, verbose=True
    ) -> Dict[str, np.ndarray]:
        """Sample permutations until every configuration converges.

        All ``max_mc_epochs * permutations_per_epoch`` permutations are drawn
        up front and shared by every configuration, so memory use grows with
        that product times the number of points.

        Args:
            verbose: When true, print progress and show tqdm progress bars.

        Returns:
            Mapping from configuration key to a
            ``(num_cardinalities x num_points)`` matrix of average marginal
            contributions.
        """
        self.permutations = np.array(
            [
                self.random_state.permutation(self.num_points)
                for _ in range(self.max_mc_epochs * self.permutations_per_epoch)
            ]
        )
        iteration = 0
        self.initialize_contributions()

        while iteration < self.max_mc_epochs and not self._all_converged():
            if verbose:
                print(f"Epoch {iteration+1}/{self.max_mc_epochs}")
            epoch_start = iteration * self.permutations_per_epoch
            epoch_stop = epoch_start + self.permutations_per_epoch
            for config_key in self.configurations:
                # Configurations that already converged are skipped.
                if self.gr_stats[config_key] <= self.gr_threshold:
                    continue
                if verbose:
                    print(f"Processing configuration {config_key}")
                for perm_idx in tqdm.tqdm(
                    range(epoch_start, epoch_stop),
                    desc=f"Config {config_key}",
                    disable=not verbose,
                ):
                    self._calculate_marginal_contributions(config_key, perm_idx)
                self.gr_stats[config_key] = self._compute_gr_statistic(config_key)
                if verbose:
                    print(
                        f"GR Statistic for {config_key}: {self.gr_stats[config_key]}"
                    )
            iteration += 1

        # Return the matrix of shape (num_cardinalities x num_points) of marginal contributions
        return {
            key: self.marginal_contrib_sum[key] / self.marginal_count[key]
            for key in self.configurations
        }

    def _calculate_marginal_contributions(
        self, config_key: str, perm_idx: int
    ) -> None:
        """Process one permutation and update the accumulators in place.

        Points are added to the coalition in permutation order; the utility
        difference caused by each addition is recorded at
        ``[position, point]``. After ``TRUNCATION_PATIENCE`` consecutive
        negligible increments the walk stops early and the remaining points
        are counted as zero contributions at their own positions.

        Args:
            config_key: Key of the configuration to evaluate.
            perm_idx: Row index into ``self.permutations``.
        """
        permutation = np.asarray(self.permutations[perm_idx])
        contrib_sum = self.marginal_contrib_sum[config_key]
        counts = self.marginal_count[config_key]

        marginal_increment = np.zeros(self.num_points)
        coalition_indices = []
        prev_perf = self.compute_utility_cached(coalition_indices, config_key)
        truncation_counter = 0
        for cardinality, idx in enumerate(permutation):
            coalition_indices.append(idx)
            curr_perf = self.compute_utility_cached(coalition_indices, config_key)
            increment = curr_perf - prev_perf
            marginal_increment[idx] = increment
            contrib_sum[cardinality, idx] += increment
            counts[cardinality, idx] += 1

            # Size of this step relative to the utility gained so far in the
            # permutation (the denominator is floored to stay positive).
            relative_change = abs(increment) / max(
                np.sum(marginal_increment), self.TRUNCATION_TOLERANCE
            )
            if relative_change < self.TRUNCATION_TOLERANCE:
                truncation_counter += 1
            else:
                truncation_counter = 0

            if truncation_counter == self.TRUNCATION_PATIENCE:
                # Count a zero contribution for every remaining point at the
                # position it would have occupied. Paired index arrays are
                # required: mixing a row slice with a column index array
                # would bump the full rows x columns sub-matrix instead.
                remaining = np.arange(cardinality + 1, len(permutation))
                counts[remaining, permutation[remaining]] += 1
                break

            prev_perf = curr_perf

        self.marginal_increment_array_stack[config_key] = np.vstack(
            [
                self.marginal_increment_array_stack[config_key],
                marginal_increment.reshape(1, -1),
            ]
        )

    def compute_utility_cached(self, coalition_indices, config_key):
        """Return the utility of a coalition, memoized per configuration.

        The cache key uses the sorted indices, so the same set of points
        shares one entry regardless of insertion order. On a miss the
        utility is computed on the rows in the order given.

        Args:
            coalition_indices: Indices of the points in the coalition.
            config_key: Key of the configuration whose utility is used.

        Returns:
            Whatever the configuration's ``compute_utility`` returns.
        """
        cache_key = (tuple(sorted(coalition_indices)), config_key)
        if cache_key in self.utility_cache:
            return self.utility_cache[cache_key]

        configuration = self.configurations[config_key]
        trainset = configuration["trainset"]
        utility_value = configuration["utility"].compute_utility(
            trainset.X[coalition_indices],
            trainset.y[coalition_indices],
            configuration,
        )
        self.utility_cache[cache_key] = utility_value
        return utility_value

    def _compute_gr_statistic(self, config_key: str, num_chains: int = 10) -> float:
        """Compute the Gelman-Rubin statistic of the sampled increments.

        The recorded increment rows are split into ``num_chains`` consecutive
        blocks of equal length (the oldest surplus rows are dropped) and the
        statistic is evaluated per point; the maximum over points is
        returned.

        Args:
            config_key: Key of the configuration to evaluate.
            num_chains: Number of blocks the samples are split into.

        Returns:
            The largest per-point statistic, or ``GR_MAX`` when fewer than
            ``min_permutations`` rows exist or a chain would hold fewer than
            two samples.

        Raises:
            ValueError: If ``num_chains`` is smaller than 2.
        """
        if num_chains < 2:
            raise ValueError("num_chains must be at least 2")

        samples = self.marginal_increment_array_stack[config_key]
        if len(samples) < self.min_permutations:
            return SystematicSampler.GR_MAX

        num_samples, num_datapoints = samples.shape
        num_samples_per_chain, offset = divmod(num_samples, num_chains)

        # The unbiased within-chain variance needs two samples per chain.
        # With fewer it is NaN, and NaN passes neither the "needs more
        # sampling" nor the "converged" comparison made by the caller.
        if num_samples_per_chain < 2:
            return SystematicSampler.GR_MAX

        # Drop the oldest rows so the rest divides evenly into chains.
        samples = samples[offset:]
        mcmc_chains = samples.reshape(num_chains, num_samples_per_chain, num_datapoints)

        # Mean within-chain variance, per point.
        s_term = np.mean(np.var(mcmc_chains, axis=1, ddof=1), axis=0)
        # Between-chain variance of the chain means, per point.
        sampling_mean = np.mean(mcmc_chains, axis=1, keepdims=False)
        b_term = num_samples_per_chain * np.var(sampling_mean, axis=0, ddof=1)

        # Avoid dividing by zero for points whose increments never vary.
        epsilon = 1e-8
        s_term = np.where(s_term == 0, epsilon, s_term)

        gr_stats = np.sqrt(
            (num_samples_per_chain - 1) / num_samples_per_chain
            + (b_term / (s_term * num_samples_per_chain))
        )

        return float(np.max(gr_stats))

    def _all_converged(self) -> bool:
        """Return True when every tracked GR statistic is within threshold."""
        return all(gr_stat <= self.gr_threshold for gr_stat in self.gr_stats.values())