import numpy as np
import torch
import torch.nn.functional as F
import os
from sklearn import metrics
from sklearn.utils import shuffle
from sklearn.metrics.cluster import normalized_mutual_info_score, homogeneity_score
from scgm_g.scgm_resnet import resnet50
from utils.utils import get_training_dataloader_breeds, get_validation_dataloader_breeds, get_test_dataloader_breeds, adjust_learning_rate_cos
from time import time
from eval.eval_performance import classify, mean_confidence_interval
from vis import vis_tsne_multiclass_means_new
from sinkhornknopp import optimize_l_sk
# import resource

# Expose only GPUs 0-3 to this process. The variable is read when CUDA is
# first initialised, so it has to be set before any CUDA call is made
# (the first one in this script is torch.cuda.is_available() below).
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'

# rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
# resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))


if __name__ == '__main__':

    # Run on the GPU when one is visible, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --- dataset -----------------------------------------------------------
    ds_name = 'living17'
    info_dir = '/nfs/data/usr/jni/datasets/imagenet_ilsvrc/ILSVRC/BREEDS/'
    data_dir = '/nfs/data/usr/jni/datasets/imagenet_ilsvrc/ILSVRC/Data/CLS-LOC/'
    n_class = 17           # handed to resnet50 as num_classes (coarse labels)
    k = 100                # handed to resnet50 as num_subclasses

    # --- optimisation ------------------------------------------------------
    lr = 0.03              # base learning rate for SGD and the cosine schedule
    epochs = 200
    num_cycles = 10        # forwarded to adjust_learning_rate_cos
    batchsz = 256          # batch size of the training / validation loaders
    batchsz_eval = 256     # chunk size used when embedding few-shot episodes
    num_workers = 32
    set_cuda = True        # move model and batches to `device`

    # --- model -------------------------------------------------------------
    hiddim = 128           # forwarded to resnet50
    with_mlp = False       # forwarded to resnet50
    kd_t = 4.0             # forwarded to resnet50

    # --- loss / E-step -----------------------------------------------------
    tau = 0.1              # forwarded to forward_to_prob, loss and pred
    alpha = 0.5            # forwarded to net.module.loss
    lmd = 25.0             # forwarded to optimize_l_sk (Sinkhorn-Knopp)
    beta1 = 1.0            # forwarded to net.module.loss
    beta2 = 1.0            # forwarded to net.module.loss
    beta3 = 1.0            # forwarded to net.module.loss
    norm_type = 'logit'    # forwarded to net.module.loss

    # Keyword arguments common to all three BREEDS loaders.
    breeds_args = dict(ds_name=ds_name, info_dir=info_dir, data_dir=data_dir)

    # Training split: shuffled, single crop per image.
    breeds_training_loader = get_training_dataloader_breeds(
        batch_size=batchsz,
        num_workers=num_workers,
        shuffle=True,
        twocrops=False,
        **breeds_args)

    # Validation split, loaded with the same batch size as training.
    breeds_validation_loader = get_validation_dataloader_breeds(
        batch_size=batchsz,
        num_workers=num_workers,
        shuffle=True,
        **breeds_args)

    # Few-shot evaluation settings.
    classifier = 'LR'            # forwarded to eval.eval_performance.classify
    n_test_runs = 100            # forwarded to the test loader
    n_ways = 68
    n_shots = 1
    n_queries = 15
    n_aug_support_samples = 5
    norm = True                  # L2-normalise embeddings before classification

    # Test loader: batch_size=1, loaded in the main process.
    breeds_test_loader = get_test_dataloader_breeds(
        n_test_runs=n_test_runs,
        n_ways=n_ways,
        n_shots=n_shots,
        n_queries=n_queries,
        n_aug_support_samples=n_aug_support_samples,
        batch_size=1,
        num_workers=0,
        **breeds_args)

    # model
    # ---
    net = resnet50(
        num_classes=n_class,
        num_subclasses=k,
        kd_t=kd_t,
        hiddim=hiddim,
        with_mlp=with_mlp)

    if set_cuda:
        # The rest of the script reaches the underlying model through
        # `net.module`, which only exists once the DataParallel wrapper is on.
        net = torch.nn.DataParallel(net.to(device))
        # net = torch.nn.DataParallel(net, device_ids=[0, 1, 2, 3])

    opt = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)

    # training
    # ---
    # Dataset sizes (samples) and loader lengths (batches per epoch).
    n_tr = len(breeds_training_loader.dataset)
    n_va = len(breeds_validation_loader.dataset)
    iter_per_epoch_tr = len(breeds_training_loader)
    iter_per_epoch_va = len(breeds_validation_loader)
    # opt_times = ((np.linspace(0, 1, nopts) ** 2)[::-1] * epochs).tolist()
    # opt_times[0] = opt_times[0] + 1
    # print('opt_times:', opt_times)

    # Per-epoch history of the total loss and its three components.
    ls_tr_all, ls1_tr_all, ls2_tr_all, ls3_tr_all = [], [], [], []

    # Accumulated wall-clock training time in seconds.
    total_time = 0

    # Main training loop. Epochs are numbered from 1; the body runs an E-step
    # (every fifth epoch), an M-step, a validation pass and, every 50 epochs,
    # a few-shot evaluation.
    for epoch in range(1, (epochs + 1)):
        # Start of the wall-clock measurement for this epoch.
        t0 = time()

        # The schedule helper is given the zero-based epoch index.
        adjust_learning_rate_cos(opt, lr, (epoch - 1), epochs, num_cycles)

        print('epoch={:d}'.format(epoch),
              'learning rate={:.3f}'.format(opt.param_groups[0]['lr']))

        # training e step
        # ---
        # On epochs 1, 6, 11, ... the hard assignments `q_new` are recomputed:
        # the model's probabilities are collected for every training sample
        # (without gradients) and passed to Sinkhorn-Knopp. The model is put
        # in train mode first, so this pass also runs in train mode.
        net.train()

        if epoch % 5 == 1:
            # if epoch >= opt_times[-1]:

            # _ = opt_times.pop()
            probs_per_batch = []
            index_per_batch = []

            with torch.no_grad():
                for (images, _, labels_coarse, selected) in breeds_training_loader:

                    # `selected` is used below as row indices into q_new
                    # (presumably dataset-level sample indices - TODO confirm).
                    sample_idx = selected.detach().cpu().numpy()
                    coarse = labels_coarse.detach().cpu().numpy().astype(np.int64)

                    # One-hot encoding of the coarse labels, (n, num_class).
                    n_batch = len(coarse)
                    one_hot = np.zeros([n_batch, n_class])
                    one_hot[np.arange(n_batch), coarse] = 1
                    batch_y = torch.tensor(one_hot, dtype=torch.float32)

                    if set_cuda:
                        images = images.to(device)
                        batch_y = batch_y.to(device)

                    embedded = net.module.embed(net(images))
                    batch_prob_y_x, _batch_prob_y_z, _batch_prob_z_x = \
                        net.module.forward_to_prob(embedded, batch_y, tau)

                    # NOTE(review): the first output is what feeds
                    # Sinkhorn-Knopp, yet the concatenated matrix is described
                    # as (n, k) - confirm this is the intended output.
                    probs_per_batch.append(batch_prob_y_x.detach().cpu().numpy())
                    index_per_batch.append(sample_idx)

                prob_all = np.concatenate(probs_per_batch, axis=0)  # (n, k)
                idx_all = np.concatenate(index_per_batch, axis=0)

                # run sinkhorn-knopp
                # ---
                q, argmax_q = optimize_l_sk(prob_all, lmd)

                # One-hot assignment table indexed by sample index; rows of
                # samples that were not visited stay all-zero.
                q_new = np.zeros((n_tr, k))  # (n, k)
                q_new[idx_all, argmax_q] = 1

        # training m step
        # ---
        # One pass of SGD over the training set using the assignments `q_new`
        # from the most recent E-step. Besides the parameter update, the loop
        # gathers running losses, coarse-label accuracy and the embeddings /
        # sub-class probabilities needed for the clustering metrics.
        #
        # Running losses are kept as Python floats so that no tensor (and no
        # autograd history attached to it) is held across iterations.
        ls_tr = 0.0
        ls1_tr = 0.0
        ls2_tr = 0.0
        ls3_tr = 0.0
        ls_div1_tr = 0.0
        ls_div2_tr = 0.0
        ls_div3_tr = 0.0
        cnt = 0
        correct = 0
        x_tr_embed = []
        y_tr_embed = []
        y_tr_embed_coarse = []
        y_pred_tr_embed = []

        for (images, labels, labels_coarse, selected) in breeds_training_loader:

            # `selected` indexes rows of q_new for the samples in this batch.
            selected = selected.detach().cpu().numpy()
            labels = labels.detach().cpu().numpy().astype(np.int64)
            labels_coarse = labels_coarse.detach().cpu().numpy().astype(np.int64)

            # One-hot encoding of the coarse labels.
            batch_y = np.zeros([len(labels_coarse), n_class])
            batch_y[np.arange(len(labels_coarse)), labels_coarse] = 1
            batch_y = torch.tensor(batch_y, dtype=torch.float32)  # (n, num_class)

            batch_q = torch.tensor(q_new[selected, :], dtype=torch.float32)  # (n, k)

            if set_cuda:
                images = images.to(device)
                batch_y = batch_y.to(device)
                batch_q = batch_q.to(device)

            outputs = net(images)
            outputs = net.module.embed(outputs)
            ls, ls1, ls2, ls3, ls_div1, ls_div2, ls_div3 = net.module.loss(
                outputs, batch_q, batch_y, tau, alpha,
                logit_t1=None, logit_t2=None, logit_t3=None,
                beta1=beta1, beta2=beta2, beta3=beta3,
                ang_norm=False, norm_type=norm_type)

            opt.zero_grad()
            ls.backward()
            opt.step()

            # float() accepts both 0-d tensors and plain Python numbers, so it
            # is safe whatever type the loss components have.
            ls_tr += float(ls)
            ls1_tr += float(ls1)
            ls2_tr += float(ls2)
            ls3_tr += float(ls3)
            ls_div1_tr += float(ls_div1)
            ls_div2_tr += float(ls_div2)
            ls_div3_tr += float(ls_div3)

            # Diagnostics only: no gradient is needed for these predictions.
            with torch.no_grad():
                prob_y_x, prob_z_x, prob_y_z = net.module.pred(outputs, tau)

            outputs = outputs.detach().cpu().numpy()
            prob_y_x = prob_y_x.detach().cpu().numpy()
            prob_z_x = prob_z_x.detach().cpu().numpy()

            correct += (prob_y_x.argmax(1) == labels_coarse).sum()
            cnt += len(labels_coarse)
            # print('number={:d}'.format(cnt))

            x_tr_embed.append(outputs)
            y_tr_embed.append(labels)
            y_tr_embed_coarse.append(labels_coarse)
            y_pred_tr_embed.append(prob_z_x)

        # Epoch averages (per batch for the losses, per sample for accuracy).
        acc_tr = correct / cnt
        ls_tr = ls_tr / iter_per_epoch_tr
        ls1_tr = ls1_tr / iter_per_epoch_tr
        ls2_tr = ls2_tr / iter_per_epoch_tr
        ls3_tr = ls3_tr / iter_per_epoch_tr
        ls_div1_tr = ls_div1_tr / iter_per_epoch_tr
        ls_div2_tr = ls_div2_tr / iter_per_epoch_tr
        ls_div3_tr = ls_div3_tr / iter_per_epoch_tr
        x_tr_embed = np.concatenate(x_tr_embed, axis=0)
        y_tr_embed = np.concatenate(y_tr_embed, axis=0)
        y_tr_embed_coarse = np.concatenate(y_tr_embed_coarse, axis=0)
        y_pred_tr_embed = np.concatenate(y_pred_tr_embed, axis=0)  # (n, k)

        ls_tr_all.append(ls_tr)
        ls1_tr_all.append(ls1_tr)
        ls2_tr_all.append(ls2_tr)
        ls3_tr_all.append(ls3_tr)

        # Clustering quality of the arg-max sub-class against the labels in
        # `labels`; the homogeneity score is what the log line calls "purity".
        nmi_score_tr = normalized_mutual_info_score(y_tr_embed, y_pred_tr_embed.argmax(1), average_method='arithmetic')
        acc_score_tr = homogeneity_score(y_tr_embed, y_pred_tr_embed.argmax(1))

        # Only the E- and M-step count towards the reported training time.
        epoch_time = time() - t0
        total_time += epoch_time

        # validation
        # ---
        # Forward-only pass over the validation loader in eval mode, gathering
        # coarse-label accuracy plus the embeddings and sub-class
        # probabilities used for the clustering metrics.
        net.eval()

        cnt, correct = 0, 0
        x_va_embed, y_va_embed = [], []
        y_va_embed_coarse, y_pred_va_embed = [], []

        with torch.no_grad():
            for (images, labels, labels_coarse, selected) in breeds_validation_loader:

                fine_np = labels.detach().cpu().numpy().astype(np.int64)
                coarse_np = labels_coarse.detach().cpu().numpy().astype(np.int64)

                if set_cuda:
                    images = images.to(device)

                embedded = net.module.embed(net(images))
                prob_y_x, prob_z_x, prob_y_z = net.module.pred(embedded, tau)

                embedded_np = embedded.detach().cpu().numpy()
                prob_y_x_np = prob_y_x.detach().cpu().numpy()
                prob_z_x_np = prob_z_x.detach().cpu().numpy()

                # Coarse prediction = arg-max of the first output of pred.
                correct += (prob_y_x_np.argmax(1) == coarse_np).sum()
                cnt += len(coarse_np)
                # print('number={:d}'.format(cnt))

                x_va_embed.append(embedded_np)
                y_va_embed.append(fine_np)
                y_va_embed_coarse.append(coarse_np)
                y_pred_va_embed.append(prob_z_x_np)

        acc_va = correct / cnt
        x_va_embed = np.concatenate(x_va_embed, axis=0)
        y_va_embed = np.concatenate(y_va_embed, axis=0)
        y_va_embed_coarse = np.concatenate(y_va_embed_coarse, axis=0)
        y_pred_va_embed = np.concatenate(y_pred_va_embed, axis=0)  # (n, k)

        # Hard sub-class assignment, shared by both clustering metrics.
        cluster_va = y_pred_va_embed.argmax(1)
        nmi_score_va = normalized_mutual_info_score(y_va_embed, cluster_va, average_method='arithmetic')
        acc_score_va = homogeneity_score(y_va_embed, cluster_va)

        # One-line summary of the epoch: training figures first, validation
        # figures after the '|' separator, then the wall-clock time since the
        # start of the epoch (validation included).
        report = [
            'training: epoch={:d}'.format(epoch),
            'loss={:.5f}'.format(ls_tr),
            'loss1={:.5f}'.format(ls1_tr),
            'loss2={:.5f}'.format(ls2_tr),
            'loss3={:.5f}'.format(ls3_tr),
            'loss_div1={:.5f}'.format(ls_div1_tr),
            'loss_div2={:.5f}'.format(ls_div2_tr),
            'loss_div3={:.5f}'.format(ls_div3_tr),
            'acc={:.5f}'.format(acc_tr),
            'purity={:.5f}'.format(acc_score_tr),
            'nmi={:.5f}'.format(nmi_score_tr),
            '| validation: acc={:.5f}'.format(acc_va),
            'purity={:.5f}'.format(acc_score_va),
            'nmi={:.5f}'.format(nmi_score_va),
            'time={:.5f}'.format(time() - t0),
        ]
        print(*report)

        # evaluation
        # ---
        # Every 50 epochs: few-shot evaluation. For each episode from the test
        # loader the support and query images are embedded, a classifier is
        # fitted on the support embeddings and scored on the query embeddings.
        if epoch % 50 == 0:

            with torch.no_grad():
                acc_te = []

                for batch_data in breeds_test_loader:

                    # The test loader is built with batch_size=1; index 0
                    # selects the single episode in each batch.
                    support_xs, support_ys, query_xs, query_ys = batch_data
                    support_xs = support_xs[0]
                    support_ys = support_ys[0]
                    query_xs = query_xs[0]
                    query_ys = query_ys[0]

                    # load support and query set embeddings
                    # ---
                    episode_feats = []

                    for (split_xs, split_ys) in ((support_xs, support_ys), (query_xs, query_ys)):
                        n_split = len(split_ys)
                        split_feats = []

                        # Walk the split in chunks of batchsz_eval. Once fewer
                        # than two chunks remain, take the whole remainder in
                        # one go, so the last batch is never a short tail.
                        # Advancing by the amount actually consumed ensures
                        # every sample is embedded, including when n_split is
                        # an exact multiple of batchsz_eval.
                        start = 0
                        while start < n_split:
                            if (n_split - start) < 2 * batchsz_eval:
                                stop = n_split
                            else:
                                stop = start + batchsz_eval

                            batch_xs = split_xs[start:stop]
                            if set_cuda is True:
                                batch_xs = batch_xs.to(device)

                            batch_xs = net(batch_xs)
                            batch_xs = net.module.embed(batch_xs)
                            if norm is True:
                                batch_xs = F.normalize(batch_xs, p=2, dim=1)

                            split_feats.append(batch_xs.detach().cpu().numpy())
                            start = stop

                        episode_feats.append(np.concatenate(split_feats, axis=0))

                    support_feats, query_feats = episode_feats

                    # classification
                    # ---
                    clf = classify(classifier, support_feats, support_ys)
                    query_preds = clf.predict(query_feats)

                    # evaluation
                    # ---
                    acc_te_q = metrics.accuracy_score(query_ys, query_preds)
                    acc_te.append(acc_te_q)

                    del clf

                # Aggregate over all episodes; the second value is printed
                # under the label 'std'.
                acc_te_mn, acc_te_std = mean_confidence_interval(acc_te)
                print('accuracy={:.5f}'.format(acc_te_mn * 100),
                      'std={:.5f}'.format(acc_te_std * 100))

    print('total training time={:.5f}'.format(total_time))

    # save model
    # ---
    # Create the output directory first: a missing directory here would throw
    # away the result of the whole training run.
    model_path = 'pretrain_model/scgm_g_' + ds_name + '.pth'
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(net.module.state_dict(), model_path)

    # vis training embedding
    # ---
    # Row-wise L2-normalised copies of the model's mu_z and mu_y parameters;
    # they are drawn in both plots.
    mu_z_tr = net.module.mu_z.data.detach().cpu().numpy()
    mu_y_tr = net.module.mu_y.data.detach().cpu().numpy()
    mu_z_tr = mu_z_tr / ((mu_z_tr ** 2).sum(1) ** 0.5).reshape(-1, 1)
    mu_y_tr = mu_y_tr / ((mu_y_tr ** 2).sum(1) ** 0.5).reshape(-1, 1)

    # Plot a random subset of at most 2000 embeddings from the final epoch,
    # L2-normalised like the means above.
    x_embed_vis, y_embed_vis = shuffle(x_tr_embed, y_tr_embed_coarse)
    x_embed_vis = x_embed_vis[:2000, :]
    y_embed_vis = y_embed_vis[:2000]
    x_embed_vis = x_embed_vis / ((x_embed_vis ** 2).sum(1) ** 0.5).reshape(-1, 1)

    destpath = 'fig/tsne_scgm_g_' + ds_name + '_tr.png'
    os.makedirs(os.path.dirname(destpath), exist_ok=True)
    vis_tsne_multiclass_means_new(x_embed_vis, y_embed_vis, mu_z_tr, mu_y_tr, destpath, y_pred=None, destpath_correct=None)

    # vis validation embedding
    # ---
    x_embed_vis, y_embed_vis = shuffle(x_va_embed, y_va_embed_coarse)
    x_embed_vis = x_embed_vis[:2000, :]
    y_embed_vis = y_embed_vis[:2000]
    x_embed_vis = x_embed_vis / ((x_embed_vis ** 2).sum(1) ** 0.5).reshape(-1, 1)

    destpath = 'fig/tsne_scgm_g_' + ds_name + '_va.png'
    os.makedirs(os.path.dirname(destpath), exist_ok=True)
    vis_tsne_multiclass_means_new(x_embed_vis, y_embed_vis, mu_z_tr, mu_y_tr, destpath, y_pred=None, destpath_correct=None)
