"""
creates plot showing graph metric
"""

import matplotlib.pyplot as plt
from pathlib import Path
import sys; import os; sys.path.append(os.getcwd())
import analysis_util as AU
import dataset.meta as DM
import numpy as np
from pprint import pprint

def plot(plot_data):
    """Plot validation accuracy against the dof hyperparameter (nu) per dataset.

    Draws a fixed 4x3 grid with one panel per dataset (the bottom-right panel
    is switched off). For each dataset only the mixup scheme listed in
    ``DM.best_kipf[ds_name]['mixup_scheme']`` is drawn: mean validation
    accuracy (scaled by 100 for display) with a shaded +/- 1 std band, on a
    log-scaled x-axis. Runs with a non-finite dof (dof=inf) cannot be placed
    on that axis and are left out; datasets with no finite-dof runs are
    skipped. The figure is saved to
    ``./hyperparam_results/kipf_results/plots/graph_metric.pdf``.

    Args:
        plot_data: mapping ``ds_name -> mixup_scheme -> series`` where
            ``series`` holds parallel lists ``'val_acc'``, ``'std_val_acc'``
            and ``'dof'`` (one entry per run). Every ``ds_name`` must be one
            of the datasets laid out in ``ds_to_pos`` below.
    """
    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 16,
        'axes.labelsize': 14,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'legend.fontsize': 16
    })
    ds_to_pos = {'cora': (0, 0), 'pubmed': (0, 1), 'reddit': (0, 2),
                 'citeseer': (1, 0), 'minesweeper': (1, 1), 'arxiv': (1, 2),
                 'tolokers': (2, 0), 'amazon-ratings': (2, 1), 'chameleon': (2, 2),
                 'squirrel': (3, 0), 'roman-empire': (3, 1)}
    fig, axes = plt.subplots(4, 3, figsize=(15, 15))
    have_added_labels = False
    for ds_name, data in plot_data.items():
        row, col = ds_to_pos[ds_name]
        ax = axes[row, col]

        mixup_scheme = DM.best_kipf[ds_name]['mixup_scheme']
        vgraph = data[mixup_scheme]

        # Pair each run's (dof, mean, std) and sort by dof numerically rather
        # than trusting the input order (a line plot needs monotone x). The
        # sort is stable, so runs sharing a dof keep their original order.
        runs = sorted(
            zip((float(d) for d in vgraph['dof']),
                vgraph['val_acc'],
                vgraph['std_val_acc']),
            key=lambda run: run[0],
        )
        # dof=inf cannot be drawn on a log-scaled x-axis.
        runs = [run for run in runs if np.isfinite(run[0])]
        if not runs:
            continue

        xs, ys, ys_std = np.array(runs, dtype=float).T
        ys = ys * 100
        ys_std = ys_std * 100

        # NOTE: no legend is drawn anywhere, so this label has no visible
        # effect at the moment; only the first plotted line carries one.
        label = None if have_added_labels else mixup_scheme
        ax.plot(xs, ys, label=label)
        ys_lb = ys - ys_std
        ys_ub = ys + ys_std
        ax.fill_between(xs, ys_lb, ys_ub, alpha=0.2)

        # 3-point margin around the whole +/- std band; the upper limit must
        # come from the upper band or the shaded region gets clipped.
        ymin = min(ys_lb.min(), ys.min()) - 3
        ymax = max(ys_ub.max(), ys.max()) + 3
        # Very flat curves: widen the window, clamped to [0, 100].
        if ymax - ymin < 8:
            ymax = min(100, round(ymax + 5))
            ymin = max(0, round(ymin - 5))
        ax.set_ylim(ymin, ymax)

        have_added_labels = True
        ax.set_xscale('log')
        plot_name = DM.plot_names[ds_name]
        ax.set_title(f"{plot_name} ($h$={DM.homophily_map[ds_name]})")
        # x-label only on the lowest visible panel of each column; panel
        # (3, 2) is switched off, so (2, 2) is the bottom of the last column.
        if row == 3 or (row, col) == (2, 2):
            ax.set_xlabel(r'$\nu$')
        if col == 0:
            ax.set_ylabel('Val acc.')  # y-label only on the first column
    axes[3, 2].axis('off')
    fig.tight_layout()
    p = Path("./hyperparam_results/kipf_results/plots")
    p.mkdir(exist_ok=True, parents=True)
    fig.savefig(p / "graph_metric.pdf")
    # Release the figure instead of merely clearing it.
    plt.close(fig)


def get_best_combo(ds_to_results):
    """Print (and return) the best-validation-accuracy hyperparameters per dataset.

    For every dataset, the run with the highest ``val_acc`` is selected. The
    first such run wins on ties, and only runs with ``val_acc > 0`` are
    eligible. Its ``mixup_scheme`` and ``dof`` params are recorded. The
    resulting dict is pretty-printed to stdout (the caller prints a
    ``name = `` prefix before it).

    Args:
        ds_to_results: mapping ``dataset name -> iterable of run records``,
            each record having a ``'val_acc'`` entry and a ``'params'`` dict
            containing ``'mixup_scheme'`` and ``'dof'``.

    Returns:
        The printed dict ``{ds: {'mixup_scheme': ..., 'dof': ...}}``.
        Datasets without an eligible run are omitted and a warning is written
        to stderr. Previously such a dataset either raised
        ``UnboundLocalError`` or silently inherited the previous dataset's
        best combination.
    """
    res = {}
    for ds, results in ds_to_results.items():
        # Reset per dataset so no state leaks between datasets.
        best_run = None
        best_val_acc = 0.
        for r in results:
            if r['val_acc'] > best_val_acc:
                best_val_acc = r['val_acc']
                best_run = r
        if best_run is None:
            # stderr keeps stdout a clean, paste-able dict literal.
            print(f"warning: no run with val_acc > 0 for dataset {ds!r}; skipping",
                  file=sys.stderr)
            continue
        ps = best_run['params']
        res[ds] = {'mixup_scheme': ps['mixup_scheme'], 'dof': ps['dof']}
    pprint(res)
    return res

if __name__ == '__main__':
    import argparse

    # Default mode prints the best hyperparameter combinations; --plot draws
    # the accuracy-vs-dof figure instead.
    cli = argparse.ArgumentParser()
    cli.add_argument('--plot', action='store_true')
    args = cli.parse_args()

    results_root = Path('./hyperparam_results/kipf_results')
    (results_root / 'plots').mkdir(exist_ok=True, parents=True)
    pkl_files = [path for path in results_root.iterdir()
                 if path.suffix == '.pkl' and path.is_file()]

    def filter_params(x):
        """Return True for runs matching the fixed Kipf-model sweep configuration."""
        # Datasets in DM.bigger_datasets are expected at 200 epochs, all others at 300.
        wanted_epochs = 200 if x['dataset'] in DM.bigger_datasets else 300
        correct_epochs = x['num_epochs'] == wanted_epochs
        return (x['dof'] in [0., 1e-3, 1e-2, 1e-1, 1e0, 'inf', 1e1, 1e2, 1e3]
                and x['Pi'] == 100
                and x['learn_Xi'] == 'yes'
                and x['model'] == 'kipf'
                and x['adj_lambda'] == 0.
                and correct_epochs
                and x['center'] == 'id'
                and x['center_learned'] == False
                and x['num_layers'] == 2)

    ds_to_results = AU.get_ds_to_results(pkl_files, filter_params)

    if args.plot:
        # Regroup each dataset's runs into per-scheme parallel lists for plot().
        plot_data = {}
        for ds_name, rs in ds_to_results.items():
            per_scheme = {}
            for scheme in ['fixed-indep', 'fixed-full']:
                matching = [r for r in rs if r['params']['mixup_scheme'] == scheme]
                per_scheme[scheme] = {
                    'val_acc': [r['val_acc'] for r in matching],
                    'std_val_acc': [r['std_val_acc'] for r in matching],
                    'dof': [r['params']['dof'] for r in matching],
                }
            plot_data[ds_name] = per_scheme
        plot(plot_data)
    else:
        AU.check_ds_to_results(ds_to_results)
        print("best_kipf = ", end='')
        get_best_combo(ds_to_results)

        def is_nngp_run(x):
            # Same sweep restricted to the dof=inf runs.
            return filter_params(x) and float(x['dof']) == float('inf')

        ds_to_results_gcnngp = AU.get_ds_to_results(pkl_files, is_nngp_run)
        AU.check_ds_to_results(ds_to_results_gcnngp)
        print()
        print("best_nngp_kipf = ", end='')
        get_best_combo(ds_to_results_gcnngp)