import argparse
import os
import numpy as np
import matplotlib.pyplot as plt

from experiments.experimental_pipeline import DEFAULT_RESULTS


def plot_shapley_comparison(dataset_name: str, show_xerr: bool = False):
    """Plot Shapley-value approximations against the Monte Carlo values.

    Loads ``<DEFAULT_RESULTS>/<dataset_name>_shapley_values.npz`` and plots the
    influence-function (``if_phi``) and rescaled influence-function
    (``rif_phi``) estimates against the Monte Carlo values (``mc_phi``). It
    adds a ``y = x`` reference line, saves the figure to
    ``<DEFAULT_RESULTS>/<dataset_name>_shapley_scatter.png``, and calls
    ``plt.show()``.

    Args:
        dataset_name: Prefix of the input ``.npz`` file and the output
            ``.png`` file.
        show_xerr: If True and the archive contains ``mc_err``, draw
            ``mc_err`` as x-axis error bars on the Monte Carlo values.
            If ``mc_err`` is absent, a notice is printed and plain
            scatter markers are used.

    Returns:
        None. If the input file does not exist, a message is printed and
        nothing is plotted.

    Raises:
        KeyError: If ``mc_phi``, ``if_phi`` or ``rif_phi`` is missing from
            the archive.
    """
    npz_path = os.path.join(DEFAULT_RESULTS, f"{dataset_name}_shapley_values.npz")
    if not os.path.exists(npz_path):
        print(f"File not found: {npz_path}")
        return

    print(f"Loading Shapley values from: {npz_path}")
    # For .npz archives np.load returns an NpzFile that keeps the underlying
    # file open; the context manager releases it. Indexing the NpzFile reads
    # each member into an in-memory ndarray, so the arrays remain valid after
    # the ``with`` block exits.
    with np.load(npz_path) as data:
        mc_runs = data["mc_runs"] if "mc_runs" in data else [0]
        print(f"Using the results of {np.shape(mc_runs)[0]} Monte Carlo runs")

        # Ground truth (normalized MC estimate)
        gt_phi = data["mc_phi"]
        mc_err = data["mc_err"] if show_xerr and "mc_err" in data else None

        # Approximations
        if_phi = data["if_phi"]
        rif_phi = data["rif_phi"]

    if show_xerr and mc_err is None:
        # Make the fallback visible instead of silently dropping the error bars.
        print(f"'mc_err' not found in {npz_path}; plotting without x error bars")

    # Plot
    plt.figure(figsize=(8, 8))

    # mc_err is only non-None when show_xerr was requested and it was found.
    if mc_err is not None:
        plt.errorbar(gt_phi, if_phi, xerr=mc_err, fmt='o', color="green",
                     label="Influence Function Estimate", alpha=0.7, capsize=3)
        plt.errorbar(gt_phi, rif_phi, xerr=mc_err, fmt='o', color="blue",
                     label="Rescaled IF Estimate", alpha=0.7, capsize=3)
    else:
        plt.scatter(gt_phi, if_phi, color="green", label="Influence Function Estimate", alpha=0.7)
        plt.scatter(gt_phi, rif_phi, color="blue", label="Rescaled IF Estimate", alpha=0.7)

    # Diagonal line y = x spanning the range of all plotted values.
    # nanmin/nanmax ignore NaN entries. Plain .min()/.max() would return NaN
    # and the builtin min/max would carry it through, so the line would not
    # be drawn.
    min_val = min(np.nanmin(gt_phi), np.nanmin(if_phi), np.nanmin(rif_phi))
    max_val = max(np.nanmax(gt_phi), np.nanmax(if_phi), np.nanmax(rif_phi))
    plt.plot([min_val, max_val], [min_val, max_val], 'k--', label="y = x")

    plt.xlabel("Monte Carlo Shapley Values (Ground Truth)")
    plt.ylabel("Estimated Shapley Values")
    plt.title(f"Shapley Estimate Comparison for {dataset_name}")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    plt_path = os.path.join(DEFAULT_RESULTS, f"{dataset_name}_shapley_scatter.png")
    plt.savefig(plt_path)
    print(f"Plot saved to: {plt_path}")
    plt.show()

if __name__ == "__main__":
    # Command-line entry point: read the dataset name and error-bar flag,
    # then render the comparison plot for that dataset.
    cli = argparse.ArgumentParser(description="Plot Shapley estimate comparisons.")
    cli.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Name of the dataset to plot",
    )
    cli.add_argument(
        "--show_xerr",
        action="store_true",
        help="Include error bars on the x-axis",
    )
    options = cli.parse_args()

    plot_shapley_comparison(options.dataset, show_xerr=options.show_xerr)
