import glob
import json
import os

import numpy as np
import scipy
import scipy.stats

from api import TransNASBenchAPI
# Benchmark query handle, constructed once from the local TransNAS-Bench data
# file and shared by every lookup below.
_BENCH_FILE = "transnas-bench_v10141024.pth"
api = TransNASBenchAPI(_BENCH_FILE)

TASK_METRIC = {
	"class_scene": "test_top1",
	"class_object": "test_top1",
	"autoencoder": "test_ssim",
	"normal": "test_ssim",
	"room_layout": "test_neg_loss",
    "segmentsemantic": "test_mIoU",
    "jigsaw": "test_top1"
}

# Per-model measures read from each results file; each one is correlated
# separately against the task's ground-truth metric.
measures = [
    "num_param",
    "flops",
    "effective_rank",
]

# For every results file (one task per file), report how well each measure
# ranks the models relative to the benchmark's ground-truth metric, using
# Spearman (printed as "PR") and Kendall tau ("KT") rank correlations.
for filename in sorted(glob.glob("results/*.json")):  # sorted: stable output order
    # File names are expected to look like "<prefix>_micro_<task>.json"; the
    # task is whatever follows the last "_micro_", without the extension.
    stem = os.path.splitext(os.path.basename(filename))[0]
    task = stem.split("_micro_")[-1]
    if task not in TASK_METRIC:
        # An unrelated or misnamed file should not abort the whole report.
        print(f"Skipping {filename}: unknown task {task!r}")
        continue

    with open(filename, encoding="utf-8") as f:
        stats = json.load(f)

    print(task.center(100, "#"))

    # The ground-truth metric depends only on (model, task), not on the
    # measure, so query the benchmark once per model rather than once per
    # (model, measure) pair.
    model_ids = list(stats)
    acc = [
        api.get_single_metric(
            api.index2arch(int(model_id)), task, TASK_METRIC[task], mode="best"
        )
        for model_id in model_ids
    ]

    for measure in measures:
        scores = [stats[model_id][measure] for model_id in model_ids]
        # Both statistics are rank-based and invariant to reordering the two
        # sequences jointly, so no pre-sorting is needed.
        rho = scipy.stats.spearmanr(scores, acc).statistic
        tau = scipy.stats.kendalltau(scores, acc).statistic
        print(f"For {measure} PR: {rho:.2f}, KT: {tau:.2f}")

    # Number of models evaluated for this task.
    print(len(model_ids))
    print(task.center(100, "#"))
    print()