import sys
sys.path.append('/mnt/data01/****/****')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from functools import partial
import os
import time

import torch
import dgl

# import read_data
from gnn.Datasets import LargeCompleteMVGraphDatasets
from gnn.networks.AutoVisualMultiGINComplete import AutoVisualNet


from gnn.useful_utils import graph_sampler
from gnn.useful_utils import visualization_metric


def get_model_ckp(ckp_name):
    """Load a ``(net, save)`` checkpoint tuple written with ``torch.save``.

    Args:
        ckp_name: path to the checkpoint file.

    Returns:
        ``(net, save)`` exactly as stored in the file.

    NOTE: the checkpoint contains a pickled network object, so it must be loaded
    with ``weights_only=False``. Newer torch releases default to ``True`` and
    would reject it. ``weights_only=False`` runs arbitrary unpickling code, so
    only load checkpoints from trusted sources.
    """
    net, save = torch.load(ckp_name, weights_only=False)
    return net, save


def get_loader(data_names, exp_name, cdist_path, pe_path, b_size=1):
    """Build a DGL graph loader over the datasets named in ``data_names``.

    The dataset is a ``LargeCompleteMVGraphDatasets.DatasetGraphDataset`` created
    with the given experiment name, distance path and precomputed positional-encoding
    path. The visualisation paths and normalisation options are fixed.

    Args:
        data_names: dataset names to load.
        exp_name: experiment identifier forwarded to the dataset.
        cdist_path: directory forwarded as ``cdist_path``.
        pe_path: directory forwarded as ``precomputed_pe_path``.
        b_size: ``1`` gives a shuffled loader with batch size 1 and no worker
            processes. Any other value switches to ``graph_sampler.GraphBatchSampler``
            with a fixed ``batch_size=30000`` and 10 workers; the value of
            ``b_size`` itself is not used in that case.

    Returns:
        A ``dgl.dataloading.GraphDataLoader``.
    """
    dataset = LargeCompleteMVGraphDatasets.DatasetGraphDataset(
        exp_name=exp_name,
        cdist_path=cdist_path,
        visual_path='/mnt/data01/public/aad_data/bo_nmi',
        visual_path_umap='/mnt/data01/public/aad_data/bo/umap/',
        is_test=False,
        normalize_z=True,
        z_cali_method='none',
        precomputed_pe_path=pe_path,
        z_mu=None,
        z_std=None,
        flip_sign_method='pos',
        data_names=data_names,
    )

    if b_size != 1:
        # Project-specific batch sampler driven by the dataset's block sizes;
        # the batch size here is hard-coded and independent of b_size.
        random_order = torch.utils.data.RandomSampler(dataset)
        batcher = graph_sampler.GraphBatchSampler(random_order,
                                                  batch_size=30000,
                                                  block_list=dataset.ds_block_sizes)
        return dgl.dataloading.GraphDataLoader(dataset, batch_sampler=batcher, num_workers=10)

    return dgl.dataloading.GraphDataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)


def get_loader_by_exp_name(data_names, exp_names):
    """Build a loader for ``data_names`` using the data paths registered for ``exp_names``.

    Args:
        data_names: dataset names forwarded to ``get_loader``.
        exp_names: experiment key; one of ``'clip'``, ``'gene'`` or ``'uci'``.

    Returns:
        The loader returned by ``get_loader``.

    Raises:
        ValueError: if ``exp_names`` is not a registered experiment. Previously an
            unknown name silently returned ``None``.
    """
    # exp_name -> (cdist_path, pe_path). 'clip' and 'uci' share the same directories.
    exp_paths = {
        'clip': ('/mnt/data01/public/aad_data',
                 '/mnt/data01/public/aad_data/pe'),
        'gene': ('/mnt/data01/public/aad_data/gene_filtered',
                 '/mnt/data01/public/aad_data/pe/gene_filtered'),
        'uci': ('/mnt/data01/public/aad_data',
                '/mnt/data01/public/aad_data/pe'),
    }
    try:
        cdist_path, pe_path = exp_paths[exp_names]
    except KeyError:
        raise ValueError(
            f'Unknown experiment name {exp_names!r}; expected one of {sorted(exp_paths)}'
        ) from None
    return get_loader(data_names,
                      exp_name=exp_names,
                      cdist_path=cdist_path,
                      pe_path=pe_path)


def get_d_names_from_save(save):
    """Return the ``(train_names, test_names)`` recorded in a checkpoint ``save`` dict."""
    return save['train_names'], save['test_names']


def plot_zs(z, zhat, y, d_name, exp_name,
            save_dir='/mnt/data01/****/****/gnn/transfer_exp/plot_res'):
    """Scatter-plot a reference embedding ``z`` and a predicted embedding ``zhat``.

    Both embeddings are drawn on the same axes and coloured by ``y``. The figure is
    written to ``{save_dir}/{exp_name}-{d_name[0]}.png`` and then closed.

    Args:
        z: tensor whose first two columns are plotted with '+' markers (moved to CPU here).
        zhat: tensor whose first two columns are plotted with 'o' markers.
        y: per-point colour values passed to ``Axes.scatter(c=...)``.
        d_name: sequence whose first element is the dataset name.
        exp_name: prefix of the output file name.
        save_dir: output directory. The default is the previously hard-coded location.
    """
    def scatter_z(points, colours, ax, label, marker, alpha=1.0):
        points = points.cpu().numpy()
        return ax.scatter(points[:, 0], points[:, 1], c=colours,
                          label=label, marker=marker, alpha=alpha,
                          cmap='viridis')

    d_name = d_name[0]
    fig = plt.figure(figsize=(10, 8))
    fig.set(tight_layout=True)
    ax = fig.add_subplot()
    try:
        scatter_z(z, y, ax, label='z', marker='+', alpha=0.5)
        s = scatter_z(zhat, y, ax, label='z_hat', marker='o', alpha=1.0)
        fig.colorbar(s, ax=ax)
        ax.set_title(f'Visualization of {d_name} CLIP Feature')
        ax.legend()
        fig.savefig(f'{save_dir}/{exp_name}-{d_name}.png')
    finally:
        # Close the figure; otherwise every call leaves an open figure behind.
        plt.close(fig)


def evaluate_model_performance(d_loader, net, n_method=2):
    """Run ``net`` on every graph in ``d_loader`` and score its embeddings.

    Each batch's output ``z_hat`` is scored with ``visualization_metric.get_nmi_sc``
    against the node labels ``ndata['y']``.

    Args:
        d_loader: iterable of ``(d_name, n_view_graphs)`` pairs. ``d_name[0]`` is
            recorded as the dataset name. ``n_view_graphs`` is presumably a DGL
            graph (``.edges()``, ``.num_nodes()``, ``.ndata`` are used) — TODO confirm.
        net: model exposing ``.device`` and called as
            ``net(edge_only_graph, n_view_graphs)``. It is left in eval mode.
        n_method: if > 1, ``z_hat[:, :2]`` is scored in the t-SNE slot and
            ``z_hat[:, 2:]`` in the UMAP slot. Otherwise only ``z_hat[:, :2]`` is
            scored, stored in the UMAP slot, and the t-SNE slot is set to 0.0.

    Returns:
        ``(nmis_t, scs_t, d_names)``: two ``(num_batches, 2)`` float tensors with
        columns ``[tsne, umap]``, and the list of dataset names. An empty loader
        yields ``(0, 2)`` tensors and an empty list instead of raising.
    """
    net.eval()
    device = net.device
    scs = []
    nmis = []
    d_names = []
    with torch.no_grad():
        for d_name, n_view_graphs in d_loader:
            # Same connectivity as the input graph but without node/edge features.
            edge_only_graph = dgl.graph(n_view_graphs.edges(),
                                        num_nodes=n_view_graphs.num_nodes()).to(device)
            z_hat = net(edge_only_graph, n_view_graphs.to(device))
            z_np = z_hat.cpu().numpy()
            labels = n_view_graphs.ndata['y'].tolist()
            if n_method > 1:
                tsne_nmi, tsne_sc = visualization_metric.get_nmi_sc(z_np[:, :2], labels)
                umap_nmi, umap_sc = visualization_metric.get_nmi_sc(z_np[:, 2:], labels)
            else:
                # Single-method output: report it in the UMAP slot.
                umap_nmi, umap_sc = visualization_metric.get_nmi_sc(z_np[:, :2], labels)
                tsne_nmi, tsne_sc = 0.0, 0.0

            scs.append(torch.Tensor([tsne_sc, umap_sc]))
            nmis.append(torch.Tensor([tsne_nmi, umap_nmi]))
            d_names.append(d_name[0])

    if not nmis:
        # torch.stack([]) raises; return correctly shaped empty results instead.
        return torch.empty((0, 2)), torch.empty((0, 2)), d_names
    return torch.stack(nmis), torch.stack(scs), d_names


def get_gt_by_name(d_name, gt_dict):
    """Return the ``(tsne, umap)`` ground-truth values stored for one dataset.

    ``gt_dict`` is indexed by the 1-tuple ``(d_name,)``, not by the bare name. Each
    entry must provide ``'tsne'`` and ``'umap'`` keys.
    """
    entry = gt_dict[(d_name,)]
    return entry['tsne'], entry['umap']


def get_gt_df(gt_save, d_names, prefix=''):
    """Tabulate ground-truth NMI and silhouette scores for the given datasets.

    Args:
        gt_save: ``(nmi_gt_dict, sc_gt_dict)`` pair, each looked up with
            ``get_gt_by_name``.
        d_names: dataset names, one output row each.
        prefix: string prepended to every metric column name (not to ``'d_name'``).

    Returns:
        DataFrame with columns ``d_name``, ``{prefix}gt_tsne_nmi``,
        ``{prefix}gt_umap_nmi``, ``{prefix}gt_tsne_sc`` and ``{prefix}gt_umap_sc``.
        The columns are present even when ``d_names`` is empty, so a later merge on
        ``'d_name'`` still works.
    """
    nmi_gt_dict, sc_gt_dict = gt_save
    columns = {
        'd_name': [],
        f'{prefix}gt_tsne_nmi': [],
        f'{prefix}gt_umap_nmi': [],
        f'{prefix}gt_tsne_sc': [],
        f'{prefix}gt_umap_sc': [],
    }
    for d_name in d_names:
        tsne_nmi, umap_nmi = get_gt_by_name(d_name, nmi_gt_dict)
        tsne_sc, umap_sc = get_gt_by_name(d_name, sc_gt_dict)
        columns['d_name'].append(d_name)
        columns[f'{prefix}gt_tsne_nmi'].append(tsne_nmi)
        columns[f'{prefix}gt_umap_nmi'].append(umap_nmi)
        columns[f'{prefix}gt_tsne_sc'].append(tsne_sc)
        columns[f'{prefix}gt_umap_sc'].append(umap_sc)
    return pd.DataFrame(columns)


def get_eval_df(nmis_t, scs_t, d_names, prefix='pred_'):
    """Tabulate predicted NMI and silhouette scores, one row per dataset.

    Args:
        nmis_t: 2-column array/tensor; column 0 is the t-SNE slot, column 1 the UMAP slot.
        scs_t: same layout as ``nmis_t`` for silhouette scores.
        d_names: dataset names aligned with the rows.
        prefix: prepended to every metric column name.

    Returns:
        DataFrame with columns ``d_name``, ``{prefix}tsne_nmi``, ``{prefix}umap_nmi``,
        ``{prefix}tsne_sc`` and ``{prefix}umap_sc``.
    """
    columns = {'d_name': d_names}
    for metric, table in (('nmi', nmis_t), ('sc', scs_t)):
        for col_idx, method in enumerate(('tsne', 'umap')):
            columns[f'{prefix}{method}_{metric}'] = table[:, col_idx]
    return pd.DataFrame(columns)


def get_gt_pred_df_from_d(gt_save, d_names, d_loader, net, to_save_name, n_method):
    """Evaluate ``net`` on ``d_loader`` and join the scores with the ground truth.

    The model is moved with ``net.cuda()`` before evaluation. Ground-truth and
    predicted rows are merged on ``'d_name'`` with pandas' default inner join, so
    names present on only one side are dropped. The result is written to
    ``to_save_name`` with ``DataFrame.to_csv`` and also returned.
    """
    ground_truth = get_gt_df(gt_save, d_names)
    scores = evaluate_model_performance(d_loader, net.cuda(), n_method=n_method)
    predictions = get_eval_df(*scores)

    merged = ground_truth.merge(predictions, on='d_name')
    merged.to_csv(to_save_name)
    return merged


def get_mean_performance(df):
    """Summarise predicted-vs-ground-truth metrics across datasets.

    For each method in {tsne, umap} and metric in {nmi, sc}, two columns are added
    to a copy of ``df``:

    * ``relative_{method}_{metric}_precision = pred / gt``
    * ``relative_{method}_{metric}_gap = gt - pred``

    Ratios whose ground truth is zero (±inf) are replaced by NaN. Mean and std then
    skip them instead of collapsing the aggregate to inf.

    Args:
        df: frame with a ``'d_name'`` column and ``gt_*`` / ``pred_*`` metric columns,
            e.g. the output of ``get_gt_pred_df_from_d``. It is NOT modified.

    Returns:
        ``(mean, std)``: Series of per-column mean and standard deviation over every
        column except ``'d_name'``.
    """
    df = df.copy()  # do not mutate the caller's frame
    for metric in ('nmi', 'sc'):
        for method in ('tsne', 'umap'):
            gt = df[f'gt_{method}_{metric}']
            pred = df[f'pred_{method}_{metric}']
            df[f'relative_{method}_{metric}_precision'] = (pred / gt).replace([np.inf, -np.inf], np.nan)
            df[f'relative_{method}_{metric}_gap'] = gt - pred

    stats = df.drop('d_name', axis=1)
    return stats.mean(), stats.std()


if __name__ == '__main__':
    root = '/mnt/data01/****/****/gnn/res/'

    # Checkpoints to evaluate, each paired (by position) with a label used in the CSV names.
    save_f_names = [
        'run2-GNN-uci-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-epoch-99.tar',
        # 'uci-GNN-clip_w_umap-umaponly-knn-l2-newparallelgt-regin2-8gt-sigma_mv-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-169.tar'
    ]
    loss_types = ['tsne-uci-knn-l2']
    if len(save_f_names) != len(loss_types):
        raise ValueError('save_f_names and loss_types must have the same length')

    exp_name = 'uci'

    # The ground truth does not depend on the checkpoint, so load it once.
    gt_save = torch.load('/mnt/data01/public/aad_data/uci_tsne_w_umap_ground_truth.tar', weights_only=False)

    for save_f_name, loss_type in zip(save_f_names, loss_types):
        net, save = get_model_ckp(root + save_f_name)
        # Evaluation-time overrides of model attributes (their meaning is defined by the model class).
        net.n_out_subnet = 1
        net.is_gt_parallel = False

        print('===============================')
        print(save_f_name)
        print(loss_type)

        train_d_names, test_d_names = get_d_names_from_save(save)
        train_d_loader = get_loader_by_exp_name(train_d_names, exp_names=exp_name)
        test_d_loader = get_loader_by_exp_name(test_d_names, exp_names=exp_name)

        # n_method=1: only z_hat[:, :2] is scored, reported in the UMAP columns.
        train_df = get_gt_pred_df_from_d(gt_save, train_d_names, train_d_loader, net,
                                         to_save_name=f'{exp_name}_train_df_{loss_type}.csv',
                                         n_method=1)
        train_performance = get_mean_performance(train_df)
        print('train', train_performance)

        test_df = get_gt_pred_df_from_d(gt_save, test_d_names, test_d_loader, net,
                                        to_save_name=f'{exp_name}_test_df_{loss_type}.csv',
                                        n_method=1)
        test_performance = get_mean_performance(test_df)
        print('test', test_performance)
        print('===========================================')

# for flip in [
#     'pos',
#         # 'sum',
#         # 'linf',
#         #                    'l1', 'l2','random'
#         ]:
#     # f_name = f'./res/GNN-resgin-8gt-sigma_mv-run4-500epoch-kl_t_64dim_svd_u-run2_{flip}_only_complete_G_sep_gnn-epoch-499.tar'
#     f_name = f'/mnt/data01/****/****/gnn/res/GNN-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-5000ds-epoch-244.tar'
#
#
#     train_names = save['train_names']
#     test_names = save['test_names']
#     print(len(train_names))
#     gt_train = get_ground_truth(ground_truth_dict, train_names)
#     gt_test = get_ground_truth(ground_truth_dict, test_names)
#
#     print(f'GT train: {gt_train}, GT test: {gt_test}')
#
#     train_nmi = save['train_nmi']
#     test_nmi = save['test_nmi']
#     print(train_nmi)
#     print(test_nmi)
#
#     print(flip)
#     print('best: ', torch.max(torch.tensor(test_nmi)), torch.argmax(torch.tensor(test_nmi)))
#     print('last: ', test_nmi[-1])
#     print('train last: ', train_nmi[-1])
#     print('===========')
#
# # plt.plot(train_nmi, label='train_nmi')
# # plt.plot(test_nmi, label='test_nmi')
# #
# # plt.ylabel('NMI')
# # plt.xlabel('10 * epoch')
# # plt.title('5 sigma view, 64 lap row norm PE dim, 5 ds, train & test nmi during the training')
# #
# # plt.legend()
# # plt.show()
#
# # plt.plot(save['running_loss'], label='train_nmi')
# # plt.show()
