from pathlib import Path
import sys; import os; sys.path.append(os.getcwd())
import analysis_util as AU

def get_best_combo(
    ds_to_results,
    parameters=('normalize_features', 'norm', 'dropout', 'nhidden', 'model'),
):
    """Print, as a Python dict literal, the best hyper-parameters per dataset.

    For every dataset in *ds_to_results*, the result with the highest
    ``'val_acc'`` is selected; the first such result wins on ties. The entries
    of its ``'params'`` dict named in *parameters* are then printed as one
    ``'<dataset>': {...},`` line. The lines are wrapped in ``{`` / ``}``, so
    the whole output is a valid Python dict literal.

    Args:
        ds_to_results: Mapping from dataset name to a list of result dicts.
            Each result dict has a ``'val_acc'`` value and a ``'params'`` dict.
        parameters: Names of the ``'params'`` entries to print, in order.

    Raises:
        ValueError: If a dataset has an empty result list.
        KeyError: If a result lacks ``'val_acc'``/``'params'``, or its
            ``'params'`` lacks one of *parameters*.
    """
    print("{")
    for ds, results in ds_to_results.items():
        if not results:
            # Previously this reused the prior dataset's params (or raised
            # UnboundLocalError); fail loudly instead.
            raise ValueError(f"no results for dataset {ds!r}")
        # max() returns the first maximal element, matching a strict '>'
        # scan. Unlike a scan starting at 0., it also picks a best result
        # when every val_acc is <= 0.
        best = max(results, key=lambda r: r['val_acc'])
        ps = best['params']
        dict_str = str({k: ps[k] for k in parameters})
        # repr(str(ds)) yields a properly quoted and escaped string literal.
        # The trailing comma keeps consecutive entries a valid dict literal.
        print(f"{str(ds)!r}: {dict_str},")
    print("}")

if __name__ == '__main__':
    root_dir = Path('./hyperparam_results/gcn_results/')
    pkl_files = [f for f in root_dir.iterdir() if f.suffix == '.pkl' and f.is_file()]
    def filter_params(x):
        return x['norm'] in ['none', 'batch'] and \
               x['model'] in ['gcn', 'kipfgcn'] and \
               x['nhidden'] in [100, 200] and \
               x['dropout'] in [0., 0.5] and \
               x['normalize_features'] in [False, True] and \
               x['num_epochs'] == 200
    ds_to_results = AU.get_ds_to_results(pkl_files, filter_params)
    ds_to_results = {k: v for k, v in ds_to_results.items()}
    AU.check_ds_to_results(ds_to_results)
    print("best_gcn = ", end='')
    get_best_combo(ds_to_results)

    def filter_params(x):
        return x['norm'] in ['none', 'batch'] and \
               x['model'] in ['gcn', 'kipfgcn'] and \
               x['nhidden'] in [100, 200] and \
               x['dropout'] in [0.] and \
               x['normalize_features'] in [False, True] and \
               x['num_epochs'] == 200
    ds_to_results = AU.get_ds_to_results(pkl_files, filter_params)
    ds_to_results = {k: v for k, v in ds_to_results.items()}
    AU.check_ds_to_results(ds_to_results)
    print()
    print("best_gcn_no_dropout = ", end='')
    get_best_combo(ds_to_results)