
from pathlib import Path
import sys; import os; sys.path.append(os.getcwd())
import analysis_util as AU
import dataset.meta as DM
from pprint import pprint

def get_best_combo(ds_to_results):
    """Print and return the best ``center`` settings for each dataset.

    For every dataset, the result with the highest ``'val_acc'`` is selected
    (the earliest one wins on ties). Its ``'center'`` and ``'center_learned'``
    params are recorded.

    Args:
        ds_to_results: mapping from dataset name to a list of result dicts.
            Each dict has a ``'val_acc'`` value and a ``'params'`` dict holding
            ``'center'`` and ``'center_learned'``.

    Returns:
        dict mapping dataset name to ``{'center': ..., 'center_learned': ...}``.
        The same dict is also pretty-printed to stdout.

    Raises:
        ValueError: if a dataset has an empty list of results.
    """
    res = {}
    for ds, results in ds_to_results.items():
        if not results:
            raise ValueError(f"no results for dataset {ds!r}")
        # max() returns the first maximal element, so ties keep the earliest
        # run, matching a strict '>' scan. Selecting fresh per dataset means a
        # dataset can never inherit the previous dataset's choice.
        best = max(results, key=lambda r: r['val_acc'])
        params = best['params']
        res[ds] = {'center': params['center'],
                   'center_learned': params['center_learned']}
    pprint(res)
    return res

if __name__ == '__main__':
    root_dir = Path('./hyperparam_results/center_results/')
    # Only regular *.pkl files directly inside root_dir are considered.
    pkl_files = [f for f in root_dir.iterdir() if f.suffix == '.pkl' and f.is_file()]

    def filter_params(x, kipf=None, arch=None):
        """Return True if run record ``x`` matches the fixed sweep settings.

        Args:
            x: a single run record (dict) with keys such as 'dataset', 'Pi',
                'learn_Xi', 'mixup_scheme', 'dof', 'model', 'adj_lambda',
                'num_epochs' and 'num_layers'.
            kipf: mapping dataset -> dict with 'mixup_scheme' and 'dof'.
            arch: mapping dataset -> dict with 'model' and 'adj_lambda'.

        Returns:
            False if any setting differs, if a required key is missing, or if
            ``kipf``/``arch`` is None. Otherwise True.
        """
        try:
            dataset = x['dataset']
            # Datasets in DM.bigger_datasets are expected at 150 epochs,
            # all others at 200.
            expected_epochs = 150 if dataset in DM.bigger_datasets else 200
            return (x['Pi'] == 100
                    and x['learn_Xi'] == 'yes'
                    and kipf[dataset]['mixup_scheme'] == x['mixup_scheme']
                    and kipf[dataset]['dof'] == x['dof']
                    and arch[dataset]['model'] == x['model']
                    and arch[dataset]['adj_lambda'] == x['adj_lambda']
                    and x['num_epochs'] == expected_epochs
                    and x['num_layers'] == 2)
        except (KeyError, TypeError):
            # KeyError: record or best-config table lacks a key.
            # TypeError: kipf/arch is None (not subscriptable).
            return False

    def is_best_run(x):
        # Runs matching the best (non-NNGP) kipf/arch configuration.
        return filter_params(x, kipf=DM.best_kipf, arch=DM.best_arch)

    def is_best_nngp_run(x):
        # Runs matching the best NNGP configuration; additionally only runs
        # whose 'dof' is infinite are kept.
        return (filter_params(x, kipf=DM.best_nngp_kipf, arch=DM.best_nngp_arch)
                and float(x['dof']) == float('inf'))

    # NOTE(review): AU.get_ds_to_results presumably groups the filtered records
    # by dataset, in the shape get_best_combo expects — confirm in analysis_util.
    ds_to_results = AU.get_ds_to_results(pkl_files, is_best_run)
    AU.check_ds_to_results(ds_to_results)
    print("best_center = ", end='')
    get_best_combo(ds_to_results)

    ds_to_results_nngp = AU.get_ds_to_results(pkl_files, is_best_nngp_run)
    AU.check_ds_to_results(ds_to_results_nngp)
    print()
    print("best_nngp_center = ", end='')
    get_best_combo(ds_to_results_nngp)