import numpy as np
import scipy.special
import numpy as np
import scipy
import math
import os
import sys
import json
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils import compute_condition_number, update_json_file

class KernelBanzhafCond:
    """Random paired design matrix used for the Kernel Banzhaf experiment.

    Calling an instance draws ``m // 2`` rows of independent 0/1 entries,
    appends the element-wise complement of every drawn row, and shifts all
    entries by 1/2 so the result only contains -0.5 and +0.5.
    """

    def __init__(self, n_features, n_samples):
        # n: number of columns (features); m: requested number of rows.
        self.n = n_features
        self.m = n_samples

    def __call__(self):
        """Return a float array of shape ``(2 * (m // 2), n)``.

        The first half of the rows comes from NumPy's global random state;
        the second half is the negation of the first half after centring.
        """
        half_rows = self.m // 2
        drawn = np.random.randint(2, size=(half_rows, self.n))
        # Stack each sampled row block with its complement (paired sampling).
        paired = np.vstack((drawn, 1 - drawn))
        return paired - 1 / 2


def ith_combination(pool, r, index):
    # Function written by ChatGPT
    """
    Compute the index-th combination (0-based) in lexicographic order
    without generating all previous combinations.

    Args:
        pool: Indexable sequence (supports ``len`` and ``pool[j]``) to pick from.
        r: Number of elements in each combination.
        index: 0-based rank of the wanted combination; NumPy integer
            scalars are accepted.

    Returns:
        A tuple with ``r`` elements of ``pool`` in pool order.

    Raises:
        IndexError: If ``index`` is outside ``[0, comb(len(pool), r))``.
        ValueError: If ``r`` is negative (raised by ``math.comb``).
    """
    n = len(pool)
    # Work with a plain Python int so the subtraction below is exact even
    # when the binomial coefficients exceed 64 bits.
    index = int(index)
    total = math.comb(n, r)
    if not 0 <= index < total:
        # Without this check an invalid rank silently produced a tuple with
        # fewer than r elements.
        raise IndexError(
            f"combination index {index} out of range for {total} combinations"
        )

    combination = []
    k = r  # elements still to be chosen
    start = 0  # first pool position that may still be used
    for _ in range(r):
        for j in range(start, n):
            # Number of combinations whose next element is pool[j]:
            # choose the remaining k - 1 elements from the n - j - 1 after it.
            count = math.comb(n - j - 1, k - 1)
            if index < count:
                combination.append(pool[j])
                k -= 1
                start = j + 1
                break
            # Skip the whole group of combinations that start with pool[j].
            index -= count
    return tuple(combination)


def combination_generator(gen, n, s, num_samples):
    """
    Yield num_samples random combinations of s elements from a pool of size n.

    Two regimes are used:
    1. If the number of combinations is small (passing it to ``gen.choice``
       does NOT raise an OverflowError), draw num_samples distinct ranks
       without replacement and turn each rank into a combination on the fly
       with ith_combination. Each item is a tuple.
    2. If the number of combinations is large (``gen.choice`` DOES raise an
       OverflowError), draw num_samples combinations directly. Combinations
       may then repeat, and each item is whatever ``gen.choice(n, s,
       replace=False)`` returns rather than a tuple.

    Args:
        gen: Object exposing a ``choice(a, size, replace=...)`` method, as
            ``numpy.random.Generator`` does.
        n: Size of the pool ``range(n)``.
        s: Size of each combination.
        num_samples: Number of combinations to yield.
    """
    num_combos = math.comb(n, s)
    # Only the rank draw is guarded: an OverflowError raised later, while
    # combinations are already being yielded, must not restart the sampling
    # in fallback mode and hand out more than num_samples items.
    try:
        indices = gen.choice(num_combos, num_samples, replace=False)
    except OverflowError:
        indices = None

    if indices is None:
        for _ in range(num_samples):
            yield gen.choice(n, s, replace=False)
        return

    for i in indices:
        yield ith_combination(range(n), s, i)


class RegressionEstimatorCond:
    """Samples coalitions and forms the projected weighted Gram matrix.

    ``compute`` fills ``SZ_binary`` (one 0/1 row per sampled coalition) and
    ``kernel_weights`` (one weight per row) and returns the centring
    projection ``P`` together with ``P @ Z.T @ diag(w) @ Z @ P``.

    Args:
        n_features: Number of features ``n`` (columns of ``SZ_binary``).
        num_samples: Sample budget; 2 is subtracted and the result is
            rounded down to an even number.
        paired_sampling: If true, every sampled coalition is stored together
            with its complement (two consecutive rows).
        leverage_sampling: If true, all coalition sizes get sampling weight 1;
            otherwise size ``s`` gets ``1 / (s * (n - s))``.
        bernoulli_sampling: If true, ``sample`` is
            ``sample_without_replacement``; otherwise
            ``sample_with_replacement``.
    """

    def __init__(
        self,
        n_features,
        num_samples,
        paired_sampling=False,
        leverage_sampling=False,
        bernoulli_sampling=False,
    ):
        # Subtract 2 for the baseline and explicand and ensure num_samples is even
        self.num_samples = int((num_samples - 2) // 2) * 2
        self.paired_sampling = paired_sampling
        self.n = n_features
        # Constructed without a seed, so draws are not reproducible across runs.
        self.gen = np.random.Generator(np.random.PCG64())
        self.sample_weight = lambda s: (
            1 / (s * (self.n - s)) if not leverage_sampling else np.ones_like(s)
        )
        self.reweight = lambda s: 1 / (self.sample_weight(s) * (s * (self.n - s)))
        self.kernel_weights = []
        self.sample = (
            self.sample_with_replacement
            if not bernoulli_sampling
            else self.sample_without_replacement
        )

    def add_one_sample(self, idx, indices, weight):
        """Record coalition ``indices`` (and its complement when paired).

        Unpaired: row ``idx`` is set. Paired: rows ``2 * idx`` and
        ``2 * idx + 1`` hold the coalition and its complement. ``weight`` is
        appended to ``kernel_weights`` once per written row.
        """
        indices = np.asarray(indices, dtype=np.intp)
        if not self.paired_sampling:
            self.SZ_binary[idx, indices] = 1
            self.kernel_weights.append(weight)
            return

        # Boolean mask of the complement: linear in n, and still a valid
        # index when the complement happens to be empty.
        in_complement = np.ones(self.n, dtype=bool)
        in_complement[indices] = False
        self.SZ_binary[2 * idx, indices] = 1
        self.kernel_weights.append(weight)
        self.SZ_binary[2 * idx + 1, in_complement] = 1
        self.kernel_weights.append(weight)

    def sample_with_replacement(self):
        """Draw coalition sizes i.i.d. and one uniform coalition per size."""
        # Start from a clean slate so repeated calls do not accumulate weights.
        self.kernel_weights = []
        self.SZ_binary = np.zeros((self.num_samples, self.n))
        valid_sizes = np.array(list(range(1, self.n)))
        prob_sizes = self.sample_weight(valid_sizes)
        prob_sizes = prob_sizes / np.sum(prob_sizes)
        # With pairing each draw fills two rows, so only half as many draws.
        num_sizes = (
            self.num_samples if not self.paired_sampling else self.num_samples // 2
        )
        sampled_sizes = self.gen.choice(valid_sizes, num_sizes, p=prob_sizes)
        for idx, s in enumerate(sampled_sizes):
            indices = self.gen.choice(self.n, s, replace=False)
            # weight = Pr(sampling this set) * w(s)
            weight = 1 / (self.sample_weight(s) * s * (self.n - s))
            self.add_one_sample(idx, indices, weight=weight)

    def find_constant_for_bernoulli(self, max_C=1e10):
        """Bisect for the scaling constant and store its rounding in ``self.C``.

        ``max_C`` is accepted for backward compatibility and is not used.
        """
        # Choose C so that sampling without replacement from min(1, C*prob) gives the same expected number of samples
        C = 1  # Assume at least n - 1 samples
        m = min(
            self.num_samples, 2 ** self.n - 2
        )  # Maximum number of samples is 2^n -2

        def expected_samples(C):
            expected = [
                min(scipy.special.binom(self.n, s), 2 * C * self.sample_weight(s))
                for s in range(1, self.n)
            ]
            return np.sum(expected)

        # Efficiently find C with binary search
        L = 1
        R = scipy.special.binom(self.n, self.n // 2) * self.n ** 2
        while True:
            current = expected_samples(C)
            if round(current) == m:
                break
            if current < m:
                L = C
            else:
                R = C
            midpoint = (L + R) / 2
            if midpoint == C:
                # The bracket has collapsed (e.g. m is below what C = 1
                # already yields); iterating further would never terminate.
                break
            C = midpoint
        self.C = round(C)

    def sample_without_replacement(self):
        """Sample a binomial number of coalitions per size, then fill rows."""
        # Start from a clean slate so repeated calls do not accumulate weights.
        self.kernel_weights = []
        self.find_constant_for_bernoulli()
        m_s_all = []
        for s in range(1, self.n):
            n_choose_s = scipy.special.binom(self.n, s)
            # Sample from Binomial distribution with (n choose s) trials and probability min(1, C*sample_weight(s) / (n choose s))
            prob = min(1, 2 * self.C * self.sample_weight(s) / n_choose_s)
            try:
                m_s = self.gen.binomial(int(n_choose_s), prob)
            except (
                OverflowError
            ):  # If the number of samples is too large, assume the number of samples is the expected number
                m_s = int(prob * n_choose_s)
            if self.paired_sampling:
                if (
                    s == self.n // 2
                ):  # Already sampled all larger sets with the complement
                    if (
                        self.n % 2 == 0
                    ):  # Special handling for middle set size if n is even
                        m_s_all.append(m_s // 2)
                    else:
                        m_s_all.append(m_s)
                    break
            m_s_all.append(m_s)

        sampled_m = int(sum(m_s_all))
        num_rows = sampled_m if not self.paired_sampling else sampled_m * 2
        self.SZ_binary = np.zeros((num_rows, self.n))
        idx = 0
        # m_s_all[k] holds the count for coalition size k + 1.
        for s, m_s in enumerate(m_s_all, start=1):
            n_choose_s = scipy.special.binom(self.n, s)
            prob = min(1, 2 * self.C * self.sample_weight(s) / n_choose_s)
            weight = 1 / (prob * n_choose_s * (self.n - s) * s)
            if self.paired_sampling and s == self.n // 2 and self.n % 2 == 0:
                # Partition the all middle sets into two
                # based on whether the combination contains n-1
                combo_gen = combination_generator(self.gen, self.n - 1, s - 1, m_s)
                for indices in combo_gen:
                    self.add_one_sample(
                        idx, list(indices) + [self.n - 1], weight=weight
                    )
                    idx += 1
            else:
                combo_gen = combination_generator(self.gen, self.n, s, m_s)
                for indices in combo_gen:
                    self.add_one_sample(idx, list(indices), weight=weight)
                    idx += 1

    def compute(self):
        """Sample, then return ``(P, P @ Z.T @ diag(w) @ Z @ P)``.

        ``P`` is ``I - (1/n) * ones`` and ``Z`` is ``SZ_binary`` restricted to
        rows that contain at least one 1.
        """
        self.sample()
        nonzero_rows = np.sum(self.SZ_binary, axis=1) != 0
        SZ_binary = self.SZ_binary[nonzero_rows]
        weights = np.asarray(self.kernel_weights, dtype=float)
        if weights.shape[0] == nonzero_rows.shape[0]:
            # Keep weights aligned with the rows that survive the filter.
            weights = weights[nonzero_rows]
        P = np.eye(self.n) - 1 / self.n * np.ones((self.n, self.n))
        # Scaling the columns of P @ Z.T by the weights equals multiplying by
        # diag(weights) but avoids allocating a dense (rows x rows) matrix.
        PZSSZP = ((P @ SZ_binary.T) * weights) @ SZ_binary @ P

        return P, PZSSZP


def leverage_shap(n_features, num_samples):
    """Run the estimator in its Leverage SHAP configuration.

    Uses paired sampling, leverage (uniform-over-sizes) sampling weights and
    the Bernoulli / without-replacement sampler, and returns whatever
    ``RegressionEstimatorCond.compute`` returns.
    """
    settings = {
        "paired_sampling": True,
        "leverage_sampling": True,
        "bernoulli_sampling": True,
    }
    return RegressionEstimatorCond(n_features, num_samples, **settings).compute()


def optimized_kernel_shap(n_features, num_samples):
    """Run the estimator in its Optimized Kernel SHAP configuration.

    Uses paired sampling and the Bernoulli / without-replacement sampler with
    the non-leverage size weights, and returns whatever
    ``RegressionEstimatorCond.compute`` returns.
    """
    settings = {
        "paired_sampling": True,
        "leverage_sampling": False,
        "bernoulli_sampling": True,
    }
    return RegressionEstimatorCond(n_features, num_samples, **settings).compute()


# Two distinct integers in [0, 1000), drawn from a fixed seed so the same
# labels are produced on every run.
random_states = np.random.RandomState(42).choice(1000, 2, replace=False)
random_states = [int(i) for i in random_states]

# NOTE(review): random_state is only forwarded to update_json_file below;
# nothing in this loop seeds the samplers with it, so repeated runs will not
# reproduce the same condition numbers. Confirm whether that is intended.
for random_state in random_states:
    print(f"Random state: {random_state}")
    for n_features in [10, 40, 160, 640]:
        result_dir = f"cond_results/results_{n_features}"
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(result_dir, exist_ok=True)

        for n_samples in [
            multiplier * n_features for multiplier in (5, 10, 20, 40, 80, 160, 320)
        ]:
            # The largest sample budgets are skipped for n = 640.
            if n_features == 640 and n_samples > 640 * 80:
                continue

            print(f"leverage shap for {n_features} features and {n_samples} samples")
            leverage_P, leverage_A = leverage_shap(n_features, n_samples)
            leverage_shap_matrix = (
                leverage_P @ leverage_A @ leverage_A.T @ leverage_P
            )
            leverage_shap_cond = compute_condition_number(leverage_shap_matrix)
            leverage_path = f"{result_dir}/shap_leverage_{n_features}_{n_samples}.json"
            update_json_file(leverage_path, leverage_shap_cond, random_state)

            print(
                f"optimized kernel shap for {n_features} features and {n_samples} samples"
            )
            optimized_P, optimized_A = optimized_kernel_shap(n_features, n_samples)
            optimized_kernel_shap_matrix = (
                optimized_P @ optimized_A.T @ optimized_A @ optimized_P
            )
            optimized_kernel_shap_cond = compute_condition_number(optimized_kernel_shap_matrix)
            optimized_kernel_path = (
                f"{result_dir}/shap_optimized_{n_features}_{n_samples}.json"
            )
            update_json_file(
                optimized_kernel_path, optimized_kernel_shap_cond, random_state
            )

            print(f"banzhaf for {n_features} features and {n_samples} samples")
            banzhaf_A = KernelBanzhafCond(n_features, n_samples)()
            banzhaf_matrix = banzhaf_A.T @ banzhaf_A
            banzhaf_cond = compute_condition_number(banzhaf_matrix)
            banzhaf_path = f"{result_dir}/banzhaf_paired_sampling_cond_{n_features}_{n_samples}.json"
            update_json_file(banzhaf_path, banzhaf_cond, random_state)


def process_results_files(result_dir, estimators, explainer):
    """Collect per-estimator results from the JSON files in ``result_dir``.

    A file is considered when its name ends with ``.json`` and contains
    ``explainer``; it is attributed to the first entry of ``estimators`` that
    occurs in the file name. The number between the last underscore and the
    extension is parsed as a float (the sample size in this script).

    Returns:
        ``(suffixes, estimations)``: two dicts keyed by estimator. The first
        maps to the list of parsed numbers, the second to a parallel list
        holding, per file, the list of values stored in its JSON object.
    """
    suffixes = {}
    estimations = {}
    for name in estimators:
        suffixes[name] = []
        estimations[name] = []

    for filename in os.listdir(result_dir):
        if not filename.endswith(".json") or explainer not in filename:
            continue
        # Only the first matching estimator claims the file.
        owner = next((name for name in estimators if name in filename), None)
        if owner is None:
            continue

        last_token = filename.split("_")[-1]
        number_text = ".".join(last_token.split(".")[:-1])
        suffixes[owner].append(float(number_text))

        with open(os.path.join(result_dir, filename), "r") as handle:
            payload = json.load(handle)
        # One entry per key (random state) stored in the file.
        estimations[owner].append(list(payload.values()))
    return suffixes, estimations


def plot_all_datasets():
    """Plot one panel per feature count and save the combined figure.

    Reads results from ``cond_results/results_<n>`` for each hard-coded
    feature count, draws them side by side, adds a shared legend taken from
    the first panel and writes ``shap_cond_by_sample_size.pdf``.
    """
    datasets = ["10", "40", "160", "640"]
    estimators = ["leverage", "optimized"]
    num_datasets = len(datasets)
    if num_datasets == 4:
        fig, axs = plt.subplots(
            1, num_datasets, figsize=(2.5 * num_datasets, 2.6), sharey=False
        )
        # Leave a strip at the bottom of the figure for the shared legend.
        rect = [0, 0.12, 1, 1]

    y_label = True
    for position, (ax, dataset) in enumerate(zip(axs, datasets)):
        # Only the first panel of each group of four carries a y-axis label.
        if position % 4 != 0:
            y_label = False
        elif num_datasets == 8:
            y_label = True

        curr_result_dir = f"cond_results/results_{dataset}"
        drawn = plot_dataset(ax, curr_result_dir, estimators, dataset, y_label)
        if position == 0:
            # Legend handles/labels are the same for every panel.
            l, lab = drawn

    plt.subplots_adjust(wspace=0.01, hspace=0.1)
    plt.tight_layout(rect=rect)

    fig.legend(l, lab, loc="lower center", ncol=4, fontsize=11, frameon=False)

    print("Saving all datasets by sample size")
    file_name = "shap_cond_by_sample_size.pdf"
    plt.savefig(file_name)


def plot_dataset(ax, result_dir, estimators, dataset, y_label):
    """Draw the condition-number curves of one result directory onto ``ax``.

    For the Kernel Banzhaf results and every entry of ``estimators`` the
    median across stored runs is plotted against the sample size, with the
    25th-75th percentile band shaded. Both axes are logarithmic.

    Args:
        ax: Matplotlib axes to draw on.
        result_dir: Directory holding the JSON result files.
        estimators: List of SHAP estimator keys to plot.
        dataset: Feature count as a string; selects the panel title.
        y_label: Whether to label the y axis.

    Returns:
        ``(lines, labels)``: the plotted line artists and their labels.
    """
    sample_sizes, conds = process_results_files(
        result_dir, estimators, explainer="shap"
    )
    banzhaf_sample_sizes, banzhaf_conds = process_results_files(
        result_dir, ["paired_sampling"], explainer="banzhaf"
    )

    # Merge the Banzhaf series in and plot it first.
    conds["paired_sampling"] = banzhaf_conds["paired_sampling"]
    sample_sizes["paired_sampling"] = banzhaf_sample_sizes["paired_sampling"]
    estimators = ["paired_sampling"] + estimators

    palette = np.concatenate((plt.cm.tab20c([0]), plt.cm.tab20([4, 12])))
    dash_patterns = ["-", (0, (3, 1, 1, 1, 1, 1)), (0, (5, 1))]
    markers = ["o", "v", "x"]
    label_dict = {
        "paired_sampling": "Kernel Banzhaf",
        "leverage": "Leverage SHAP",
        "optimized": "Optimized Kernel SHAP",
    }

    lines = []
    for name, color, dash, marker in zip(estimators, palette, dash_patterns, markers):
        order = np.argsort(sample_sizes[name])
        xs = np.array(sample_sizes[name])[order]
        values = np.array(conds[name])[order]

        # Median and interquartile band across the stored runs.
        middle = np.median(values, axis=1)
        lower = np.percentile(values, 25, axis=1)
        upper = np.percentile(values, 75, axis=1)

        (line,) = ax.plot(
            xs,
            middle,
            marker=marker,
            linestyle=dash,
            color=color,
            label=label_dict[name],
            markersize=4,
            linewidth=1,
        )
        lines.append(line)
        ax.fill_between(xs, lower, upper, color=color, alpha=0.2)

    # Axes cosmetics: hide the top/right frame and thin the remaining spines.
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    for side in ("left", "bottom"):
        ax.spines[side].set_linewidth(0.5)

    dataset_dict = {
        "10": "n = 10",
        "40": "n = 40",
        "160": "n = 160",
        "640": "n = 640",
    }
    ax.set_title(dataset_dict[dataset], fontsize=11)
    ax.set_xlabel("Number of Samples", fontsize=10)
    ax.tick_params(axis="x", labelsize=7)
    if y_label:
        ax.set_ylabel("Condition Number", fontsize=10)
    ax.tick_params(axis="y", labelsize=7)

    # Log scales with major ticks at powers of 10 on both axes.
    ax.set_xscale("log")
    ax.xaxis.set_major_locator(ticker.LogLocator(base=10.0))
    ax.set_yscale("log")
    ax.yaxis.set_major_locator(ticker.LogLocator(base=10.0))

    ax.minorticks_off()

    labels = [line.get_label() for line in lines]
    return lines, labels


# Runs unconditionally at module level: build the summary figure from the
# JSON results on disk and save it as a PDF.
plot_all_datasets()
