import sys; import os
sys.path.append(os.getcwd())
import torch as t
import matplotlib.pyplot as plt
import analysis_util as AU
from pathlib import Path
import dataset.meta as DM


def turn_off_axes_ticks(ax):
    """Clear the x and y tick locations of ``ax`` (x first, then y)."""
    for set_ticks in (ax.set_xticks, ax.set_yticks):
        set_ticks([])

# Global matplotlib font sizes applied to every figure this script draws.
plt.rcParams.update([
    ('font.size', 12),
    ('axes.titlesize', 16),
    ('axes.labelsize', 14),
    ('xtick.labelsize', 12),
    ('ytick.labelsize', 12),
    ('legend.fontsize', 16),
])

if __name__ == '__main__':
    root_dir = Path('./shaped/')
    pkl_files = [f for f in root_dir.iterdir() if f.suffix == '.pkl' and f.is_file()]

    def filter_params(x):
        """Return True if the run hyper-parameters ``x`` belong to the plotted sweep.

        ``x`` is a dict of run hyper-parameters. Keys are read with ``[]``
        (so must exist when reached), except ``scale_inputs``, which defaults
        to False when absent.
        """
        # NOTE(review): datasets listed in DM.bigger_datasets can never pass this
        # check, whatever their epoch count -- confirm that exclusion is intended.
        correct_epochs = x['dataset'] not in DM.bigger_datasets and x['num_epochs'] == 300
        # 'inf' appears as a string because infinite dof is stored that way.
        res = x['dof'] in [1e0, 'inf', 1e3, 1e1] and \
               x['center'] == 'id' and x['center_learned'] == False and \
               x['Pi'] == 100 and x['learn_Xi'] == 'yes' and \
               DM.best_kipf[x['dataset']]['mixup_scheme'] == x['mixup_scheme'] and \
               correct_epochs and \
               x['num_layers'] == 2 and \
               x['model'] == 'kipf' and \
               x['adj_lambda'] == 0.3 and \
               x.get('scale_inputs',False) == True
        return res

    ds_to_results = AU.get_ds_to_results(pkl_files, filter_params)

    def get_graph_data(results_list):
        """Map each dof value (as float) to {'K', 'yyT', 'cka'} for one dataset.

        Assumes each result's ``metrics[0]['kernels']`` holds the learned kernel
        'K', integer labels 'y' and a 'cka' score -- the layout this script's
        pickles are read with. Exactly one run per dof is expected.
        """
        # Group runs by dof in a single pass, rejecting duplicates.
        metrics_by_dof = {}
        for result in sorted(results_list, key=lambda r: float(r['params']['dof'])):
            dof = float(result['params']['dof'])
            assert dof not in metrics_by_dof, \
                f"should only be a single run for each set of params (duplicate dof={dof})"
            metrics_by_dof[dof] = result['metrics']

        graph_data = {}
        for dof, metrics in metrics_by_dof.items():
            kernels = metrics[0]['kernels']

            # Normalise K to unit diagonal: K_ij / sqrt(K_ii * K_jj).
            # Moved to CPU so matplotlib can render it.
            K = kernels['K'].clone().detach().cpu()
            K_diag = K.diag()
            K = K * t.rsqrt(K_diag.unsqueeze(-1) * K_diag)

            # Label Gram matrix Y Y^T from one-hot labels, normalised the same
            # way. Computed in float so the matmul never depends on integer
            # matmul support; the 0/1 entries are exact either way.
            y = kernels['y'].clone().detach().cpu()
            yoh = t.nn.functional.one_hot(y).float()
            yyT = yoh @ yoh.T
            yyT_diag = yyT.diag()
            yyT = yyT * t.rsqrt(yyT_diag.unsqueeze(-1) * yyT_diag)

            graph_data[dof] = dict(K=K, yyT=yyT, cka=kernels['cka'])
        return graph_data

    # Figure layout: one row per dataset; one column per dof value, then a
    # column for Y Y^T, then a thin colorbar column.
    dataset_to_row = {'cora': 0, 'roman-empire': 1}
    dataset_to_ylabel = {'cora': 'Cora', 'roman-empire': 'Roman Empire'}

    # Only datasets that have a row in the figure are processed.
    graph_datas = {k: get_graph_data(ds_to_results[k])
                   for k in ds_to_results.keys() if k in dataset_to_row}

    n_dof_cols = max(len(gd) for gd in graph_datas.values())
    nkernels = n_dof_cols + 1  # dof panels + the Y Y^T panel
    fig, axes = plt.subplots(len(dataset_to_row), nkernels + 1, figsize=(18, 6), squeeze=False,
                             gridspec_kw={'width_ratios': [1] * nkernels + [0.05]})

    for dataset, graph_data in graph_datas.items():
        row = dataset_to_row[dataset]
        dof_values = sorted(graph_data.keys())
        if not dof_values:
            continue

        # One shared colour scale for every panel in this row (all K and Y Y^T).
        panels = [graph_data[dof][name] for dof in dof_values for name in ('K', 'yyT')]
        vmin = min(p.min().item() for p in panels)
        vmax = max(p.max().item() for p in panels)
        imshow_params = dict(vmin=vmin, vmax=vmax, cmap='viridis')

        axes[row, 0].set_ylabel(dataset_to_ylabel[dataset])
        for col, dof in enumerate(dof_values):
            ax = axes[row, col]
            ax.imshow(graph_data[dof]['K'], **imshow_params)
            nu = r'\infty' if dof == float('inf') else str(round(dof))
            ax.set_title(rf'$\nu={nu}$, CKA={graph_data[dof]["cka"]:.2f}')
            turn_off_axes_ticks(ax)

        # Y Y^T is built only from the labels; taken from the largest-dof run
        # (presumably identical across runs of the same dataset).
        yyT_ax = axes[row, -2]
        im_yyT = yyT_ax.imshow(graph_data[dof_values[-1]]['yyT'], **imshow_params)
        turn_off_axes_ticks(yyT_ax)
        yyT_ax.set_title(r'$\mathbf{Y}\mathbf{Y}^T$')

        # Every image in this row shares imshow_params, so this row's colorbar
        # must be built here, from this row's image, not after the loop.
        fig.colorbar(im_yyT, cax=axes[row, -1])

    plt.tight_layout()
    plt.savefig(root_dir / 'shaped.pdf')
