import pickle
from utils.plots import compute_normalized_rj
import numpy as np
from math import comb
from scipy.special import beta
import matplotlib.pyplot as plt
import os

def semivalue_weights(n: int, alpha: float = 4, beta_b: float = 1):
    """Return the per-cardinality weights of three semivalues.

    For coalition sizes ``j = 1..n`` the weights are

    * Shapley: ``1 / n``
    * Banzhaf: ``C(n-1, j-1) / 2**(n-1)``
    * Beta(alpha, beta_b):
      ``C(n-1, j-1) * B(j + beta_b - 1, n - j + alpha) / B(alpha, beta_b)``

    Each family sums to 1 over ``j``.

    Args:
        n: Number of cardinalities (must be >= 1).
        alpha: First Beta parameter (default 4, as in the original plot).
        beta_b: Second Beta parameter (default 1, as in the original plot).

    Returns:
        A tuple ``(shapley_w, banzhaf_w, beta_w)`` of 1-D float arrays of
        length ``n``.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    # Imported locally: only `beta` is imported from scipy.special at file level.
    from scipy.special import betaln, gammaln

    n = int(n)  # keep 2 ** (n - 1) in arbitrary-precision Python ints
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")

    sizes = np.arange(1, n + 1)

    shapley_w = np.full(n, 1.0 / n)

    # int / int true division is correctly rounded and never converts the huge
    # binomial coefficient to a float on its own, so it cannot overflow.
    denominator = 2 ** (n - 1)
    banzhaf_w = np.array([comb(n - 1, k) / denominator for k in range(n)])

    # Work in log space: C(n-1, j-1) overflows and B(., .) underflows a double
    # for large n, although their product is a perfectly ordinary number.
    log_binom = gammaln(n) - gammaln(sizes) - gammaln(n - sizes + 1)
    log_beta_ratio = (
        betaln(sizes + beta_b - 1, n - sizes + alpha) - betaln(alpha, beta_b)
    )
    beta_w = np.exp(log_binom + log_beta_ratio)

    return shapley_w, banzhaf_w, beta_w


def plot_normalized_rj(rj_dict_normalized, output_path, alpha=4, beta_b=1):
    """Plot normalized r_j per dataset against the semivalue weight curves.

    The left axis shows each dataset's values with error bars; a twin right
    axis shows the Shapley, Beta(alpha, beta_b) and Banzhaf weights. The
    figure is written to ``output_path`` and then closed.

    Args:
        rj_dict_normalized: Mapping of dataset name to a pair
            ``(rj_mean, rj_se)``; both are passed to ``errorbar`` as the
            values and their error bars. The length of the first entry's
            ``rj_mean`` defines ``n``.
        output_path: Destination image file.
        alpha: First Beta parameter for the Beta weight curve.
        beta_b: Second Beta parameter for the Beta weight curve.

    Raises:
        ValueError: If ``rj_dict_normalized`` is empty.
    """
    if not rj_dict_normalized:
        raise ValueError("No datasets to plot: the normalized r_j dictionary is empty.")

    first_mean = next(iter(rj_dict_normalized.values()))[0]
    n = len(first_mean)
    j = np.arange(1, n + 1)
    shapley_w, banzhaf_w, beta_w = semivalue_weights(n, alpha, beta_b)

    fig, ax1 = plt.subplots(figsize=(8, 4))
    colors = {"breast": "#AEC6CF", "titanic": "#FFB347"}

    for dataset, (rj_mean, rj_se) in rj_dict_normalized.items():
        marker = "o" if dataset == "breast" else "s"
        ax1.errorbar(
            j, rj_mean, yerr=rj_se, fmt=marker,
            label=f"{dataset.capitalize()} $r_j$",
            color=colors.get(dataset, "gray"), capsize=3,
        )

    ax1.set_xlabel("$j$")
    ax1.set_ylabel("$r_j$ (mean ± SE)")
    ax1.set_ylim(0, 1)

    # Weights live on a very different scale than r_j, hence the twin axis.
    ax2 = ax1.twinx()
    ax2.plot(j, shapley_w, "--", label="Shapley", color="#B39EB5")
    ax2.plot(j, beta_w, "-.", label=f"Beta ({alpha},{beta_b})", color="#77DD77")
    ax2.plot(j, banzhaf_w, ":", label="Banzhaf", color="#FF6961")
    ax2.set_ylabel(r"$\omega_j$")
    ax2.set_ylim(0, max(shapley_w.max(), beta_w.max(), banzhaf_w.max()) * 1.2)

    # Merge both axes' handles so a single legend covers every curve.
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc="upper right", fontsize="small")

    ax1.grid(True)
    plt.tight_layout()

    plt.savefig(output_path, dpi=300)
    plt.close(fig)  # release the figure once it is on disk


def main():
    """Load the stored marginal contributions and render the r_j plot.

    Raises:
        FileNotFoundError: If the pickled results file does not exist.
    """
    data_path = "results/marg_contrib_for_rp.pkl"
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Missing file: '{data_path}'. Please run the experiment to generate it.")
    # NOTE: pickle.load can execute arbitrary code embedded in the file; only
    # load result files produced by your own experiment runs.
    with open(data_path, "rb") as f:
        marg_contrib_dict = pickle.load(f)

    rj_dict_normalized = compute_normalized_rj(marg_contrib_dict)

    plot_normalized_rj(rj_dict_normalized, "results/normalized_rj_max_j_plot.png")


if __name__ == "__main__":
    main()
