import sys
sys.path.append('/mnt/data01/****/****')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from functools import partial
import os
import time

import torch
import dgl

# import read_data
from gnn.Datasets import LargeCompleteMVGraphDatasets
from gnn.networks.AutoVisualMultiGINComplete import AutoVisualNet


from gnn.useful_utils import graph_sampler
from gnn.useful_utils import visualization_metric


def get_model_ckp(ckp_name):
    """Load a ``(net, save)`` checkpoint pair written with ``torch.save``.

    The checkpoint stores the model object itself (it is later called and has
    attributes set on it), not only a ``state_dict``. Such files must be loaded
    with ``weights_only=False``: since PyTorch 2.6 ``torch.load`` defaults to
    ``weights_only=True``, which refuses arbitrary pickled objects. This
    matches how the ground-truth file is loaded in ``__main__``.

    NOTE: ``weights_only=False`` executes arbitrary pickle code -- only load
    checkpoints from a trusted location.

    Args:
        ckp_name: Path to the checkpoint file.

    Returns:
        Tuple ``(net, save)``: the model and, presumably, a dict of training
        metadata (``train_names``, ``test_names``, ``test_nmi``, ...).
    """
    net, save = torch.load(ckp_name, weights_only=False)
    return net, save


def get_loader(data_names, exp_name, cdist_path, pe_path, b_size=1):
    """Build a DGL ``GraphDataLoader`` over the named multi-view graph datasets.

    Args:
        data_names: Names of the datasets to load.
        exp_name: Experiment family, forwarded to the dataset class.
        cdist_path: Directory holding the precomputed distance data.
        pe_path: Directory holding the precomputed positional encodings.
        b_size: ``1`` selects one graph per batch (shuffled, no worker
            processes). Any other value selects the ``GraphBatchSampler`` path;
            the number itself is not used as a batch size.

    Returns:
        A ``dgl.dataloading.GraphDataLoader``.
    """
    dataset = LargeCompleteMVGraphDatasets.DatasetGraphDataset(
        data_names=data_names,
        exp_name=exp_name,
        cdist_path=cdist_path,
        visual_path='/mnt/data01/public/aad_data/bo',
        visual_path_umap='/mnt/data01/public/aad_data/bo/umap',
        is_test=False,
        normalize_z=True,
        z_cali_method='none',
        precomputed_pe_path=pe_path,
        z_mu=None,
        z_std=None,
        flip_sign_method='pos',
    )

    if b_size != 1:
        # Random order, grouped into large batches by the project sampler.
        # NOTE(review): batch_size=30000 presumably caps the total size of
        # each packed batch, with ds_block_sizes giving per-dataset sizes --
        # confirm against graph_sampler.GraphBatchSampler.
        batch_sampler = graph_sampler.GraphBatchSampler(
            torch.utils.data.RandomSampler(dataset),
            batch_size=30000,
            block_list=dataset.ds_block_sizes,
        )
        return dgl.dataloading.GraphDataLoader(
            dataset,
            batch_sampler=batch_sampler,
            num_workers=10,
        )

    # One graph per batch, shuffled, loaded in the main process.
    return dgl.dataloading.GraphDataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0,
    )


def get_loader_by_exp_name(data_names, exp_names):
    """Build a loader for ``data_names`` using the data paths of an experiment family.

    Args:
        data_names: Dataset names forwarded to ``get_loader``.
        exp_names: Experiment family: one of ``'clip'``, ``'gene'`` or
            ``'uci'``. Also forwarded to ``get_loader`` as ``exp_name``.

    Returns:
        The loader built by ``get_loader`` (with its default ``b_size``).

    Raises:
        ValueError: If ``exp_names`` is not a known experiment family. Failing
            here is clearer than returning ``None`` and letting the caller
            crash later when it iterates the "loader".
    """
    # experiment family -> (cdist_path, pe_path)
    paths_by_exp = {
        'clip': ('/mnt/data01/public/aad_data',
                 '/mnt/data01/public/aad_data/pe'),
        'gene': ('/mnt/data01/public/aad_data/gene_filtered',
                 '/mnt/data01/public/aad_data/pe/gene_filtered'),
        # NOTE(review): 'uci' reads PEs from the same directory as 'clip';
        # confirm this is intended and not a missing '/uci' suffix.
        'uci': ('/mnt/data01/public/aad_data/uci',
                '/mnt/data01/public/aad_data/pe'),
    }
    try:
        cdist_path, pe_path = paths_by_exp[exp_names]
    except KeyError:
        raise ValueError(
            f'unknown exp_names {exp_names!r}; expected one of {sorted(paths_by_exp)}'
        ) from None
    return get_loader(data_names,
                      exp_name=exp_names,
                      cdist_path=cdist_path,
                      pe_path=pe_path)


def get_d_names_from_save(save):
    """Return the ``(train_names, test_names)`` pair recorded in a checkpoint's ``save`` dict.

    Raises:
        KeyError: If ``'train_names'`` or ``'test_names'`` is missing.
    """
    return save['train_names'], save['test_names']


def plot_zs(zhats, ys, d_name, nmi, sc, method_name):
    """Scatter-plot a 2-D embedding coloured by class and save it as a PDF.

    The figure goes to ``./plot_res/{d_name}.pdf``. The directory is created
    if needed, and the figure is closed afterwards so repeated calls do not
    accumulate open pyplot figures.

    Args:
        zhats: Embedding with at least two columns; only columns 0 and 1 are
            plotted. Must support boolean-mask row indexing (e.g. a numpy
            array).
        ys: Per-point labels; anything with ``.tolist()`` (array / tensor).
        d_name: Dataset name, used as the output file stem.
        nmi: NMI shown in the title (printed multiplied by 100).
        sc: SC shown in the title (printed multiplied by 100).
        method_name: Method label shown in the title.
    """
    raw_labels = ys.tolist()
    # Remap raw labels to 0..K-1. Sorting the unique labels makes the mapping
    # (and therefore colours / "Class i" names) deterministic across runs,
    # which iterating a bare ``set`` does not guarantee (e.g. string labels).
    label_to_idx = {lbl: i for i, lbl in enumerate(sorted(set(raw_labels)))}
    labels = np.array([label_to_idx[lbl] for lbl in raw_labels])

    title_fs = 16

    fig, ax = plt.subplots(figsize=(4, 4))

    # One scatter call per class so each class gets its own colour.
    for lbl in np.unique(labels):
        idx = labels == lbl
        ax.scatter(
            zhats[idx, 0],
            zhats[idx, 1],
            label=f'Class {lbl}',
            s=20,
            linewidths=1,
        )

    ax.set_title(f'{method_name}, NMI: {nmi*100:.2f}, SC: {sc*100:.2f}', fontsize=title_fs)

    # No ticks, but keep the plot border visible.
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(True)

    fig.tight_layout()
    os.makedirs('./plot_res', exist_ok=True)
    fig.savefig(f'./plot_res/{d_name}.pdf')
    # pyplot keeps every figure alive until closed; release it explicitly.
    plt.close(fig)


def evaluate_model_performance(d_loader, net, n_method=2):
    """Run ``net`` over every graph in ``d_loader`` and score its embeddings.

    Args:
        d_loader: Iterable of ``(d_name, n_view_graphs)`` pairs, e.g. the
            loader from ``get_loader`` with ``b_size=1``. ``d_name`` is
            presumably a 1-tuple holding the dataset name, as produced by
            DGL's collation; ``d_name[0]`` is what gets recorded as the name.
            ``n_view_graphs`` must carry node data ``'z'``, ``'z_umap'`` and
            ``'y'``.
        net: Model with a ``device`` attribute, called as
            ``net(edge_only_graph, n_view_graphs)``. It is switched to eval
            mode and left there.
        n_method: If ``> 1``, output columns ``[:, :2]`` are scored as the
            t-SNE prediction and ``[:, 2:]`` as the UMAP prediction.
            Otherwise ``[:, :2]`` is scored as the UMAP prediction and the
            t-SNE metrics are reported as 0.

    Returns:
        ``(nmis_t, scs_t, d_names, zhats_d)``:

        * ``nmis_t`` and ``scs_t`` are float tensors of shape
          (n_datasets, 2) with columns ``[tsne, umap]``; shape (0, 2) for an
          empty loader.
        * ``d_names`` lists the dataset names in iteration order.
        * ``zhats_d`` maps each ``d_name`` to a dict of CPU tensors
          ``'tsne_gt'``, ``'umap_gt'``, ``'pred'`` (first two output columns)
          and ``'y'``.
    """
    net.eval()
    device = net.device
    scs = []
    nmis = []
    d_names = []
    zhats_d = {}
    with torch.no_grad():
        for d_name, n_view_graphs in d_loader:
            # Structure-only copy of the input graph, on the model's device.
            edge_only_graph = dgl.graph(n_view_graphs.edges(), num_nodes=n_view_graphs.num_nodes()).to(device)
            tsne_z = n_view_graphs.ndata['z'].to(torch.float32).to(device)
            umap_z = n_view_graphs.ndata['z_umap'].to(torch.float32).to(device)
            z_hat = net(edge_only_graph, n_view_graphs)
            y = n_view_graphs.ndata['y']

            # Store CPU copies so embeddings from every dataset do not pile up
            # in GPU memory (and the file saved from them loads without a GPU).
            zhats_d[d_name] = {
                'tsne_gt': tsne_z.cpu(),
                'umap_gt': umap_z.cpu(),
                'pred': z_hat[:, :2].cpu(),
                'y': y.cpu(),
            }

            if n_method > 1:
                # NOTE(review): [:, 2:] takes every column after the first two;
                # this assumes the model outputs exactly 4 columns -- confirm.
                tsne_nmi, tsne_sc = visualization_metric.get_nmi_sc(z_hat[:, :2].cpu().numpy(), y.tolist())
                umap_nmi, umap_sc = visualization_metric.get_nmi_sc(z_hat[:, 2:].cpu().numpy(), y.tolist())
            else:
                umap_nmi, umap_sc = visualization_metric.get_nmi_sc(z_hat[:, :2].cpu().numpy(), y.tolist())
                tsne_nmi, tsne_sc = 0.0, 0.0

            scs.append(torch.Tensor([tsne_sc, umap_sc]))
            nmis.append(torch.Tensor([tsne_nmi, umap_nmi]))
            d_names.append(d_name[0])

    if not nmis:
        # torch.stack rejects an empty list; return correctly shaped empties.
        return torch.empty((0, 2)), torch.empty((0, 2)), d_names, zhats_d
    return torch.stack(nmis), torch.stack(scs), d_names, zhats_d


def get_gt_by_name(d_name, gt_dict):
    """Look up the ground-truth ``(tsne, umap)`` values for one dataset.

    ``gt_dict`` is keyed by 1-tuples ``(d_name,)``. Presumably this is because
    the names were collected from a DGL loader that yields them wrapped in
    tuples; confirm against the script that built the file. Each value must
    be a mapping with ``'tsne'`` and ``'umap'`` entries.

    Raises:
        KeyError: If the dataset or either entry is missing.
    """
    entry = gt_dict[(d_name,)]
    return entry['tsne'], entry['umap']


def get_gt_df(gt_save, d_names, prefix=''):
    """Build a DataFrame of ground-truth NMI / SC scores, one row per dataset.

    Args:
        gt_save: Pair ``(nmi_gt_dict, sc_gt_dict)``; both are read through
            ``get_gt_by_name``.
        d_names: Dataset names, one row each, in order.
        prefix: Prepended to every score column name.

    Returns:
        DataFrame with columns ``d_name``, ``{prefix}gt_tsne_nmi``,
        ``{prefix}gt_umap_nmi``, ``{prefix}gt_tsne_sc``, ``{prefix}gt_umap_sc``.
        The columns exist even when ``d_names`` is empty, so a later
        ``merge(on='d_name')`` still works.

    Raises:
        KeyError: If a dataset is missing from either ground-truth dict.
    """
    nmi_gt_dict, sc_gt_dict = gt_save
    # Append to per-column lists (O(n) overall) instead of rebuilding each
    # list with ``+`` on every row.
    columns = {
        'd_name': [],
        f'{prefix}gt_tsne_nmi': [],
        f'{prefix}gt_umap_nmi': [],
        f'{prefix}gt_tsne_sc': [],
        f'{prefix}gt_umap_sc': [],
    }
    for d_name in d_names:
        tsne_nmi, umap_nmi = get_gt_by_name(d_name, nmi_gt_dict)
        tsne_sc, umap_sc = get_gt_by_name(d_name, sc_gt_dict)
        columns['d_name'].append(d_name)
        columns[f'{prefix}gt_tsne_nmi'].append(tsne_nmi)
        columns[f'{prefix}gt_umap_nmi'].append(umap_nmi)
        columns[f'{prefix}gt_tsne_sc'].append(tsne_sc)
        columns[f'{prefix}gt_umap_sc'].append(umap_sc)
    return pd.DataFrame(columns)


def get_eval_df(nmis_t, scs_t, d_names, prefix='pred_'):
    """Build a DataFrame of predicted NMI / SC scores, one row per dataset.

    Args:
        nmis_t: (n, 2) NMI scores with columns ``[tsne, umap]`` (torch tensor
            on any device, or numpy array).
        scs_t: (n, 2) SC scores, same layout.
        d_names: The n dataset names.
        prefix: Prepended to every score column name.

    Returns:
        DataFrame with columns ``d_name``, ``{prefix}tsne_nmi``,
        ``{prefix}umap_nmi``, ``{prefix}tsne_sc``, ``{prefix}umap_sc``.
    """
    # Convert explicitly to numpy so every score column is a plain float
    # column, rather than relying on how a given pandas version treats torch
    # tensors. This also handles CUDA tensors and numpy input.
    nmis = torch.as_tensor(nmis_t).detach().cpu().numpy()
    scs = torch.as_tensor(scs_t).detach().cpu().numpy()
    return pd.DataFrame({
        'd_name': d_names,
        f'{prefix}tsne_nmi': nmis[:, 0],
        f'{prefix}umap_nmi': nmis[:, 1],
        f'{prefix}tsne_sc': scs[:, 0],
        f'{prefix}umap_sc': scs[:, 1],
    })


def get_gt_pred_df_from_d(gt_save, d_names, d_loader, net, to_save_name, n_method):
    """Evaluate ``net`` on ``d_loader``, join with ground truth, and save the results.

    Writes two files:

    * ``to_save_name``: the merged DataFrame as CSV.
    * The same path with its extension replaced by ``.tar``: the per-dataset
      embeddings returned by ``evaluate_model_performance``, via
      ``torch.save``.

    Args:
        gt_save: Pair ``(nmi_gt_dict, sc_gt_dict)`` for ``get_gt_df``.
        d_names: Dataset names to look up in the ground truth.
        d_loader: Loader yielding ``(d_name, graph)`` pairs.
        net: Model to evaluate; it is moved to CUDA first.
        to_save_name: Output CSV path.
        n_method: Forwarded to ``evaluate_model_performance``.

    Returns:
        The ground-truth and prediction DataFrames merged on ``d_name``
        (inner join, so datasets missing from either side are dropped).
    """
    gt_df = get_gt_df(gt_save, d_names)

    net = net.cuda()
    nmis_t, scs_t, pred_d_names, zhats_d = evaluate_model_performance(d_loader, net, n_method=n_method)
    pred_df = get_eval_df(nmis_t, scs_t, pred_d_names)

    df = pd.merge(gt_df, pred_df, on='d_name')
    df.to_csv(to_save_name)
    # splitext handles any extension length (or none) instead of blindly
    # dropping the last four characters.
    torch.save(zhats_d, f'{os.path.splitext(to_save_name)[0]}.tar')
    return df


def get_mean_performance(df):
    """Summarise per-dataset predicted-vs-ground-truth scores.

    For each metric (``nmi``, ``sc``) and method (``tsne``, ``umap``), two
    derived columns are computed:

    * ``relative_{method}_{metric}_precision`` = pred / gt
    * ``relative_{method}_{metric}_gap`` = gt - pred

    The derived columns are added to a copy; the caller's ``df`` is not
    modified. A zero ground-truth score yields an ``inf`` or ``NaN``
    precision.

    Args:
        df: DataFrame with a ``d_name`` column plus ``gt_*`` and ``pred_*``
            score columns, as produced by ``get_gt_pred_df_from_d``.

    Returns:
        ``(mean, std)`` Series over every numeric column except ``d_name``,
        including the derived ones.

    Raises:
        KeyError: If a required ``gt_*`` / ``pred_*`` column or ``d_name`` is
            missing.
    """
    df = df.copy()
    for metric in ('nmi', 'sc'):
        for method in ('tsne', 'umap'):
            gt = df[f'gt_{method}_{metric}']
            pred = df[f'pred_{method}_{metric}']
            df[f'relative_{method}_{metric}_precision'] = pred / gt
            df[f'relative_{method}_{metric}_gap'] = gt - pred

    numeric = df.drop('d_name', axis=1)
    return numeric.mean(numeric_only=True), numeric.std(numeric_only=True)

def load_d_data_names(datasets, root='/mnt/data01/public/aad_data'):
    """List dataset names (file names minus their last extension) under each dataset directory.

    Args:
        datasets: Sub-directory names of ``root`` to scan, e.g. ``['cifar10']``.
        root: Base data directory; defaults to the shared data root used by
            this script.

    Returns:
        Names from all directories, concatenated in ``datasets`` order.
        Within each directory the names are sorted, because ``os.listdir``
        order is arbitrary and callers may sample from this list.
    """
    names = []
    for dataset in datasets:
        file_names = sorted(os.listdir(os.path.join(root, dataset)))
        # splitext strips only the final extension ('a.tar.gz' -> 'a.tar') and
        # leaves extension-less names intact instead of reducing them to ''.
        names.extend(os.path.splitext(name)[0] for name in file_names)
    return names


if __name__ == '__main__':
    root = '/mnt/data01/****/****/gnn/res/'
    # Previously evaluated checkpoints (relative to root):
    #   500ds-GNN-clip_w_umap-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-59.tar
    #   500ds-GNN-clip_w_umap-bce-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-119.tar
    #   500ds-GNN-clip_w_umap-l2-subnet-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-149.tar
    #   500ds-GNN-clip_w_umap-bothl2-uninet-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-219.tar
    #   500ds-GNN-clip_w_umap-umaponly-knn-l2-newparallelgt-regin2-8gt-sigma_mv-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-94.tar

    # Checkpoints to evaluate, paired by position with a loss-type tag that is
    # used in printed output and result file names.
    save_f_names = [
        '500ds-GNN-clip_w_umap-l2-uninet-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-154.tar',
    ]
    loss_types = ['tsne-knn-l2']
    if len(loss_types) != len(save_f_names):
        raise ValueError('save_f_names and loss_types must have the same length')

    # Also evaluate on the checkpoint's training datasets. Off by default:
    # it builds a loader over every training dataset, which is slow.
    eval_train = False

    # Alternative test set: a random sample of 500 cifar10 datasets.
    # test_names = load_d_data_names(['cifar10'])
    # shuffle = torch.randperm(len(test_names))
    # test_names = [test_names[i] for i in shuffle[:500]]

    tnames = '''fmnist-2class-comb6_size9_seed0_cdist
mnist-3class-comb16_size8_seed0_cdist
mnist-4class-comb10_size7_seed0_cdist
mnist-5class-comb39_size6_seed0_cdist
fmnist-6class-comb19_size6_seed0_cdist
fmnist-7class-comb42_size7_seed0_cdist
mnist-8class-comb24_size9_seed0_cdist
cifar10-2class-comb41_size9_seed0_cdist
cifar10-7class-comb10_size8_seed0_cdist
cifar10-3class-comb23_size3_seed0_cdist
cifar10-8class-comb17_size5_seed0_cdist
cifar10-6class-comb27_size9_seed0_cdist
cifar10-4class-comb7_size4_seed0_cdist
cifar10-5class-comb55_size8_seed0_cdist'''
    test_names = tnames.split()

    # The ground truth does not depend on the checkpoint, so load it once.
    # NOTE: weights_only=False unpickles arbitrary objects -- trusted files only.
    gt_save = torch.load('/mnt/data01/public/aad_data/clip_tsne_w_umap_ground_truth.tar', weights_only=False)

    for save_f_name, loss_type in zip(save_f_names, loss_types):
        ckp_name = root + save_f_name

        net, save = get_model_ckp(ckp_name)
        net.n_out_subnet = 1
        net.is_gt_parallel = False

        # Index of the recorded evaluation with the best mean of column 1 of
        # save['test_nmi'] (presumably the UMAP NMI; confirm against the
        # training script that wrote the checkpoint).
        best_idx = torch.argmax(torch.tensor([torch.mean(i, dim=0)[1] for i in save['test_nmi']]))
        print('===============================')
        print(save_f_name)
        print(loss_type)
        print(best_idx)

        if eval_train:
            train_d_names, _ = get_d_names_from_save(save)
            train_d_loader = get_loader_by_exp_name(train_d_names, exp_names='clip')
            train_df = get_gt_pred_df_from_d(gt_save, train_d_names, train_d_loader, net,
                                             to_save_name=f'2for_plot_clip_500_train_df_{loss_type}.csv',
                                             n_method=2)
            train_performance = get_mean_performance(train_df)
            print('train', train_performance)

        test_d_loader = get_loader_by_exp_name(test_names, exp_names='clip')
        test_df = get_gt_pred_df_from_d(gt_save, test_names, test_d_loader, net,
                                        to_save_name=f'2for_plot_clip_500_test_df_{loss_type}.csv',
                                        n_method=2)
        test_performance = get_mean_performance(test_df)
        print('test', test_performance)
        print('===========================================')

# for flip in [
#     'pos',
#         # 'sum',
#         # 'linf',
#         #                    'l1', 'l2','random'
#         ]:
#     # f_name = f'./res/GNN-resgin-8gt-sigma_mv-run4-500epoch-kl_t_64dim_svd_u-run2_{flip}_only_complete_G_sep_gnn-epoch-499.tar'
#     f_name = f'/mnt/data01/****/****/gnn/res/GNN-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-5000ds-epoch-244.tar'
#
#
#     train_names = save['train_names']
#     test_names = save['test_names']
#     print(len(train_names))
#     gt_train = get_ground_truth(ground_truth_dict, train_names)
#     gt_test = get_ground_truth(ground_truth_dict, test_names)
#
#     print(f'GT train: {gt_train}, GT test: {gt_test}')
#
#     train_nmi = save['train_nmi']
#     test_nmi = save['test_nmi']
#     print(train_nmi)
#     print(test_nmi)
#
#     print(flip)
#     print('best: ', torch.max(torch.tensor(test_nmi)), torch.argmax(torch.tensor(test_nmi)))
#     print('last: ', test_nmi[-1])
#     print('train last: ', train_nmi[-1])
#     print('===========')
#
# # plt.plot(train_nmi, label='train_nmi')
# # plt.plot(test_nmi, label='test_nmi')
# #
# # plt.ylabel('NMI')
# # plt.xlabel('10 * epoch')
# # plt.title('5 sigma view, 64 lap row norm PE dim, 5 ds, train & test nmi during the training')
# #
# # plt.legend()
# # plt.show()
#
# # plt.plot(save['running_loss'], label='train_nmi')
# # plt.show()
