import numpy as np
import torch
import logging
import losses
import json
from tqdm import tqdm
import torch.nn.functional as F
import math
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
import torch.utils.data as data
from sklearn.cluster import KMeans
from sklearn.utils.linear_assignment_ import linear_assignment
from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score
from sklearn.metrics import adjusted_rand_score as ari_score
import time
from functools import wraps
from torch.utils.data import  Dataset
from torchvision import datasets, transforms

import matplotlib as mpl
import matplotlib.pyplot as plt
import os
from sklearn.manifold import TSNE
from PIL import ImageFilter
import random



def plot_TSNE(model, criterion, gallery_loader, args, epoch, mAP):
    """Save a t-SNE scatter plot of soft-quantized gallery descriptors.

    The model and criterion are moved to CUDA and put in eval mode, gallery
    embeddings are extracted with ``predict_batchwise``, every third sample is
    kept, the kept embeddings are soft-assigned to ``criterion.proxies`` and
    the result is projected with t-SNE (cosine metric).

    Args:
        model: Embedding network; called through ``predict_batchwise``.
        criterion: Object exposing ``proxies``, ``normsoftmax.proxies`` and
            ``SoftAssignment``.
        gallery_loader: Loader passed to ``predict_batchwise``; two outputs
            (embeddings, labels) are unpacked from it.
        args: Namespace; reads ``num_partitionings``, ``sz_embedding``,
            ``seen_rate``, ``partitionings_path``, ``LOG_DIR``,
            ``project_name`` and ``name``.
        epoch: Zero-based epoch index; ``epoch + 1`` appears in the title and
            in the output file name.
        mAP: Score shown in the figure title.

    Returns:
        The ``matplotlib.pyplot`` module (not the figure object).
    """
    # Seeding code that was disabled in the original (kept for reference):
    # seed = 1
    # random.seed(seed)
    # np.random.seed(seed)
    # torch.manual_seed(seed)
    # torch.cuda.manual_seed_all(seed)

    model = model.cuda()
    criterion = criterion.cuda()
    model.eval()
    criterion.eval()

    # train_X, _, train_T = predict_batchwise(model, train_loader)
    gallery_X, gallery_T = predict_batchwise(model, gallery_loader)

    # sampled_train_X = train_X[::100]
    # sampled_train_T = train_T[::100]
    # db_size = sampled_train_X.size(0)

    # Keep every third gallery sample to limit the number of plotted points.
    sampled_gallery_X = gallery_X[::3]
    sampled_gallery_T = gallery_T[::3]

    # sampled_train_X = sampled_train_X.view(-1, args.num_partitionings, args.sz_embedding)
    sampled_gallery_X = sampled_gallery_X.view(-1, args.num_partitionings, args.sz_embedding)

    # quantized_sampled_train_X = criterion.SoftAssignment(criterion.proxies, sampled_train_X, args.num_partitionings, zeta=20)
    quantized_sampled_gallery_X = criterion.SoftAssignment(criterion.proxies, sampled_gallery_X, args.num_partitionings,
                                                           zeta=20)

    # NOTE(review): meta_proxies and global_proxies are read here but not used
    # anywhere below in this function.
    meta_proxies = criterion.proxies.data
    global_proxies = criterion.normsoftmax.proxies.data

    feats = torch.cat([quantized_sampled_gallery_X, ]).detach().cpu().numpy()
    reduced_feats = TSNE(learning_rate=300, perplexity=26, metric='cosine').fit_transform(feats)

    # NOTE(review): spl_feats is computed but not used below.
    spl_feats = torch.cat([sampled_gallery_X, ]).cpu().numpy()
    spl_feats = spl_feats.reshape((-1, args.num_partitionings, args.sz_embedding))

    # N is the hard-coded total number of plotted groups; the first
    # int(N * seen_rate) of them are drawn as "seen".
    N = 10
    nb_classes = int(N * args.seen_rate)
    # define the colormap
    cmap = plt.cm.jet
    # extract all colors from the .jet map
    cmaplist = [cmap(i) for i in range(cmap.N)]
    # create the new map
    cmap = cmap.from_list('Custom cmap', cmaplist, cmap.N)

    # define the bins and normalize
    bounds = np.linspace(0, N, N + 1)
    norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

    # Split sampled labels into seen (< nb_classes) and unseen (>= nb_classes).
    seen_idx = (sampled_gallery_T < nb_classes).nonzero().squeeze()
    unseen_idx = (sampled_gallery_T >= nb_classes).nonzero().squeeze()
    seen_gallery_T = sampled_gallery_T[seen_idx]
    unseen_gallery_T = sampled_gallery_T[unseen_idx]

    # Only the first block of `num_partitionings` rows of the stored candidates
    # is used (the literal 0 selects block index 0).
    partitionings_candidate = torch.load(args.partitionings_path)
    partitionings = partitionings_candidate[0 * args.num_partitionings:(0 + 1) * args.num_partitionings].t()
    # train_meta_labels = LabelToMetaLabel(partitionings, sampled_train_T)
    # NOTE(review): gallery_meta_labels is not used below.
    gallery_meta_labels = LabelToMetaLabel(partitionings, seen_gallery_T)

    save_path = args.LOG_DIR + '/{}/{}'.format(args.project_name, args.name)
    title = '{} epoch:{} mAP:{:.4}'.format(args.name, epoch + 1, mAP)

    (fig, subplots) = plt.subplots(1, 1, figsize=(28, 28), constrained_layout=True)

    ax = subplots
    ax.set_title(title, fontsize=50)
    ax.axis('off')
    # Points are coloured in N consecutive, equally sized index ranges of
    # length `term`.
    # NOTE(review): this presumably assumes the sampled gallery is ordered by
    # class with equal class sizes - confirm against the loader.
    term = int(sampled_gallery_T.shape[0] / N)
    color_lst = ['tab:blue', '#89ABE3FF', 'tab:cyan', 'tab:olive', '#435E55FF', '#2BAE66FF', '#A2A2A1FF', '#F95700FF',
                 'tab:purple', 'tab:red']
    plots = {}
    # First nb_classes ranges: default round markers.
    for i in range(nb_classes):
        p = ax.scatter(reduced_feats[(i * term): ((i + 1) * term), 0], reduced_feats[(i * term): ((i + 1) * term), 1],
                       c=color_lst[i], cmap=cmap, alpha=.6, norm=norm, s=300)
        plots[i] = p
    # Remaining ranges up to the literal 10 (same value as N): large 'x' markers.
    for i in range(nb_classes, 10):
        p = ax.scatter(reduced_feats[(i * term):  ((i + 1) * term), 0], reduced_feats[(i * term):((i + 1) * term), 1],
                       c=color_lst[i], cmap=cmap, alpha=.5, norm=norm, marker='x', s=600, linewidth=4)
        plots[i] = p
    if not os.path.isdir(save_path):
        os.makedirs(save_path)
    plt.savefig(os.path.join(save_path, 'embeddings_%d_epochs.png' % (epoch + 1)), bbox_inches="tight", pad_inches=0.0)
    return plt


class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    Args:
        sigma: Two-element sequence ``(low, high)``; the blur radius for each
            call is drawn uniformly from this range.
    """

    # The default is an immutable tuple so that instances created without an
    # explicit ``sigma`` never share (and cannot accidentally mutate) one list.
    def __init__(self, sigma=(.1, 2.)):
        self.sigma = sigma

    def __call__(self, x):
        """Return ``x`` filtered with a randomly sized PIL Gaussian blur."""
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x



def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242

    Returns ``exp(-5 * (1 - t)^2)`` with ``t = clip(current, 0, rampup_length)
    / rampup_length``, or ``1.0`` when ``rampup_length`` is zero.
    """
    # A zero-length ramp is already at full weight.
    if rampup_length == 0:
        return 1.0

    clipped = np.clip(current, 0.0, rampup_length)
    remaining = 1.0 - clipped / rampup_length
    return float(np.exp(-5.0 * remaining * remaining))


def linear_rampup(current, rampup_length):
    """Linear rampup: ``current / rampup_length``, saturating at ``1.0``.

    Both arguments must be non-negative (checked with ``assert``).
    """
    assert current >= 0 and rampup_length >= 0
    # The saturation test comes first, so a zero-length ramp never divides.
    return 1.0 if current >= rampup_length else current / rampup_length


def cosine_rampdown(current, rampdown_length):
    """Cosine rampdown from https://arxiv.org/abs/1608.03983

    Requires ``0 <= current <= rampdown_length`` (checked with ``assert``) and
    returns ``0.5 * (cos(pi * current / rampdown_length) + 1)`` as a float.
    """
    assert 0 <= current <= rampdown_length
    cosine = np.cos(np.pi * current / rampdown_length)
    return float(.5 * (cosine + 1))



class DatasetSplitMultiView(Dataset):
    """Index-restricted view over a two-view dataset.

    The wrapped dataset must yield ``((view1, view2), label)`` items; only the
    positions listed in ``idxs`` are exposed, in the given order.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Coerce every index to a plain int.
        self.idxs = list(map(int, idxs))

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        views, label = self.dataset[self.idxs[item]]
        view1, view2 = views
        return torch.tensor(view1), torch.tensor(view2), torch.tensor(label)


class MultiViewDataInjector(object):
    """Apply several transforms to one sample and return the list of views."""

    def __init__(self, *args):
        # Only the first positional argument is used: an iterable of callables.
        self.transforms = args[0]
        self.random_flip = transforms.RandomHorizontalFlip()

    def __call__(self, sample, *with_consistent_flipping):
        # Any extra positional argument (whatever its value) makes the tuple
        # non-empty and so enables the shared flip applied before all views.
        flipped = self.random_flip(sample) if with_consistent_flipping else sample
        return [view_fn(flipped) for view_fn in self.transforms]

def test(model, criterion, test_loader, args):
    """Cluster test embeddings with k-means, before and after soft quantization.

    Embeddings are extracted on CUDA without gradients, L2-normalised per
    partition and clustered with as many clusters as there are distinct
    targets. The same is repeated on ``criterion.SoftAssignment`` descriptors.
    Both sets of scores are printed.

    Returns:
        ``(acc, nmi, ari)`` of the quantized descriptors.
    """
    model.eval()
    criterion.eval()

    targets = np.array([])
    batch_feats = []
    for x, label in tqdm(test_loader):
        x, label = x.cuda(), label.cuda()
        with torch.no_grad():
            batch_feats.append(model(x))
        targets = np.append(targets, label.cpu().numpy())

    # Normalise each partition's sub-embedding separately, then flatten back.
    feats = torch.cat(batch_feats, dim=0)
    feats = F.normalize(feats.view(-1, args.num_partitionings, args.sz_embedding), dim=2)
    feats = feats.view(-1, args.num_partitionings * args.sz_embedding)

    # k-means on the raw (non-quantized) embeddings.
    raw_kmeans = KMeans(n_clusters=len(np.unique(targets.flatten())))
    raw_kmeans.fit(feats.cpu().detach().numpy())
    raw_preds = np.array(raw_kmeans.labels_)
    acc = cluster_acc(targets.astype(int), raw_preds.astype(int))
    nmi = nmi_score(targets, raw_preds)
    ari = ari_score(targets, raw_preds)
    print('Not Quantized Kmeans Test acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(acc, nmi, ari))

    # k-means on the soft-quantized descriptors.
    descriptor = criterion.SoftAssignment(criterion.proxies, feats, args.num_partitionings, zeta=20)
    quant_kmeans = KMeans(n_clusters=len(np.unique(targets.flatten())))
    quant_kmeans.fit(descriptor.cpu().detach().numpy())
    quant_preds = np.array(quant_kmeans.labels_)
    Q_acc = cluster_acc(targets.astype(int), quant_preds.astype(int))
    Q_nmi = nmi_score(targets, quant_preds)
    Q_ari = ari_score(targets, quant_preds)
    print('Quantized Kmeans Test acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(Q_acc, Q_nmi, Q_ari))

    return Q_acc, Q_nmi, Q_ari


def test2(model, criterion, test_loader, args):
    """Evaluate quantized k-means clustering separately on seen and unseen classes.

    Embeddings are extracted on CUDA, L2-normalised per partition, soft-assigned
    to ``criterion.proxies`` and clustered with k-means using one cluster per
    distinct target. Samples with ``target < args.nb_classes`` count as seen,
    the rest as unseen; accuracy, NMI and ARI are printed for each group.

    Returns:
        ``((seen_acc, unseen_acc), (seen_nmi, unseen_nmi), (seen_ari, unseen_ari))``.
    """
    model.eval()
    targets = np.array([])
    feats = []
    for batch_idx, (x, label) in enumerate(tqdm(test_loader)):
        x, label = x.cuda(), label.cuda()
        # Inference only: without no_grad every stored batch would keep its
        # autograd graph alive until the end of the loop.
        with torch.no_grad():
            feat = model(x)
        targets = np.append(targets, label.cpu().numpy())
        feats.append(feat)
    feats = torch.cat(feats, dim=0)
    feats = F.normalize(feats.view(-1, args.num_partitionings, args.sz_embedding), dim=2)
    feats = feats.view(-1, args.num_partitionings * args.sz_embedding)

    descriptor = criterion.SoftAssignment(criterion.proxies, feats, args.num_partitionings, zeta=20)
    # Cluster the quantized descriptors, one cluster per distinct target.
    kmeans = KMeans(n_clusters=len(np.unique(targets.flatten())))
    kmeans.fit(descriptor.cpu().detach().numpy())
    clustered_preds = np.array(kmeans.labels_)

    # Partition samples into seen / unseen by label value.
    seen_idxs = targets < args.nb_classes
    seen_targets = targets[seen_idxs]
    seen_clustered_preds = clustered_preds[seen_idxs]
    unseen_idxs = targets >= args.nb_classes
    unseen_targets = targets[unseen_idxs]
    unseen_clustered_preds = clustered_preds[unseen_idxs]

    seen_Q_nmi = nmi_score(seen_targets, seen_clustered_preds)
    seen_Q_ari = ari_score(seen_targets, seen_clustered_preds)
    unseen_Q_nmi = nmi_score(unseen_targets, unseen_clustered_preds)
    unseen_Q_ari = ari_score(unseen_targets, unseen_clustered_preds)

    seen_acc = cluster_acc(seen_targets.astype(int), seen_clustered_preds.astype(int))
    unseen_acc = cluster_acc(unseen_targets.astype(int), unseen_clustered_preds.astype(int))

    print('Kmeans Test seen acc {:.4f} nmi {:.4f}, ari {:.4f}'.format(seen_acc, seen_Q_nmi, seen_Q_ari))

    print('Kmeans Test unseen acc {:.4f} nmi {:.4f}, ari {:.4f}'.format(unseen_acc, unseen_Q_nmi, unseen_Q_ari))

    return (seen_acc, unseen_acc), (seen_Q_nmi, unseen_Q_nmi), (seen_Q_ari, unseen_Q_ari)


def test3(model, criterion, test_loader, args):
    """Print seen/unseen NMI and ARI of k-means on quantized descriptors.

    Experimental variant of ``test2`` with a hard-coded split: labels below 7
    are treated as seen and labels 7 and above as unseen. It also compares the
    descriptors of label 9 with per-class "global" proxies, but that
    comparison is not used afterwards. Nothing is returned.
    """
    model.eval()
    targets = np.array([])
    feats = []
    for batch_idx, (x, label) in enumerate(tqdm(test_loader)):
        x, label = x.cuda(), label.cuda()
        # Inference only: avoid keeping an autograd graph per stored batch.
        with torch.no_grad():
            feat = model(x)
        targets = np.append(targets, label.cpu().numpy())
        feats.append(feat)
    feats = torch.cat(feats, dim=0)
    feats = F.normalize(feats.view(-1, args.num_partitionings, args.sz_embedding), dim=2)
    feats = feats.view(-1, args.num_partitionings * args.sz_embedding)

    descriptor = criterion.SoftAssignment(criterion.proxies, feats, args.num_partitionings, zeta=20)

    # One flattened proxy vector per entry of criterion.partitionings.
    # NOTE(review): the literal 24 presumably equals args.num_partitionings -
    # confirm before parameterizing.
    global_proxies = []
    for i in criterion.partitionings:
        global_proxies.append(criterion.proxies.data[torch.tensor(range(24)), i, :].view(-1))
    global_proxies = torch.stack(global_proxies)

    print('되라 제발')
    idxs = targets == 9
    unseen_descriptors = descriptor[idxs]
    # Cosine similarity between label-9 descriptors and the global proxies.
    # NOTE(review): `result` / `argmaxes` are not used further in this function.
    unseen_descriptors = F.normalize(unseen_descriptors, p=2, dim=1)
    global_proxies = F.normalize(global_proxies, p=2, dim=1)
    result = torch.matmul(unseen_descriptors, global_proxies.t())
    argmaxes = torch.argmax(result, dim=1)

    kmeans = KMeans(n_clusters=len(np.unique(targets.flatten())))
    kmeans.fit(descriptor.cpu().detach().numpy())
    clustered_preds = np.array(kmeans.labels_)

    # Seen / unseen split. `>=` (as in test2) ensures label 7 is counted as
    # unseen; the previous `> 7` left it out of both groups.
    seen_idxs = targets < 7
    seen_targets = targets[seen_idxs]
    seen_clustered_preds = clustered_preds[seen_idxs]
    unseen_idxs = targets >= 7
    unseen_targets = targets[unseen_idxs]
    unseen_clustered_preds = clustered_preds[unseen_idxs]

    seen_Q_nmi = nmi_score(seen_targets, seen_clustered_preds)
    seen_Q_ari = ari_score(seen_targets, seen_clustered_preds)
    unseen_Q_nmi = nmi_score(unseen_targets, unseen_clustered_preds)
    unseen_Q_ari = ari_score(unseen_targets, unseen_clustered_preds)

    print('Quantized Kmeans Test seen nmi {:.4f}, ari {:.4f}'.format(seen_Q_nmi, seen_Q_ari))

    print('Quantized Kmeans Test unseen nmi {:.4f}, ari {:.4f}'.format(unseen_Q_nmi, unseen_Q_ari))


from sklearn.metrics.cluster._supervised import check_clusterings, contingency_matrix, mutual_info_score, entropy, _generalized_average
from sklearn.utils.fixes import _astype_copy_false

def NMI_split(labels_true, labels_pred, average_method='arithmetic'):
    """Print NMI-style scores separately for seen and unseen classes.

    The split is hard-coded: labels below 75 are seen, labels 75 and above are
    unseen. For each group the mutual information is taken from the matching
    rows of the full contingency matrix and divided by the generalized average
    of the group's label and prediction entropies (``entropy2``, normalised by
    the total sample count).

    Args:
        labels_true: Ground-truth labels.
        labels_pred: Cluster assignments.
        average_method: Passed to sklearn's ``_generalized_average``.

    Returns:
        ``1.0`` in the degenerate case where both labelings have a single (or
        no) class; otherwise ``None`` - the scores are only printed.
    """
    labels_true, labels_pred = check_clusterings(labels_true, labels_pred)
    classes = np.unique(labels_true)
    clusters = np.unique(labels_pred)
    # Special limit cases: no clustering since the data is not split.
    # This is a perfect match hence return 1.0.
    if (classes.shape[0] == clusters.shape[0] == 1 or
            classes.shape[0] == clusters.shape[0] == 0):
        return 1.0
    contingency = contingency_matrix(labels_true, labels_pred, sparse=False)
    contingency = contingency.astype(np.float64,
                                     **_astype_copy_false(contingency))

    seen_idxs = labels_true < 75
    seen_targets = labels_true[seen_idxs]
    seen_clustered_preds = labels_pred[seen_idxs]
    # `>=` so that label 75 belongs to the unseen group, matching the
    # contingency[75:] slice below (the previous `> 75` dropped it).
    unseen_idxs = labels_true >= 75
    unseen_targets = labels_true[unseen_idxs]
    unseen_clustered_preds = labels_pred[unseen_idxs]

    ## Seen
    # When `contingency` is given, mutual_info_score uses it instead of the labels.
    mi = mutual_info_score(seen_targets, seen_clustered_preds,
                           contingency=contingency[:75])
    h_true = entropy2(seen_targets, labels_true.size)
    h_pred = entropy2(seen_clustered_preds, labels_true.size)
    normalizer = _generalized_average(h_true, h_pred, average_method)
    # Avoid 0.0 / 0.0 when either entropy is zero.
    normalizer = max(normalizer, np.finfo('float64').eps)
    seen_nmi = mi / normalizer

    print('Seen NMI', seen_nmi)

    ## UnSeen
    mi = mutual_info_score(unseen_targets, unseen_clustered_preds,
                           contingency=contingency[75:])
    # Entropies of the unseen group; the previous code reused the seen labels
    # here, so the unseen score was normalised by the wrong quantity.
    h_true = entropy2(unseen_targets, labels_true.size)
    h_pred = entropy2(unseen_clustered_preds, labels_true.size)
    normalizer = _generalized_average(h_true, h_pred, average_method)
    # Avoid 0.0 / 0.0 when either entropy is zero.
    normalizer = max(normalizer, np.finfo('float64').eps)
    unseen_nmi = mi / normalizer

    print('UnSeen NMI', unseen_nmi)

from math import log

def entropy2(labels, total):
    """Entropy-like sum of a labeling, normalised by an external sample count.

    Parameters
    ----------
    labels : int array, shape = [n_samples]
        The labels
    total : int
        Count used as the denominator of every label frequency (instead of
        ``len(labels)``).

    Returns
    -------
    ``1.0`` for an empty labeling, otherwise
    ``-sum((c / total) * (log(c) - log(total)))`` over the label counts ``c``.

    Notes
    -----
    The logarithm used is the natural logarithm (base-e).
    """
    if len(labels) == 0:
        return 1.0

    _, inverse = np.unique(labels, return_inverse=True)
    counts = np.bincount(inverse).astype(np.float64)
    counts = counts[counts > 0]

    # log(a / b) is computed as log(a) - log(b) to limit precision loss.
    weights = counts / total
    log_ratio = np.log(counts) - log(total)
    return -np.sum(weights * log_ratio)



def test_split(model, criterion, test_loader, args):
    """K-means evaluation that reports overall, seen and unseen accuracy.

    Embeddings are extracted on CUDA and L2-normalised per partition. K-means
    (one cluster per distinct target) is first run on the raw embeddings and
    its acc/NMI/ARI printed; it is then run on the soft-quantized descriptors
    and scored with ``cluster_acc_split`` using ``seen=args.nb_classes``.

    Returns:
        ``(Q_acc, Q_seen_acc, Q_unseen_acc)`` as returned by ``cluster_acc_split``.
    """
    model.eval()
    targets = np.array([])
    feats = []
    for batch_idx, (x, label) in enumerate(tqdm(test_loader)):
        x, label = x.cuda(), label.cuda()
        # Inference only: avoid keeping an autograd graph per stored batch.
        with torch.no_grad():
            feat = model(x)
        targets = np.append(targets, label.cpu().numpy())
        feats.append(feat)
    feats = torch.cat(feats, dim=0)
    feats = F.normalize(feats.view(-1, args.num_partitionings, args.sz_embedding), dim=2)
    feats = feats.view(-1, args.num_partitionings * args.sz_embedding)

    # k-means on the raw (non-quantized) embeddings.
    kmeans = KMeans(n_clusters=len(np.unique(targets.flatten())))
    kmeans.fit(feats.cpu().detach().numpy())
    clustered_preds = np.array(kmeans.labels_)
    acc = cluster_acc(targets.astype(int), clustered_preds.astype(int))
    nmi = nmi_score(targets, clustered_preds)
    ari = ari_score(targets, clustered_preds)
    print('Not Quantized Kmeans Test acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(acc, nmi, ari))

    # k-means on the soft-quantized descriptors.
    descriptor = criterion.SoftAssignment(criterion.proxies, feats, args.num_partitionings, zeta=20)
    kmeans = KMeans(n_clusters=len(np.unique(targets.flatten())))
    kmeans.fit(descriptor.cpu().detach().numpy())
    clustered_preds = np.array(kmeans.labels_)
    Q_acc, Q_seen_acc, Q_unseen_acc = cluster_acc_split(targets.astype(int), clustered_preds.astype(int), seen=args.nb_classes)

    return Q_acc, Q_seen_acc, Q_unseen_acc

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k

    Args:
        output: Score tensor of shape ``(batch, num_classes)``.
        target: Ground-truth class indices of shape ``(batch,)``.
        topk: Iterable of k values to evaluate.

    Returns:
        List with one single-element tensor per k, holding the percentage of
        samples whose target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` comes from a transposed tensor and may be
            # non-contiguous; reshape copies when needed, whereas view(-1)
            # raises a RuntimeError for k > 1.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

class AverageMeter(object):
    """Tracks the latest value together with a running weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

def test_acc(model, criterion, classifier, test_loader, args):
    """Print the average top-1 accuracy of ``classifier`` on quantized descriptors.

    For every batch the embeddings are L2-normalised per partition,
    soft-assigned to ``criterion.proxies`` and fed to ``classifier``; the
    top-1 accuracy is accumulated weighted by batch size. Nothing is returned.
    """
    model.eval()
    criterion.eval()
    classifier.eval()
    acc_record = AverageMeter()
    for batch_idx, (x, label) in enumerate(tqdm(test_loader)):
        x, label = x.cuda(), label.cuda()
        # Pure evaluation: no gradients are needed for any of these steps.
        with torch.no_grad():
            feat = model(x)
            feat = F.normalize(feat.view(-1, args.num_partitionings, args.sz_embedding), dim=2)
            feat = feat.view(-1, args.num_partitionings * args.sz_embedding)
            descriptor = criterion.SoftAssignment(criterion.proxies, feat, args.num_partitionings, zeta=20)
            output = classifier(descriptor)

        acc = accuracy(output, label)
        acc_record.update(acc[0].item(), x.size(0))

    print('Test: Avg Acc: {:.4f}'.format(acc_record.avg))


class TransformTwice:
    """Wrap a transform so each call yields two independently transformed views."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, inp):
        # The transform is invoked twice, so a random transform gives two
        # different views of the same input.
        return self.transform(inp), self.transform(inp)

def l2_norm(input):
    """L2-normalise a 2-D tensor along dim 1.

    ``1e-12`` is added to the squared norm before the square root, so all-zero
    rows yield zeros instead of NaN.
    """
    squared = torch.pow(input, 2)
    row_norm = torch.sqrt(torch.sum(squared, 1).add_(1e-12))
    scaled = torch.div(input, row_norm.view(-1, 1).expand_as(input))
    return scaled.view(input.size())

def calc_recall_at_k(T, Y, k):
    """Fraction of samples whose target appears among their first k predictions.

    T : [nb_samples] (target labels)
    Y : [nb_samples x k] (k predicted labels/neighbours)
    """
    hits = sum(1 for t, y in zip(T, Y) if t in torch.Tensor(y).long()[:k])
    return hits / (1. * len(T))


def predict_batchwise(model, dataloader):
    """Run ``model`` over a loader and collect per-sample outputs.

    The first element of every batch is moved to CUDA and replaced by the
    model output; all remaining batch elements are collected unchanged.
    The model is switched to eval mode and is NOT switched back afterwards.

    Returns:
        A list with one stacked tensor per field of a dataset sample.
    """
    model.eval()

    num_fields = len(dataloader.dataset[0])
    collected = [[] for _ in range(num_fields)]
    with torch.no_grad():
        for batch in tqdm(dataloader):
            for field_idx, field in enumerate(batch):
                # field_idx 0: images -> embeddings; other fields pass through.
                if field_idx == 0:
                    field = model(field.cuda())
                # Split the batch into individual samples.
                collected[field_idx].extend(field)

    return [torch.stack(samples) for samples in collected]


def predict_batchwise2(model, dataloader):
    """Collect model outputs and remaining batch fields sample by sample.

    Same procedure as ``predict_batchwise``: field 0 of each batch is sent to
    CUDA and through the model under ``torch.no_grad``; every field is then
    unrolled into individual samples. The model stays in eval mode on return.

    Returns:
        A list of stacked tensors, one per field of a dataset sample.
    """
    model.eval()

    first_sample = dataloader.dataset[0]
    buckets = [[] for _ in first_sample]
    with torch.no_grad():
        for batch in tqdm(dataloader):
            for pos, values in enumerate(batch):
                if pos == 0:
                    # Images are replaced by their embeddings.
                    values = model(values.cuda())
                for sample in values:
                    buckets[pos].append(sample)

    return [torch.stack(bucket) for bucket in buckets]


def predict_batchwise_nus3(model, dataloader):
    """Variant of ``predict_batchwise`` that restores the model's training flag.

    Field 0 of each batch is embedded by the model on CUDA without gradients;
    all fields are collected per sample. Afterwards the model is returned to
    whatever ``model.training`` was on entry.

    Returns:
        A list of stacked tensors, one per field of a dataset sample.
    """
    was_training = model.training
    model.eval()

    per_field = [[] for _ in range(len(dataloader.dataset[0]))]
    with torch.no_grad():
        for batch in tqdm(dataloader):
            for pos, values in enumerate(batch):
                if pos == 0:
                    values = model(values.cuda())
                per_field[pos].extend(values)

    # Switch back to train mode, then to the mode that was active on entry.
    model.train()
    model.train(was_training)

    return [torch.stack(column) for column in per_field]

def proxy_init_calc(model, dataloader):
    """Return the mean embedding of every class as initial proxies.

    Class indices ``0 .. nb_classes - 1`` are taken from
    ``dataloader.dataset.nb_classes()``; the result stacks one mean vector per
    class, in index order.
    """
    nb_classes = dataloader.dataset.nb_classes()
    # Only embeddings and labels are needed; any extra fields are ignored.
    X, T, *_ = predict_batchwise(model, dataloader)

    class_means = []
    for class_idx in range(nb_classes):
        class_means.append(X[T == class_idx].mean(0))
    return torch.stack(class_means)

def evaluate_cos(model, dataloader):
    """Cosine-similarity retrieval evaluation within a single set.

    Prints and returns Recall@{1, 2, 4, 8, 16, 32} together with R-Precision
    and MAP@R from ``AccuracyCalculator``.

    Returns:
        ``(recall_list, r_precision, map_at_r)``.
    """
    # Embed the whole set and L2-normalise so dot products are cosines.
    X, T = predict_batchwise(model, dataloader)
    X = l2_norm(X)

    # Labels of the K most similar samples, skipping the top-ranked column
    # (each sample's own best match).
    K = 32
    similarity = F.linear(X, X)
    neighbour_idx = similarity.topk(1 + K)[1][:, 1:]
    Y = T[neighbour_idx].float().cpu()

    calculator = AccuracyCalculator(include=(), exclude=(['AMI', 'NMI']), avg_of_avgs=False, k=None)
    embeddings_np = X.cpu().numpy()
    labels_np = T.cpu().numpy()
    result = calculator.get_accuracy(embeddings_np, embeddings_np, labels_np, labels_np,
                                     embeddings_come_from_same_source=False)

    r_precision = result['r_precision']
    map_at_r = result['mean_average_precision_at_r']
    print("R_Precision : {:.4f}, MAP@R : {:.4f}".format(r_precision, map_at_r))

    recall = []
    for k in (1, 2, 4, 8, 16, 32):
        r_at_k = calc_recall_at_k(T, Y, k)
        recall.append(r_at_k)
        print("R@{} : {:.3f}".format(k, 100 * r_at_k))

    return recall, r_precision, map_at_r



def evaluate_cos_Inshop(model, query_dataloader, gallery_dataloader):
    """Query-vs-gallery Recall@{1, 10, 20, 30, 40, 50} with cosine similarity.

    Returns:
        List of recall values, one per k, in the order listed above.
    """
    # Result unused; the call is kept so behaviour for datasets without
    # ``nb_classes`` is unchanged.
    query_dataloader.dataset.nb_classes()

    query_X, query_T = predict_batchwise(model, query_dataloader)
    gallery_X, gallery_T = predict_batchwise(model, gallery_dataloader)

    query_X = l2_norm(query_X)
    gallery_X = l2_norm(gallery_X)

    cos_sim = F.linear(query_X, gallery_X)

    def recall_k(cos_sim, query_T, gallery_T, k):
        # A query is a hit when fewer than k negatives score strictly above
        # its best-scoring positive.
        hits = 0
        for i, sims in enumerate(cos_sim):
            best_positive = torch.max(sims[gallery_T == query_T[i]]).item()
            negatives = sims[gallery_T != query_T[i]]
            if torch.sum(negatives > best_positive) < k:
                hits += 1
        return hits / len(cos_sim)

    recall = []
    for k in (1, 10, 20, 30, 40, 50):
        r_at_k = recall_k(cos_sim, query_T, gallery_T, k)
        recall.append(r_at_k)
        print("R@{} : {:.3f}".format(k, 100 * r_at_k))

    return recall

def evaluate_cos_SOP(model, dataloader):
    """Cosine-similarity retrieval evaluation processed in chunks.

    The similarity matrix is computed chunk by chunk so the full N x N matrix
    is never held at once. Prints and returns Recall@{1, 10, 100, 1000} plus
    R-Precision and MAP@R from ``AccuracyCalculator``.

    Returns:
        ``(recall_list, r_precision, map_at_r)``.
    """
    # The former unused ``dataloader.dataset.nb_classes()`` lookup was
    # dropped; nothing in this function needs it.

    # calculate embeddings with model and get targets
    X, T = predict_batchwise(model, dataloader)
    X = l2_norm(X)

    # Labels of the K nearest neighbours by cosine similarity, skipping the
    # top-ranked column (each sample's own best match).
    K = 1000
    # Same chunk length as the previous hand-rolled buffering. Slicing also
    # covers a sample count that is an exact multiple of the chunk length,
    # where the old trailing ``torch.stack`` of an empty list raised.
    chunk_size = 10001
    Y = []
    for start in range(0, len(X), chunk_size):
        cos_sim = F.linear(X[start:start + chunk_size], X)
        y = T[cos_sim.topk(1 + K)[1][:, 1:]]
        Y.append(y.float().cpu())
    Y = torch.cat(Y, dim=0)

    calculator = AccuracyCalculator(include=(), exclude=(['AMI', 'NMI']), avg_of_avgs=False, k=None)
    q = X.cpu().numpy()
    label = T.cpu().numpy()
    result = calculator.get_accuracy(q, q, label, label, embeddings_come_from_same_source=False)

    r_precision = result['r_precision']
    map_at_r = result['mean_average_precision_at_r']
    print("R_Precision : {:.4f}, MAP@R : {:.4f}".format(r_precision, map_at_r))

    # calculate recall @ 1, 10, 100, 1000
    recall = []
    for k in [1, 10, 100, 1000]:
        r_at_k = calc_recall_at_k(T, Y, k)
        recall.append(r_at_k)
        print("R@{} : {:.3f}".format(k, 100 * r_at_k))
    return recall, r_precision, map_at_r


def evaluate(args, model, z, query_dataloader, gallery_dataloader, label_similarity):
    """Retrieval evaluation with whole-vector normalisation and PQ-based mAP.

    Prints Recall@{1, 10, 20, 30, 40, 50}, R-Precision, MAP@R and NMI on the
    continuous embeddings, then indexes the gallery against the codebook ``z``
    (``IndexingNoSplit`` / ``pqDistNoSplit``) and prints the resulting mAP.

    Returns:
        ``(recall_list, r_precision, map_at_r, mAP)``.
    """
    # Result unused; the call is kept so behaviour is unchanged.
    query_dataloader.dataset.nb_classes()

    # Embed both sets and L2-normalise each full embedding vector.
    query_X, query_T = predict_batchwise(model, query_dataloader)
    gallery_X, gallery_T = predict_batchwise(model, gallery_dataloader)
    query_X = F.normalize(query_X, dim=1)
    gallery_X = F.normalize(gallery_X, dim=1)

    cos_sim = F.linear(query_X, gallery_X)

    def recall_k(cos_sim, query_T, gallery_T, k):
        # A query is a hit when fewer than k negatives score strictly above
        # its best-scoring positive.
        hits = 0
        for i, sims in enumerate(cos_sim):
            best_positive = torch.max(sims[gallery_T == query_T[i]]).item()
            negatives = sims[gallery_T != query_T[i]]
            if torch.sum(negatives > best_positive) < k:
                hits += 1
        return hits / len(cos_sim)

    recall = []
    for k in (1, 10, 20, 30, 40, 50):
        r_at_k = recall_k(cos_sim, query_T, gallery_T, k)
        recall.append(r_at_k)
        print("R@{} : {:.3f}".format(k, 100 * r_at_k))

    # Metric-learning metrics on the continuous embeddings.
    calculator = AccuracyCalculator(include=(), exclude=(['AMI']), avg_of_avgs=False, k=None)
    query_np = query_X.cpu().detach().numpy()
    gallery_np = gallery_X.cpu().detach().numpy()
    result = calculator.get_accuracy(query_np, gallery_np,
                                     query_T.cpu().detach().numpy(), gallery_T.cpu().detach().numpy(),
                                     embeddings_come_from_same_source=False)

    r_precision = result['r_precision']
    map_at_r = result['mean_average_precision_at_r']
    NMI = result['NMI']
    print("\nR_Precision : {:.4f}, MAP@R : {:.4f}, NMI : {:.4f}".format(r_precision, map_at_r, NMI))

    # Quantized retrieval: index the gallery, rank by PQ distance, score mAP.
    idxed_descriptor = IndexingNoSplit(z, gallery_X, numSeg=args.num_partitionings)
    quantizedDist = pqDistNoSplit(
        z.cpu().detach().numpy(),
        args.num_partitionings,
        idxed_descriptor.cpu().detach().numpy(),
        query_X.cpu().detach().numpy(),
    ).T
    Rank = np.argsort(quantizedDist, axis=0)
    mAP = cat_apcal(label_similarity, Rank, top_N=label_similarity.shape[1])

    print("\n mAP : {:.4f}".format(mAP))

    return recall, r_precision, map_at_r, mAP

def evaluate_split(args, model, z, query_dataloader, gallery_dataloader, label_similarity):
    """Retrieval evaluation with per-partition normalisation and PQ-based mAP.

    Each embedding is viewed as ``num_partitionings`` sub-vectors of size
    ``sz_embedding`` that are L2-normalised independently. Prints
    Recall@{1, 10, 20, 30, 40, 50}, R-Precision, MAP@R and NMI, then indexes
    the gallery against the codebook ``z`` and prints the quantized mAP.

    Side effect:
        ``z.data`` is overwritten with its L2-normalised version (dim 2)
        after the gallery has been indexed.

    Returns:
        ``(recall_list, r_precision, map_at_r, mAP)``.
    """
    flat_dim = args.num_partitionings * args.sz_embedding

    query_X, query_T = predict_batchwise(model, query_dataloader)
    gallery_X, gallery_T = predict_batchwise(model, gallery_dataloader)

    # Normalise every partition's sub-vector, then flatten back.
    query_X = F.normalize(query_X.view(-1, args.num_partitionings, args.sz_embedding), dim=2)
    query_X = query_X.view(-1, flat_dim)
    gallery_X = F.normalize(gallery_X.view(-1, args.num_partitionings, args.sz_embedding), dim=2)
    gallery_X = gallery_X.view(-1, flat_dim)

    cos_sim = F.linear(query_X, gallery_X)

    def recall_k(cos_sim, query_T, gallery_T, k):
        # A query is a hit when fewer than k negatives score strictly above
        # its best-scoring positive.
        hits = 0
        for i, sims in enumerate(cos_sim):
            best_positive = torch.max(sims[gallery_T == query_T[i]]).item()
            negatives = sims[gallery_T != query_T[i]]
            if torch.sum(negatives > best_positive) < k:
                hits += 1
        return hits / len(cos_sim)

    recall = []
    for k in (1, 10, 20, 30, 40, 50):
        r_at_k = recall_k(cos_sim, query_T, gallery_T, k)
        recall.append(r_at_k)
        print("R@{} : {:.3f}".format(k, 100 * r_at_k))

    # Metric-learning metrics on the continuous embeddings.
    calculator = AccuracyCalculator(include=(), exclude=(['AMI']), avg_of_avgs=False, k=None)
    result = calculator.get_accuracy(query_X.cpu().detach().numpy(), gallery_X.cpu().detach().numpy(),
                                     query_T.cpu().detach().numpy(), gallery_T.cpu().detach().numpy(),
                                     embeddings_come_from_same_source=False)

    r_precision = result['r_precision']
    map_at_r = result['mean_average_precision_at_r']
    NMI = result['NMI']
    print("\nR_Precision : {:.4f}, MAP@R : {:.4f}, NMI : {:.4f}".format(r_precision, map_at_r, NMI))

    # Index first, normalise the codebook afterwards: order matters because
    # the second step mutates z in place.
    idxed_descriptor = Indexing(z, gallery_X, numSeg=args.num_partitionings)
    z.data = F.normalize(z.data, p=2, dim=2)
    quantizedDist = pqDist(
        z.cpu().detach().numpy(),
        args.num_partitionings,
        idxed_descriptor.cpu().detach().numpy(),
        query_X.cpu().detach().numpy(),
    ).T
    Rank = np.argsort(quantizedDist, axis=0)
    mAP = cat_apcal(label_similarity, Rank, top_N=label_similarity.shape[1])

    print("\n mAP : {:.4f}".format(mAP))

    return recall, r_precision, map_at_r, mAP



def evaluate_split_2(args, model, z, query_dataloader, gallery_dataloader, label_similarity):
    """Compute retrieval mAP using product-quantized gallery codes only.

    Args:
        args: namespace providing ``num_partitionings`` and ``sz_embedding``.
        model: embedding network handed to ``predict_batchwise``.
        z: codebook tensor, indexed as (partitioning, codeword, embedding dim).
            NOTE: ``z.data`` is overwritten with its L2-normalized version, so
            the caller's codebook is modified by this call.
        query_dataloader: loader for the query set.
        gallery_dataloader: loader for the gallery set.
        label_similarity: relevance matrix indexed ``[query, gallery]``; an
            entry equal to 1 marks a relevant gallery item.

    Returns:
        tuple: ``(None, None, None, mAP)``. The first three slots stand for
        ``recall``, ``r_precision`` and ``map_at_r``, which this variant never
        computes. The original code returned those undefined names and raised
        ``NameError`` on every call; the 4-tuple shape is kept so callers that
        unpack four values keep working.
    """
    # Labels returned by predict_batchwise are not needed for this metric.
    query_X, _ = predict_batchwise(model, query_dataloader)
    gallery_X, _ = predict_batchwise(model, gallery_dataloader)

    num_seg = args.num_partitionings
    seg_dim = args.sz_embedding

    # L2-normalize every sub-vector separately, then flatten back to 2-D.
    query_X = F.normalize(query_X.view(-1, num_seg, seg_dim), dim=2)
    query_X = query_X.view(-1, num_seg * seg_dim)
    gallery_X = F.normalize(gallery_X.view(-1, num_seg, seg_dim), dim=2)
    gallery_X = gallery_X.view(-1, num_seg * seg_dim)

    # Indexing: replace each gallery sub-vector by the index of its codeword.
    idxed_descriptor = Indexing(z, gallery_X, numSeg=num_seg)
    z.data = F.normalize(z.data, p=2, dim=2)
    # pqDist gives (query, gallery); transpose so each column is one query.
    quantizedDist = pqDist(z.cpu().detach().numpy(), num_seg,
                           idxed_descriptor.cpu().detach().numpy(),
                           query_X.cpu().detach().numpy()).T
    Rank = np.argsort(quantizedDist, axis=0)
    mAP = cat_apcal(label_similarity, Rank, top_N=label_similarity.shape[1])

    print("\n mAP : {:.4f}".format(mAP))

    return None, None, None, mAP

def evaluate_split_NUS(args, model, z, query_dataloader, gallery_dataloader, label_similarity, topk=5000):
    """Compute mAP over the top-``topk`` results ranked by quantized distance.

    Args:
        args: namespace providing ``num_partitionings`` and ``sz_embedding``.
        model: embedding network handed to ``predict_batchwise``.
        z: codebook tensor, indexed as (partitioning, codeword, embedding dim).
            NOTE: ``z.data`` is overwritten with its L2-normalized version, so
            the caller's codebook is modified by this call.
        query_dataloader: loader for the query set.
        gallery_dataloader: loader for the gallery set.
        label_similarity: relevance matrix indexed ``[query, gallery]``; an
            entry equal to 1 marks a relevant gallery item.
        topk: ranking depth passed to ``cat_apcal`` as ``top_N``.

    Returns:
        The mAP value produced by ``cat_apcal``.
    """
    # Only the embeddings are used; relevance comes from label_similarity.
    # (The previous unused ``dataset.nb_classes`` lookup was dropped so that
    # datasets without that attribute can be evaluated as well.)
    query_X, _ = predict_batchwise(model, query_dataloader)
    gallery_X, _ = predict_batchwise(model, gallery_dataloader)

    num_seg = args.num_partitionings
    seg_dim = args.sz_embedding

    # L2-normalize every sub-vector separately, then flatten back to 2-D.
    query_X = F.normalize(query_X.view(-1, num_seg, seg_dim), dim=2)
    query_X = query_X.view(-1, num_seg * seg_dim)
    gallery_X = F.normalize(gallery_X.view(-1, num_seg, seg_dim), dim=2)
    gallery_X = gallery_X.view(-1, num_seg * seg_dim)

    # Indexing: replace each gallery sub-vector by the index of its codeword.
    idxed_descriptor = Indexing(z, gallery_X, numSeg=num_seg)
    z.data = F.normalize(z.data, p=2, dim=2)
    # pqDist gives (query, gallery); transpose so each column is one query.
    quantizedDist = pqDist(z.cpu().detach().numpy(), num_seg,
                           idxed_descriptor.cpu().detach().numpy(),
                           query_X.cpu().detach().numpy()).T
    Rank = np.argsort(quantizedDist, axis=0)
    mAP = cat_apcal(label_similarity, Rank, top_N=topk)

    print("\n mAP : {:.4f}".format(mAP))

    return mAP

def evaluate_split_NUS3(args, model, z, query_dataloader, gallery_dataloader, label_similarity, topk=500):
    """Compute mAP over the top-``topk`` results ranked by quantized distance.

    Args:
        args: namespace providing ``num_partitionings`` and ``sz_embedding``.
        model: embedding network handed to ``predict_batchwise``.
        z: codebook tensor, indexed as (partitioning, codeword, embedding dim).
            NOTE: ``z.data`` is overwritten with its L2-normalized version.
        query_dataloader: loader for the query set.
        gallery_dataloader: loader for the gallery set.
        label_similarity: relevance matrix indexed ``[query, gallery]``.
        topk: ranking depth passed to ``cat_apcal`` as ``top_N``.

    Returns:
        The mAP value produced by ``cat_apcal``.
    """
    num_seg, seg_dim = args.num_partitionings, args.sz_embedding

    def normalize_segments(embeddings):
        # Unit-normalize each sub-vector, then restore the flat 2-D layout.
        per_segment = F.normalize(embeddings.view(-1, num_seg, seg_dim), dim=2)
        return per_segment.view(-1, num_seg * seg_dim)

    query_emb, _ = predict_batchwise(model, query_dataloader)
    gallery_emb, _ = predict_batchwise(model, gallery_dataloader)
    query_emb = normalize_segments(query_emb)
    gallery_emb = normalize_segments(gallery_emb)

    # Encode the gallery as codeword indices before normalizing the codebook
    # in place (same statement order as the other evaluate_split_* variants).
    gallery_codes = Indexing(z, gallery_emb, numSeg=num_seg)
    z.data = F.normalize(z.data, p=2, dim=2)

    codebook_np = z.cpu().detach().numpy()
    gallery_codes_np = gallery_codes.cpu().detach().numpy()
    query_np = query_emb.cpu().detach().numpy()

    # Transposed so that every column holds the distances of one query.
    dist_gallery_by_query = pqDist(codebook_np, num_seg, gallery_codes_np, query_np).T
    ranking = np.argsort(dist_gallery_by_query, axis=0)
    mAP = cat_apcal(label_similarity, ranking, top_N=topk)

    print("\n mAP : {:.4f}".format(mAP))

    return mAP



def evaluate_NUS(args, model, query_dataloader, gallery_dataloader):
    """Compute mAP from cosine distance between real-valued embeddings.

    A gallery item counts as relevant to a query when the dot product of their
    label vectors is positive (presumably multi-hot labels - confirm against
    ``predict_batchwise``).

    Args:
        args: unused here; kept for a uniform evaluate_* signature.
        model: embedding network handed to ``predict_batchwise``.
        query_dataloader: loader for the query set.
        gallery_dataloader: loader for the gallery set.

    Returns:
        Sum of per-query average precision divided by the total number of
        queries. Queries without any relevant gallery item add nothing to the
        sum but are still counted in the denominator.
    """
    query_emb, query_T = predict_batchwise(model, query_dataloader)
    gallery_emb, gallery_T = predict_batchwise(model, gallery_dataloader)

    # Unit-length rows make the dot product below a cosine similarity.
    query_feats = F.normalize(query_emb, dim=1).cpu().numpy()
    gallery_feats = F.normalize(gallery_emb, dim=1).cpu().numpy()

    num_query = query_feats.shape[0]
    ap_total = 0

    for q in range(num_query):
        relevant = (np.dot(query_T[q, :], gallery_T.T) > 0).astype(np.float32)
        num_relevant = np.sum(relevant)
        if num_relevant == 0:
            continue
        cos_dis = 1 - np.dot(query_feats[q, :], gallery_feats.T)
        # Relevance flags reordered from nearest to farthest gallery item.
        ranked_relevant = relevant[np.argsort(cos_dis)]
        # k-th hit divided by its 1-based rank = precision at that hit.
        hit_counts = np.linspace(1, num_relevant, int(num_relevant))
        hit_ranks = np.asarray(np.where(ranked_relevant == 1)) + 1.0
        ap_total = ap_total + np.mean(hit_counts / hit_ranks)
    mAP = ap_total / num_query

    print("\n mAP : {:.4f}".format(mAP))

    return mAP

def encode_onehot(labels, num_classes=196):
    """Convert integer class labels into a one-hot matrix.

    Args:
        labels: sequence of class indices, one per sample.
        num_classes (int): number of columns of the result.

    Returns:
        numpy.ndarray: float array of shape ``(len(labels), num_classes)``
        with a 1 at ``[i, labels[i]]`` and 0 elsewhere.
    """
    onehot_labels = np.zeros((len(labels), num_classes))

    for row, label in enumerate(labels):
        onehot_labels[row, label] = 1

    return onehot_labels

def evaluate_fine(args, model, query_dataloader, gallery_dataloader):
    """Report metric-learning scores and cosine-ranking mAP.

    First prints R-Precision, MAP@R and NMI from ``AccuracyCalculator`` on the
    raw embeddings, then computes mAP from cosine distance on L2-normalized
    embeddings.

    Args:
        args: namespace providing ``dataset``. For ``'cub200'`` / ``'car196'``
            the labels are one-hot encoded with 200 / 196 classes; for any
            other value they are used as returned by ``predict_batchwise``.
        model: embedding network handed to ``predict_batchwise``.
        query_dataloader: loader for the query set.
        gallery_dataloader: loader for the gallery set.

    Returns:
        Sum of per-query average precision divided by the total number of
        queries. Queries without any relevant gallery item add nothing to the
        sum but are still counted in the denominator.
    """
    query_emb, query_T = predict_batchwise(model, query_dataloader)
    gallery_emb, gallery_T = predict_batchwise(model, gallery_dataloader)

    def as_numpy(tensor):
        return tensor.cpu().detach().numpy()

    calculator = AccuracyCalculator(include=(), exclude=['AMI'], avg_of_avgs=False, k=None)
    result = calculator.get_accuracy(as_numpy(query_emb), as_numpy(gallery_emb),
                                     as_numpy(query_T), as_numpy(gallery_T),
                                     embeddings_come_from_same_source=False)

    print("\nR_Precision : {:.4f}, MAP@R : {:.4f}, NMI : {:.4f}".format(
        result['r_precision'], result['mean_average_precision_at_r'], result['NMI']))

    # Unit-length rows make the dot product below a cosine similarity.
    query_feats = F.normalize(query_emb, dim=1).cpu().numpy()
    gallery_feats = F.normalize(gallery_emb, dim=1).cpu().numpy()

    # Class labels become one-hot rows so that "same class" is dot product > 0.
    classes_per_dataset = {'cub200': 200, 'car196': 196}
    if args.dataset in classes_per_dataset:
        class_count = classes_per_dataset[args.dataset]
        query_T = encode_onehot(query_T, num_classes=class_count)
        gallery_T = encode_onehot(gallery_T, num_classes=class_count)

    num_query = query_feats.shape[0]
    ap_total = 0

    for q in range(num_query):
        relevant = (np.dot(query_T[q, :], gallery_T.T) > 0).astype(np.float32)
        num_relevant = np.sum(relevant)
        if num_relevant == 0:
            continue
        cos_dis = 1 - np.dot(query_feats[q, :], gallery_feats.T)
        # Relevance flags reordered from nearest to farthest gallery item.
        ranked_relevant = relevant[np.argsort(cos_dis)]
        # k-th hit divided by its 1-based rank = precision at that hit.
        hit_counts = np.linspace(1, num_relevant, int(num_relevant))
        hit_ranks = np.asarray(np.where(ranked_relevant == 1)) + 1.0
        ap_total = ap_total + np.mean(hit_counts / hit_ranks)
    mAP = ap_total / num_query

    print("\n mAP : {:.4f}".format(mAP))

    return mAP



## TODO: verify how this operates
def IndexingNoSplit(z, descriptor, numSeg):
    """Assign each descriptor to its most similar codeword in every codebook.

    Unlike ``Indexing``, the descriptor is NOT split: the full descriptor is
    compared against the codewords of each of the ``numSeg`` codebooks.

    Args:
        z: codebook tensor indexed as (segment, codeword, dim); normalized
            along ``dim=2`` before use (the input itself is not modified).
        descriptor: 2-D tensor (B, dim) - presumably already L2-normalized so
            that the dot product is a cosine similarity; confirm with callers.
        numSeg: number of codebooks of ``z`` to use.

    Returns:
        LongTensor of shape (B, numSeg) holding, per codebook, the index of the
        codeword with the largest dot product.
    """
    codebooks = F.normalize(z, dim=2)

    code_columns = []
    for seg in range(numSeg):
        # (B, dim) @ (dim, K) -> (B, K). Same sums as the previous
        # repeat/permute broadcast, without materializing a (B, dim, K) tensor.
        similarity = torch.matmul(descriptor, codebooks[seg].t())
        code_columns.append(torch.argmax(similarity, dim=1).view(-1, 1))

    return torch.cat(code_columns, dim=1)

## TODO: verify how this operates
## z = num_partitionings X num_partitions X sz_embeddings
## descriptor = B X (num_partitionings * sz_embeddings), split along dim 1 into num_partitionings chunks
def Indexing(z, descriptor, numSeg):
    """Product-quantize descriptors: pick the nearest codeword per segment.

    Args:
        z: codebook tensor indexed as (segment, codeword, dim) with
            ``z.size(0) == numSeg``; normalized along ``dim=2`` before use
            (the input itself is not modified).
        descriptor: 2-D tensor (B, numSeg * dim); it is split along ``dim=1``
            into ``numSeg`` equal chunks, one per codebook.
        numSeg: number of segments / codebooks.

    Returns:
        LongTensor of shape (B, numSeg); entry ``[b, s]`` is the index of the
        codeword of codebook ``s`` with the largest dot product with segment
        ``s`` of descriptor ``b``.
    """
    codebooks = F.normalize(z, dim=2)
    descriptor_parts = torch.split(descriptor, int(descriptor.size(1) / numSeg), dim=1)
    codebook_parts = torch.split(codebooks, int(z.size(0) / numSeg), dim=0)

    code_columns = []
    for seg in range(numSeg):
        # Drop only the leading segment axis; a bare squeeze() would also
        # collapse the codeword axis when a codebook has a single codeword.
        codewords = codebook_parts[seg].squeeze(0)
        # (B, dim) @ (dim, K) -> (B, K). Same sums as the previous
        # repeat/permute broadcast, without materializing a (B, dim, K) tensor.
        similarity = torch.matmul(descriptor_parts[seg], codewords.t())
        code_columns.append(torch.argmax(similarity, dim=1).view(-1, 1))

    return torch.cat(code_columns, dim=1)

# Compute distances and build look-up-table
def pqDistNoSplit(Z, numSeg, g_x, q_x):
    """Asymmetric quantized distance where queries are NOT split.

    For every codebook a look-up table of ``1 - <query, codeword>`` is built
    from the full query vector; the distance to a gallery item is the sum over
    codebooks of the table entry selected by that item's code.

    Args:
        Z: numpy array indexed as (segment, codeword, dim).
        numSeg: number of codebooks; ``g_x`` is split into this many columns.
        g_x: integer array (n2, numSeg) of gallery codeword indices.
        q_x: array (n1, dim) of query vectors.

    Returns:
        numpy.ndarray: float32 array of shape (n1, n2).
    """
    n1 = q_x.shape[0]
    n2 = g_x.shape[0]
    code_columns = np.split(g_x, numSeg, 1)

    Dpq = np.zeros((n1, n2), dtype=np.float32)

    for seg in range(numSeg):
        # Look-up table for all queries at once: lut[i, k] = 1 - q_i . c_k.
        # Replaces the per-query / per-codeword Python loop of scalar dots.
        lut = (1 - q_x @ Z[seg].T).astype(np.float32)
        # Gather each gallery item's entry and accumulate over the codebooks.
        Dpq += lut[:, code_columns[seg].reshape(-1)]
    return Dpq


# Compute distances and build look-up-table
## z = num_partitionings x num_partitions X sz_embeddings
def pqDist(Z, numSeg, g_x, q_x):
    """Asymmetric product-quantization distance between queries and gallery.

    Each query is split into ``numSeg`` sub-vectors. For every segment a
    look-up table of ``1 - <query sub-vector, codeword>`` is built, and the
    distance to a gallery item is the sum over segments of the table entry
    selected by that item's code.

    Args:
        Z: numpy array indexed as (segment, codeword, dim) with
            ``Z.shape[0] == numSeg``.
        numSeg: number of segments.
        g_x: integer array (n2, numSeg) of gallery codeword indices.
        q_x: array (n1, numSeg * dim) of query vectors.

    Returns:
        numpy.ndarray: float32 array of shape (n1, n2).
    """
    n1 = q_x.shape[0]
    n2 = g_x.shape[0]

    query_parts = np.split(q_x, numSeg, 1)
    code_columns = np.split(g_x, numSeg, 1)
    codebook_parts = np.split(Z, numSeg, 0)

    Dpq = np.zeros((n1, n2), dtype=np.float32)

    for seg in range(numSeg):
        # Drop only the leading segment axis; a bare squeeze() would also
        # collapse the codeword axis for a single-codeword codebook.
        codewords = codebook_parts[seg].squeeze(axis=0)
        # Look-up table for all queries at once: lut[i, k] = 1 - q_i . c_k.
        # Replaces the per-query / per-codeword Python loop of scalar dots.
        lut = (1 - query_parts[seg] @ codewords.T).astype(np.float32)
        # Gather each gallery item's entry and accumulate over the segments.
        Dpq += lut[:, code_columns[seg].reshape(-1)]
    return Dpq


# Average Precision (AP) Calculation
def cat_apcal(label_Similarity, IX, top_N):
    """Mean average precision over the top-``top_N`` ranked results.

    Args:
        label_Similarity: relevance matrix indexed ``[query, gallery]``; an
            entry equal to 1 marks a relevant gallery item.
        IX: 2-D integer array of shape (num_ranked, numtest); column ``i``
            lists gallery indices for query ``i`` from best to worst.
        top_N: ranking depth to evaluate. Values larger than the number of
            ranked items are clamped (previously this raised ``IndexError``).

    Returns:
        Mean over queries of the average precision; a query without any
        relevant item inside the evaluated depth contributes 0.
    """
    num_ranked, numtest = IX.shape
    depth = max(0, min(top_N, num_ranked))

    apall = np.zeros(numtest)

    for i in range(numtest):
        hits = 0
        precision_sum = 0.0

        for rank, gallery_idx in enumerate(IX[:depth, i], start=1):
            if label_Similarity[i, gallery_idx] == 1:
                hits += 1
                # Precision at the position of this hit.
                precision_sum += float(hits) / rank

        apall[i] = precision_sum / hits if hits else 0

    mAP = np.mean(apall)

    return mAP


def check_unique_code_assign(partitionings, num_classes):
    """Tell whether the rows of ``partitionings`` form exactly ``num_classes`` distinct codes.

    Rows are compared through their ``str()`` representation.

    Args:
        partitionings: iterable of per-class code rows.
        num_classes: expected number of distinct rows.

    Returns:
        bool: True when the number of distinct rows equals ``num_classes``.
    """
    distinct_codes = {str(code) for code in partitionings}
    return len(distinct_codes) == num_classes


def LabelToMetaLabel(partitionings, instance_labels):
    """Look up the meta-label row of every instance label.

    Args:
        partitionings: array indexed as (class, partitioning).
        instance_labels: class indices used to select rows.

    Returns:
        The rows of ``partitionings`` selected by ``instance_labels``.
    """
    return partitionings[instance_labels, :]

def adaptation_factor(p, gamma=10):
    """Map a progress value to a ramp-up weight ``2 / (1 + exp(-gamma * p)) - 1``.

    Args:
        p: progress value; clipped to the interval [0, 1] before use.
        gamma: steepness of the ramp.

    Returns:
        float: the weight, capped at 1.0 (0.0 for ``p <= 0``).
    """
    clipped = max(min(p, 1.0), 0.0)
    weight = 2.0 / (1.0 + math.exp(-gamma * clipped)) - 1.0
    return min(weight, 1.0)




class UnlabeledDataManager(object):
    """Endless batch source over an unlabeled training set.

    Wraps a shuffled ``DataLoader`` (incomplete last batch dropped) and
    restarts it transparently whenever it is exhausted. Every batch is
    expected to unpack into two items ``(x, y)``.
    """

    def __init__(self, args, unlabeled_trainset):
        """Build the loader from ``args.sz_batch`` and ``args.nb_workers``."""
        self.unlabeled_trainset = unlabeled_trainset
        self.unlabeled_train_loader = data.DataLoader(
            self.unlabeled_trainset,
            batch_size=args.sz_batch,
            shuffle=True,
            num_workers=args.nb_workers,
            pin_memory=True,
            drop_last=True,
        )
        self.unlabeled_iter = enumerate(self.unlabeled_train_loader)

    def _next_batch(self):
        """Fetch the next batch, starting a new pass when the loader is exhausted."""
        try:
            _, batch = next(self.unlabeled_iter)
        except StopIteration:
            self.unlabeled_iter = enumerate(self.unlabeled_train_loader)
            _, batch = next(self.unlabeled_iter)
        return batch

    def next_unlabeled_train(self):
        """Return only the inputs ``x`` of the next batch."""
        x, _ = self._next_batch()
        return x

    def next_unlabeled_train_pseudo(self):
        """Return ``(x, y)`` of the next batch."""
        x, y = self._next_batch()
        return x, y

    def next_unlabeled_train_oracle(self):
        """Return ``(x, y)`` of the next batch (same behavior as the pseudo variant)."""
        x, y = self._next_batch()
        return x, y

"""Changed"""
class UnlabeledDataManager2(object):
    """Endless batch source over an unlabeled training set (two-view variant).

    Wraps a shuffled ``DataLoader`` (incomplete last batch dropped) and
    restarts it transparently whenever it is exhausted.

    NOTE(review): ``next_unlabeled_train`` unpacks three items per batch while
    the pseudo/oracle accessors unpack two, so a given dataset can satisfy only
    one of the two families - confirm which accessors callers actually use.
    """

    def __init__(self, args, unlabeled_trainset):
        """Build the loader from ``args.sz_batch`` and ``args.nb_workers``."""
        self.unlabeled_trainset = unlabeled_trainset
        self.unlabeled_train_loader = data.DataLoader(
            self.unlabeled_trainset,
            batch_size=args.sz_batch,
            shuffle=True,
            num_workers=args.nb_workers,
            pin_memory=True,
            drop_last=True,
        )
        self.unlabeled_iter = enumerate(self.unlabeled_train_loader)

    def _next_batch(self):
        """Fetch the next batch, starting a new pass when the loader is exhausted."""
        try:
            _, batch = next(self.unlabeled_iter)
        except StopIteration:
            self.unlabeled_iter = enumerate(self.unlabeled_train_loader)
            _, batch = next(self.unlabeled_iter)
        return batch

    def next_unlabeled_train(self):
        """Return ``(x1, x2, y)`` of the next batch."""
        x1, x2, y = self._next_batch()
        return x1, x2, y

    def next_unlabeled_train_pseudo(self):
        """Return ``(x, y)`` of the next batch."""
        x, y = self._next_batch()
        return x, y

    def next_unlabeled_train_oracle(self):
        """Return ``(x, y)`` of the next batch (same behavior as the pseudo variant)."""
        x, y = self._next_batch()
        return x, y

class UnlabeledDataManagerNUS3(object):
    """Endless batch source over an unlabeled training set (NUS3 variant).

    Wraps a shuffled ``DataLoader`` that keeps the incomplete last batch and
    restarts it transparently whenever it is exhausted.

    NOTE(review): ``next_unlabeled_train`` unpacks three items per batch while
    ``next_unlabeled_train_oracle`` unpacks two, so a given dataset can satisfy
    only one of them - confirm which accessor callers actually use.
    """

    def __init__(self, args, unlabeled_trainset):
        """Build the loader from ``args.sz_batch`` and ``args.nb_workers``."""
        self.unlabeled_trainset = unlabeled_trainset
        self.unlabeled_train_loader = data.DataLoader(
            self.unlabeled_trainset,
            batch_size=args.sz_batch,
            shuffle=True,
            num_workers=args.nb_workers,
            pin_memory=True,
            drop_last=False,
        )
        self.unlabeled_iter = enumerate(self.unlabeled_train_loader)

    def _next_batch(self):
        """Fetch the next batch, starting a new pass when the loader is exhausted."""
        try:
            _, batch = next(self.unlabeled_iter)
        except StopIteration:
            self.unlabeled_iter = enumerate(self.unlabeled_train_loader)
            _, batch = next(self.unlabeled_iter)
        return batch

    def next_unlabeled_train(self):
        """Return only the first item ``x`` of the next three-item batch."""
        x, _, _ = self._next_batch()
        return x

    def next_unlabeled_train_oracle(self):
        """Return ``(x, y)`` of the next two-item batch."""
        x, y = self._next_batch()
        return x, y



def PairEnum(x, mask=None):
    """Enumerate all ordered pairs of rows of a 2-D tensor.

    Args:
        x: tensor of shape (N, F).
        mask: optional mask with N*N elements selecting which pairs to keep.

    Returns:
        tuple: ``(x1, x2)`` where pair ``p`` is ``(x1[p], x2[p])``. Without a
        mask both have N*N rows; ``x1`` cycles through the rows of ``x`` while
        ``x2`` repeats each row N times in a row.
    """
    assert x.ndimension() == 2, 'Input dimension must be 2'
    num_samples, feat_dim = x.size(0), x.size(1)

    tiled = x.repeat(num_samples, 1)
    repeated = x.repeat(1, num_samples).view(-1, feat_dim)
    if mask is None:
        return tiled, repeated

    # Expand the per-pair mask over the feature axis (dim 0: pair, dim 1: feature).
    keep = mask.view(-1, 1).repeat(1, feat_dim)
    return tiled[keep].view(-1, feat_dim), repeated[keep].view(-1, feat_dim)


def cluster_acc(y_true, y_pred):
    """
    Calculate clustering accuracy under the best one-to-one label mapping.

    # Arguments
        y_true: true labels, integer numpy.array with shape `(n_samples,)`
        y_pred: predicted labels, integer numpy.array with shape `(n_samples,)`

    # Return
        accuracy, in [0,1]
    """
    # Local import: SciPy's solver replaces
    # sklearn.utils.linear_assignment_.linear_assignment, which was deprecated
    # and later removed from scikit-learn.
    from scipy.optimize import linear_sum_assignment

    y_true = y_true.astype(np.int64)
    assert y_pred.size == y_true.size
    D = max(y_pred.max(), y_true.max()) + 1
    # Contingency matrix: w[t, p] = number of samples with true t, predicted p.
    w = np.zeros((D, D), dtype=np.int64)
    np.add.at(w, (y_true, y_pred), 1)
    # Maximizing matched counts == minimizing (max - count).
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() * 1.0 / y_pred.size



def cluster_acc_split(y_true, y_pred, seen=7):
    """
    Calculate clustering accuracy overall and separately for seen / unseen classes.

    Classes with true label `< seen` are "seen", the rest "unseen".

    # Arguments
        y_true: true labels, integer numpy.array with shape `(n_samples,)`
        y_pred: predicted labels, integer numpy.array with shape `(n_samples,)`
        seen: number of seen classes (true labels `0 .. seen-1`)

    # Return
        (total_acc, seen_acc, unseen_acc), each in [0,1]; a group without any
        sample yields `nan` for its accuracy
    """
    # Local import: SciPy's solver replaces
    # sklearn.utils.linear_assignment_.linear_assignment, which was deprecated
    # and later removed from scikit-learn.
    from scipy.optimize import linear_sum_assignment

    y_true = y_true.astype(np.int64)
    assert y_pred.size == y_true.size
    D = max(y_pred.max(), y_true.max()) + 1
    # Contingency matrix: w[t, p] = number of samples with true t, predicted p.
    w = np.zeros((D, D), dtype=np.int64)
    np.add.at(w, (y_true, y_pred), 1)
    # Maximizing matched counts == minimizing (max - count).
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    # row_ind is sorted by true label, so the first `seen` entries are exactly
    # the seen classes.
    matched = w[row_ind, col_ind]

    num_seen = int(np.count_nonzero(y_true < seen))
    num_unseen = int(y_true.size - num_seen)
    seen_acc = matched[:seen].sum() * 1.0 / num_seen if num_seen else float('nan')
    unseen_acc = matched[seen:].sum() * 1.0 / num_unseen if num_unseen else float('nan')
    total_acc = matched.sum() * 1.0 / y_pred.size

    return total_acc, seen_acc, unseen_acc


def calculate_threshold(descriptor_l, targets, global_proxies, num_classes=7):
    """Average, over classes, of the lowest mean intra-class cosine similarity.

    For every class the descriptors are L2-normalized, the pairwise cosine
    similarity matrix is formed, each row is averaged, and the minimum of
    those row means is kept. The result is the mean of that value over all
    classes.

    Args:
        descriptor_l: 2-D tensor (N, dim) of labeled descriptors.
        targets: 1-D tensor (N,) of class indices.
        global_proxies: unused; kept so the signature stays unchanged (the
            proxy-based similarity is disabled in this function).
        num_classes: classes ``0 .. num_classes-1`` are evaluated.

    Returns:
        0-dim tensor with the averaged similarity.

    Raises:
        RuntimeError: if a class has no sample (minimum of an empty tensor).
    """
    sim_lst = []
    for cls in range(num_classes):
        # A boolean mask always yields a 2-D (n, dim) tensor; the previous
        # nonzero(...).squeeze() collapsed to 0-d for a single-sample class
        # and made the dim=1 normalization fail.
        class_descriptors = F.normalize(descriptor_l[targets == cls], p=2, dim=1)
        sim = torch.matmul(class_descriptors, class_descriptors.t())
        # Mean similarity of each sample to its class, then the worst sample.
        sim_lst.append(sim.mean(1).min())
    avg_sim = sum(sim_lst) / num_classes
    return avg_sim






def timing(f):
    """Decorator that prints the wall-clock time taken by each call to ``f``."""

    @wraps(f)
    def timed(*call_args, **call_kwargs):
        started_at = time.time()
        result = f(*call_args, **call_kwargs)
        print(f'total time = {time.time() - started_at:.4f}')
        return result

    return timed





#######################################################################################################################
## Code for SSAH

def sim_degree(x_1, x_2):
    """Rescale pairwise dot products to ``(<a, b> + k) / (2k)``, k = last-dim size.

    For codes with entries in {-1, +1} this maps the dot product range
    [-k, k] onto [0, 1].
    """
    code_len = x_1.size(-1)
    dots = torch.matmul(x_1, torch.transpose(x_2, -1, -2))
    return (dots + code_len) / (2 * code_len)

def compute_mAP(trn_binary, tst_binary, trn_label, tst_label):
    """
    compute mAP by searching testset from trainset
    https://github.com/flyingpot/pytorch_deephash

    Database items are ranked by Hamming distance to each query code; an item
    is relevant when its label equals the query label.

    Args:
        trn_binary: (N, bits) database codes.
        tst_binary: (Q, bits) query codes.
        trn_label: (N,) database labels.
        tst_label: (Q,) query labels.

    Returns:
        0-dim CPU tensor with the mean average precision. A query without any
        relevant database item contributes 0 (previously 0/0 = NaN, which
        turned the whole mean into NaN). NaN only when there are no queries.
    """
    # Ranks live on the labels' device so the division below also works for
    # CUDA inputs (a CPU-only arange caused a device mismatch).
    ranks = torch.arange(1, trn_binary.size(0) + 1, device=trn_label.device)

    ap_per_query = []
    for i in range(tst_binary.size(0)):
        query_label, query_binary = tst_label[i], tst_binary[i]
        # Hamming distance = number of differing bits.
        hamming = torch.sum((query_binary != trn_binary).long(), dim=1)
        _, order = hamming.sort()
        correct = (query_label == trn_label[order]).float()
        num_correct = torch.sum(correct)
        if num_correct == 0:
            ap_per_query.append(torch.zeros((), device=correct.device))
            continue
        precision = torch.cumsum(correct, dim=0) / ranks
        ap_per_query.append(torch.sum(precision * correct) / num_correct)

    if not ap_per_query:
        return torch.tensor(float('nan'))
    return torch.stack(ap_per_query).mean().cpu()


@timing
def evaluate_binary(args, model, query_dataloader, gallery_dataloader, label_similarity):
    """Evaluate Hamming-ranking mAP on binarized model outputs.

    Outputs are thresholded at 0.5 into {0, 1} codes (an output of exactly 0.5
    maps to 0.5, because ``torch.sign(0)`` is 0) and scored with
    ``compute_mAP``, using the gallery as database.

    Args:
        args: unused here; kept for a uniform evaluate_* signature.
        model: network handed to ``predict_batchwise``.
        query_dataloader: loader for the query set.
        gallery_dataloader: loader for the gallery set.
        label_similarity: unused here.

    Returns:
        The mAP value returned by ``compute_mAP``.
    """
    def binarize(outputs):
        # sign(x - 0.5) is in {-1, 0, 1}; shift and scale to {0, 0.5, 1}.
        return (torch.sign(outputs - 0.5) + 1) * 0.5

    query_out, query_labels = predict_batchwise(model, query_dataloader)
    gallery_out, gallery_labels = predict_batchwise(model, gallery_dataloader)

    query_codes = binarize(query_out)
    gallery_codes = binarize(gallery_out)

    mAP = compute_mAP(gallery_codes, query_codes, gallery_labels, query_labels)

    print("\n mAP : {:.4f}".format(mAP))

    return mAP




