
from pathlib import Path
import sys; import os; sys.path.append(os.getcwd())
import analysis_util as AU
import dataset.meta as DM

import pandas as pd

def _merge_results(ds_to_results, ds_to_results_nngp):
    """Return a new dataset -> result-list mapping combining both inputs.

    Fresh lists are built so that neither input mapping is mutated. Only
    datasets present in ``ds_to_results`` are kept; a dataset that has no
    entry in ``ds_to_results_nngp`` simply contributes no NNGP runs.
    NNGP runs are placed after the regular runs, so they win whenever both
    write the same table cell (see ``_fill_measured``).
    """
    return {
        ds: list(results) + list(ds_to_results_nngp.get(ds, []))
        for ds, results in ds_to_results.items()
    }


def _fill_measured(data, dtr):
    """Write measured GCDKM / sparse-GCNNGP accuracies from ``dtr`` into ``data``.

    Each result dict ``r`` is read for ``'test_acc'``, ``'split_std_test_acc'``
    and ``'params'``. Both accuracy values are multiplied by 100 (presumably
    fraction -> percent; confirm against the results files).

    Runs whose params contain ``'dof'``:
      * ``dof == inf``: stored as 'our-gcnngp'. If ``DM.best_kipf[ds]['dof']``
        is also inf, the run is copied into 'gcdkm' and marked with a star.
      * finite ``dof``: stored as 'gcdkm'.
    Runs without ``'dof'`` are skipped. GCN numbers are hard-coded elsewhere.
    Later runs overwrite earlier ones for the same cell.
    """
    star = '$^*$'
    inf = float('inf')
    for ds, results in dtr.items():
        for r in results:
            test_acc = r['test_acc'] * 100
            test_acc_std = r['split_std_test_acc'] * 100

            if 'dof' not in r['params']:
                # must be gcn -- TODO: GCN runs are not read from results yet.
                continue

            dof = float(r['params']['dof'])
            entry = {'test_acc': test_acc, 'std_test_acc': test_acc_std}
            if dof == inf:
                data[ds]['our-gcnngp'].update(entry)
                if float(DM.best_kipf[ds]['dof']) == inf:
                    # NOTE(review): NNGP runs come after GCDKM runs in ``dtr``,
                    # so the final GCDKM cell holds the NNGP-config run here.
                    # Confirm that this is intended.
                    data[ds]['gcdkm'].update(entry, extra=star)
            elif dof < inf:
                data[ds]['gcdkm'].update(entry)


def _add_reported_baselines(data):
    """Overwrite the baseline columns of ``data`` with hard-coded numbers.

    Fills 'other-gcnngp' (marked with a dagger), 'gcn-nodropout' and 'gcn'
    for the datasets listed below. A std of -1. matches the -1 placeholder
    used for "no value" in the initial table.
    """
    dagger = '$^\\dagger$'
    other_gcnngp = {
        'cora': (82.80, -1.),
        'pubmed': (79.60, -1.),
        'citeseer': (69.5, -1.),
        'arxiv': (70.11, 0.11),
        'reddit': (94.65, 0.03),
    }
    for ds, (acc, std) in other_gcnngp.items():
        data[ds]['other-gcnngp'] = {'test_acc': acc, 'std_test_acc': std, 'extra': dagger}

    # NOTE(review): minesweeper, chameleon and citeseer have identical GCN and
    # GCN (no dropout) entries. Confirm this is intended and not a copy-paste.
    gcn_nodropout = {
        'reddit': (95.30, 0.08),
        'minesweeper': (85.63, 0.36),
        'roman-empire': (79.92, 0.51),
        'squirrel': (57.3, 1.40),
        'tolokers': (81.63, 0.70),
        'chameleon': (65.32, 2.58),
        'pubmed': (79.36, 0.22),
        'citeseer': (71.7100, 0.2119),
        'amazon-ratings': (48.48, 0.48),
        'arxiv': (70.02, 0.11),
        'cora': (80.62, 0.24),
    }
    gcn = {
        'reddit': (95.77, 0.05),
        'minesweeper': (85.63, 0.36),
        'roman-empire': (82.77, 0.48),
        'squirrel': (57.73, 1.43),
        'tolokers': (82.19, 0.74),
        'chameleon': (65.32, 2.58),
        'pubmed': (79.47, 0.13),
        'citeseer': (71.7100, 0.2119),
        'amazon-ratings': (49.79, 0.60),
        'arxiv': (70.57, 0.24),
        'cora': (81.06, 0.29),
    }
    for model, table in (('gcn-nodropout', gcn_nodropout), ('gcn', gcn)):
        for ds, (acc, std) in table.items():
            data[ds][model] = {'test_acc': acc, 'std_test_acc': std}

    # Unused reference numbers, kept for an optional 'oldpaper' column:
    #   cora 81.5±0.2, pubmed 79.4±0.1, citeseer 69.5±0.2, minesweeper 85.7±0.5,
    #   tolokers 80.7±0.6, amazon-ratings 48.8±0.6, chameleon 70.8±1.9,
    #   squirrel 57.7±1.8, roman-empire 75.6±0.7
    # Old-paper GCN numbers:
    #   cora 81.5, pubmed 79.0, citeseer 70.3, minesweeper 86.6±0.6,
    #   tolokers 80.7±0.3, amazon-ratings 49.3±0.6, chameleon 67.8±2.8,
    #   squirrel 58.5±1.9, roman-empire 83.4±0.4
    # 'rel-improv' was (gcdkm - our-gcnngp) / our-gcnngp (currently disabled).


def create_final_results_table(ds_to_results, ds_to_results_nngp):
    """Print the final-results comparison table as LaTeX and return it.

    Args:
        ds_to_results: mapping dataset -> list of result dicts for the selected
            GCDKM runs. Each dict has 'test_acc', 'split_std_test_acc' and
            'params'. The mapping is not modified.
        ds_to_results_nngp: same layout, for the selected NNGP (dof = inf)
            runs. It is not modified.

    Returns:
        pandas.DataFrame: the printed table. Rows are ``DM.plot_names``,
        sorted by ``DM.homophily_map`` in descending order. Columns are the
        prettified model names. Cells are whatever ``AU.means_stds_to_latex``
        produced.
    """
    models = ['gcdkm', 'our-gcnngp', 'other-gcnngp', 'rel-improv', 'gcn-nodropout', 'gcn']
    # dataset -> model -> {'test_acc', 'std_test_acc'[, 'extra']}; -1 = no value.
    data = {
        ds: {model: {'test_acc': -1, 'std_test_acc': -1} for model in models}
        for ds in DM.all_datasets
    }

    _fill_measured(data, _merge_results(ds_to_results, ds_to_results_nngp))
    _add_reported_baselines(data)

    cols = ['gcdkm', 'our-gcnngp', 'other-gcnngp', 'gcn-nodropout', 'gcn']
    prettify_col = {'gcdkm': 'GCDKM', 'our-gcnngp': 'sparse GCNNGP', 'other-gcnngp': 'GCNNGP',
                    'gcn-nodropout': 'GCN (no dropout)', 'gcn': 'GCN'}

    rows = {}
    for ds, entries in data.items():
        cells = []
        for col in cols:
            c = entries[col]
            cell = [c['test_acc']]
            if 'std_test_acc' in c:
                cell.append(c['std_test_acc'])
            if 'extra' in c:
                cell.append(c['extra'])
            cells.append(cell)
        # presumably formats each [mean, std(, marker)] cell as LaTeX -- confirm in AU
        rows[DM.plot_names[ds]] = AU.means_stds_to_latex(cells)

    df = pd.DataFrame(rows).transpose()
    df.columns = [prettify_col[c] for c in cols]

    # Order rows by homophily ratio, highest first, using a temporary column.
    homophily = {DM.plot_names[k]: DM.homophily_map[k] for k in DM.plot_names}
    df['homophily_ratio'] = df.index.map(homophily)
    df = df.sort_values('homophily_ratio', ascending=False).drop('homophily_ratio', axis=1)

    print(df.to_latex())
    return df

if __name__ == '__main__':
    # Every pickled sweep result sitting directly in the final-results folder.
    results_dir = Path('./hyperparam_results/final_results/')
    pkl_files = [p for p in results_dir.iterdir() if p.suffix == '.pkl' and p.is_file()]

    def filter_params(x, kipf, arch, center, Pi):
        """Return whether run ``x`` matches the selected configuration for its dataset.

        ``kipf``, ``arch``, ``center`` and ``Pi`` map dataset -> dict of chosen
        hyper-parameters (the ``DM.best_*`` tables). A run must also have
        learn_Xi == 'yes', exactly 2 layers, and 150 epochs for datasets in
        ``DM.bigger_datasets`` (200 epochs otherwise).
        """
        ds = x['dataset']
        epochs_ok = x['num_epochs'] == (150 if ds in DM.bigger_datasets else 200)
        if x['Pi'] != Pi[ds]['Pi'] or x['learn_Xi'] != 'yes' or not epochs_ok:
            return False
        # (table, key) pairs: the run's x[key] must equal table[ds][key].
        for table, key in (
            (kipf, 'mixup_scheme'),
            (kipf, 'dof'),
            (arch, 'model'),
            (arch, 'adj_lambda'),
            (center, 'center'),
            (center, 'center_learned'),
        ):
            if table[ds][key] != x[key]:
                return False
        return x['num_layers'] == 2

    def is_best_gcdkm_run(x):
        # Run matches the best GCDKM configuration tables.
        return filter_params(x, DM.best_kipf, DM.best_arch, DM.best_center, DM.best_Pi)

    def is_best_nngp_run(x):
        # Run matches the best NNGP configuration tables and has infinite dof.
        return (filter_params(x, DM.best_nngp_kipf, DM.best_nngp_arch,
                              DM.best_nngp_center, DM.best_nngp_Pi)
                and float(x['dof']) == float('inf'))

    ds_to_results = AU.get_ds_to_results(pkl_files, is_best_gcdkm_run)
    AU.check_ds_to_results(ds_to_results)

    ds_to_results_nngp = AU.get_ds_to_results(pkl_files, is_best_nngp_run)
    AU.check_ds_to_results(ds_to_results_nngp)

    create_final_results_table(ds_to_results, ds_to_results_nngp)