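"""
Reproduce figure A.6.

For every benchmark results file, plot test accuracy against number of parameters
and, for each zero-cost proxy, highlight the 50 architectures whose proxy rank
deviates most from their accuracy rank. Each proxy's Spearman and Kendall-tau
correlation with test accuracy is also printed.
"""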

import glob
import json
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import scipy


# Display name used as the subplot title for each measure key found in the results JSON.
titles = dict(
    effective_rank="NEAR",
    swap_reg="reg_swap",
    meco_opt="MeCo_opt",
    zico="ZiCo",
    zen="Zen-Score",
    synflow="SynFlow",
    fisher="Fisher",
    grasp="GraSP",
    snip="SNIP",
    grad_norm="Grad_norm",
    num_param="#Params",
    flops="FLOPs",
)
# One zero-initialised counter dict per benchmark, plus an "overall" bucket.
average_ranking = {
    benchmark: defaultdict(int)
    for benchmark in ("NATSBench_TSS", "NATSBench_SSS", "NASBench-101", "overall")
}
def largest_rank_deviation(scores, acc, k=50):
    """Return indices of the ``k`` entries whose score rank deviates most from their accuracy rank.

    Ranks use ``scipy.stats.rankdata`` with average ranks for ties. This is the
    same tie convention Spearman's rho uses, so architectures with identical
    proxy values (common for ``num_param``/``flops``) are ranked identically
    rather than arbitrarily.

    Args:
        scores: proxy value per architecture.
        acc: test accuracy per architecture, aligned with ``scores``.
        k: number of indices to return; fewer are returned if there are fewer entries.

    Returns:
        Integer index array, ordered by ascending rank deviation.
    """
    deviation = np.abs(scipy.stats.rankdata(acc) - scipy.stats.rankdata(scores))
    # Stable sort so equal deviations are broken deterministically by position.
    order = np.argsort(deviation, kind="stable")
    # Explicit start index: ``order[-k:]`` would return everything when k == 0.
    return order[max(len(order) - k, 0):]


# Proxies to plot, in subplot order (row-major over the 3x4 grid).
measures = (
    "effective_rank", "meco_opt", "zico", "synflow",
    "swap_reg", "num_param", "zen", "snip",
    "flops", "grad_norm", "fisher", "grasp",
)

for filename in glob.glob("../NAS_Benchmarks/results/*.json"):
    with open(filename, encoding="utf-8") as f:
        stats = json.load(f)

    # Accuracy and parameter count do not depend on the proxy, so collect them once.
    records = list(stats.values())
    acc = np.array([record["accuracy"] for record in records])
    params = np.array([record["num_param"] for record in records])

    fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(15, 10))
    axes = axes.flatten()
    correlations = []
    print(f"{filename.center(100, '#')}")
    for ax, measure in zip(axes, measures):
        scores = [record[measure] for record in records]
        spr = scipy.stats.spearmanr(scores, acc).statistic
        kt = scipy.stats.kendalltau(scores, acc).statistic
        correlations.append(round(spr, 2))
        print(f"For {measure} SPR: {spr:.3f}, KT: {kt:.2f}")

        # The 50 architectures this proxy mis-ranks worst relative to accuracy.
        worst = largest_rank_deviation(scores, acc, k=50)

        ax.set_title(titles[measure], fontsize=14)
        # Every 100th architecture, drawn as a background cloud.
        ax.scatter(params[::100], acc[::100], color="#0077BB", alpha=0.5)
        ax.scatter(params[worst], acc[worst], label=measure, color="#EE7733", alpha=0.5)
        ax.set_xlabel("Number of Parameters", fontsize=14)
        ax.set_ylabel("Test accuracy", fontsize=14)

    plt.tight_layout()
    # basename/splitext handle OS-specific separators and strip only the trailing extension.
    name = os.path.splitext(os.path.basename(filename))[0]
    print(name)
    plt.savefig(f"{name}.pdf")
    # Release the figure; otherwise one 3x4 figure per results file stays alive.
    plt.close(fig)
