import os
import sys
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils import calculate_normalized_l2_norm


def process_results_files(result_dir, estimators, dataset, explainer, mode=None):
    """Load exact and estimated attribution results and score the estimates.

    The reference values are read from ``{explainer}_exact_{dataset}.json`` in
    ``result_dir``. Estimate files are then scanned in ``result_dir/noise``
    when ``mode == "noise"``, otherwise in ``result_dir``. A file is used when:

    * its name ends in ``.json`` and contains ``explainer``;
    * it contains ``"cond"`` if and only if ``mode == "cond"``;
    * it contains one of ``estimators``. The file is assigned to the first
      such estimator, in list order.

    The text after the last ``_`` in the file name, minus the extension, is
    parsed as a float. Callers treat this value as the sample size or noise
    level.

    Args:
        result_dir: Directory holding the exact-results file and, unless
            ``mode == "noise"``, the estimate files.
        estimators: Estimator names, matched as substrings of file names.
        dataset: Dataset name used to locate the exact-results file.
        explainer: ``"banzhaf"`` (per-random-state ``{feature: value}`` dicts)
            or another explainer name such as ``"shap"`` (per-random-state
            arrays).
        mode: ``None``, ``"noise"`` or ``"cond"``.

    Returns:
        ``(suffixes, errors, estimations)``, each a dict keyed by estimator.

        * ``suffixes[e]`` holds one float per matched file.
        * ``errors[e]`` and ``estimations[e]`` hold one list per matched file.
        * For Banzhaf outside ``"cond"`` mode, the estimations list holds every
          feature value of every random state, flattened.
        * ``estimations`` also has an ``"exact"`` key holding the reference
          values.
        * In ``"cond"`` mode the files hold condition numbers, not
          attributions. No error is computed, so each per-file error list is
          empty.
    """
    errors = {estimator: [] for estimator in estimators}
    suffixes = {estimator: [] for estimator in estimators}
    estimations = {estimator: [] for estimator in estimators}

    est_result_dir = os.path.join(result_dir, "noise") if mode == "noise" else result_dir
    cond_mode = mode == "cond"

    exact_path = os.path.join(result_dir, f"{explainer}_exact_{dataset}.json")
    with open(exact_path, "r", encoding="utf-8") as file:
        exact_data = json.load(file)

    if explainer == "banzhaf":
        # Banzhaf results are stored per random state as {feature: value}; flatten.
        estimations["exact"] = [
            per_state[feature]
            for per_state in exact_data.values()
            for feature in per_state
        ]
    else:
        estimations["exact"] = list(exact_data.values())

    for filename in os.listdir(est_result_dir):
        if not filename.endswith(".json") or explainer not in filename:
            continue
        # "cond" files hold condition numbers; use them only in cond mode and
        # only them in cond mode. This does not depend on the estimator.
        if ("cond" in filename) != cond_mode:
            continue
        estimator = next((e for e in estimators if e in filename), None)
        if estimator is None:
            continue

        # "..._1000.json" -> 1000.0, "..._0.01.json" -> 0.01
        suffix = float(os.path.splitext(filename)[0].rsplit("_", 1)[-1])

        estimated_filepath = os.path.join(est_result_dir, filename)
        with open(estimated_filepath, "r", encoding="utf-8") as file:
            estimated_data = json.load(file)

        error = []
        curr_estimations = []
        for random_state, estimated_values in estimated_data.items():
            if cond_mode:
                # Condition numbers: there is nothing to compare to the exact
                # attributions, so only collect them.
                curr_estimations.append(estimated_values)
            elif explainer == "banzhaf":
                curr_estimations.extend(estimated_values.values())
                error.append(
                    calculate_normalized_l2_norm(
                        exact_data[random_state], estimated_values
                    )
                )
            else:
                curr_estimations.append(estimated_values)
                exact = np.array(exact_data[random_state])
                differences = exact - np.array(estimated_values)
                error.append(
                    np.linalg.norm(differences, axis=0)
                    / np.linalg.norm(exact, axis=0)
                )

        suffixes[estimator].append(suffix)
        estimations[estimator].append(curr_estimations)
        errors[estimator].append(error)
    return suffixes, errors, estimations


def plot_all_datasets(datasets, estimators, error, noise):
    """Plot several datasets side by side with a shared legend and save a PDF.

    Four datasets are drawn in a single row and eight datasets in a 2x4 grid.
    Only the first panel of each row of four gets a y-axis label. The legend
    entries come from the first panel.

    Args:
        datasets: Dataset names. Each one is read from the directory
            ``shap_results_{dataset}``. Must contain exactly 4 or 8 entries.
        estimators: Estimator names forwarded to ``plot_dataset``.
        error: Forwarded to ``plot_dataset``. Use ``"l2"`` for the L2 error
            or ``None`` for condition numbers. It is also part of the output
            file name.
        noise: ``None`` to plot against sample size; any other value plots
            against noise level.

    Raises:
        ValueError: If ``datasets`` does not contain 4 or 8 entries.
    """
    num_datasets = len(datasets)
    if num_datasets == 4:
        fig, axs = plt.subplots(
            1, num_datasets, figsize=(2.6 * num_datasets, 2.6), sharey=False
        )
        rect = [0, 0.12, 1, 1]
    elif num_datasets == 8:
        fig, axs = plt.subplots(2, 4, figsize=(9.7, 4.3), sharey=False)
        axs = axs.flatten()
        rect = [0, 0.08, 1, 1]
    else:
        # The figure sizes and legend margins are tuned for these two layouts
        # only; without this check fig/axs/rect would be unbound below.
        raise ValueError(
            f"plot_all_datasets supports 4 or 8 datasets, got {num_datasets}"
        )

    handles, labels = None, None
    for idx, (ax, dataset) in enumerate(zip(axs, datasets)):
        # Label the y-axis only on the leftmost panel of each row of four.
        y_label = idx % 4 == 0
        panel_handles, panel_labels = plot_dataset(
            ax, f"shap_results_{dataset}", estimators, error, dataset, y_label, noise
        )
        if idx == 0:
            handles, labels = panel_handles, panel_labels

    plt.subplots_adjust(wspace=0.01, hspace=0.1)
    plt.tight_layout(rect=rect)

    fig.legend(handles, labels, loc="lower center", ncol=4, fontsize=11, frameon=False)

    if noise is not None:
        print("Saving all datasets by noise")
        file_name = f"shap_all_datasets_{error}_by_noise.pdf"
    else:
        print("Saving all datasets by sample size")
        error_tag = "cond" if error is None else error
        file_name = f"shap_all_datasets_{error_tag}_by_sample_size.pdf"
    plt.savefig(file_name)


def plot_dataset(ax, result_dir, estimators, error, dataset, y_label, noise):
    """Draw one dataset's estimator curves on ``ax`` and return legend entries.

    Kernel Banzhaf (``"paired_sampling"``, loaded with ``explainer="banzhaf"``)
    is placed before the SHAP ``estimators``. Only the first three series get
    a color/style and are drawn. For each series:

    * x-values above ``2**n`` are dropped, where ``n`` is the dataset's
      feature count from the table below;
    * the median across random states is plotted as a line;
    * the 25th-75th percentile range is shaded.

    Args:
        ax: Matplotlib axes to draw on.
        result_dir: Directory passed to ``process_results_files``.
        estimators: SHAP estimator names. The caller's list is not mutated.
        error: ``"l2"`` plots the normalized L2 error. ``None`` plots
            condition numbers and takes precedence over ``noise``.
        dataset: Key into the feature-count and title tables below.
        y_label: Whether to label the y-axis.
        noise: ``None`` for a sample-size sweep, otherwise a noise sweep.

    Returns:
        ``(lines, labels)``: the plotted line handles and their labels.
    """
    # Number of features per dataset. Larger x-values are filtered out.
    n_features = {
        "diabetes": 8,
        "adult": 14,
        "bank": 16,
        "german_credit": 20,
        "nhanes": 79,
        "brca": 100,
        "communitiesandcrime": 101,
        "tuandromd": 241,
    }
    display_names = {
        "diabetes": "Diabetes",
        "adult": "Adult",
        "bank": "Bank",
        "german_credit": "German Credit",
        "nhanes": "NHANES",
        "brca": "BRCA",
        "communitiesandcrime": "Communities and Crime",
        "tuandromd": "TUANDROMD",
    }
    label_dict = {
        "paired_sampling": "Kernel Banzhaf",
        "leverage": "Leverage SHAP",
        "optimized": "Optimized Kernel SHAP",
    }

    mode = None if noise is None else "noise"
    if error is None:
        mode = "cond"

    sample_sizes, errors, conds = process_results_files(
        result_dir, estimators, dataset, explainer="shap", mode=mode
    )
    banzhaf_sample_sizes, banzhaf_errors, banzhaf_conds = process_results_files(
        result_dir, ["paired_sampling"], dataset, explainer="banzhaf", mode=mode
    )

    errors["paired_sampling"] = banzhaf_errors["paired_sampling"]
    conds["paired_sampling"] = banzhaf_conds["paired_sampling"]
    sample_sizes["paired_sampling"] = banzhaf_sample_sizes["paired_sampling"]
    estimators = ["paired_sampling"] + estimators
    values = conds if mode == "cond" else errors

    if noise is not None:
        colors = np.concatenate((plt.cm.tab20c([0]), plt.cm.tab20([4, 12])))
    else:
        colors = np.concatenate((plt.cm.tab20c([1]), plt.cm.tab20([5, 13])))
    line_styles = ["-", (0, (3, 1, 1, 1, 1, 1)), (0, (5, 1))]
    marker_styles = ["o", "v", "x"]
    max_x = 2 ** n_features[dataset]

    lines = []
    for estimator, color, line_style, marker_style in zip(
        estimators, colors, line_styles, marker_styles
    ):
        order = np.argsort(sample_sizes[estimator])
        x = np.array(sample_sizes[estimator])[order]
        y = np.array(values[estimator])[order]

        # Keep only x-values that are at most 2**n.
        keep = x <= max_x
        x, y = x[keep], y[keep]

        # Median and interquartile band across random states (axis 1).
        median_values = np.median(y, axis=1)
        perc25_values = np.percentile(y, 25, axis=1)
        perc75_values = np.percentile(y, 75, axis=1)

        (line,) = ax.plot(
            x,
            median_values,
            marker=marker_style,
            linestyle=line_style,
            color=color,
            label=label_dict[estimator],
            markersize=4,
            linewidth=1,
        )
        lines.append(line)
        ax.fill_between(x, perc25_values, perc75_values, color=color, alpha=0.2)

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["left"].set_linewidth(0.5)
    ax.spines["bottom"].set_linewidth(0.5)

    error_name = "Normalized L2-norm" if error == "l2" else "Condition Number"

    ax.set_title(
        f"{display_names[dataset]} (n={n_features[dataset]})", fontsize=11
    )
    if noise is not None:
        ax.set_xlabel("Noise Level", fontsize=10)
    else:
        ax.set_xlabel("Number of Samples", fontsize=10)
    ax.tick_params(axis="x", labelsize=7)
    if y_label:
        ax.set_ylabel(error_name, fontsize=10)
    ax.tick_params(axis="y", labelsize=7)
    ax.set_xscale("log")
    # Major ticks at powers of 10.
    ax.xaxis.set_major_locator(ticker.LogLocator(base=10.0))
    if mode != "cond":
        # Condition numbers stay on a linear y-axis; errors use a log y-axis.
        ax.set_yscale("log")
        ax.yaxis.set_major_locator(ticker.LogLocator(base=10.0))

    ax.minorticks_off()

    return lines, [line.get_label() for line in lines]


def plot_by_sample_size(result_dir, estimators, dataset, n):
    """Save a log-log plot of normalized L2 error against the number of samples.

    Kernel Banzhaf (``"paired_sampling"``, read with ``explainer="banzhaf"``)
    is drawn first, followed by the SHAP ``estimators``. Only three series in
    total get a color/style and are drawn. Sample sizes above ``2**n`` are
    dropped. Each series shows the median across random states plus a shaded
    25th-75th percentile band. The result is written to
    ``shap_{dataset}_l2_by_sample_size.png``.
    """
    pretty_names = {
        "paired_sampling": "Kernel Banzhaf",
        "leverage": "Leverage SHAP",
        "optimized": "Optimized Kernel SHAP",
    }

    sizes, l2_errors, _ = process_results_files(
        result_dir, estimators, dataset, explainer="shap"
    )
    bz_sizes, bz_errors, _ = process_results_files(
        result_dir, ["paired_sampling"], dataset, explainer="banzhaf"
    )
    # The Banzhaf results replace any SHAP entry of the same name.
    l2_errors["paired_sampling"] = bz_errors["paired_sampling"]
    sizes["paired_sampling"] = bz_sizes["paired_sampling"]
    series_names = ["paired_sampling"] + estimators

    plt.figure(figsize=(6, 4.5))
    styles = zip(plt.cm.tab20c([1, 5, 9]), ["-", "--", "-."], ["o", "^", "D"])

    for name, (shade, dash, glyph) in zip(series_names, styles):
        order = np.argsort(sizes[name])
        xs = np.array(sizes[name])[order]
        ys = np.array(l2_errors[name])[order]

        # Drop sample sizes greater than 2**n.
        in_range = np.where(xs <= 2 ** n)
        xs, ys = xs[in_range], ys[in_range]

        mid = np.median(ys, axis=1)
        low, high = np.percentile(ys, [25, 75], axis=1)

        plt.plot(
            xs,
            mid,
            marker=glyph,
            linestyle=dash,
            color=shade,
            label=pretty_names[name],
            markersize=5,
            linewidth=2,
        )
        plt.fill_between(xs, low, high, color=shade, alpha=0.1)

    title_dataset = dataset[0].upper() + dataset[1:]
    plt.title(f"L2 Norm Error by Sample Size on {title_dataset} Dataset")
    plt.xlabel("Number of Samples")
    plt.ylabel("l2 Error")

    plt.xscale("log")
    plt.yscale("log")

    axes = plt.gca()
    for axis in (axes.xaxis, axes.yaxis):
        axis.set_major_locator(ticker.LogLocator(base=10.0))
    for axis in (axes.xaxis, axes.yaxis):
        axis.set_minor_locator(ticker.NullLocator())

    plt.grid(True, which="major", linestyle="--", linewidth=0.5)
    # A small font keeps the legend from covering the curves.
    plt.legend(fontsize=8)

    plt.savefig(f"shap_{dataset}_l2_by_sample_size.png")


def plot_condition_by_sample_size(result_dir, estimators, dataset, n):
    """Save a plot of condition number against the number of samples.

    Condition-number results are read with ``mode="cond"``. Kernel Banzhaf
    (``"paired_sampling"``, read with ``explainer="banzhaf"``) is drawn first,
    followed by the SHAP ``estimators``. Only three series in total get a
    color/style and are drawn. Sample sizes above ``2**n`` are dropped. Each
    series shows the median across random states plus a shaded 25th-75th
    percentile band. The x-axis is logarithmic; the y-axis stays linear. The
    result is written to ``shap_{dataset}_cond_by_sample_size.png``.
    """
    label_dict = {
        "paired_sampling": "Kernel Banzhaf",
        "leverage": "Leverage SHAP",
        "optimized": "Optimized Kernel SHAP",
    }

    sample_sizes, _, conds = process_results_files(
        result_dir, estimators, dataset, explainer="shap", mode="cond"
    )
    banzhaf_sample_sizes, _, banzhaf_conds = process_results_files(
        result_dir, ["paired_sampling"], dataset, explainer="banzhaf", mode="cond"
    )

    conds["paired_sampling"] = banzhaf_conds["paired_sampling"]
    sample_sizes["paired_sampling"] = banzhaf_sample_sizes["paired_sampling"]
    estimators = ["paired_sampling"] + estimators

    print("Estimators: ", estimators)
    # This dict holds condition numbers, not errors.
    print(f"Condition-number keys: {conds.keys()}")

    plt.figure(figsize=(6, 4.5))
    colors = plt.cm.tab20c([1, 5, 9])
    line_styles = ["-", "--", "-."]
    marker_styles = ["o", "^", "D"]

    for estimator, color, line_style, marker_style in zip(
        estimators, colors, line_styles, marker_styles
    ):
        order = np.argsort(sample_sizes[estimator])
        sorted_sample_sizes = np.array(sample_sizes[estimator])[order]
        sorted_conds = np.array(conds[estimator])[order]

        # Keep only sample sizes that are at most 2**n.
        keep = sorted_sample_sizes <= 2 ** n
        sorted_sample_sizes = sorted_sample_sizes[keep]
        sorted_conds = sorted_conds[keep]

        # Median and interquartile band across random states (axis 1).
        median_conds = np.median(sorted_conds, axis=1)
        perc25_conds = np.percentile(sorted_conds, 25, axis=1)
        perc75_conds = np.percentile(sorted_conds, 75, axis=1)

        plt.plot(
            sorted_sample_sizes,
            median_conds,
            marker=marker_style,
            linestyle=line_style,
            color=color,
            label=label_dict[estimator],
            markersize=5,
            linewidth=2,
        )
        plt.fill_between(
            sorted_sample_sizes, perc25_conds, perc75_conds, color=color, alpha=0.1
        )

    # Capitalize only the first letter; str.capitalize() would lowercase the
    # rest, and slicing avoids an IndexError on an empty name.
    title_dataset = dataset[:1].upper() + dataset[1:]
    plt.title(f"Condition Number by Sample Size on {title_dataset} Dataset")
    plt.xlabel("Number of Samples")
    plt.ylabel("Condition Number")

    # Only the x-axis is logarithmic; condition numbers stay on a linear axis.
    plt.xscale("log")
    plt.gca().xaxis.set_major_locator(ticker.LogLocator(base=10.0))
    plt.gca().xaxis.set_minor_locator(ticker.NullLocator())

    plt.grid(True, which="major", linestyle="--", linewidth=0.5)
    # A small font keeps the legend from covering the curves.
    plt.legend(fontsize=8)

    file_name = f"shap_{dataset}_cond_by_sample_size.png"
    plt.savefig(file_name)


def plot_by_noise(result_dir, estimators, dataset):
    """Save a log-log plot of normalized L2 error against noise level.

    Results are read with ``mode="noise"``, i.e. from ``result_dir/noise``.
    Kernel Banzhaf (``"paired_sampling"``, read with ``explainer="banzhaf"``)
    is drawn first, followed by the SHAP ``estimators``. Only three series in
    total get a color/style and are drawn. Each series shows the median
    across random states plus a shaded 25th-75th percentile band. The result
    is written to ``shap_{dataset}_l2_by_noise.png``.
    """
    label_dict = {
        "paired_sampling": "Kernel Banzhaf",
        "leverage": "Leverage SHAP",
        "optimized": "Optimized Kernel SHAP",
    }

    noise, errors, _ = process_results_files(
        result_dir, estimators, dataset, explainer="shap", mode="noise"
    )
    banzhaf_noise, banzhaf_errors, _ = process_results_files(
        result_dir, ["paired_sampling"], dataset, explainer="banzhaf", mode="noise"
    )

    errors["paired_sampling"] = banzhaf_errors["paired_sampling"]
    noise["paired_sampling"] = banzhaf_noise["paired_sampling"]
    estimators = ["paired_sampling"] + estimators

    plt.figure(figsize=(6, 4.5))
    colors = plt.cm.tab20c([1, 5, 9])
    line_styles = ["-", "--", "-."]
    marker_styles = ["o", "^", "D"]

    for estimator, color, line_style, marker_style in zip(
        estimators, colors, line_styles, marker_styles
    ):
        order = np.argsort(noise[estimator])
        sorted_noise = np.array(noise[estimator])[order]
        sorted_errors = np.array(errors[estimator])[order]

        # Median and interquartile band across random states (axis 1).
        median_errors = np.median(sorted_errors, axis=1)
        perc25_errors = np.percentile(sorted_errors, 25, axis=1)
        perc75_errors = np.percentile(sorted_errors, 75, axis=1)

        plt.plot(
            sorted_noise,
            median_errors,
            marker=marker_style,
            linestyle=line_style,
            color=color,
            label=label_dict[estimator],
            markersize=5,
            linewidth=2,
        )
        plt.fill_between(
            sorted_noise, perc25_errors, perc75_errors, color=color, alpha=0.1
        )

    # Capitalize only the first letter; str.capitalize() would lowercase the
    # rest, and slicing avoids an IndexError on an empty name.
    title_dataset = dataset[:1].upper() + dataset[1:]
    # The x-axis here is noise level; the old title wrongly said "Sample Size".
    plt.title(f"L2 Norm Error by Noise Level on {title_dataset} Dataset")
    plt.xlabel("Noise Level")
    plt.ylabel("l2 Error")

    plt.xscale("log")
    plt.yscale("log")

    plt.gca().xaxis.set_major_locator(ticker.LogLocator(base=10.0))
    plt.gca().yaxis.set_major_locator(ticker.LogLocator(base=10.0))

    plt.gca().xaxis.set_minor_locator(ticker.NullLocator())
    plt.gca().yaxis.set_minor_locator(ticker.NullLocator())

    plt.grid(True, which="major", linestyle="--", linewidth=0.5)

    plt.legend(fontsize=8)
    plt.savefig(f"shap_{dataset}_l2_by_noise.png")
