import os
import sys
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils import (
    calculate_l2_norm,
    calculate_objective,
    calculate_gamma,
)


def process_results_files(result_dir, estimators, plot_error, dataset, mode=None):
    """Gather file suffixes, errors and raw estimates for every estimator.

    The exact values are read from ``banzhaf_exact_{dataset}.json`` inside
    ``result_dir``. Estimate files are looked up in ``result_dir/noise`` when
    ``mode == "noise"`` and in ``result_dir`` otherwise. Every ``.json`` file
    found there is attributed to the first entry of ``estimators`` whose name
    occurs anywhere in the filename; files matching no estimator are skipped.

    Args:
        result_dir: Directory holding the exact-value JSON (and, for the
            "objective" error, the ``S_*.npy`` / ``b_*.npy`` arrays).
        estimators: Iterable of estimator names used as filename substrings.
        plot_error: ``"l2"`` routes each random state through
            ``calculate_l2_norm``; ``"objective"`` routes it through
            ``calculate_objective``; any other value leaves the per-file
            error lists empty.
        dataset: Dataset name embedded in the file names.
        mode: ``"noise"`` to read estimates from the ``noise`` subdirectory.

    Returns:
        A ``(suffixes, errors, estimations)`` tuple of dicts keyed by
        estimator. ``suffixes[e]`` holds the float parsed from the last
        ``_``-separated filename component (extension removed),
        ``errors[e]`` one list per file with one entry per random state, and
        ``estimations[e]`` one flat list per file with every estimated value.
        ``estimations["exact"]`` additionally holds the flattened exact
        values. Entries follow ``os.listdir`` order, which is arbitrary.
    """
    print(f"Processing results for {dataset}")
    suffixes = {name: [] for name in estimators}
    errors = {name: [] for name in estimators}
    estimations = {name: [] for name in estimators}

    est_result_dir = (
        os.path.join(result_dir, "noise") if mode == "noise" else result_dir
    )

    exact_path = os.path.join(result_dir, f"banzhaf_exact_{dataset}.json")
    with open(exact_path, "r") as handle:
        exact_data = json.load(handle)

    flat_exact = []
    for state_key in exact_data.keys():
        per_state = exact_data[state_key]
        flat_exact.extend(per_state[feature] for feature in per_state)
    estimations["exact"] = flat_exact

    for filename in os.listdir(est_result_dir):
        if not filename.endswith(".json"):
            continue

        # Only the first estimator whose name appears in the filename counts.
        owner = next((name for name in estimators if name in filename), None)
        if owner is None:
            continue

        # The trailing "_<number>.json" component carries the x-axis value
        # (sample size or noise level, depending on the caller).
        last_component = filename.split("_")[-1]
        suffixes[owner].append(float(last_component.rsplit(".", 1)[0]))

        with open(os.path.join(est_result_dir, filename), "r") as handle:
            estimated_data = json.load(handle)

        file_errors = []
        file_estimations = []
        for random_state, estimated_values in estimated_data.items():
            file_estimations.extend(estimated_values.values())

            if plot_error == "l2":
                state_error = calculate_l2_norm(
                    exact_data[random_state], estimated_values
                )
                file_errors.append(state_error)
            elif plot_error == "objective":
                # S and b always come from result_dir, even in noise mode.
                S = np.load(
                    os.path.join(result_dir, f"S_{dataset}_{random_state}.npy")
                )
                b = np.load(
                    os.path.join(result_dir, f"b_{dataset}_{random_state}.npy")
                )
                state_error = calculate_objective(
                    exact_data[random_state], S, b, estimated_values
                )
                file_errors.append(state_error)

        estimations[owner].append(file_estimations)
        errors[owner].append(file_errors)

    return suffixes, errors, estimations


def plot_all_datasets(datasets, estimators, error, noise, model_type):
    """Draw one error panel per dataset and save the combined figure as a PDF.

    Exactly 4 datasets produce a single row of panels; exactly 8 produce a
    2x4 grid. Each panel is drawn by ``plot_dataset`` from the directory
    ``results_{dataset}`` (or ``nn_results_{dataset}`` when
    ``model_type == "nn"``). The legend is taken from the first panel and
    placed below the grid.

    Args:
        datasets: Sequence of 4 or 8 dataset names.
        estimators: Estimator names forwarded to ``plot_dataset``.
        error: Error kind forwarded to ``plot_dataset``; also embedded in the
            output file name.
        noise: ``None`` for the "by sample size" figure, anything else for
            the "by noise" figure.
        model_type: ``"nn"`` selects the ``nn_``-prefixed result directories
            and output file names.

    Raises:
        ValueError: If ``datasets`` does not contain exactly 4 or 8 entries.
            Previously this surfaced as an unbound-variable error.
    """
    num_datasets = len(datasets)
    if num_datasets == 4:
        grid = (1, 4)
        figsize = (2.6 * num_datasets, 2.6)
        rect = [0, 0.12, 1, 1]
    elif num_datasets == 8:
        grid = (2, 4)
        figsize = (9.7, 4.3)
        rect = [0, 0.08, 1, 1]
    else:
        raise ValueError(
            f"plot_all_datasets supports exactly 4 or 8 datasets, got {num_datasets}"
        )

    fig, axs = plt.subplots(*grid, figsize=figsize, sharey=False)
    axs = axs.flatten()

    prefix = "nn_" if model_type == "nn" else ""

    legend_handles, legend_labels = None, None
    for index, (ax, dataset) in enumerate(zip(axs, datasets)):
        # Only the left-most panel of each row of four carries a y-axis label.
        show_y_label = index % 4 == 0
        handles, labels = plot_dataset(
            ax,
            f"{prefix}results_{dataset}",
            estimators,
            error,
            dataset,
            show_y_label,
            noise,
        )
        if index == 0:
            legend_handles, legend_labels = handles, labels

    plt.subplots_adjust(wspace=0.01, hspace=0.1)
    # rect reserves a strip at the bottom of the figure for the shared legend.
    plt.tight_layout(rect=rect)

    fig.legend(
        legend_handles,
        legend_labels,
        loc="lower center",
        ncol=4,
        fontsize=11,
        frameon=False,
    )

    if len(estimators) == 2:
        file_name = "all_datasets_swor_vs_swr.pdf"
    elif noise is not None:
        print("Saving all datasets by noise")
        file_name = f"{prefix}all_datasets_{error}_by_noise.pdf"
    else:
        print("Saving all datasets by sample size")
        file_name = f"{prefix}all_datasets_{error}_by_sample_size.pdf"
    plt.savefig(file_name)


def plot_dataset(ax, result_dir, estimators, error, dataset, y_label, noise):
    """Draw median error curves with an interquartile band on one axes.

    Results come from ``process_results_files``; the x values are the numeric
    filename suffixes it returns (read in "noise" mode when ``noise`` is not
    ``None``). For every estimator the per-file errors are sorted by x value,
    then the median and the 25th/75th percentiles across axis 1 are plotted
    on log-log axes.

    Args:
        ax: Matplotlib axes to draw on.
        result_dir: Directory passed through to ``process_results_files``.
        estimators: Estimator names; exactly two selects the two-colour
            palette, otherwise a four-entry palette is used.
        error: ``"l2"`` gives the "L2 Norm" axis label, anything else gives
            "Relative Objective".
        dataset: Key into the panel-title table.
        y_label: Whether to write the y-axis label on this panel.
        noise: ``None`` for sample-size plots, anything else for noise plots.

    Returns:
        ``(lines, labels)``: the plotted ``Line2D`` handles and their labels.
    """
    by_noise = noise is not None

    x_by_estimator, errors, _ = process_results_files(
        result_dir, estimators, error, dataset, "noise" if by_noise else None
    )

    if len(estimators) == 2:
        palette = np.concatenate((plt.cm.tab20c([1]), plt.cm.tab20b([13])))
        dash_patterns = ["-", (0, (5, 10))]
        markers = ["o", "p"]
    else:
        palette = plt.cm.tab20c([0, 4, 8, 12] if by_noise else [1, 5, 9, 13])
        dash_patterns = ["-", (5, (10, 3)), "--", "-."]
        markers = ["o", "^", "D", "s"]

    label_dict = {
        "paired_sampling": "Kernel Banzhaf",
        "kernel": "Kernel Banzhaf (excl. Pairs)",
        "mc": "MC",
        "msr": "MSR",
        "swor": "Kernel Banzhaf (SWOR)",
    }

    lines = []
    for name, shade, dash, marker in zip(
        estimators, palette, dash_patterns, markers
    ):
        order = np.argsort(x_by_estimator[name])
        xs = np.array(x_by_estimator[name])[order]
        ordered_errors = np.array(errors[name])[order]

        # Axis 1 runs over the entries collected per file (one per random
        # state in process_results_files).
        middle = np.median(ordered_errors, axis=1)
        lower = np.percentile(ordered_errors, 25, axis=1)
        upper = np.percentile(ordered_errors, 75, axis=1)

        (curve,) = ax.plot(
            xs,
            middle,
            marker=marker,
            linestyle=dash,
            color=shade,
            label=label_dict[name],
            markersize=3,
            linewidth=1,
        )
        lines.append(curve)
        ax.fill_between(xs, lower, upper, color=shade, alpha=0.2)

        for side in ("top", "right"):
            ax.spines[side].set_visible(False)
        for side in ("left", "bottom"):
            ax.spines[side].set_linewidth(0.5)

    dataset_dict = {
        "diabetes": "Diabetes (n=8)",
        "adult": "Adult (n=14)",
        "bank": "Bank (n=16)",
        "german_credit": "German Credit (n=20)",
        "nhanes": "NHANES (n=79)",
        "brca": "BRCA (n=100)",
        "communitiesandcrime": "Communities and Crime (n=101)",
        "tuandromd": "TUANDROMD (n=241)",
    }
    ax.set_title(dataset_dict[dataset], fontsize=11)

    ax.set_xlabel("Noise Level" if by_noise else "Number of Samples", fontsize=10)
    ax.tick_params(axis="x", labelsize=7)

    if y_label:
        error_name = "L2 Norm" if error == "l2" else "Relative Objective"
        ax.set_ylabel(f"{error_name} Error", fontsize=10)
    ax.tick_params(axis="y", labelsize=7)

    ax.set_xscale("log")
    ax.set_yscale("log")

    # Major ticks at powers of 10 on both axes; minor ticks are disabled.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(ticker.LogLocator(base=10.0))
    ax.minorticks_off()

    return lines, [curve.get_label() for curve in lines]


def plot_by_sample_size_median(result_dir, estimators, error, dataset):
    """Plot median error against sample size for one dataset and save a PNG.

    Results are read with ``process_results_files``. For each estimator the
    per-file errors are sorted by sample size; the median is drawn as a line
    and the 25th-75th percentile range as a shaded band on log-log axes.

    Args:
        result_dir: Directory passed to ``process_results_files``. If the
            string contains ``"nn"`` the output file name gets an ``nn_``
            prefix. NOTE(review): this is a substring test on the whole
            path, so any path containing "nn" triggers it.
        estimators: Up to five estimator names (five styles are defined;
            extra names are silently dropped by ``zip``).
        error: ``"l2"`` gives "L2 Norm" wording, anything else gives
            "Relative Objective".
        dataset: Dataset name used in the title and the output file name.
    """
    sample_sizes, errors, _ = process_results_files(
        result_dir, estimators, error, dataset
    )

    plt.figure(figsize=(6, 4.5))
    colors = plt.cm.tab20c([1, 5, 9, 13, 17])
    line_styles = ["-", (5, (10, 3)), "--", "-.", (0, (3, 1, 1, 1, 1, 1))]
    marker_styles = ["o", "^", "D", "s", "v"]

    # Five styles are available, so the label table covers all five estimator
    # names used in this file; unknown names fall back to the raw name
    # instead of raising KeyError.
    label_dict = {
        "paired_sampling": "Kernel Banzhaf",
        "kernel": "Kernel Banzhaf (excl. Pairs)",
        "mc": "MC",
        "msr": "MSR",
        "swor": "Kernel Banzhaf (SWOR)",
    }

    for estimator, color, line_style, marker_style in zip(
        estimators, colors, line_styles, marker_styles
    ):
        sorted_indices = np.argsort(sample_sizes[estimator])
        sorted_sample_sizes = np.array(sample_sizes[estimator])[sorted_indices]
        sorted_errors = np.array(errors[estimator])[sorted_indices]

        # Median and interquartile range across the entries of each file.
        median_errors = np.median(sorted_errors, axis=1)
        perc25_errors = np.percentile(sorted_errors, 25, axis=1)
        perc75_errors = np.percentile(sorted_errors, 75, axis=1)

        plt.plot(
            sorted_sample_sizes,
            median_errors,
            marker=marker_style,
            linestyle=line_style,
            color=color,
            label=label_dict.get(estimator, estimator),
            markersize=5,
            linewidth=2,
        )
        plt.fill_between(
            sorted_sample_sizes, perc25_errors, perc75_errors, color=color, alpha=0.1
        )

    error_name = "L2 Norm" if error == "l2" else "Relative Objective"
    # Slicing (rather than indexing) keeps an empty dataset name from raising.
    dataset_title = dataset[:1].upper() + dataset[1:]
    plt.title(f"{error_name} Error by Sample Size on {dataset_title} Dataset")
    plt.xlabel("Number of Samples")
    plt.ylabel(f"{error_name} Error")

    plt.xscale("log")
    plt.yscale("log")

    current_axes = plt.gca()
    current_axes.xaxis.set_major_locator(ticker.LogLocator(base=10.0))
    current_axes.yaxis.set_major_locator(ticker.LogLocator(base=10.0))
    current_axes.xaxis.set_minor_locator(ticker.NullLocator())
    current_axes.yaxis.set_minor_locator(ticker.NullLocator())

    plt.grid(True, which="major", linestyle="--", linewidth=0.5)

    # A smaller font keeps the legend from covering the curves.
    plt.legend(fontsize=8)

    file_name = (
        f"{dataset}_{error}_by_sample_size.png"
        if "nn" not in result_dir
        else f"nn_{dataset}_{error}_by_sample_size.png"
    )
    plt.savefig(file_name)


def plot_by_noise(result_dir, estimators, error, dataset):
    """Plot median error against noise level for one dataset and save a PNG.

    Results are read with ``process_results_files`` in ``"noise"`` mode, so
    estimate files come from the ``noise`` subdirectory of ``result_dir``.
    For each estimator the median error is drawn as a line and the 25th-75th
    percentile range as a shaded band on log-log axes. The figure is written
    to ``{dataset}_{error}_by_noise.png``.

    Args:
        result_dir: Directory passed to ``process_results_files``.
        estimators: Up to five estimator names (five colours are defined;
            extra names are silently dropped by ``zip``).
        error: ``"l2"`` gives "L2 Norm" wording, anything else gives
            "Relative Objective".
        dataset: Dataset name used in the title and the output file name.
    """
    noise, errors, _ = process_results_files(
        result_dir, estimators, error, dataset, mode="noise"
    )
    plt.figure(figsize=(8, 6))
    colors = plt.cm.tab20c([1, 5, 9, 13, 17])

    # Covers all five estimator names used in this file; unknown names fall
    # back to the raw name instead of raising KeyError.
    label_dict = {
        "paired_sampling": "Kernel Banzhaf",
        "kernel": "Kernel Banzhaf (excl. Pairs)",
        "mc": "MC",
        "msr": "MSR",
        "swor": "Kernel Banzhaf (SWOR)",
    }

    for estimator, color in zip(estimators, colors):
        sorted_indices = np.argsort(noise[estimator])
        sorted_noise = np.array(noise[estimator])[sorted_indices]
        sorted_errors = np.array(errors[estimator])[sorted_indices]

        # Median and interquartile range across the entries of each file.
        median_errors = np.median(sorted_errors, axis=1)
        perc25_errors = np.percentile(sorted_errors, 25, axis=1)
        perc75_errors = np.percentile(sorted_errors, 75, axis=1)

        plt.plot(
            sorted_noise,
            median_errors,
            "o-",
            color=color,
            label=label_dict.get(estimator, estimator),
        )
        plt.fill_between(
            sorted_noise, perc25_errors, perc75_errors, color=color, alpha=0.1
        )

    error_name = "L2 Norm" if error == "l2" else "Relative Objective"
    # Slicing (rather than indexing) keeps an empty dataset name from raising.
    dataset_title = dataset[:1].upper() + dataset[1:]
    plt.title(f"{error_name} Error by Noise Level on {dataset_title} Dataset")
    # Axis labels match the other single-dataset and multi-panel plots.
    plt.xlabel("Noise Level")
    plt.ylabel(f"{error_name} Error")

    plt.xscale("log")
    plt.yscale("log")

    current_axes = plt.gca()
    current_axes.xaxis.set_major_locator(ticker.LogLocator(base=10.0))
    current_axes.yaxis.set_major_locator(ticker.LogLocator(base=10.0))
    current_axes.xaxis.set_minor_locator(ticker.NullLocator())
    current_axes.yaxis.set_minor_locator(ticker.NullLocator())

    plt.grid(True, which="major", linestyle="--", linewidth=0.5)

    plt.legend()
    plt.savefig(f"{dataset}_{error}_by_noise.png")


def plot_by_set_function(result_dir, estimators, plot_error, dataset):
    """Plot error against ``calculate_gamma(...) - 1`` and save a PNG.

    Three run directories are scanned: ``result_dir``, ``result_dir + "_1"``
    and ``result_dir + "_2"``. In each, every ``.json`` file whose name
    contains both an estimator name and the fixed sample-size token
    (``"100."`` for the ``adult`` dataset, ``"300."`` otherwise) contributes
    one list of per-random-state errors. The exact values always come from
    ``banzhaf_exact_{dataset}.json`` in ``result_dir``; the ``S``/``b``
    arrays come from the run directory being scanned.

    For plotting, the collected error lists are transposed so that each row
    belongs to one random state; the median and 25th/75th percentiles are
    taken across the matched files and drawn against the gamma-based x value
    on log-log axes.

    NOTE(review): the x values kept for an estimator are those of the last
    matching file processed (each match overwrites the previous list), and
    the sample-size token is matched as a substring, so e.g. ``"100."`` also
    matches a file ending in ``_1100.json``. Both behaviours are preserved
    here; confirm they are intended.

    Args:
        result_dir: Base result directory.
        estimators: Up to five estimator names (five colours are defined).
        plot_error: ``"l2"`` uses ``calculate_l2_norm``; ``"objective"`` uses
            ``calculate_objective``; any other value collects no errors.
        dataset: Dataset name used in file names and in the title.
    """
    result_dirs = [result_dir, f"{result_dir}_1", f"{result_dir}_2"]

    errors = {estimator: [] for estimator in estimators}
    values = {estimator: [] for estimator in estimators}

    sample_size = "100." if dataset == "adult" else "300."

    with open(os.path.join(result_dir, f"banzhaf_exact_{dataset}.json"), "r") as file:
        exact_data = json.load(file)

    # A separate loop name keeps the result_dir parameter intact.
    for run_dir in result_dirs:
        for filename in os.listdir(run_dir):
            if not filename.endswith(".json"):
                continue
            for estimator in estimators:
                if estimator not in filename or sample_size not in filename:
                    continue
                with open(os.path.join(run_dir, filename), "r") as file:
                    estimated_data = json.load(file)
                error = []
                value = []
                for random_state, estimated_values in estimated_data.items():
                    S = np.load(
                        os.path.join(run_dir, f"S_{dataset}_{random_state}.npy")
                    )
                    b = np.load(
                        os.path.join(run_dir, f"b_{dataset}_{random_state}.npy")
                    )
                    value.append(calculate_gamma(exact_data[random_state], S, b) - 1)
                    if plot_error == "l2":
                        error.append(
                            calculate_l2_norm(
                                exact_data[random_state], estimated_values
                            )
                        )
                    elif plot_error == "objective":
                        error.append(
                            calculate_objective(
                                exact_data[random_state], S, b, estimated_values
                            )
                        )

                errors[estimator].append(error)
                values[estimator] = value

    plt.figure(figsize=(6, 4.5))
    colors = plt.cm.tab20c([1, 5, 9, 13, 17])

    # Covers all five estimator names used in this file; unknown names fall
    # back to the raw name instead of raising KeyError.
    label_dict = {
        "paired_sampling": "Kernel Banzhaf",
        "kernel": "Kernel Banzhaf (excl. Pairs)",
        "mc": "MC",
        "msr": "MSR",
        "swor": "Kernel Banzhaf (SWOR)",
    }

    for estimator, color in zip(estimators, colors):
        # Rows become random states, columns the matched files.
        per_state_errors = np.array(errors[estimator]).T

        sorted_indices = np.argsort(values[estimator])
        sorted_values = np.array(values[estimator])[sorted_indices]
        sorted_errors = per_state_errors[sorted_indices]

        median_errors = np.median(sorted_errors, axis=1)
        perc25_errors = np.percentile(sorted_errors, 25, axis=1)
        perc75_errors = np.percentile(sorted_errors, 75, axis=1)

        plt.plot(
            sorted_values,
            median_errors,
            "o-",
            color=color,
            label=label_dict.get(estimator, estimator),
            markersize=5,
            linewidth=2,
        )
        plt.fill_between(
            sorted_values, perc25_errors, perc75_errors, color=color, alpha=0.1
        )

    error_name = "L2 Norm" if plot_error == "l2" else "Relative Objective"
    # Slicing (rather than indexing) keeps an empty dataset name from raising.
    dataset_title = dataset[:1].upper() + dataset[1:]
    plt.title(f"{error_name} Error by Gamma Function on {dataset_title} Dataset")
    plt.xlabel("Value of Gamma Function")
    plt.ylabel(f"{error_name} Error")

    plt.xscale("log")
    plt.yscale("log")

    current_axes = plt.gca()
    current_axes.xaxis.set_major_locator(ticker.LogLocator(base=10.0))
    current_axes.yaxis.set_major_locator(ticker.LogLocator(base=10.0))
    current_axes.xaxis.set_minor_locator(ticker.NullLocator())
    current_axes.yaxis.set_minor_locator(ticker.NullLocator())

    plt.grid(True, which="major", linestyle="--", linewidth=0.5)

    plt.legend()

    file_name = f"{dataset}_{plot_error}_by_set_function.png"
    plt.savefig(file_name)


def plot_all_time(datasets, estimators):
    """Draw a 2x4 grid of timing panels and save it as a PDF.

    Each dataset gets one panel, drawn by ``plot_time_by_sample_size`` from
    the directory ``results_{dataset}``. The shared legend is taken from the
    first panel and the figure is written to
    ``all_datasets_time_by_sample_size.pdf``. The grid has eight slots;
    datasets beyond the eighth are not drawn.

    Args:
        datasets: Non-empty sequence of dataset names.
        estimators: Estimator names forwarded to ``plot_time_by_sample_size``.

    Raises:
        ValueError: If ``datasets`` is empty. Previously a figure was created
            and the call then failed on the unbound legend handles.
    """
    if len(datasets) == 0:
        raise ValueError("plot_all_time needs at least one dataset to plot")

    fig, axs = plt.subplots(2, 4, figsize=(9.7, 4.3), sharey=False)
    axs = axs.flatten()
    rect = [0, 0.08, 1, 1]

    legend_handles, legend_labels = None, None
    for index, (ax, dataset) in enumerate(zip(axs, datasets)):
        # Only the left-most panel of each row carries a y-axis label.
        show_y_label = index % 4 == 0
        handles, labels = plot_time_by_sample_size(
            ax, f"results_{dataset}", estimators, dataset, show_y_label
        )
        if index == 0:
            legend_handles, legend_labels = handles, labels

    plt.subplots_adjust(wspace=0.01, hspace=0.1)
    # rect reserves a strip at the bottom of the figure for the shared legend.
    plt.tight_layout(rect=rect)

    fig.legend(
        legend_handles,
        legend_labels,
        loc="lower center",
        ncol=4,
        fontsize=11,
        frameon=False,
    )

    plt.savefig("all_datasets_time_by_sample_size.pdf")


def plot_time_by_sample_size(ax, result_dir, estimators, dataset, y_label):
    """Draw median running time against sample size on one axes.

    Every file in ``result_dir`` named ``time*.json`` is read. The sample
    size is the integer before the first ``.`` in the last ``_``-separated
    filename component. Each file maps a state key to a dict of
    ``{estimator key: time}``; estimator keys are translated through a fixed
    table (e.g. ``"kernel_paired"`` becomes ``"paired_sampling"``) and kept
    only when the translated name is in ``estimators``.

    All statistics are computed before anything is drawn, so a missing
    estimator is reported without leaving a half-drawn panel.

    Args:
        ax: Matplotlib axes to draw on.
        result_dir: Directory containing the ``time*.json`` files.
        estimators: Estimator names to plot. Four styles are defined, so at
            most the first four are drawn.
        dataset: Key into the panel-title table.
        y_label: Whether to write the "Time (s)" y-axis label on this panel.

    Returns:
        ``(lines, labels)``: the plotted ``Line2D`` handles and their labels.

    Raises:
        ValueError: If a plotted estimator has no timing measurements.
        KeyError: If a timing file contains an estimator key missing from
            the translation table, or ``dataset`` is not a known dataset.
    """
    estimator_dict = {
        "kernel_paired": "paired_sampling",
        "kernel": "kernel",
        "mc": "mc",
        "msr": "msr",
        "exact": "exact",
        "swor": "swor",
    }
    # "swor" is included to match the translation table above; other names
    # fall back to the raw estimator name instead of raising KeyError.
    label_dict = {
        "paired_sampling": "Kernel Banzhaf",
        "kernel": "Kernel Banzhaf (excl. Pairs)",
        "mc": "MC",
        "msr": "MSR",
        "swor": "Kernel Banzhaf (SWOR)",
    }
    dataset_dict = {
        "diabetes": "Diabetes (n=8)",
        "adult": "Adult (n=14)",
        "bank": "Bank (n=16)",
        "german_credit": "German Credit (n=20)",
        "nhanes": "NHANES (n=79)",
        "brca": "BRCA (n=100)",
        "communitiesandcrime": "Communities and Crime (n=101)",
        "tuandromd": "TUANDROMD (n=241)",
    }
    line_styles = ["-", (5, (10, 3)), "--", "-."]
    marker_styles = ["o", "^", "D", "s"]

    # times[estimator][sample_size] -> list of measured times
    times = {est: {} for est in estimators}
    for filename in os.listdir(result_dir):
        if not (filename.startswith("time") and filename.endswith(".json")):
            continue
        sample_size = int(filename.split("_")[-1].split(".")[0])
        with open(os.path.join(result_dir, filename), "r") as file:
            content = json.load(file)
        for measurements in content.values():
            for estimator, curr_time in measurements.items():
                name = estimator_dict[estimator]
                if name in times:
                    times[name].setdefault(sample_size, []).append(curr_time)

    # Pair estimators with their styles first so that only the estimators
    # that will actually be drawn are validated.
    styled = list(zip(estimators, line_styles, marker_styles))

    summaries = {}
    for estimator, _, _ in styled:
        by_size = times[estimator]
        if not by_size:
            raise ValueError(
                f"No timing measurements found for estimator {estimator!r} "
                f"in {result_dir!r}"
            )
        sizes = sorted(by_size)
        summaries[estimator] = (
            np.array(sizes),
            np.array([np.median(by_size[size]) for size in sizes]),
            np.array([np.percentile(by_size[size], 25) for size in sizes]),
            np.array([np.percentile(by_size[size], 75) for size in sizes]),
        )

    colors = plt.cm.tab20c([1, 5, 9, 13])

    lines = []
    for (estimator, line_style, marker_style), color in zip(styled, colors):
        x, median_y, percentile_25, percentile_75 = summaries[estimator]
        (line,) = ax.plot(
            x,
            median_y,
            marker=marker_style,
            linestyle=line_style,
            color=color,
            label=label_dict.get(estimator, estimator),
            markersize=3,
            linewidth=1,
        )
        lines.append(line)
        ax.fill_between(x, percentile_25, percentile_75, color=color, alpha=0.2)

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["left"].set_linewidth(0.5)
    ax.spines["bottom"].set_linewidth(0.5)

    ax.set_title(dataset_dict[dataset], fontsize=11)
    ax.set_xlabel("Number of Samples", fontsize=10)
    ax.tick_params(axis="x", labelsize=7)
    if y_label:
        ax.set_ylabel("Time (s)", fontsize=10)
    # Applied to every panel (not just labeled ones) so tick fonts are
    # uniform across the grid, as in plot_dataset.
    ax.tick_params(axis="y", labelsize=7)
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.xaxis.set_major_locator(ticker.LogLocator(base=10.0))
    ax.yaxis.set_major_locator(ticker.LogLocator(base=10.0))
    ax.minorticks_off()

    return lines, [line.get_label() for line in lines]


def plot_by_feature(datasets, estimators, size_index=2, state_index=6, n_states=50):
    """Scatter estimated against exact values, one panel per estimator.

    For every dataset, results are read from ``results_{dataset}`` with
    ``process_results_files``. The estimates of the file at position
    ``size_index`` (after sorting files by their numeric suffix) are reshaped
    so the leading axis has ``n_states`` entries, and the slice at
    ``state_index`` is plotted against the matching slice of the exact
    values. Both slices are divided by the largest absolute value found in
    either, so all datasets share the range [-1, 1]. Each panel title
    reports the norm of the pooled differences between the normalized exact
    and estimated values. The figure is saved to
    ``all_datasets_by_feature.pdf``.

    Args:
        datasets: Dataset names; each must be a key of the title table.
        estimators: Estimator names, one panel each.
        size_index: Position of the result file to plot after sorting by
            suffix. Defaults to 2.
        state_index: Index along the leading (``n_states``) axis of the
            reshaped arrays to plot. Defaults to 6.
        n_states: Length of the leading axis used when reshaping the flat
            value lists. Defaults to 50. NOTE(review): presumably the number
            of random states per result file; confirm against the
            experiment setup.
    """
    # squeeze=False keeps `axes` an array even when there is one estimator.
    fig, axes = plt.subplots(
        1, len(estimators), figsize=(9.5, 3.2), squeeze=False
    )
    axes = axes[0]

    colors = plt.cm.tab20([4, 2, 7, 0, 8, 16, 12, 18])
    markers = ["o", "^", "v", "s", "p", "D", "x", "+"]
    # Both tables wrap around so more than eight datasets do not raise.
    dataset_colors = {
        dataset: colors[i % len(colors)] for i, dataset in enumerate(datasets)
    }
    dataset_markers = {
        dataset: markers[i % len(markers)] for i, dataset in enumerate(datasets)
    }

    all_estimations = {dataset: {} for dataset in datasets}
    exact_banzhaf = {}
    for dataset in datasets:
        result_dir = f"results_{dataset}"
        sample_sizes, _, estimations = process_results_files(
            result_dir, estimators, "l2", dataset
        )
        for estimator in estimators:
            order = np.argsort(sample_sizes[estimator])
            at_size = np.asarray(np.array(estimations[estimator])[order][size_index])
            # Split the flat leading axis into (n_states, rest) and keep any
            # trailing axes, so scalar-valued and vector-valued entries are
            # handled without a hard-coded list of dataset names.
            all_estimations[dataset][estimator] = at_size.reshape(
                n_states, -1, *at_size.shape[1:]
            )
        exact = np.array(estimations["exact"])
        exact_banzhaf[dataset] = exact.reshape(n_states, -1, *exact.shape[1:])

    dataset_dict = {
        "diabetes": "Diabetes (n=8)",
        "adult": "Adult (n=14)",
        "bank": "Bank (n=16)",
        "german_credit": "German Credit (n=20)",
        "nhanes": "NHANES (n=79)",
        "brca": "BRCA (n=100)",
        "communitiesandcrime": "Communities and Crime (n=101)",
        "tuandromd": "TUANDROMD (n=241)",
    }
    # Unknown estimator names fall back to the raw name instead of raising.
    label_dict = {
        "paired_sampling": "Kernel Banzhaf",
        "kernel": "Kernel Banzhaf (excl. Pairs)",
        "mc": "MC",
        "msr": "MSR",
        "swor": "Kernel Banzhaf (SWOR)",
    }

    legend_handles = []
    for i, estimator in enumerate(estimators):
        ax = axes[i]
        exacts = []
        estimates = []
        for dataset in datasets:
            exact = exact_banzhaf[dataset][state_index]
            estimate = all_estimations[dataset][estimator][state_index]
            # Normalize by the largest magnitude over both the exact and the
            # estimated values of this dataset.
            max_value = np.max(np.abs(np.concatenate((exact, estimate))))
            if max_value == 0:
                # All values are zero; avoid dividing by zero.
                max_value = 1.0
            exact = exact / max_value
            estimate = estimate / max_value
            exacts.extend(exact.flatten())
            estimates.extend(estimate.flatten())
            handle = ax.scatter(
                exact,
                estimate,
                s=10,
                color=dataset_colors[dataset],
                marker=dataset_markers[dataset],
                label=dataset_dict[dataset],
                alpha=0.7,
            )
            # The shared legend needs each dataset once, so collect handles
            # from the first panel only.
            if i == 0:
                legend_handles.append(handle)

        # Diagonal reference line: estimated value equals exact value.
        ax.plot([-1.0, 1.0], [-1.0, 1.0], "g--", alpha=0.5, linewidth=0.5)

        differences = np.array(exacts) - np.array(estimates)
        l2_norm_errors = np.linalg.norm(differences)
        ax.set_title(
            f"{label_dict.get(estimator, estimator)} (l2-norm: {l2_norm_errors:.2f})",
            fontsize=11,
        )
        ax.set_xlabel("Exact Banzhaf Values", fontsize=9)
        ax.set_ylabel("Estimated Banzhaf Values", fontsize=9)
        for side in ("top", "right", "left", "bottom"):
            ax.spines[side].set_linewidth(0.5)
        ax.grid(True, linestyle="--", linewidth=0.3, color="gray", alpha=0.3)

    plt.subplots_adjust(wspace=0.01, hspace=0.1)
    # rect reserves the bottom fifth of the figure for the shared legend.
    plt.tight_layout(rect=[0, 0.2, 1, 1])
    fig.legend(
        legend_handles,
        [handle.get_label() for handle in legend_handles],
        loc="lower center",
        ncol=4,
        fontsize=10.5,
        frameon=False,
    )
    plt.savefig("all_datasets_by_feature.pdf")
