import os

# Restrict BLAS/MKL/OpenMP to a single thread
os.environ["OMP_NUM_THREADS"]         = "1"   # OpenMP
os.environ["OPENBLAS_NUM_THREADS"]    = "1"   # OpenBLAS
os.environ["MKL_NUM_THREADS"]         = "1"   # Intel MKL
os.environ["VECLIB_MAXIMUM_THREADS"]  = "1"   # Mac OS X Accelerate
os.environ["NUMEXPR_NUM_THREADS"]     = "1"   # NumExpr

import sys
sys.path.append('/mnt/data01/****/****')
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from scipy.spatial.distance import cdist as scipy_cdist
from sklearn.utils.extmath import randomized_svd

from functools import partial
import os
import time

import torch
import dgl

torch.set_num_threads(1)
torch.set_num_interop_threads(1)

# import read_data
from gnn.Datasets import LargeCompleteMVGraphDatasets
from gnn.Datasets import graph_cut
from gnn.pe_utils import posenc_config
from gnn.networks.AutoVisualMultiGINComplete import AutoVisualNet


from gnn.useful_utils import graph_sampler
from gnn.useful_utils import visualization_metric


def get_model_ckp(ckp_name):
    """Load a checkpoint file and return its two stored components.

    Args:
        ckp_name: Path of the checkpoint, passed straight to ``torch.load``.

    Returns:
        A ``(net, save)`` tuple. The loaded object must unpack into exactly
        two items, otherwise the unpacking raises.
    """
    model, metadata = torch.load(ckp_name)
    return model, metadata


def get_loader(data_names, exp_name, cdist_path, pe_path, b_size=1):
    """Create a DGL graph data loader over the requested datasets.

    Args:
        data_names: Dataset names forwarded to ``DatasetGraphDataset``.
        exp_name: Experiment name forwarded to the dataset.
        cdist_path: Forwarded to the dataset as ``cdist_path``.
        pe_path: Forwarded to the dataset as ``precomputed_pe_path``.
        b_size: Only whether this equals 1 matters. With 1, a shuffled
            loader yielding one graph per batch is returned; with any other
            value, batches come from ``graph_sampler.GraphBatchSampler``
            using a fixed ``batch_size=30000`` and the dataset's
            ``ds_block_sizes``.

    Returns:
        A ``dgl.dataloading.GraphDataLoader`` over the constructed dataset.
    """
    dataset = LargeCompleteMVGraphDatasets.DatasetGraphDataset(
        exp_name=exp_name,
        cdist_path=cdist_path,
        visual_path='/mnt/data01/public/aad_data/bo',
        visual_path_umap='/mnt/data01/public/aad_data/bo/umap',
        is_test=False,
        normalize_z=True,
        z_cali_method='none',
        precomputed_pe_path=pe_path,
        z_mu=None,
        z_std=None,
        flip_sign_method='pos',
        data_names=data_names,
    )

    if b_size == 1:
        # One graph per batch, shuffled, loaded in the main process.
        return dgl.dataloading.GraphDataLoader(
            dataset,
            batch_size=1,
            shuffle=True,
            num_workers=0,
        )

    # Batched mode: a random ordering is grouped into batches by the
    # project's GraphBatchSampler.
    random_order = torch.utils.data.RandomSampler(dataset)
    batch_sampler = graph_sampler.GraphBatchSampler(
        random_order,
        batch_size=30000,
        block_list=dataset.ds_block_sizes,
    )
    return dgl.dataloading.GraphDataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=10,
    )


def get_loader_by_exp_name(data_names, exp_names):
    """Return the loader for one of the known experiments.

    Args:
        data_names: Dataset names forwarded to ``get_loader``.
        exp_names: Experiment identifier; ``'clip'``, ``'gene'`` and
            ``'uci'`` are recognised.

    Returns:
        The loader built by ``get_loader`` with the experiment's distance
        and positional-encoding directories, or ``None`` when ``exp_names``
        matches none of the known experiments.
    """
    # (experiment name, cdist directory, positional-encoding directory)
    known_experiments = (
        ('clip',
         '/mnt/data01/public/aad_data',
         '/mnt/data01/public/aad_data/pe'),
        ('gene',
         '/mnt/data01/public/aad_data/gene_filtered',
         '/mnt/data01/public/aad_data/pe/gene_filtered'),
        ('uci',
         '/mnt/data01/public/aad_data/uci',
         '/mnt/data01/public/aad_data/pe'),
    )
    for known_name, cdist_dir, pe_dir in known_experiments:
        if exp_names == known_name:
            return get_loader(data_names,
                              exp_name=exp_names,
                              cdist_path=cdist_dir,
                              pe_path=pe_dir)
    return None


def get_d_names_from_save(save):
    """Extract the train and test dataset names stored in a checkpoint dict.

    Args:
        save: Mapping containing the keys ``'train_names'`` and
            ``'test_names'``.

    Returns:
        ``(train_names, test_names)`` exactly as stored in ``save``.
    """
    return save['train_names'], save['test_names']


def plot_zs(z, zhat, y, d_name, exp_name):
    """Scatter two 2-D embeddings on one axis and save the figure as a PNG.

    Args:
        z: Tensor-like with ``.cpu().numpy()``; its first two columns are
            drawn with ``'+'`` markers at alpha 0.5 under the label ``'z'``.
        zhat: Same kind of object; its first two columns are drawn with
            ``'o'`` markers under the label ``'z_hat'``.
        y: Per-point values used to colour both scatters (viridis colormap).
        d_name: Indexable whose first element is the dataset name used in
            the title and the output file name.
        exp_name: Prefix of the output file name.

    The image is written to
    ``/mnt/data01/****/****/gnn/transfer_exp/plot_res/{exp_name}-{d_name}.png``.
    The figure is closed afterwards (even if plotting or saving fails) so
    that repeated calls do not accumulate open matplotlib figures.
    """
    def scatter_z(points, colors, ax, label, marker, alpha=1.0):
        # Only the first two embedding columns are drawn.
        coords = points.cpu().numpy()
        return ax.scatter(coords[:, 0], coords[:, 1], c=colors,
                          label=label, marker=marker, alpha=alpha,
                          cmap='viridis')

    d_name = d_name[0]
    fig = plt.figure(figsize=(10, 8))
    try:
        fig.set(tight_layout=True)
        ax = fig.add_subplot()

        scatter_z(z, y, ax, label='z', marker='+', alpha=0.5)
        s = scatter_z(zhat, y, ax, label='z_hat', marker='o', alpha=1.0)

        plt.colorbar(s)
        plt.title(f'Visualization of {d_name} CLIP Feature')
        plt.legend()
        plt.savefig(f'/mnt/data01/****/****/gnn/transfer_exp/plot_res/{exp_name}-{d_name}.png')
    finally:
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close(fig)


def evaluate_model_performance(d_loader, net, n_method=2):
    """Score the embeddings ``net`` produces for every batch of ``d_loader``.

    For each ``(d_name, n_view_graphs)`` batch the network is run without
    gradients and its output is scored against the node labels
    ``n_view_graphs.ndata['y']`` with ``visualization_metric.get_nmi_sc``.

    Args:
        d_loader: Iterable yielding ``(d_name, n_view_graphs)`` pairs.
        net: Model exposing ``eval()``, a ``device`` attribute and
            ``net(edge_only_graph, n_view_graphs)``.
        n_method: When greater than 1, output columns ``[:2]`` are scored
            as the "tsne" result and columns ``[2:]`` as the "umap" result.
            Otherwise columns ``[:2]`` are scored as "umap" and the "tsne"
            scores are reported as 0.0.

    Returns:
        ``(nmis_t, scs_t, d_names)``: two tensors of shape
        ``(num_batches, 2)`` whose columns are ``[tsne, umap]``, and the
        list of ``d_name[0]`` for each batch. For an empty loader both
        tensors have shape ``(0, 2)`` and the list is empty.

    Note:
        The network is switched to eval mode and is not switched back.
    """
    net.eval()
    device = net.device

    scs = []
    nmis = []
    d_names = []
    with torch.no_grad():
        for d_name, n_view_graphs in d_loader:
            # Structure-only copy of the batch graph, moved to the model device.
            # NOTE(review): n_view_graphs itself is passed to the network
            # without being moved to `device` -- confirm the network handles it.
            edge_only_graph = dgl.graph(n_view_graphs.edges(),
                                        num_nodes=n_view_graphs.num_nodes()).to(device)
            z_hat = net(edge_only_graph, n_view_graphs)
            labels = n_view_graphs.ndata['y'].tolist()

            if n_method > 1:
                tsne_nmi, tsne_sc = visualization_metric.get_nmi_sc(z_hat[:, :2].cpu().numpy(), labels)
                umap_nmi, umap_sc = visualization_metric.get_nmi_sc(z_hat[:, 2:].cpu().numpy(), labels)
            else:
                umap_nmi, umap_sc = visualization_metric.get_nmi_sc(z_hat[:, :2].cpu().numpy(), labels)
                tsne_nmi, tsne_sc = 0.0, 0.0

            scs.append(torch.Tensor([tsne_sc, umap_sc]))
            nmis.append(torch.Tensor([tsne_nmi, umap_nmi]))
            d_names.append(d_name[0])

    if not d_names:
        # torch.stack rejects an empty list; return correctly shaped empties.
        return torch.empty((0, 2)), torch.empty((0, 2)), d_names

    return torch.stack(nmis), torch.stack(scs), d_names


def get_gt_by_name(d_name, gt_dict):
    """Look up the stored t-SNE and UMAP values for one dataset.

    Args:
        d_name: Dataset name; the dictionary is keyed by the 1-tuple
            ``(d_name,)``.
        gt_dict: Mapping from ``(d_name,)`` to a mapping with the keys
            ``'tsne'`` and ``'umap'``.

    Returns:
        ``(tsne_value, umap_value)`` for the dataset.
    """
    entry = gt_dict[(d_name,)]
    return entry['tsne'], entry['umap']


def get_gt_df(gt_save, d_names, prefix=''):
    """Collect the stored NMI and SC values of several datasets in a frame.

    Args:
        gt_save: Pair ``(nmi_gt_dict, sc_gt_dict)``; each is looked up per
            dataset through ``get_gt_by_name``.
        d_names: Iterable of dataset names, one output row per name.
        prefix: String prepended to the four value column names.

    Returns:
        A ``pandas.DataFrame`` with the columns ``d_name``,
        ``{prefix}gt_tsne_nmi``, ``{prefix}gt_umap_nmi``,
        ``{prefix}gt_tsne_sc`` and ``{prefix}gt_umap_sc``. The columns are
        present even when ``d_names`` is empty, so the result can still be
        merged on ``d_name``.
    """
    nmi_gt_dict, sc_gt_dict = gt_save

    name_col = 'd_name'
    tsne_nmi_col = f'{prefix}gt_tsne_nmi'
    umap_nmi_col = f'{prefix}gt_umap_nmi'
    tsne_sc_col = f'{prefix}gt_tsne_sc'
    umap_sc_col = f'{prefix}gt_umap_sc'

    # Pre-create every column so rows are appended in O(1) each and an empty
    # input still yields the full schema.
    columns = {
        name_col: [],
        tsne_nmi_col: [],
        umap_nmi_col: [],
        tsne_sc_col: [],
        umap_sc_col: [],
    }
    for d_name in d_names:
        tsne_nmi, umap_nmi = get_gt_by_name(d_name, nmi_gt_dict)
        tsne_sc, umap_sc = get_gt_by_name(d_name, sc_gt_dict)
        columns[name_col].append(d_name)
        columns[tsne_nmi_col].append(tsne_nmi)
        columns[umap_nmi_col].append(umap_nmi)
        columns[tsne_sc_col].append(tsne_sc)
        columns[umap_sc_col].append(umap_sc)

    return pd.DataFrame(columns)


def get_eval_df(nmis_t, scs_t, d_names, prefix='pred_'):
    """Arrange per-dataset NMI and SC scores into a data frame.

    Args:
        nmis_t: 2-D array-like; column 0 is stored as the t-SNE NMI and
            column 1 as the UMAP NMI.
        scs_t: 2-D array-like; column 0 is stored as the t-SNE SC and
            column 1 as the UMAP SC.
        d_names: Dataset names, one per row.
        prefix: String prepended to the four score column names.

    Returns:
        A ``pandas.DataFrame`` with the columns ``d_name``,
        ``{prefix}tsne_nmi``, ``{prefix}umap_nmi``, ``{prefix}tsne_sc`` and
        ``{prefix}umap_sc``.
    """
    columns = {
        'd_name': d_names,
        f'{prefix}tsne_nmi': nmis_t[:, 0],
        f'{prefix}umap_nmi': nmis_t[:, 1],
        f'{prefix}tsne_sc': scs_t[:, 0],
        f'{prefix}umap_sc': scs_t[:, 1],
    }
    return pd.DataFrame(columns)


def get_gt_pred_df_from_d(gt_save, d_names, d_loader, net, to_save_name, n_method):
    """Join stored reference scores with freshly evaluated model scores.

    Args:
        gt_save: Pair of reference dictionaries forwarded to ``get_gt_df``.
        d_names: Dataset names whose reference scores are collected.
        d_loader: Loader forwarded to ``evaluate_model_performance``.
        net: Model to evaluate; it is moved to CUDA unconditionally.
        to_save_name: Destination passed to ``DataFrame.to_csv``.
        n_method: Forwarded to ``evaluate_model_performance``.

    Returns:
        The merged frame (joined on ``d_name``), which is also written to
        ``to_save_name`` as CSV.
    """
    reference_df = get_gt_df(gt_save, d_names)

    # NOTE(review): requires a CUDA device; there is no CPU fallback here.
    cuda_net = net.cuda()
    predicted_df = get_eval_df(
        *evaluate_model_performance(d_loader, cuda_net, n_method=n_method)
    )

    merged = pd.merge(reference_df, predicted_df, on='d_name')
    merged.to_csv(to_save_name)
    return merged


def get_mean_performance(df):
    """Add relative precision/gap columns and summarise all value columns.

    For every combination of method (``tsne``, ``umap``) and metric
    (``nmi``, ``sc``) two columns are added to ``df`` in place:

    * ``relative_{method}_{metric}_precision`` = prediction / reference
    * ``relative_{method}_{metric}_gap`` = reference - prediction

    Args:
        df: Frame containing ``d_name`` plus the ``pred_*`` and ``gt_*``
            columns for each method/metric combination. It is modified.

    Returns:
        ``(mean, std)``: column-wise mean and standard deviation of ``df``
        with the ``d_name`` column dropped.
    """
    # Compute every derived series before touching df, so a missing input
    # column raises before any new column has been written.
    derived = {}
    for metric in ('nmi', 'sc'):
        for method in ('tsne', 'umap'):
            pred_col = f'pred_{method}_{metric}'
            gt_col = f'gt_{method}_{metric}'
            derived[f'relative_{method}_{metric}_precision'] = df[pred_col] / df[gt_col]
            derived[f'relative_{method}_{metric}_gap'] = df[gt_col] - df[pred_col]

    for column, values in derived.items():
        df[column] = values

    value_columns = df.drop('d_name', axis=1)
    return value_columns.mean(), value_columns.std()



from torch_geometric.graphgym.config import (cfg, dump_cfg,
                                             # set_agg_dir,
                                             set_cfg, load_cfg,
                                             makedirs_rm_exist)
# Benchmark script: loads a checkpointed network, then times ten rounds of
# (preprocessing of synthetic random data + one forward pass) on the CPU and
# saves the per-round totals.
if __name__ == '__main__':
    root = '/mnt/data01/****/****/gnn/res/'
    # Earlier checkpoints that were benchmarked with this script:
    # ckp_name = root + '500ds-GNN-clip_w_umap-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-59.tar'
    # ckp_name = root + '500ds-GNN-clip_w_umap-bce-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-119.tar'
    # ckp_name = root + '500ds-GNN-clip_w_umap-l2-subnet-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-149.tar'
    # ckp_name = root + '500ds-GNN-clip_w_umap-l2-uninet-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-154.tar'
    # ckp_name = root + '500ds-GNN-clip_w_umap-bothl2-uninet-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-219.tar'

    # Checkpoint file names (relative to `root`) to benchmark, one run each.
    save_f_names = [
        '500ds-GNN-clip_w_umap-umaponly-knn-l2-newparallelgt-regin2-8gt-sigma_mv-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-94.tar'
                    ]
    # Indexed in parallel with save_f_names; only the first len(save_f_names)
    # entries are ever read.
    loss_types = ['kl-student-t', 'kl-umap', 'bce', 'bce-knn', 'l2-umap', 'l2-dist', 'umap-orig', 'knn-l2', 'knn-l2']

    for ind, save_f_name in enumerate(save_f_names):
        ckp_name = root + save_f_name
        # NOTE: loss_type is assigned but not used further down in this block.
        loss_type = loss_types[ind]

        net, save = get_model_ckp(ckp_name)


        # Train names come from the benchmarked checkpoint; test names are
        # taken from a second, fixed checkpoint.
        train_d_names, _ = get_d_names_from_save(save)
        _, test_d_names = get_d_names_from_save(get_model_ckp(root + '500ds-GNN-clip_w_umap-umaponly-umap_knn_l2-new-regin2-8gt-sigma_mv-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-84.tar')[1])
        # NOTE: both loaders are built, but only the commented-out loop near
        # the end of this block refers to one of them (test_d_loader).
        train_d_loader = get_loader_by_exp_name(train_d_names, exp_names='clip')
        test_d_loader = get_loader_by_exp_name(test_d_names, exp_names='clip')

        # Run the benchmark on the CPU.
        device = 'cpu'
        net = net.cpu()
        net.device = 'cpu'
        print(net.device)

        # Initialise the graphgym config object and register the
        # positional-encoding options on it.
        # NOTE(review): presumably the network reads `cfg` during forward -- confirm.
        set_cfg(cfg)
        posenc_config.set_cfg_posenc(cfg)
        with torch.no_grad():
            # One entry per round: preprocessing time + forward-pass time.
            used_times = []
            for i in range(10):
                # ---- timed preprocessing: data, distances, kernels, SVD ----
                d_time = time.perf_counter()
                # Synthetic input: 3000 random points with 512 features.
                x = np.random.randn(3000, 512)
                cdist = scipy_cdist(x, x).astype(np.float32)
                pes = []
                ss = []
                # The same kernel and SVD are recomputed five times, filling
                # the five weight{j}/pe{j} slots used below.
                # NOTE(review): presumably this emulates five views -- confirm.
                for j in range(5):
                    bandwidth = torch.median(torch.from_numpy(cdist)) * 1.0  # Median of the distance-matrix entries as scaling factor
                    gaussian_weights = torch.exp(-cdist ** 2 / (2 * bandwidth ** 2))  # Gaussian similarity
                    weights_a = gaussian_weights
                    # Left singular vectors (64 components) serve as the
                    # per-node positional encoding.
                    u, s, vh = randomized_svd(weights_a.numpy(), n_components=64)
                    graph_pe = u
                    pes.append(graph_pe)
                    ss.append(weights_a)
                d_time = time.perf_counter() - d_time
                print(d_time)

                # ---- graph construction (not included in either timer) ----
                # One edge for every non-NaN entry of the distance matrix.
                src, dst = torch.where(torch.isnan(torch.from_numpy(cdist)) == False)
                N = cdist.shape[0]
                print('N', N)
                n_view_graphs = dgl.graph((src, dst), num_nodes=N)
                for j in range(5):
                    # `weights_a` is the value left over from the last
                    # iteration of the kernel loop above.
                    src, dst = torch.where(torch.isnan(weights_a) == False)
                    edge_weights = ss[j][src, dst].to(torch.float32)

                    n_view_graphs.edata[f'weight{j}'] = edge_weights.to(torch.float32)
                    n_view_graphs.ndata[f'pe{j}'] = torch.from_numpy(pes[j])
                edge_only_graph = dgl.graph(n_view_graphs.edges(), num_nodes=n_view_graphs.num_nodes()).to(device)

                # ---- timed forward pass ----
                time_start = time.perf_counter()
                z_hat = net(edge_only_graph, n_view_graphs)
                t = time.perf_counter() - time_start
                used_times.append(t + d_time)
                print(d_time, t)

            # Disabled alternative: time the forward pass on the first ten
            # batches of test_d_loader instead of on synthetic data.
            # i = 0
            # for d_name, n_view_graphs in test_d_loader:
            #     edge_only_graph = dgl.graph(n_view_graphs.edges(), num_nodes=n_view_graphs.num_nodes()).to(device)
            #     n_view_graphs = n_view_graphs.to(device)
            #     # edge_only_graph.edata["_edge_weight"] = edge_only_graph.edata["edge_weight"].to(device)
            #     # n_view_graphs.edata["_edge_weight"] = n_view_graphs.edata["_edge_weight"].to(device)
            #     for j in range(5):
            #         n_view_graphs.ndata[f'pe{j}'] = n_view_graphs.ndata[f'pe{j}'].to(device)
            #         n_view_graphs.edata[f'weight{j}'] = n_view_graphs.edata[f'weight{j}'].to(device)
            #     time_start = time.perf_counter()
            #     z_hat = net(edge_only_graph, n_view_graphs)
            #     t = time.perf_counter() - time_start
            #     used_times.append(t)
            #     print(t)
            #     i = i + 1
            #     if i == 10:
            #         break

        # Persist the per-round totals (a NumPy array written with torch.save)
        # and report their mean and standard deviation.
        used_times = np.array(used_times)
        torch.save(used_times, './autodv_w_pe_used_times.save')
        print(used_times.mean(), used_times.std())