from pathlib import Path
import sys; import os; sys.path.append(os.getcwd())
import analysis_util as AU
import dataset.meta as DM
from pprint import pprint

def get_best_combo(ds_to_results):
    """Print (and return) the best-validation-accuracy hyperparameters per dataset.

    Args:
        ds_to_results: mapping from dataset name to a list of result records.
            Each record is read as ``r['val_acc']`` and ``r['params']``, where
            ``params`` holds at least ``'model'`` and ``'adj_lambda'``.

    Returns:
        dict mapping each dataset name to ``{'model': ..., 'adj_lambda': ...}``
        taken from its highest-``val_acc`` record (the first one on ties), or
        ``None`` when no record's ``val_acc`` compares greater than ``-inf``
        (e.g. an empty result list). The same dict is pretty-printed to stdout.
    """
    res = {}
    for ds, results in ds_to_results.items():
        # Reset the winner for every dataset so one dataset can never inherit
        # the previous dataset's model/adj_lambda.
        best = None
        best_val_acc = float('-inf')
        for r in results:
            # Strict '>' keeps the first record among ties and never selects NaN.
            if r['val_acc'] > best_val_acc:
                best_val_acc = r['val_acc']
                best = r
        if best is None:
            res[ds] = None
        else:
            ps = best['params']
            res[ds] = {'model': ps['model'], 'adj_lambda': ps['adj_lambda']}
    pprint(res)
    return res

if __name__ == '__main__':
    root_dir = Path('./hyperparam_results/arch_results/')
    # Every regular *.pkl file directly inside root_dir (non-recursive).
    pkl_files = [f for f in root_dir.iterdir() if f.suffix == '.pkl' and f.is_file()]

    def filter_params(x, kipf=None):
        """Return whether result record ``x`` belongs to the architecture sweep.

        ``kipf`` maps dataset name -> dict providing at least ``'mixup_scheme'``
        and ``'dof'``; the record must match those per-dataset settings.
        Records missing a key, or holding values that cannot be compared, are
        excluded (False) instead of aborting the whole scan.
        """
        try:
            # Datasets in DM.bigger_datasets were run for 150 epochs, others 200.
            expected_epochs = 150 if x['dataset'] in DM.bigger_datasets else 200
            # '== False' (not 'is False') is kept so falsy equivalents such as 0
            # still match, as before.
            return x['center'] == 'id' and x['center_learned'] == False and \
                x['Pi'] == 100 and x['learn_Xi'] == 'yes' and \
                kipf[x['dataset']]['mixup_scheme'] == x['mixup_scheme'] and \
                kipf[x['dataset']]['dof'] == x['dof'] and \
                x['num_epochs'] == expected_epochs and \
                x['num_layers'] == 2 and \
                (x['model'] in ['kipfres', 'kipf'] or x['adj_lambda'] <= 0.)
        except (KeyError, TypeError, ValueError):
            # KeyError: missing field in x or kipf; TypeError: kipf is None or
            # an incomparable value (e.g. adj_lambda of the wrong type);
            # ValueError: ambiguous comparisons. Narrowed from a bare except so
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            return False

    ds_to_results = AU.get_ds_to_results(pkl_files,
                                         lambda x: filter_params(x, kipf=DM.best_kipf))
    AU.check_ds_to_results(ds_to_results)
    print("best_arch = ", end="")
    get_best_combo(ds_to_results)

    # Same selection, restricted to records run with infinite dof.
    ds_to_results_nngp = AU.get_ds_to_results(pkl_files,
                                              lambda x: filter_params(x, kipf=DM.best_nngp_kipf) and float(x['dof']) == float('inf'),
                                              )
    AU.check_ds_to_results(ds_to_results_nngp)
    print()
    print("best_nn_arch = ", end="")
    get_best_combo(ds_to_results_nngp)