from pathlib import Path
import sys; import os; sys.path.append(os.getcwd())
import analysis_util as AU
import dataset.meta as DM

def get_best_combo(ds_to_results):
    """Print, per dataset, the test accuracy of the run with the best validation accuracy.

    For each dataset, the result with the largest ``'val_acc'`` is selected
    (the first one wins on ties). Its ``'test_acc'`` and ``'std_test_acc'``
    are printed multiplied by 100, separated by a literal LaTeX ``\\pm``.

    Args:
        ds_to_results: mapping from dataset name to a list of result dicts.
            Each dict must have ``'val_acc'``, ``'test_acc'`` and
            ``'std_test_acc'`` keys. These are presumably fractions in
            [0, 1], given the x100 scaling; confirm against the producer.

    Datasets with an empty result list print ``"<ds>: no results"``.
    """
    for ds, results in ds_to_results.items():
        if not results:
            # Without this guard there is no selected run to report, and the
            # function would crash or reuse the previous dataset's numbers.
            print(f"{ds}: no results")
            continue
        # Select on the unscaled val_acc so the x100 display scaling below
        # cannot distort the comparison.
        best = max(results, key=lambda r: r['val_acc'])
        test_acc = best['test_acc'] * 100
        std_test_acc = best['std_test_acc'] * 100
        # "\\pm" emits a literal backslash-pm (LaTeX plus-minus).
        print(f"{ds}: test acc: {test_acc:.2f} \\pm {std_test_acc:.2f}")

# Hyper-parameter keys a run must share with its dataset's "best" config.
# Compared in this order; comparison stops at the first mismatch.
BEST_PARAM_KEYS = ('norm', 'model', 'nhidden', 'dropout', 'normalize_features')


def make_param_filter(best_params, num_epochs=200):
    """Build a predicate that keeps runs matching each dataset's best configuration.

    Args:
        best_params: mapping from dataset name to a dict containing (at least)
            the keys in ``BEST_PARAM_KEYS``.
        num_epochs: required value of a run's ``'num_epochs'`` field.

    Returns:
        A function taking a run-parameter dict ``x`` and returning True iff
        every key in ``BEST_PARAM_KEYS`` satisfies
        ``x[k] == best_params[x['dataset']][k]`` and
        ``x['num_epochs'] == num_epochs``. The predicate may raise
        ``KeyError`` if the dataset is absent from ``best_params`` or a
        compared key is missing.
    """
    def filter_params(x):
        best = best_params[x['dataset']]
        return all(x[k] == best[k] for k in BEST_PARAM_KEYS) and \
            x['num_epochs'] == num_epochs
    return filter_params


def report_best(title, best_params, pkl_files):
    """Load runs matching ``best_params`` from ``pkl_files``, check them, and print the best per dataset.

    Prints ``title``, then whatever ``AU.check_ds_to_results`` and
    ``get_best_combo`` print, then a ``"=="`` separator.
    """
    # Shallow copy into a plain dict, as the original comprehension did.
    ds_to_results = dict(AU.get_ds_to_results(pkl_files, make_param_filter(best_params)))
    print(title)
    AU.check_ds_to_results(ds_to_results)
    get_best_combo(ds_to_results)
    print("==")


if __name__ == '__main__':
    root_dir = Path('./hyperparam_results/final_gcn_results/')
    pkl_files = [f for f in root_dir.iterdir() if f.suffix == '.pkl' and f.is_file()]
    report_best("GCN _without_ dropout", DM.best_gcn_no_dropout, pkl_files)
    report_best("GCN _with_ dropout", DM.best_gcn, pkl_files)