import numpy as np
import torch
import torch.nn.functional as F
import os
from sklearn import metrics
from sklearn.utils import shuffle
from sklearn.metrics.cluster import normalized_mutual_info_score, homogeneity_score
from scgm_a.model_generator import MoEncoderGenerator
from utils.utils import get_training_dataloader_breeds, get_validation_dataloader_breeds, get_test_dataloader_breeds, adjust_learning_rate_cos, write_values
from time import time
from eval.eval_performance import classify, mean_confidence_interval
from vis import vis_tsne_multiclass_means_new
from sinkhornknopp import optimize_l_sk
# import resource

# Restrict this process to GPUs 0-3. The variable is written at import time,
# before any CUDA call is made further down in the script.
os.environ.update(CUDA_VISIBLE_DEVICES='0,1,2,3')

# rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
# resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))


if __name__ == '__main__':

    # Use the GPU when CUDA is available; otherwise stay on the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # model settings (forwarded to generate_ancor_model further down)
    # ---
    arch, head_type = 'resnet50', 'seq_em'
    cst_dim, queue_k, encoder_m = 128, 65536, 0.999
    cls_t, cst_t = 1.0, 0.2
    n_subclass = 100
    alpha = 0.5  # weight of (ls3 + ls4) in the total training loss
    lmd = 25.0   # second argument of optimize_l_sk in the e step
    tau1, tau2 = 0.1, 1.0
    head_norm, with_mlp = False, False
    queue_type, metric_type = 'multi', 'angular'
    calc_types = ['cls', 'cst_by_class', 'cst_by_subclass', 'cst_two_class']

    # optimisation settings
    # ---
    lr, epochs = 0.03, 200
    n_class = 17
    batchsz, batchsz_eval = 256, 128
    num_workers = 32
    num_cycles = 10
    set_cuda = True

    # dataset location
    # ---
    ds_name = 'living17'
    info_dir = '/nfs/data/usr/jni/datasets/imagenet_ilsvrc/ILSVRC/BREEDS/'
    data_dir = '/nfs/data/usr/jni/datasets/imagenet_ilsvrc/ILSVRC/Data/CLS-LOC/'

    # Keyword arguments shared by the three BREEDS loader factories.
    breeds_kwargs = dict(ds_name=ds_name, info_dir=info_dir, data_dir=data_dir)

    # Training loader: requested with twocrops=True; the training loop reads
    # images[0] and images[1] from every batch.
    breeds_training_loader = get_training_dataloader_breeds(
        batch_size=batchsz,
        num_workers=num_workers,
        shuffle=True,
        twocrops=True,
        **breeds_kwargs)

    breeds_validation_loader = get_validation_dataloader_breeds(
        batch_size=batchsz,
        num_workers=num_workers,
        shuffle=True,
        **breeds_kwargs)

    # few-shot evaluation settings
    # ---
    classifier = 'LR'
    n_test_runs = 100
    n_ways, n_shots, n_queries = 68, 1, 15
    n_aug_support_samples = 5
    norm = True  # L2-normalise the embeddings before fitting the classifier

    # Test loader: batch_size=1, so each batch carries a leading axis of
    # length one that the evaluation code strips with [0].
    breeds_test_loader = get_test_dataloader_breeds(
        n_test_runs=n_test_runs,
        n_ways=n_ways,
        n_shots=n_shots,
        n_queries=n_queries,
        n_aug_support_samples=n_aug_support_samples,
        batch_size=1,
        num_workers=0,
        **breeds_kwargs)

    # Dataset sizes and number of batches per epoch.
    n_tr = len(breeds_training_loader.dataset)
    n_va = len(breeds_validation_loader.dataset)
    iter_per_epoch_tr = len(breeds_training_loader)
    iter_per_epoch_va = len(breeds_validation_loader)

    print('dataset={:s}'.format(ds_name))
    print('training: size={:d}, iter per epoch={:d} | '
          'validation: size={:d}, iter per epoch={:d}'.format(
              n_tr, iter_per_epoch_tr, n_va, iter_per_epoch_va))

    # model
    # ---
    # All architecture settings are collected first and handed to the
    # generator in a single call. T packs the four temperatures in the order
    # [cls_t, cst_t, tau1, tau2].
    model_kwargs = dict(
        arch=arch,
        head_type=head_type,
        dim=cst_dim,
        K=queue_k,
        m=encoder_m,
        T=[cls_t, cst_t, tau1, tau2],
        mlp=with_mlp,
        num_classes=n_class,
        num_subclasses=n_subclass,
        norm=head_norm,
        queue_type=queue_type,
        metric=metric_type,
        calc_types=calc_types,
    )
    model = MoEncoderGenerator().generate_ancor_model(**model_kwargs)

    if set_cuda is True:
        # Move the weights to the target device, then wrap the model so the
        # rest of the script can reach the underlying network via model.module.
        model = torch.nn.DataParallel(model.to(device))
        # model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])

    # Plain SGD with momentum; the learning rate is overwritten every epoch by
    # adjust_learning_rate_cos in the training loop.
    sgd_kwargs = dict(lr=lr, momentum=0.9, weight_decay=1e-4)
    opt = torch.optim.SGD(model.parameters(), **sgd_kwargs)

    # training
    # ---
    # Every epoch runs, in order:
    #   1. (epochs 1, 6, 11, ...) an e step that recomputes the subclass
    #      assignment of every training sample with optimize_l_sk,
    #   2. an m step that optimises the four-term loss,
    #   3. a validation pass,
    #   4. (every 50 epochs) a few-shot evaluation on the test loader.

    # Per-epoch averages of the total loss and of its four components.
    ls_tr_all = []
    ls1_tr_all = []
    ls2_tr_all = []
    ls3_tr_all = []
    ls4_tr_all = []

    total_time = 0

    def _embed_in_batches(samples, n_samples):
        """Embed `samples[:n_samples]` with the query encoder, batch by batch.

        Batches hold `batchsz_eval` samples, except that the trailing
        remainder is folded into the last batch (which can therefore hold up
        to `2 * batchsz_eval` samples). Embeddings are L2-normalised when
        `norm` is True. Returns one numpy array with the embeddings of all
        batches concatenated along axis 0.

        Callers are expected to wrap the call in `torch.no_grad()`.
        """
        if n_samples > batchsz_eval:
            starts = range(0, (n_samples - batchsz_eval), batchsz_eval)
        else:
            starts = [0]

        feats = []
        for start in starts:
            # `<=` rather than `<`: at the last start index between
            # batchsz_eval + 1 and 2 * batchsz_eval samples remain. With `<`
            # a remainder of exactly 2 * batchsz_eval was cut down to
            # batchsz_eval, silently dropping the final batchsz_eval samples.
            if (n_samples - start) <= 2 * batchsz_eval:
                size = n_samples - start
            else:
                size = batchsz_eval

            batch = samples[start:(start + size)]
            if set_cuda is True:
                batch = batch.to(device)

            outputs, _ = model.module.encoder_q(batch)
            if norm is True:
                outputs = F.normalize(outputs, p=2, dim=1)

            feats.append(outputs.detach().cpu().numpy())

        return np.concatenate(feats, axis=0)

    for epoch in range(1, (epochs + 1)):
        t0 = time()

        adjust_learning_rate_cos(opt, lr, (epoch - 1), epochs, num_cycles)

        print('epoch={:d}'.format(epoch),
              'learning rate={:.3f}'.format(opt.param_groups[0]['lr']))

        # training e step
        # ---
        model.train()

        # The assignments are refreshed on epochs 1, 6, 11, ... and reused
        # unchanged by the m steps of the epochs in between.
        if epoch % 5 == 1:
            prob_tr = []
            batch_idx = []

            with torch.no_grad():
                for (images, _, labels_coarse, selected) in breeds_training_loader:

                    # Only the first of the two crops is used here.
                    images_q = images[0]
                    # presumably the dataset indices of the samples in this
                    # batch -- confirm against the loader implementation
                    selected = selected.detach().cpu().numpy()

                    if set_cuda:
                        images_q = images_q.to(device)
                        labels_coarse = labels_coarse.to(device)

                    outputs, _ = model.module.encoder_q(images_q)
                    batch_prob_y_x, _, _ = model.module.forward_to_prob(outputs, labels_coarse, tau1)
                    prob_tr.append(batch_prob_y_x.detach().cpu().numpy())

                    batch_idx.append(selected)

                prob_tr = np.concatenate(prob_tr, axis=0)  # (n, k)
                batch_idx = np.concatenate(batch_idx, axis=0)

                # run sinkhorn-knopp
                # ---
                _, argmax_q = optimize_l_sk(prob_tr, lmd)
                # Scatter the assignments from loader order back to the
                # positions given by `selected`, so the m step can look them
                # up per batch.
                argmax_q_new = np.zeros(n_tr)  # (n,)
                argmax_q_new[batch_idx] = argmax_q

        # training m step
        # ---
        ls_tr = 0
        ls1_tr = 0
        ls2_tr = 0
        ls3_tr = 0
        ls4_tr = 0
        cnt = 0
        correct_1 = 0
        correct_2 = 0
        x_tr_embed = []
        y_tr_embed_coarse = []
        y_tr_embed = []
        y_pred_tr_embed = []

        for (images, labels, labels_coarse, selected) in breeds_training_loader:

            images_q = images[0]
            images_k = images[1]

            # Look up the subclass assignment computed in the last e step.
            selected = selected.detach().cpu().numpy()
            batch_argmax_q = argmax_q_new[selected]
            batch_argmax_q = torch.tensor(batch_argmax_q, dtype=torch.int64)

            if set_cuda:
                images_q = images_q.to(device)
                images_k = images_k.to(device)
                labels_coarse = labels_coarse.to(device)
                batch_argmax_q = batch_argmax_q.to(device)

            # The model returns four (logits, targets) pairs, one per loss
            # term -- presumably in the order of `calc_types`; confirm
            # against the model implementation.
            logits_and_labels = model(im_q=images_q, im_k=images_k, cls_labels=labels_coarse, subcls_labels=batch_argmax_q)

            ls1 = F.cross_entropy(logits_and_labels[0][0], logits_and_labels[0][1], weight=None)
            ls2 = F.cross_entropy(logits_and_labels[1][0], logits_and_labels[1][1], weight=None)
            ls3 = F.cross_entropy(logits_and_labels[2][0], logits_and_labels[2][1], weight=None)
            ls4 = F.cross_entropy(logits_and_labels[3][0], logits_and_labels[3][1], weight=None)
            ls = ls1 + ls2 + alpha * (ls3 + ls4)

            opt.zero_grad()
            ls.backward()
            opt.step()

            # Accumulate detached losses so no autograd graph is kept alive.
            ls_tr += ls.detach()
            ls1_tr += ls1.detach()
            ls2_tr += ls2.detach()
            ls3_tr += ls3.detach()
            ls4_tr += ls4.detach()

            # Monitoring-only forward pass with the updated weights. Nothing
            # here is back-propagated, so gradient tracking is switched off
            # to avoid building a second autograd graph per step.
            with torch.no_grad():
                batch_embed, batch_pred_coarse_1 = model.module.encoder_q(images_q)
                batch_pred_coarse_1 = torch.softmax(batch_pred_coarse_1, dim=1)
                batch_pred_coarse_2, batch_pred, _ = model.module.pred(batch_embed, tau1)

            batch_embed = batch_embed.detach().cpu().numpy()
            batch_pred_coarse_1 = batch_pred_coarse_1.detach().cpu().numpy()
            batch_pred_coarse_2 = batch_pred_coarse_2.detach().cpu().numpy()
            batch_pred = batch_pred.detach().cpu().numpy()

            labels_coarse = labels_coarse.detach().cpu().numpy()
            labels_coarse = labels_coarse.astype(np.int64)
            labels = labels.detach().cpu().numpy()
            labels = labels.astype(np.int64)

            correct_1 += (batch_pred_coarse_1.argmax(1) == labels_coarse).sum()
            correct_2 += (batch_pred_coarse_2.argmax(1) == labels_coarse).sum()
            cnt += len(labels_coarse)

            x_tr_embed.append(batch_embed)
            y_tr_embed_coarse.append(labels_coarse)
            y_tr_embed.append(labels)
            y_pred_tr_embed.append(batch_pred)

        acc_tr_1 = correct_1 / cnt
        acc_tr_2 = correct_2 / cnt
        ls_tr = ls_tr.cpu().numpy() / iter_per_epoch_tr
        ls1_tr = ls1_tr.cpu().numpy() / iter_per_epoch_tr
        ls2_tr = ls2_tr.cpu().numpy() / iter_per_epoch_tr
        ls3_tr = ls3_tr.cpu().numpy() / iter_per_epoch_tr
        ls4_tr = ls4_tr.cpu().numpy() / iter_per_epoch_tr
        x_tr_embed = np.concatenate(x_tr_embed, axis=0)
        y_tr_embed_coarse = np.concatenate(y_tr_embed_coarse, axis=0)
        y_tr_embed = np.concatenate(y_tr_embed, axis=0)
        y_pred_tr_embed = np.concatenate(y_pred_tr_embed, axis=0)  # (n, k)

        ls_tr_all.append(ls_tr)
        ls1_tr_all.append(ls1_tr)
        ls2_tr_all.append(ls2_tr)
        ls3_tr_all.append(ls3_tr)
        ls4_tr_all.append(ls4_tr)

        # Clustering quality of the predicted subclasses against `labels`.
        # The value printed as "purity" is sklearn's homogeneity score.
        nmi_score_tr = normalized_mutual_info_score(y_tr_embed, y_pred_tr_embed.argmax(1), average_method='arithmetic')
        acc_score_tr = homogeneity_score(y_tr_embed, y_pred_tr_embed.argmax(1))

        # total_time covers the e and m steps only; validation and few-shot
        # evaluation are excluded.
        epoch_time = time() - t0
        total_time += epoch_time

        # validation
        # ---
        model.eval()
        cnt = 0
        correct_1 = 0
        correct_2 = 0
        x_va_embed = []
        y_va_embed_coarse = []
        y_va_embed = []
        y_pred_va_embed = []

        with torch.no_grad():
            for (images, labels, labels_coarse, selected) in breeds_validation_loader:

                images_q = images

                if set_cuda:
                    images_q = images_q.to(device)

                batch_embed, batch_pred_coarse_1 = model.module.encoder_q(images_q)
                batch_pred_coarse_1 = torch.softmax(batch_pred_coarse_1, dim=1)
                batch_pred_coarse_2, batch_pred, _ = model.module.pred(batch_embed, tau1)

                batch_embed = batch_embed.detach().cpu().numpy()
                batch_pred_coarse_1 = batch_pred_coarse_1.detach().cpu().numpy()
                batch_pred_coarse_2 = batch_pred_coarse_2.detach().cpu().numpy()
                batch_pred = batch_pred.detach().cpu().numpy()

                labels_coarse = labels_coarse.detach().cpu().numpy()
                labels_coarse = labels_coarse.astype(np.int64)
                labels = labels.detach().cpu().numpy()
                labels = labels.astype(np.int64)

                correct_1 += (batch_pred_coarse_1.argmax(1) == labels_coarse).sum()
                correct_2 += (batch_pred_coarse_2.argmax(1) == labels_coarse).sum()
                cnt += len(labels_coarse)

                # Each batch is recorded exactly once, so the four lists stay
                # aligned sample-for-sample.
                x_va_embed.append(batch_embed)
                y_va_embed_coarse.append(labels_coarse)
                y_va_embed.append(labels)
                y_pred_va_embed.append(batch_pred)

            acc_va_1 = correct_1 / cnt
            acc_va_2 = correct_2 / cnt
            x_va_embed = np.concatenate(x_va_embed, axis=0)
            y_va_embed_coarse = np.concatenate(y_va_embed_coarse, axis=0)
            y_va_embed = np.concatenate(y_va_embed, axis=0)
            y_pred_va_embed = np.concatenate(y_pred_va_embed, axis=0)  # (n, k)

            nmi_score_va = normalized_mutual_info_score(y_va_embed, y_pred_va_embed.argmax(1), average_method='arithmetic')
            acc_score_va = homogeneity_score(y_va_embed, y_pred_va_embed.argmax(1))

        print('training: epoch={:d}'.format(epoch),
              'loss={:.5f}'.format(ls_tr),
              'loss1={:.5f}'.format(ls1_tr),
              'loss2={:.5f}'.format(ls2_tr),
              'loss3={:.5f}'.format(ls3_tr),
              'loss4={:.5f}'.format(ls4_tr),
              'acc1={:.5f}'.format(acc_tr_1),
              'acc2={:.5f}'.format(acc_tr_2),
              'purity={:.5f}'.format(acc_score_tr),
              'nmi={:.5f}'.format(nmi_score_tr),
              '| validation: acc1={:.5f}'.format(acc_va_1),
              'acc2={:.5f}'.format(acc_va_2),
              'purity={:.5f}'.format(acc_score_va),
              'nmi={:.5f}'.format(nmi_score_va),
              'time={:.5f}'.format(time() - t0))

        # evaluation
        # ---
        # Few-shot evaluation; the model is still in eval mode from the
        # validation pass above.
        if epoch % 50 == 0:

            with torch.no_grad():
                acc_te = []

                for batch_data in breeds_test_loader:

                    # The test loader uses batch_size=1; [0] strips that
                    # leading axis from every component of the episode.
                    support_xs, support_ys, query_xs, query_ys = batch_data
                    support_xs = support_xs[0]
                    support_ys = support_ys[0]
                    query_xs = query_xs[0]
                    query_ys = query_ys[0]

                    # load support set embeddings
                    # ---
                    support_feats = _embed_in_batches(support_xs, len(support_ys))

                    # load query set embeddings
                    # ---
                    query_feats = _embed_in_batches(query_xs, len(query_ys))

                    # classification
                    # ---
                    clf = classify(classifier, support_feats, support_ys)
                    # NOTE(review): support_preds is currently unused; only
                    # the query accuracy is recorded below.
                    support_preds = clf.predict(support_feats)
                    query_preds = clf.predict(query_feats)

                    # evaluation
                    # ---
                    acc_te_q = metrics.accuracy_score(query_ys, query_preds)
                    acc_te.append(acc_te_q)

                    del clf

                acc_te_mn, acc_te_std = mean_confidence_interval(acc_te)
                print('accuracy={:.5f}'.format(acc_te_mn * 100),
                      'std={:.5f}'.format(acc_te_std * 100))

    print('total training time={:.5f}'.format(total_time))

    # save model
    # ---
    # NOTE(review): the file name contains a space before '.pth'. It looks
    # like a typo but is kept as-is in case loaders elsewhere use the same
    # name -- confirm and clean up together with them.
    model_path = 'pretrain_model/scgm_a_' + ds_name + ' .pth'
    # torch.save does not create missing parent directories; create the
    # folder first so a finished training run cannot fail at this point.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(model.module.state_dict(), model_path)

    def _unit_rows(mat):
        """Return the 2-D array `mat` with every row scaled to unit L2 norm."""
        return mat / ((mat ** 2).sum(1) ** 0.5).reshape(-1, 1)

    def _sample_for_vis(x_embed, y_embed, n_vis=2000):
        """Jointly shuffle embeddings and labels, keep the first `n_vis`
        samples and L2-normalise the kept embeddings."""
        x_vis, y_vis = shuffle(x_embed, y_embed)
        return _unit_rows(x_vis[:n_vis, :]), y_vis[:n_vis]

    # vis training embedding
    # ---
    # mu_z / mu_y are read from the head of the query encoder and
    # row-normalised like the embeddings they are plotted with.
    mu_z_tr = _unit_rows(model.module.encoder_q.fc.mu_z.detach().cpu().numpy())
    mu_y_tr = _unit_rows(model.module.encoder_q.fc.mu_y.detach().cpu().numpy())

    # x_tr_embed / y_tr_embed_coarse hold the values of the last epoch.
    x_embed_vis, y_embed_vis = _sample_for_vis(x_tr_embed, y_tr_embed_coarse)

    destpath = '../fig/tsne_scgm_a_' + ds_name + '_tr.png'
    # presumably the figure is written to destpath; make sure its folder exists
    os.makedirs(os.path.dirname(destpath), exist_ok=True)
    vis_tsne_multiclass_means_new(x_embed_vis, y_embed_vis, mu_z_tr, mu_y_tr, destpath, y_pred=None, destpath_correct=None)

    # vis validation embedding
    # ---
    x_embed_vis, y_embed_vis = _sample_for_vis(x_va_embed, y_va_embed_coarse)

    destpath = '../fig/tsne_scgm_a_' + ds_name + '_va.png'
    os.makedirs(os.path.dirname(destpath), exist_ok=True)
    vis_tsne_multiclass_means_new(x_embed_vis, y_embed_vis, mu_z_tr, mu_y_tr, destpath, y_pred=None, destpath_correct=None)
