import os, pickle, sys
import matplotlib.pyplot as plt
from scipy import stats
import numpy as np
import glob
from tqdm import tqdm
from prettytable import PrettyTable

# Directory scanned for result files, plus the accumulators filled below:
#   runs      - one record dict per distinct p['i'] (first-seen order)
#   accs      - p['testacc'] of each record in `runs`, index-aligned with it
#   processed - maps p['i'] -> index into `runs`, used to merge duplicates
d = './'
runs = []
accs = []
processed = {}

dataset = 'ImageNet16-120'  # cifar10 cifar100 ImageNet16-120

# Set to an explicit list of file names to restrict the scan; None = all of `d`.
file_list = None

if file_list is None:
    file_list = os.listdir(d)

for f in tqdm(file_list):
    # NOTE(review): '.p' is a substring test, so e.g. '.py'/'.png' names that
    # also contain the dataset prefix would match too -- tighten if needed.
    if dataset + '_' not in f or '.p' not in f:
        continue
    # Each file holds several pickled records written back to back; keep
    # calling pickle.load until the stream is exhausted.
    # NOTE(review): pickle executes code on load -- only run on trusted files.
    with open(os.path.join(d, f), 'rb') as pf:
        while True:
            try:
                p = pickle.load(pf)
            except EOFError:
                # Normal end of the record stream.
                break
            except Exception as err:
                # Best effort (e.g. a truncated tail from an interrupted
                # writer): keep the records read so far, but say so.
                print(f'warning: stopped reading {f}: {err!r}', file=sys.stderr)
                break
            if p['i'] in processed:
                # Same record id seen before (e.g. measures computed in a
                # separate pass): merge its measures into the existing run.
                idx = processed[p['i']]
                runs[idx]['logmeasures'] = {**runs[idx]['logmeasures'], **p['logmeasures']}
                continue
            processed[p['i']] = len(runs)
            runs.append(p)
            accs.append(p['testacc'])

t = None

print(d, len(runs))
if not runs:
    # Without records there is nothing to correlate (runs[0] would fail).
    sys.exit(f'No {dataset} records found in {d}')

# Measure names across ALL runs, in first-seen order. Records are merged per
# 'i', so different runs may carry different sets of measures.
metric_names = list(dict.fromkeys(k for r in runs for k in r['logmeasures']))
# One entry per run for every measure (NaN where a run lacks it), so each
# list stays index-aligned with `acc`; spearmanr(..., nan_policy='omit')
# then drops exactly those pairs instead of pairing the wrong runs.
metrics = {
    k: [r['logmeasures'].get(k, np.nan) for r in runs]
    for k in metric_names
}
acc = accs

if t is None:
    # Table header: dataset name followed by the measures to report.
    hl = ['Dataset']
    hl.extend(['effective_capacity'])
    t = PrettyTable(hl)

print(hl)

# Spearman rank correlation between test accuracy and each reported measure.
res = []
for k in hl:
    if k == 'Dataset':
        continue
    v = metrics[k]
    cr = stats.spearmanr(acc, v, nan_policy='omit').correlation
    res.append(round(cr, 3))

ds = dataset  # 'CIFAR10' CIFAR100, ImageNet16-120
t.add_row([ds] + res)

print(t)