
from pathlib import Path
import sys; import os; sys.path.append(os.getcwd())
import analysis_util as AU
import dataset.meta as DM
from pprint import pprint

def get_best_combo(ds_to_results):
    """Pretty-print (and return) the Pi with the highest validation accuracy per dataset.

    Parameters
    ----------
    ds_to_results : dict
        Maps a dataset name to an iterable of result records. Each record is
        read as ``r['val_acc']`` (compared numerically) and
        ``r['params']['Pi']``.

    Returns
    -------
    dict
        ``{dataset: {'Pi': best_Pi}}``. On ties the first record with the
        maximal ``val_acc`` wins. A dataset with no records maps to
        ``{'Pi': None}`` so the gap is visible in the printed output instead
        of silently reusing another dataset's value. The same dict is
        pretty-printed to stdout.
    """
    res = {}
    for ds, results in ds_to_results.items():
        # The best record is chosen independently for every dataset; nothing
        # carries over between iterations.
        best = max(results, key=lambda r: r['val_acc'], default=None)
        res[ds] = {'Pi': None if best is None else best['params']['Pi']}
    pprint(res)
    return res

if __name__ == '__main__':
    root_dir = Path('./hyperparam_results/Pi_results/')
    pkl_files = [f for f in root_dir.iterdir() if f.suffix == '.pkl' and f.is_file()]

    # Pi values covered by this sweep; records with other values are ignored.
    PI_CANDIDATES = (50, 100, 200, 300, 400)

    def filter_params(x, kipf=None, arch=None, center=None):
        """Return True if result record ``x`` belongs to the Pi sweep being analysed.

        A record matches when all of the following hold:
        - its Pi is one of PI_CANDIDATES;
        - ``learn_Xi == 'yes'`` and ``num_layers == 2``;
        - ``num_epochs`` is 150 for datasets in ``DM.bigger_datasets`` and
          200 otherwise;
        - its mixup_scheme/dof, model/adj_lambda and center/center_learned
          equal the entries for ``x['dataset']`` in ``kipf``, ``arch`` and
          ``center`` respectively.

        Records missing any of these fields, datasets absent from a lookup
        table, and lookup tables left as None are all treated as non-matching.
        """
        try:
            ds = x['dataset']
            expected_epochs = 150 if ds in DM.bigger_datasets else 200
            return (x['Pi'] in PI_CANDIDATES
                    and x['learn_Xi'] == 'yes'
                    and x['num_layers'] == 2
                    and x['num_epochs'] == expected_epochs
                    and kipf[ds]['mixup_scheme'] == x['mixup_scheme']
                    and kipf[ds]['dof'] == x['dof']
                    and arch[ds]['model'] == x['model']
                    and arch[ds]['adj_lambda'] == x['adj_lambda']
                    and center[ds]['center'] == x['center']
                    and center[ds]['center_learned'] == x['center_learned'])
        except (KeyError, TypeError):
            # KeyError: the record lacks a field, or its dataset is missing from
            # a lookup table. TypeError: a lookup table was not supplied (None).
            # Anything else is a real bug and should propagate.
            return False

    def report(label, kipf, arch, center, extra=lambda x: True):
        """Select matching records, sanity-check them and print '<label> = {best Pi per dataset}'."""
        ds_to_results = AU.get_ds_to_results(
            pkl_files,
            lambda x: filter_params(x, kipf=kipf, arch=arch, center=center) and extra(x))
        AU.check_ds_to_results(ds_to_results)
        print(label + ' = ', end='')
        get_best_combo(ds_to_results)

    report('best_Pi', DM.best_kipf, DM.best_arch, DM.best_center)
    print()
    # The NNGP selection additionally keeps only records whose dof is +inf.
    report('best_nngp_Pi', DM.best_nngp_kipf, DM.best_nngp_arch, DM.best_nngp_center,
           extra=lambda x: float(x['dof']) == float('inf'))