import json
import scipy
import matplotlib.pyplot as plt


# Display name used as the subplot title for each measure key
# looked up in the stats JSON (see titles[measure] in the plotting loop).
titles = dict(
    effective_rank="NEAR",
    swap_reg="reg_swap",
    meco_opt="MeCo_opt",
    zico="ZiCo",
    zen="Zen-Score",
    synflow="SynFlow",
    fisher="Fisher",
    grasp="GraSP",
    snip="SNIP",
    grad_norm="Grad_norm",
    num_param="#Params",
    flops="FLOPs",
)

# Load per-model statistics.
# NOTE(review): the lookups below assume the JSON maps
#   model_id -> epoch (string key) -> {<measure name>: value, "accuracy": value, ...}
# -- confirm against whatever produces this file.
filename = "NATSBench_SSS_CIFAR10_trained.json"
with open(filename, encoding="utf-8") as f:
    stats = json.load(f)

# Training epochs at which scores were recorded. JSON object keys are strings,
# so each epoch is looked up as str(epoch) but plotted on a numeric x-axis.
EPOCHS = [0, 1, 3, 5, 10]
# Measures to evaluate, in subplot order (row-major over the 3x4 grid below).
MEASURES = [
    "effective_rank", "meco_opt", "zico", "synflow",
    "swap_reg", "num_param", "zen", "snip",
    "flops", "grad_norm", "fisher", "grasp",
]

fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(15, 10))
axes = axes.flatten()
if len(MEASURES) > len(axes):
    raise ValueError(
        f"{len(MEASURES)} measures do not fit in a grid of {len(axes)} subplots"
    )

print(f"{filename.center(100, '#')}")
for ax, measure in zip(axes, MEASURES):
    # Spearman correlation between the measure and accuracy, one entry per epoch.
    spr_correlations = []
    for epoch in EPOCHS:
        epoch_key = str(epoch)
        scores = [stats[model_id][epoch_key][measure] for model_id in stats]
        acc = [stats[model_id][epoch_key]["accuracy"] for model_id in stats]
        # Compute each correlation once and reuse it for both the plot and the log line.
        spr = scipy.stats.spearmanr(scores, acc).statistic
        kt = scipy.stats.kendalltau(scores, acc).statistic
        spr_correlations.append(round(spr, 2))
        print(f"For {measure} at epoch {epoch} SPR: {spr:.3f}, KT: {kt:.2f}")

    ax.set_title(f"{titles[measure]}", fontsize=14)
    ax.scatter(EPOCHS, spr_correlations)
    ax.plot(EPOCHS, spr_correlations)
    ax.set_xlabel("Training epoch", fontsize=14)
    ax.set_ylabel(r"Spearman's $\rho$", fontsize=14)

plt.tight_layout()
plt.savefig("correlation_vs_training.pdf")
