import sys
sys.path.append('/mnt/data01/****/****')

from functools import partial
import os

import torch
import dgl

# import read_data
from gnn.Datasets import LargeCompleteMVGraphDatasets as LargeCompleteMVGraphDatasets
from gnn.networks import AutoVisualMultiGINComplete
from gnn.useful_utils import visualization_metric

from gnn.useful_utils import train_utils
from gnn.useful_utils import graph_sampler

def evaluate(d_loader, net,):
    """Evaluate ``net`` on every batch yielded by ``d_loader``.

    For each batched graph, the pairwise distances of the stored ``z`` node
    features are computed separately for every member graph, handed to
    ``net.loss_fn_kl_t_pdist`` together with an edge-only copy of the graph,
    and the embedding it returns is scored with
    ``visualization_metric.get_nmi_sc`` against the ``y`` node data.

    Gradient tracking is disabled for the whole pass. The network is put in
    eval mode on entry and switched back to train mode on exit, even when
    evaluation raises.

    Args:
        d_loader: Iterable of batched DGL graphs carrying ``ndata['z']`` and
            ``ndata['y']``.
        net: Model exposing ``eval()``, ``train()``, ``device`` and
            ``loss_fn_kl_t_pdist(graph, edge_only_graph, zdist, sizes)``.

    Returns:
        Tuple ``(nmi, sc, loss)`` of tensors, each the mean of the respective
        per-batch values (NaN when ``d_loader`` yields nothing).
    """
    # print(f'net.type = {type(net)}')
    net.eval()
    device = net.device
    scs = []
    nmis = []
    losses = []
    try:
        with torch.no_grad():
            for n_view_graphs in d_loader:
                edge_only_graph = dgl.graph(n_view_graphs.edges(), num_nodes=n_view_graphs.num_nodes()).to(device)
                z = n_view_graphs.ndata['z'].to(torch.float32).to(device)
                batch_graph_size = n_view_graphs.batch_num_nodes()

                # Node offsets of each member graph inside the batch. The
                # leading zero is created with the dtype/device of the node
                # counts so the cumulative sum stays in exact integer
                # arithmetic instead of being promoted to float32.
                batch_graph_size_cumsum = torch.cumsum(
                    torch.cat([batch_graph_size.new_zeros(1), batch_graph_size]), dim=0).long()
                zdist = torch.cat([torch.pdist(z[batch_graph_size_cumsum[i]: batch_graph_size_cumsum[i + 1]])
                                   for i in range(n_view_graphs.batch_size)])

                # Number of pairwise distances per member graph: n * (n - 1) / 2.
                # n * (n - 1) is always even, so floor division is exact; true
                # division would round through float32 and give wrong counts
                # once the value exceeds 2**24.
                batch_zdist_size = (batch_graph_size * (batch_graph_size - 1) // 2).long().to(device)

                student_t_loss, z_hat = net.loss_fn_kl_t_pdist(n_view_graphs, edge_only_graph, zdist, batch_zdist_size)

                y = n_view_graphs.ndata['y']
                nmi, sc = visualization_metric.get_nmi_sc(z_hat.cpu().numpy(), y.tolist())

                scs.append(sc)
                nmis.append(nmi)
                losses.append(student_t_loss)
    finally:
        # Never leave the model stuck in eval mode if a batch fails.
        net.train()

    scs_t = torch.mean(torch.Tensor(scs))
    nmis_t = torch.mean(torch.Tensor(nmis))
    losses_t = torch.mean(torch.Tensor(losses))
    return nmis_t, scs_t, losses_t


def load_d_data_names(datasets, root='/mnt/data01/public/aad_data/gene_filtered'):
    """Collect the extension-less file names found in each dataset directory.

    Args:
        datasets: Iterable of sub-directory names located under ``root``.
        root: Directory containing one sub-directory per dataset.

    Returns:
        A flat list with, for every dataset in the given order, the names of
        its directory entries with the last ``.suffix`` removed. Entries are
        sorted within each dataset so the result does not depend on the
        arbitrary order of ``os.listdir``. An entry whose name contains no
        ``.`` contributes an empty string.

    Raises:
        OSError: If a dataset directory cannot be listed.
    """
    train_names = []
    for dataset in datasets:
        entries = sorted(os.listdir(os.path.join(root, dataset)))
        # rpartition keeps everything before the final dot ('a.tar.gz' -> 'a.tar').
        train_names.extend(entry.rpartition('.')[0] for entry in entries)
    return train_names

def main():
    """Compute the reference metrics of the gene datasets and save them.

    Lists every data file of the selected gene datasets, wraps them in a
    ``DatasetGraphDataset``, iterates it one graph at a time through
    ``train_utils.get_true_metric`` and stores the two returned values as a
    ``(nmi, sc)`` tuple with ``torch.save``.
    """
    data_root = '/mnt/data01/public/aad_data'
    datasets = ['PBMC68K', 'Campbell', 'Mouse_retina', 'Baron Human']

    train_names = load_d_data_names(datasets)

    train_ds = LargeCompleteMVGraphDatasets.DatasetGraphDataset(
        data_names=train_names,
        exp_name='gene',
        is4gt=True,
        cdist_path=f'{data_root}/gene_filtered',
        visual_path=f'{data_root}/bo/gene_filtered',
        visual_path_umap=f'{data_root}/bo/umap/gene_filtered',
        normalize_z=True,
        z_cali_method='none',
        z_mu=None,
        z_std=None,
        flip_sign_method='pos',
        precomputed_pe_path=f'{data_root}/pe',
    )

    # NOTE(review): shuffle=True randomizes the visiting order; if
    # get_true_metric returns per-dataset values they will not line up with
    # train_names -- confirm this is intended.
    train_loader4val = dgl.dataloading.GraphDataLoader(
        train_ds,
        batch_size=1,
        shuffle=True,
        num_workers=0,
    )

    train_nmi, train_sc = train_utils.get_true_metric(train_loader4val)
    ground_truth = train_nmi, train_sc
    torch.save(ground_truth, f'{data_root}/gene_tsne_w_umap_ground_truth.tar')


# Run the ground-truth computation only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()