from functools import partial

import torch
import os

# import read_data
from .Datasets import L2OGraphDatasets
from .networks import L2OGWDMultiGINComplete
from .useful_utils import visualization_metric

from .useful_utils import train_utils

def evaluate(net, d_loader, optimizer):
    """
    Evaluate ``net`` on every batch of ``d_loader`` with gradients disabled.

    For each batch the network output is scored with
    ``visualization_metric.get_nmi_sc`` against the labels ``y[0]`` and the
    loss ``net.loss_fn_kl_t`` is computed. The per-batch values are averaged
    over all batches.

    :param net: model that is callable as ``net(xs, n_view_graphs)`` and
        exposes ``eval()``, ``train()``, ``device`` and ``loss_fn_kl_t``
    :param d_loader: iterable yielding ``(xs, n_view_graphs, z, zdist, y)``
        batches; only element 0 of ``zdist`` and ``y`` is read, so this only
        works for batch size 1
    :param optimizer: unused, kept so existing callers keep working.
        Evaluation no longer calls ``optimizer.zero_grad()``: nothing is
        back-propagated here, so zeroing would only discard gradients the
        caller may still need.
    :return: ``(nmi_mean, sc_mean, loss_mean)`` as 0-dim tensors (NaN when
        ``d_loader`` yields no batches)
    """
    print(f'net.type = {type(net)}')

    # Remember the current mode so evaluating a model that is already in
    # eval mode does not silently flip it to training mode afterwards.
    was_training = getattr(net, 'training', True)
    net.eval()

    scs = []
    nmis = []
    losses = []
    try:
        with torch.no_grad():
            for xs, n_view_graphs, _z, zdist, y in d_loader:
                z_hat = net(xs, n_view_graphs)
                zdist = zdist[0].to(torch.float32).to(net.device)
                student_t_loss = net.loss_fn_kl_t(xs, n_view_graphs, zdist)

                nmi, sc = visualization_metric.get_nmi_sc(z_hat.cpu().numpy(), y[0].tolist())

                scs.append(sc)
                nmis.append(nmi)
                # Store a plain float instead of the loss tensor so no
                # device tensors are kept alive across batches.
                losses.append(float(student_t_loss))
    finally:
        # Restore the previous mode even if a batch raised.
        if was_training:
            net.train()

    scs_t = torch.mean(torch.Tensor(scs))
    nmis_t = torch.mean(torch.Tensor(nmis))
    losses_t = torch.mean(torch.Tensor(losses))
    return nmis_t, scs_t, losses_t

def fetch_data(dataloader):
    """
    Create an infinite generator over ``dataloader``.

    The dataloader is re-iterated from the start every time it is exhausted.
    Each batch must unpack into five fields; only the first two are yielded.

    :param dataloader: re-iterable yielding
        ``(graph_pes, graph_views, selected_emb, zdist, y)`` batches
    :return: generator of ``(graph_pes, graph_views)`` pairs
    :raises ValueError: on the first ``next()`` after a full pass over
        ``dataloader`` produced no batch (otherwise the loop would spin
        forever without ever yielding)
    """
    while True:
        produced_any = False
        for graph_pes, graph_views, _selected_emb, _zdist, _y in dataloader:
            produced_any = True
            yield graph_pes, graph_views
        if not produced_any:
            raise ValueError('fetch_data: dataloader produced no batches')

def _transpose_batch(samples):
    """
    Collate a list of per-sample tuples into a list of per-field tuples.

    Defined at module level (not as a lambda) so it can be pickled when the
    DataLoader worker processes are started with the ``spawn`` method.

    :param samples: list of samples as returned by the dataset
    :return: ``list(zip(*samples))``
    """
    return list(zip(*samples))


def main():
    """
    Train an ``L2OGWDMultiGINComplete`` network with its ``loss_gwd`` loss.

    Builds the train/test ``DatasetGraphDataset`` objects (the test set is
    normalised with the training set's ``z_mu``/``z_std``), wraps the training
    loader in the infinite ``fetch_data`` generator and runs ``n_epoch``
    optimisation steps. Each step draws two consecutive batches and minimises
    ``net.loss_gwd`` between them, then prints the instantaneous and running
    average loss.
    """
    # Use the first GPU when present; otherwise run on CPU instead of
    # failing when the model is moved to a CUDA device that does not exist.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    train_names = ['cifar10_1class_comb0_seed0']
    test_names = ['cifar10_1class_comb1_seed0']

    z_cali = 'none'
    precomputed_pe_path = 'pe_data/pe_for_gat'

    train_ds = L2OGraphDatasets.DatasetGraphDataset(
        data_names=train_names,
        cdist_path='features',
        visual_path='../prepare_data/bo/res-2',
        normalize_z=True,
        z_cali_method=z_cali,
        precomputed_pe_path=precomputed_pe_path,
        z_mu=None,
        z_std=None,
    )

    # The test set reuses the normalisation statistics of the training set.
    test_ds = L2OGraphDatasets.DatasetGraphDataset(
        data_names=test_names,
        cdist_path='features',
        visual_path='../prepare_data/bo/res-2',
        normalize_z=True,
        z_cali_method=z_cali,
        precomputed_pe_path=precomputed_pe_path,
        z_mu=train_ds.z_mu,
        z_std=train_ds.z_std,
    )

    get_loader = partial(
        torch.utils.data.DataLoader,
        batch_size=1,
        shuffle=False,
        num_workers=8,
        collate_fn=_transpose_batch,
    )

    train_loader = get_loader(train_ds)
    # Only needed by the (currently disabled) periodic evaluation below.
    test_loader = get_loader(test_ds)

    # Infinite stream of (graph_pes, graph_views) training batches.
    train_iloader = fetch_data(train_loader)

    input_dim = train_ds[0][0][0].shape[1]
    print(input_dim)

    net = L2OGWDMultiGINComplete.L2OGWDMultiGINComplete(
        input_dim=input_dim,
        gnn_hidden=128,
        gnn_out=128,
        out_dim=2,
        n_graph_view=5,
        n_transformer=8,
        device=device,
    )
    net = net.to(device)

    # Number of optimisation steps; also the cosine schedule length, because
    # the scheduler is stepped once per optimisation step.
    n_epoch = 500

    optimizer = torch.optim.AdamW(params=net.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epoch)

    running_loss = 0
    net.train()
    for step in range(n_epoch):
        optimizer.zero_grad()

        # Two consecutive batches from the infinite training stream.
        x1, graph1 = next(train_iloader)
        x2, graph2 = next(train_iloader)
        loss = net.loss_gwd(x1, graph1, x2, graph2)

        loss.backward()
        optimizer.step()

        # .item() detaches the value so the graph is not kept alive.
        step_loss = loss.item()
        running_loss += step_loss

        scheduler.step()
        print(f'epoch: {step}, [loss_ins loss_avg] = [{step_loss} {running_loss / (step + 1)}]')

    # NOTE: periodic evaluation and checkpointing are currently disabled.
    # To re-enable, every N steps inside the loop call
    #     evaluate(net, train_loader, optimizer)
    #     evaluate(net, test_loader, optimizer)
    # collect the returned (nmi, sc, loss) values in a ``records`` dict and
    # persist them with ``torch.save((net, records), <path>)``.

# Script entry point: run the training only when executed directly.
# NOTE(review): this module uses package-relative imports
# (``from .Datasets import ...``), so it presumably has to be launched as a
# module of its package (``python -m <package>.<module>``) rather than as a
# plain file path — confirm against how the project is run.
if __name__ == '__main__':
    main()