import glob
import json
import re
from collections import defaultdict

import numpy as np
import scipy
import scipy.stats

# Rank zero-cost proxy measures by how well they correlate with accuracy on
# every *.json result file in the working directory, then report each
# measure's average rank per search space and over all files (1 = best).

measures = [
    "effective_rank", "swap", "swap_reg", "meco_opt", "zico", "zen", "synflow",
    "fisher", "grasp", "snip", "grad_norm", "num_param", "flops",
]

# Search spaces that are aggregated; result files belonging to any other
# prefix are skipped instead of aborting the whole run.
SEARCH_SPACES = ["NATSBench_TSS", "NATSBench_SSS", "NASBench-101"]


def _search_space(filename):
    """Return the search-space prefix of a result file name.

    The prefix is everything before the first ``_CIFAR`` / ``_ImageNet``
    dataset tag, or before the ``.json`` extension when there is no tag, e.g.
    ``NATSBench_TSS_CIFAR10.json`` -> ``NATSBench_TSS`` and
    ``NASBench-101.json`` -> ``NASBench-101``.
    """
    # `\.json$` must be an explicit alternative: the bare `$` fallback would
    # otherwise keep the extension in the prefix.
    return re.match(r".+?(?=_CIFAR|_ImageNet|\.json$|$)", filename).group(0)


def _rank_measures(stats):
    """Print correlations for one result file and return the dense rank of each measure.

    ``stats`` maps each model id to a dict holding one score per name in
    ``measures`` plus an ``"accuracy"`` entry. Ranks follow the order of
    ``measures``; rank 1 is the measure with the highest Spearman correlation.
    """
    acc = [entry["accuracy"] for entry in stats.values()]
    correlations = []
    for measure in measures:
        scores = [entry[measure] for entry in stats.values()]
        spr = scipy.stats.spearmanr(scores, acc).statistic
        kt = scipy.stats.kendalltau(scores, acc).statistic
        # Rank on the 2-decimal value, so measures that print identically tie.
        correlations.append(round(spr, 2))
        print(f"For {measure} SPR: {spr:.2f}, KT: {kt:.2f}")
    print(len(stats))
    corr = np.array(correlations, dtype=float)
    # A NaN correlation (e.g. constant or non-finite scores) would make
    # rankdata propagate NaN to every rank; rank such measures last instead.
    corr[np.isnan(corr)] = -np.inf
    return scipy.stats.rankdata(-corr, method="dense")


def _mean_ranks(rank_sums, count):
    """Return {measure: rank_sum / count}, or NaN per measure when count is 0."""
    return {
        measure: float(rank_sums[measure]) / count if count else float("nan")
        for measure in measures
    }


_rank_sums = {space: defaultdict(float) for space in SEARCH_SPACES}
_file_counts = dict.fromkeys(SEARCH_SPACES, 0)
for _filename in sorted(glob.glob("*.json")):
    _space = _search_space(_filename)
    if _space not in _rank_sums:
        print(f"Skipping {_filename}: unknown search space {_space!r}")
        continue
    with open(_filename, encoding="utf-8") as f:
        _stats = json.load(f)

    print(f"{_filename.center(100, '#')}")
    _ranks = _rank_measures(_stats)
    print(f"{'#'*100}\n\n")

    # Count the files actually seen instead of assuming 3 per NATS space and
    # 1 for NASBench-101, so a missing dataset cannot skew the averages.
    _file_counts[_space] += 1
    for _measure, _rank in zip(measures, _ranks):
        _rank_sums[_space][_measure] += _rank

average_ranking = {
    space: _mean_ranks(_rank_sums[space], _file_counts[space])
    for space in SEARCH_SPACES
}
# Overall = mean rank across every file (equals the original 3/3/1-weighted
# average over 7 files when all result files are present).
average_ranking["overall"] = _mean_ranks(
    {m: sum(_rank_sums[s][m] for s in SEARCH_SPACES) for m in measures},
    sum(_file_counts.values()),
)

print(average_ranking)
