import sys
sys.path.append('/mnt/data01/****/****')

from functools import partial
import os
import time
import multiprocessing


import torch
import dgl

# import read_data
from gnn.Datasets import LargeCompleteMVGraphDatasets_gene as LargeCompleteMVGraphDatasets
from gnn.networks import AutoVisualMultiGINComplete
from gnn.useful_utils import visualization_metric

from gnn.useful_utils import train_utils
from gnn.useful_utils import graph_sampler

def evaluate(d_loader, net):
    """Evaluate ``net`` over every batch of ``d_loader`` and average the metrics.

    For each batch, the condensed pairwise distances of the target embedding
    ``ndata['z']`` are computed per graph and passed, together with the graph,
    to ``net.loss_fn_kl_t_pdist``.  The returned layout ``z_hat`` is scored
    against the labels ``ndata['y']`` with ``visualization_metric.get_nmi_sc``.

    Args:
        d_loader: iterable yielding ``(_, batched_graph)`` pairs; each graph
            must carry ``ndata['z']`` and ``ndata['y']``.
        net: model exposing ``device``, ``eval()``/``train()`` and
            ``loss_fn_kl_t_pdist``.

    Returns:
        ``(nmi, sc, loss)``: 0-d float tensors, each the mean over batches
        (NaN when ``d_loader`` yields nothing).

    The network is put back into training mode on return, including when an
    exception propagates.
    """
    net.eval()
    device = net.device
    scs, nmis, losses = [], [], []
    try:
        with torch.no_grad():
            for _, n_view_graphs in d_loader:
                edge_only_graph = dgl.graph(n_view_graphs.edges(), num_nodes=n_view_graphs.num_nodes()).to(device)
                z = n_view_graphs.ndata['z'].to(torch.float32).to(device)
                batch_graph_size = n_view_graphs.batch_num_nodes()

                # Condensed pairwise distances of each graph's nodes, concatenated.
                zdist = torch.cat([torch.pdist(z_g) for z_g in torch.split(z, batch_graph_size.tolist())])

                # n*(n-1)/2 entries per graph. Integer arithmetic keeps this exact:
                # true division promotes to float32, which rounds once the value
                # exceeds 2**24. `.long()` first avoids overflow if sizes are int32.
                n_nodes = batch_graph_size.long()
                batch_zdist_size = (n_nodes * (n_nodes - 1) // 2).to(device)

                student_t_loss, z_hat = net.loss_fn_kl_t_pdist(n_view_graphs, edge_only_graph, zdist,
                                                               batch_zdist_size)

                y = n_view_graphs.ndata['y']
                nmi, sc = visualization_metric.get_nmi_sc(z_hat.cpu().numpy(), y.tolist())

                scs.append(sc)
                nmis.append(nmi)
                # Store a Python float, not a (possibly CUDA) tensor.
                losses.append(student_t_loss.item())
    finally:
        net.train()

    nmis_t = torch.mean(torch.Tensor(nmis))
    scs_t = torch.mean(torch.Tensor(scs))
    losses_t = torch.mean(torch.Tensor(losses))
    return nmis_t, scs_t, losses_t


def load_d_data_names(datasets, root='/mnt/data01/public/aad_data/gene_filtered'):
    """Collect data-file names, with their last extension stripped, from dataset directories.

    Args:
        datasets: names of sub-directories of ``root``, scanned in the given order.
        root: parent directory containing one sub-directory per dataset.

    Returns:
        A flat list of names. Within each dataset the entries are sorted,
        because ``os.listdir`` order is arbitrary and the list is saved with
        checkpoints. Everything after the last ``.`` is dropped; a file name
        without a ``.`` yields ``''`` (same as the original behaviour).
    """
    names = []
    for dataset in datasets:
        for file_name in sorted(os.listdir(os.path.join(root, dataset))):
            # rpartition('.')[0] == '.'.join(file_name.split('.')[:-1])
            names.append(file_name.rpartition('.')[0])
    return names

def main():
    """Train ``AutoVisualNet`` on the filtered gene datasets and checkpoint it.

    Steps:
      1. Either build a fresh network and the train/test name lists (filtered
         by ground-truth NMI), or restore the network and name lists from a
         checkpoint (``is_resume``).
      2. Build the datasets, the sampler-driven training loader and the
         batch-size-1 evaluation loaders.
      3. Train with AdamW and a cosine-annealed learning rate for ``n_epoch``
         epochs. Batches with more than ``max_batch_nodes`` nodes, and batches
         that raise CUDA OOM, are skipped.
      4. Every ``eval_every`` epochs, evaluate on both splits, print the latest
         metrics and save ``(net, records)`` under ``./res``.
    """
    multiprocessing.set_start_method('spawn')
    device = 'cuda:0'

    # Earlier runs also used 'mnist' and 'fmnist'.
    datasets = ['Campbell', 'PBMC68K', 'Baron Human']

    # Ground-truth NMI/SC lookup; needed on both the fresh and the resume path.
    uci_gt_dict = torch.load('/mnt/data01/public/aad_data/gene_filtered/gene_tsne_nmi_ground_truth_test.tar',
                             weights_only=False)

    is_resume = False
    if is_resume:
        ckp_name = '/mnt/data01/****/****/gnn/res/GNN-gene-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-5000ds-epoch-144.tar'
        # The checkpoint pickles the whole nn.Module, so full unpickling is
        # required (torch>=2.6 defaults to weights_only=True). Trusted files only.
        net, save = torch.load(ckp_name, weights_only=False)
        train_names = save['train_names']
        test_names = save['test_names']
        # Must equal the checkpoint's epoch + 1.
        # NOTE(review): optimizer state is not checkpointed, so AdamW moments restart from zero.
        start_epoch = 145
    else:
        train_names = load_d_data_names(datasets)
        train_names = train_utils.filter_low_nmi(train_names, uci_gt_dict, th=0.1)
        print(f'train len: {len(train_names)}')
        test_names = load_d_data_names(['Mouse_retina'])
        test_names = train_utils.filter_low_nmi(test_names, uci_gt_dict, th=0.1)
        print(f'test len: {len(test_names)}')
        start_epoch = 0
        net = AutoVisualMultiGINComplete.AutoVisualNet(input_dim=64, gnn_hidden=128, gnn_out=128, out_dim=2,
                                                       n_graph_view=5, n_transformer=8, device=device)

    train_nmi, train_sc = train_utils.get_ground_truth_from_precomp(train_names, uci_gt_dict)
    print(f'train_nmi: {train_nmi}, train_sc: {train_sc}')
    test_nmi, test_sc = train_utils.get_ground_truth_from_precomp(test_names, uci_gt_dict)
    print(f'test_nmi: {test_nmi}, test_sc: {test_sc}')

    # ---- datasets and loaders -------------------------------------------
    precomputed_pe_path = '/mnt/data01/public/aad_data/pe/gene_filtered'
    get_ds_fn = partial(LargeCompleteMVGraphDatasets.DatasetGraphDataset,
                        cdist_path='/mnt/data01/public/aad_data/gene_filtered',
                        visual_path='/mnt/data01/public/aad_data/bo/gene_filtered',
                        normalize_z=True, z_cali_method='none',
                        z_mu=None, z_std=None, flip_sign_method='pos')
    train_ds = get_ds_fn(data_names=train_names, precomputed_pe_path=precomputed_pe_path)
    test_ds = get_ds_fn(data_names=test_names, precomputed_pe_path=precomputed_pe_path)

    # NOTE(review): batch_size for GraphBatchSampler looks like a per-batch node
    # budget rather than a graph count (cf. max_batch_nodes below) — confirm.
    train_sampler = torch.utils.data.RandomSampler(train_ds)
    train_loader = dgl.dataloading.GraphDataLoader(
        train_ds,
        batch_sampler=graph_sampler.GraphBatchSampler(train_sampler,
                                                      batch_size=25000,
                                                      block_list=train_ds.ds_block_sizes),
        num_workers=10,
    )
    get_loader_for_eval = partial(dgl.dataloading.GraphDataLoader, batch_size=1, shuffle=True, num_workers=0)
    train_loader4val = get_loader_for_eval(train_ds)
    test_loader4val = get_loader_for_eval(test_ds)

    # Sanity print of the positional-encoding width of the first sample.
    input_dim = train_ds[0][1].ndata['pe0'].shape[1]
    print(input_dim)

    net = net.to(device)
    print(net)

    # ---- optimisation ---------------------------------------------------
    n_epoch = 300
    lr = 1e-4
    max_batch_nodes = 40000
    eval_every = 5
    optimizer = torch.optim.AdamW(params=net.parameters(), lr=lr, weight_decay=0)
    # Constructed at last_epoch=-1 so epoch 0 trains at `lr`. When resuming,
    # fast-forward so epoch `start_epoch` gets the same LR as an uninterrupted
    # run (the chainable cosine formula needs every intermediate step).
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epoch)
    if start_epoch:
        import warnings
        with warnings.catch_warnings():
            # Stepping before any optimizer.step() is intentional here.
            warnings.simplefilter('ignore', UserWarning)
            for _ in range(start_epoch):
                scheduler.step()
    print(scheduler.get_last_lr())

    records = {}
    for epoch in range(start_epoch, n_epoch):
        net.train()
        running_loss = 0.0
        n_trained = 0  # batches that actually completed a step
        for _, n_view_graphs in train_loader:
            batch_graph_size = n_view_graphs.batch_num_nodes()
            if torch.sum(batch_graph_size) > max_batch_nodes:
                print(epoch, 'skipped')
                continue

            optimizer.zero_grad()
            edge_only_graph = dgl.graph(n_view_graphs.edges(), num_nodes=n_view_graphs.num_nodes()).to(device)
            z = n_view_graphs.ndata['z'].to(torch.float32).to(device)
            # Condensed pairwise distances of each graph's nodes, concatenated.
            zdist = torch.cat([torch.pdist(z_g) for z_g in torch.split(z, batch_graph_size.tolist())])
            # n*(n-1)/2 entries per graph, computed in integers: true division
            # promotes to float32 and rounds once the value exceeds 2**24.
            n_nodes = batch_graph_size.long()
            batch_zdist_size = (n_nodes * (n_nodes - 1) // 2).to(device)

            try:
                loss, _ = net.loss_fn_kl_t_pdist(n_view_graphs, edge_only_graph, zdist, batch_zdist_size)
                loss.backward()
                optimizer.step()
            except torch.cuda.OutOfMemoryError:
                # Skip this batch: `loss` is unset or stale and must not be accumulated.
                print('e', len(batch_graph_size), torch.sum(batch_graph_size**2))
                continue
            running_loss += loss.item()
            n_trained += 1

        scheduler.step()

        if (epoch + 1) % eval_every == 0:
            train_nmi, train_sc, train_loss = evaluate(train_loader4val, net)
            test_nmi, test_sc, test_loss = evaluate(test_loader4val, net)
            latest = {
                'train_nmi': train_nmi,
                'train_sc': train_sc,
                'train_loss': train_loss,
                'test_nmi': test_nmi,
                'test_sc': test_sc,
                'test_loss': test_loss,
                'running_loss': running_loss / n_trained if n_trained else float('nan'),
            }
            for key, value in latest.items():
                records.setdefault(key, []).append(value)
            print(list(latest.items()))

            records['train_names'] = train_names
            records['test_names'] = test_names
            os.makedirs('./res', exist_ok=True)
            torch.save((net, records), f'./res/GNN-gene_filtered-regin2-8gt-sigma_mv-run4-{n_epoch}epoch-kl_t_64dim_svdu_complete_G-filteredds-epoch-{epoch}.tar')

if __name__ == '__main__':
    # Run training only when executed as a script. The guard is required here:
    # main() selects the 'spawn' start method, under which child processes
    # (e.g. data-loading workers) import this module. Without it, unguarded
    # code would start training again in each of them.
    main()