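"""
Reproduce Table 12: for every benchmark results file, report how well each
measure ranks the models by accuracy (Spearman's rho, Kendall's tau, and the
fraction of randomly sampled model pairs that the measure orders correctly).
"""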

import glob
import json
import scipy
import numpy as np
import re
from collections import defaultdict
import random

measures = ["swap", "swap_reg", "meco_opt", "zico", "zen", "synflow", "fisher", "grasp", "snip", "grad_norm", "num_param", "flops", "effective_rank"]

# Number of random model pairs drawn per measure when estimating how often a
# measure orders two models the same way their accuracies do.
NUM_PAIR_SAMPLES = 1_000_000


def rank_correlations(scores, acc):
    """Return ``(spearman_rho, kendall_tau)`` between *scores* and *acc*.

    Both sequences must be aligned: ``scores[i]`` and ``acc[i]`` belong to
    the same model.
    """
    spr = scipy.stats.spearmanr(scores, acc).statistic
    kt = scipy.stats.kendalltau(scores, acc).statistic
    return spr, kt


def pairwise_ranking_accuracy(acc, scores, num_samples=NUM_PAIR_SAMPLES, rng=None):
    """Estimate the fraction of model pairs that *scores* orders like *acc*.

    Draws *num_samples* pairs of distinct models uniformly at random (with
    replacement between draws). A pair counts as correct when:

    * the accuracies differ and the scores differ in the same direction, or
    * the accuracies tie and the scores tie as well.

    Pairs involving NaN never count as correct, because every comparison
    with NaN is false.

    Args:
        acc: per-model accuracies.
        scores: per-model measure values, aligned with *acc*.
        num_samples: number of random pairs to draw; must be positive.
        rng: optional ``numpy.random.Generator`` for reproducible sampling.

    Returns:
        The fraction of sampled pairs ranked correctly, in ``[0, 1]``.

    Raises:
        ValueError: fewer than two models, mismatched lengths, or a
            non-positive *num_samples*.
    """
    acc = np.asarray(acc)
    scores = np.asarray(scores)
    n = len(acc)
    if n < 2:
        raise ValueError(f"need at least two models to sample a pair, got {n}")
    if len(scores) != n:
        raise ValueError(f"acc has {n} entries but scores has {len(scores)}")
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")
    if rng is None:
        rng = np.random.default_rng()

    first = rng.integers(0, n, size=num_samples)
    # Draw the partner from the remaining n-1 models and skip over `first`,
    # so each pair is two distinct models (as random.sample(keys, 2) gave).
    second = rng.integers(0, n - 1, size=num_samples)
    second += second >= first

    acc_a, acc_b = acc[first], acc[second]
    score_a, score_b = scores[first], scores[second]
    correct = (
        ((acc_a > acc_b) & (score_a > score_b))
        | ((acc_a < acc_b) & (score_a < score_b))
        # Ties: compare the two *different* models' scores.
        | ((acc_a == acc_b) & (score_a == score_b))
    )
    return np.count_nonzero(correct) / num_samples


for filename in glob.glob("../NAS_Benchmarks/results/*.json"):
    # Each results file maps model id -> {"accuracy": ..., <measure>: ...}
    # (schema inferred from the lookups below).
    with open(filename, encoding="utf-8") as f:
        stats = json.load(f)

    print(f"{filename.center(100, '#')}")
    keys = list(stats.keys())
    acc = [stats[model_id]["accuracy"] for model_id in keys]

    # Rank correlation of every measure with accuracy.
    correlations = []
    for measure in measures:
        scores = [stats[model_id][measure] for model_id in keys]
        spr, kt = rank_correlations(scores, acc)
        correlations.append(round(spr, 2))
        print(f"For {measure} SPR: {spr:.2f}, KT: {kt:.2f}")
    print(len(keys))

    # Monte-Carlo estimate of pairwise ranking accuracy for every measure.
    for measure in measures:
        scores = [stats[model_id][measure] for model_id in keys]
        ratio = pairwise_ranking_accuracy(acc, scores)
        print(f"For {measure}: {ratio:.3f} are ranked correctly.")
    print(f"{'#'*100}\n\n")
     
