import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.optim.lr_scheduler
import random
import numpy as np
import math
from utils.utils import *
from utils import get_data
from utils import convmix
from core import pretrain
from scipy.spatial.distance import cdist
from torch.optim.lr_scheduler import _LRScheduler

def _save_models(args, dg_featurizer, dg_classifier):
    """Save both networks under file names derived from ``args.title``."""
    save_model(args, dg_featurizer, 'featurizer_' + args.title + '.pt')
    save_model(args, dg_classifier, 'classifier_' + args.title + '.pt')


def train(args, dg_featurizer, dg_classifier, src_data_loader_proto, src_data_loader_train, src_data_loader_valid, ul1_dataset_train, ul2_dataset_train, ul1_data_loader_valid, ul2_data_loader_valid, test_data_loader, transform_train):
    """Train the featurizer/classifier on source data mixed with two pseudo-labelled domains.

    Each epoch
      1. switches the networks to eval mode and derives pseudo labels, per-sample
         mixing weights (``lam``) and class prototypes for both unlabelled datasets,
      2. builds a shuffled loader over the pseudo-labelled samples,
      3. for every unlabelled batch draws one source batch per pseudo label, mixes
         source and unlabelled features with ``lam``, applies mixup on top and
         optimises mixup cross-entropy + contrastive + entropy + diversity terms,
      4. evaluates on the source/UL1/UL2/test loaders and periodically plots and
         checkpoints.

    Args:
        args: configuration namespace. Attributes read here: ``device``,
            ``optimizer``, ``lr``, ``wd``, ``momentum``, ``nesterov``, ``schedule``,
            ``power``, ``augmentation``, ``aug_ogcat``, ``epoch``, ``num_classes``,
            ``batch_size``, ``pl``, ``trans_ensemble``, ``mixup_alpha``, ``delta``,
            ``beta``, ``gamma``, ``dataset`` and ``title``.
        dg_featurizer: feature extractor module.
        dg_classifier: classification head applied to the features.
        src_data_loader_proto: source loader used to compute source prototypes.
        src_data_loader_train: container indexed by class id (``0..num_classes-1``)
            of source loaders yielding ``(image, label, domain)``.
            NOTE(review): the feature mixing requires the concatenated source batch
            to have ``args.batch_size`` rows, so each per-class loader presumably
            yields one sample per batch - confirm against the caller.
        src_data_loader_valid, ul1_data_loader_valid, ul2_data_loader_valid,
        test_data_loader: loaders passed to ``evaluate`` after every epoch.
        ul1_dataset_train, ul2_dataset_train: unlabelled datasets exposing
            ``get_raw_data()`` -> ``(images, labels, domain_labels)``.
        transform_train: transform handed to the pseudo-labelled dataset.

    Returns:
        Tuple ``(src_acc, ul1_acc, ul2_acc, test_acc, lam1_list, lam2_list,
        ul1_pl_acc, ul2_pl_acc)`` of per-epoch histories. The two pseudo-label
        accuracy lists are only filled when ``args.pl == 'shot'``.

    Raises:
        ValueError: if ``args.optimizer`` or ``args.pl`` is not supported.
    """
    dg_featurizer.to(args.device)
    dg_classifier.to(args.device)

    dg_featurizer.train()
    dg_classifier.train()

    if args.optimizer == 'SGD':
        optimizer = optim.SGD([
                {'params': dg_featurizer.parameters()},
                {'params': dg_classifier.parameters()}
            ], lr=args.lr, weight_decay=args.wd, momentum=args.momentum, nesterov=args.nesterov)
    else:
        # Fail early instead of hitting a NameError on the first optimizer use.
        raise ValueError("Unsupported optimizer: {!r}".format(args.optimizer))

    if args.schedule == 'poly':
        # Remember each group's initial lr so the decay can be applied to it.
        optimizer = op_copy(optimizer)

    criterion = nn.CrossEntropyLoss().to(args.device)

    src_acc = []
    ul1_acc = []
    ul2_acc = []
    test_acc = []
    lam1_list = []
    lam2_list = []
    ul1_pl_acc = []
    ul2_pl_acc = []

    if args.augmentation == 'convmix':
        aug = convmix.ConvMix(args).to(args.device)

    iter_num = 0

    for epoch in range(args.epoch):

        # One fresh iterator per class over the source data.
        src_iter = [iter(src_data_loader_train[i]) for i in range(args.num_classes)]

        dg_featurizer.eval()
        dg_classifier.eval()

        images1, labels1, domain_labels1 = ul1_dataset_train.get_raw_data()
        images2, labels2, domain_labels2 = ul2_dataset_train.get_raw_data()
        images = images1 + images2
        domain_labels = np.concatenate((domain_labels1, domain_labels2))
        # Unshuffled loaders keep the pseudo labels aligned with the raw data order.
        ul1_data_loader_unshuffle = data.DataLoader(ul1_dataset_train, batch_size=args.batch_size, num_workers=0, shuffle=False, drop_last=False)
        ul2_data_loader_unshuffle = data.DataLoader(ul2_dataset_train, batch_size=args.batch_size, num_workers=0, shuffle=False, drop_last=False)

        # Generate pseudo labels, mixing weights and prototypes every epoch.
        if args.pl == 'oracle':
            ul1_pseudo_labels = labels1
            ul2_pseudo_labels = labels2
            _, lam1, _, ul1_protos = shot_label(args, ul1_data_loader_unshuffle, dg_featurizer, dg_classifier)
            _, lam2, _, ul2_protos = shot_label(args, ul2_data_loader_unshuffle, dg_featurizer, dg_classifier)
            # The contrastive loss needs the mean prototypes in this mode as well.
            src_protos = src_proto(args, src_data_loader_proto, dg_featurizer)
        elif args.pl == 'shot':
            if args.trans_ensemble:
                ul1_pseudo_labels, lam1, ul1_pl_shot, ul1_protos = shot_label_trans_ens(args, ul1_data_loader_unshuffle, dg_featurizer, dg_classifier, epoch+1, 'ul1')
                ul2_pseudo_labels, lam2, ul2_pl_shot, ul2_protos = shot_label_trans_ens(args, ul2_data_loader_unshuffle, dg_featurizer, dg_classifier, epoch+1, 'ul2')
                src_protos = src_proto_trans_ens(args, src_data_loader_proto, dg_featurizer, epoch+1)
            else:
                ul1_pseudo_labels, lam1, ul1_pl_shot, ul1_protos = shot_label(args, ul1_data_loader_unshuffle, dg_featurizer, dg_classifier)
                ul2_pseudo_labels, lam2, ul2_pl_shot, ul2_protos = shot_label(args, ul2_data_loader_unshuffle, dg_featurizer, dg_classifier)
                src_protos = src_proto(args, src_data_loader_proto, dg_featurizer)

            ul1_pl_acc.append(ul1_pl_shot)
            ul2_pl_acc.append(ul2_pl_shot)
        else:
            raise ValueError("Unsupported pseudo-label mode: {!r}".format(args.pl))

        # Average the class prototypes of the three domains.
        domain_protos = np.stack((src_protos, ul1_protos, ul2_protos), axis=0)
        protos_mean = np.mean(domain_protos, axis=0)

        lam1_list.append(np.mean(lam1))
        lam2_list.append(np.mean(lam2))
        pseudo_labels = np.concatenate((np.array(ul1_pseudo_labels), np.array(ul2_pseudo_labels)))
        lam_all = np.concatenate((lam1, lam2))
        ul_pseudo_dataset = get_data.utilDataset(images, pseudo_labels, domain_labels, lam_all, transform=transform_train)
        # drop_last=True guarantees every batch has exactly args.batch_size rows,
        # which the view/index arithmetic below relies on.
        ul_pseudo_dataloader = data.DataLoader(ul_pseudo_dataset, batch_size=args.batch_size, num_workers=0, shuffle=True, drop_last=True)

        dg_featurizer.train()
        dg_classifier.train()

        max_iter = args.epoch * len(ul_pseudo_dataloader)

        for step, (ul_images, ul_labels, _ul_domains, batch_lam) in enumerate(ul_pseudo_dataloader):

            iter_num += 1
            if args.schedule == 'poly':
                lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter, power=args.power)

            ul_images = ul_images.to(args.device)
            batch_lam = batch_lam.to(args.device)

            # Draw one source batch for every pseudo label in the unlabelled batch.
            src_images_list = []
            src_labels_list = []
            for cls in ul_labels.tolist():
                try:
                    src_image, src_label, _ = next(src_iter[cls])
                except StopIteration:
                    # This class ran out of source samples mid-epoch: restart it.
                    src_iter[cls] = iter(src_data_loader_train[cls])
                    src_image, src_label, _ = next(src_iter[cls])
                src_images_list.append(src_image)
                src_labels_list.append(src_label)

            src_images = torch.cat(src_images_list, dim=0).to(args.device)
            src_labels = torch.cat(src_labels_list, dim=0).to(args.device)

            # Augmentation (optionally concatenated with the original images).
            org_src = src_images
            org_ul = ul_images

            if not args.augmentation == 'none':
                src_images = aug(src_images)
                ul_images = aug(ul_images)
            if args.aug_ogcat:
                src_images = torch.cat([org_src, src_images])
                ul_images = torch.cat([org_ul, ul_images])
                src_labels = torch.cat([src_labels, src_labels])

            src_feature = dg_featurizer(src_images)
            # NOTE(review): the outputs of the two classifier passes on the unmixed
            # features are not part of the loss. The calls are kept because a
            # train-mode forward pass may update normalisation statistics or
            # consume dropout RNG state - confirm before removing them.
            dg_classifier(src_feature)
            ul_feature = dg_featurizer(ul_images)

            if args.aug_ogcat:
                # For each sample, randomly decide whether its original or its
                # augmented unlabelled feature goes into the first half of the
                # batch; the other one goes into the second half.
                indices1 = torch.randint(0, 2, (args.batch_size,))
                indices2 = 1-indices1
                ar = torch.arange(0,args.batch_size)
                indices1 = indices1*args.batch_size + ar
                indices2 = indices2*args.batch_size + ar
                indices = torch.cat([indices1, indices2])
                ul_feature = ul_feature[indices, :]

            dg_classifier(ul_feature)

            batch_lam = batch_lam.view(args.batch_size, 1)

            if args.aug_ogcat:
                batch_lam = torch.cat([batch_lam, batch_lam]).to(args.device)

            # Domain mix: per-sample interpolation of source and unlabelled features.
            mixed_feature = (1-batch_lam)*src_feature + batch_lam*ul_feature
            # Mixup between samples of the mixed batch.
            lam_mixup = np.random.beta(args.mixup_alpha, args.mixup_alpha)
            index = torch.randperm(mixed_feature.size(0)).to(args.device)
            mixup_feature = (lam_mixup * mixed_feature + (1 - lam_mixup) * mixed_feature[index, :])
            y_a, y_b = src_labels, src_labels[index]

            optimizer.zero_grad()

            preds = dg_classifier(mixed_feature)
            preds_mixup = dg_classifier(mixup_feature)

            loss = lam_mixup * criterion(preds_mixup, y_a) + (1 - lam_mixup) * criterion(preds_mixup, y_b)
            loss += args.delta * (ContrastiveLoss(args, src_feature, src_labels, protos_mean) + ContrastiveLoss(args, ul_feature, src_labels, protos_mean))/2
            loss += args.beta * EMloss(args, preds)
            loss += args.gamma * LDloss(args, preds)

            loss.backward()
            optimizer.step()

            log_message = "Epoch [{}/{}] Step [{}/{}]: loss={}".format(epoch + 1, args.epoch, step + 1, len(ul_pseudo_dataloader), loss.item())

            write_log(args, log_message)

        # The polynomial decay is applied per iteration by lr_scheduler(); no
        # epoch-level scheduler object exists, so nothing is stepped here.

        evaluate(args, dg_featurizer, dg_classifier, src_data_loader_valid, src_acc, 'Source')
        evaluate(args, dg_featurizer, dg_classifier, ul1_data_loader_valid, ul1_acc, 'UL1')
        evaluate(args, dg_featurizer, dg_classifier, ul2_data_loader_valid, ul2_acc, 'UL2')
        evaluate(args, dg_featurizer, dg_classifier, test_data_loader, test_acc, 'Test')

        # Intermediate plot + checkpoint: every 200 epochs for Digits, else every 20.
        checkpoint_interval = 200 if args.dataset == 'Digits' else 20
        if epoch % checkpoint_interval == 0:
            make_plot(args, src_acc, ul1_acc, ul2_acc, test_acc, lam1_list, lam2_list, ul1_pl_acc, ul2_pl_acc, mid=True)
            _save_models(args, dg_featurizer, dg_classifier)

    # save final model
    _save_models(args, dg_featurizer, dg_classifier)

    return src_acc, ul1_acc, ul2_acc, test_acc, lam1_list, lam2_list, ul1_pl_acc, ul2_pl_acc

def evaluate(args, featurizer, classifier, data_loader, acc_list, domain_name):
    """Evaluate the model on one loader, log the result and record the accuracy.

    Both networks are put in eval mode for the measurement and switched back to
    train mode before returning.

    Args:
        args: configuration namespace; ``args.device`` is read here and the
            namespace is forwarded to ``write_log``.
        featurizer: feature extractor module.
        classifier: classification head applied to the features.
        data_loader: loader yielding ``(images, labels, domains)`` batches.
        acc_list: list that receives the accuracy of this run (appended in place).
        domain_name: label used in the log message.

    The logged loss is the mean of the per-batch cross-entropy values; the
    accuracy is the number of correct predictions divided by
    ``len(data_loader.dataset)``.
    """
    # Eval state for Dropout and BN layers.
    featurizer.eval()
    classifier.eval()

    loss_fn = nn.CrossEntropyLoss()
    total_loss = 0
    correct = 0.0

    with torch.no_grad():
        for images, labels, _ in data_loader:
            targets = labels.to(args.device)
            logits = classifier(featurizer(images.to(args.device)))
            total_loss += loss_fn(logits, targets).item()

            _, predicted = torch.max(logits, 1)
            # Adding a tensor count turns the running total into a CPU tensor.
            correct += predicted.eq(targets).cpu().sum()

    mean_loss = total_loss / len(data_loader)
    accuracy = correct / len(data_loader.dataset)

    acc_list.append(accuracy)

    featurizer.train()
    classifier.train()

    write_log(args, "{} Loss = {}, {} Accuracy = {:2%}".format(domain_name, mean_loss, domain_name, accuracy))

def shot_label(args, loader, featurizer, classifier):
    """Assign pseudo labels by clustering features around class centres.

    Features and classifier probabilities are collected over ``loader``. Class
    centres are first computed as probability-weighted feature means, every
    sample is assigned to its most similar centre, and the centres are then
    recomputed once from those hard assignments before the final assignment.

    Args:
        args: configuration namespace. Attributes read: ``device``, ``distance``
            (``'cosine'``, ``'euclidean'`` or ``'sqeuclidean'``), ``fea_dim``
            (scale distances by the feature dimension), ``stemp`` (similarity
            temperature), ``ltemp`` (entropy temperature) and ``lam_thres``.
        loader: iterable of batches whose first two items are inputs and labels.
        featurizer: callable mapping inputs to a 2-D feature tensor.
        classifier: callable mapping features to class logits.

    Returns:
        Tuple ``(pseudo_labels, lam, acc, centres)``:
        ``pseudo_labels`` - integer array with one class id per sample;
        ``lam`` - float32 array ``sigmoid(-entropy / ltemp)`` of the final
        similarity softmax, with entries that are not below ``lam_thres``
        replaced by values from ``above_thres``;
        ``acc`` - fraction of pseudo labels equal to the loader's labels;
        ``centres`` - the class centres used for the final assignment, shape
        ``(num_classes, feature_dim + 1)``.

    Raises:
        ValueError: if ``args.distance`` is not one of the supported metrics.
    """
    if args.distance not in ('cosine', 'euclidean', 'sqeuclidean'):
        raise ValueError("Unsupported distance: {!r}".format(args.distance))

    fea_chunks = []
    output_chunks = []
    label_chunks = []
    with torch.no_grad():
        for batch in loader:
            feas = featurizer(batch[0].to(args.device))
            outputs = classifier(feas)

            fea_chunks.append(feas.float().cpu())
            output_chunks.append(outputs.float().cpu())
            label_chunks.append(batch[1].float())

    all_fea = torch.cat(fea_chunks)
    all_output = nn.Softmax(dim=1)(torch.cat(output_chunks))
    all_label = torch.cat(label_chunks)

    # Initial hard prediction of the classifier; converted to NumPy so that the
    # one-hot lookup below is a plain integer-array index.
    predict = torch.max(all_output, 1)[1].numpy()

    # Append a constant 1 to every feature vector -> (N, dim+1).
    all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
    if args.distance == 'cosine':
        # L2-normalise each row.
        all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()

    all_fea = all_fea.float().cpu().numpy()   # (N, dim+1)
    K = all_output.size(1)
    aff = all_output.float().cpu().numpy()    # (N, C) soft assignment
    fea_dim = all_fea.shape[1] - 1

    for _ in range(2):
        # Assignment-weighted mean feature per class -> (C, dim+1).
        initc = aff.transpose().dot(all_fea)
        initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
        # Only classes that currently own at least one sample are candidates.
        cls_count = np.eye(K)[predict].sum(axis=0)
        labelset = np.where(cls_count > 0)[0]

        # Similarity of every sample to every candidate centre (higher = closer).
        if args.distance == 'cosine':
            dd = 1 - cdist(all_fea, initc[labelset], 'cosine')
        elif args.distance == 'euclidean':
            dd = -cdist(all_fea, initc[labelset], 'euclidean')
            if args.fea_dim:
                dd = dd / math.sqrt(fea_dim)
        else:
            dd = -cdist(all_fea, initc[labelset], 'sqeuclidean')
            if args.fea_dim:
                dd = dd / fea_dim
        dd = dd / args.stemp

        predict = labelset[dd.argmax(axis=1)]   # (N,)
        aff = np.eye(K)[predict]                # one-hot (N, C)

    # Softmax of the final similarities. Subtracting the row maximum leaves the
    # result unchanged mathematically but prevents every exp() from underflowing
    # to zero (0/0 = NaN) when the scaled distances are large.
    exp_dd = np.exp(dd - dd.max(axis=1, keepdims=True))
    pred_softmax = exp_dd / exp_dd.sum(axis=1, keepdims=True)

    entropy = Entropy_(pred_softmax)
    lam = np.exp(-entropy/args.ltemp) / (1 + np.exp(-entropy/args.ltemp))
    lam = np.where(lam < args.lam_thres, lam, above_thres(args, lam))
    lam = lam.astype(np.float32)

    acc = np.sum(predict == all_label.float().numpy()) / len(all_fea)

    return predict.astype('int'), lam, acc, initc

def shot_label_trans_ens(args, loader, featurizer, classifier, epoch, dname):
    """Pseudo-label a dataset with predictions averaged over repeated loader passes.

    ``loader`` is iterated ``args.num_ens`` times and all features are pooled.
    Class centres are computed from the pooled features, the similarity softmax
    of each pass is averaged per sample, and the resulting labels are used to
    recompute the centres once before the final (again averaged) assignment.

    NOTE(review): averaging pass ``k`` with pass ``0`` row by row assumes the
    loader returns the samples in the same order on every pass, and is only
    useful if the dataset applies a random transform - confirm with the caller.

    Args:
        args: configuration namespace. Attributes read: ``device``, ``num_ens``,
            ``distance`` (``'cosine'``, ``'euclidean'`` or ``'sqeuclidean'``),
            ``fea_dim``, ``stemp``, ``ltemp`` and ``lam_thres``.
        loader: iterable of batches whose first two items are inputs and labels.
        featurizer: callable mapping inputs to a 2-D feature tensor.
        classifier: callable mapping features to class logits.
        epoch: unused; kept for interface compatibility.
        dname: unused; kept for interface compatibility.

    Returns:
        Tuple ``(pseudo_labels, lam, acc, centres)`` where ``pseudo_labels``,
        ``lam`` (float32) and ``acc`` refer to the samples of a single pass, and
        ``centres`` has shape ``(num_classes, feature_dim + 1)``.

    Raises:
        ValueError: if ``args.distance`` is not one of the supported metrics.
    """
    if args.distance not in ('cosine', 'euclidean', 'sqeuclidean'):
        raise ValueError("Unsupported distance: {!r}".format(args.distance))

    fea_chunks = []
    output_chunks = []
    label_chunks = []
    with torch.no_grad():
        for _ in range(args.num_ens):
            for batch in loader:
                feas = featurizer(batch[0].to(args.device))
                outputs = classifier(feas)

                fea_chunks.append(feas.float().cpu())
                output_chunks.append(outputs.float().cpu())
                label_chunks.append(batch[1].float())

    all_fea = torch.cat(fea_chunks)
    all_output = nn.Softmax(dim=1)(torch.cat(output_chunks))
    all_label = torch.cat(label_chunks)
    predict = torch.max(all_output, 1)[1].numpy()

    # Append a constant 1 to every feature vector -> (N, dim+1).
    all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
    if args.distance == 'cosine':
        # L2-normalise each row.
        all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()

    all_fea = all_fea.float().cpu().numpy()   # (N, dim+1), N = num_ens * data_num
    K = all_output.size(1)
    aff = all_output.float().cpu().numpy()    # (N, C) soft assignment
    fea_dim = all_fea.shape[1] - 1
    data_num = all_fea.shape[0] // args.num_ens

    def class_centers(assignment, labels):
        """Return assignment-weighted class means and the ids of non-empty classes."""
        centers = assignment.transpose().dot(all_fea)
        centers = centers / (1e-8 + assignment.sum(axis=0)[:, None])
        present = np.where(np.eye(K)[labels].sum(axis=0) > 0)[0]
        return centers, present

    def ensemble_softmax(centers):
        """Softmax over similarities to ``centers``, averaged across the passes."""
        if args.distance == 'cosine':
            dd = 1 - cdist(all_fea, centers, 'cosine')
        elif args.distance == 'euclidean':
            dd = -cdist(all_fea, centers, 'euclidean')
            if args.fea_dim:
                dd = dd / math.sqrt(fea_dim)
        else:
            dd = -cdist(all_fea, centers, 'sqeuclidean')
            if args.fea_dim:
                dd = dd / fea_dim
        dd = dd / args.stemp
        # Shift by the row maximum so exp() cannot underflow to an all-zero row.
        exp_dd = np.exp(dd - dd.max(axis=1, keepdims=True))
        probs = exp_dd / exp_dd.sum(axis=1, keepdims=True)
        passes = [probs[k*data_num:(k+1)*data_num] for k in range(args.num_ens)]
        return np.sum(passes, axis=0) / args.num_ens

    # Round 1: centres from the classifier's soft predictions.
    initc, labelset = class_centers(aff, predict)
    pred_label = ensemble_softmax(initc[labelset]).argmax(axis=1)
    # Every pass of a sample receives the same ensemble label.
    predict = labelset[np.tile(pred_label, args.num_ens)]
    aff = np.eye(K)[predict]                  # one-hot (N, C)

    # Round 2: centres from the hard ensemble labels.
    initc, labelset = class_centers(aff, predict)
    pred_softmax_ens = ensemble_softmax(initc[labelset])
    predict = labelset[pred_softmax_ens.argmax(axis=1)]   # (data_num,)

    entropy = Entropy_(pred_softmax_ens)
    lam = np.exp(-entropy / args.ltemp) / (1 + np.exp(-entropy / args.ltemp))
    lam = np.where(lam < args.lam_thres, lam, above_thres(args, lam))
    lam = lam.astype(np.float32)

    acc = np.sum(predict == all_label[:data_num].float().numpy()) / data_num

    return predict.astype('int'), lam, acc, initc

def src_proto(args, loader, featurizer):
    """Compute one mean feature vector (prototype) per class from labelled data.

    Args:
        args: configuration namespace; ``device``, ``distance`` and
            ``num_classes`` are read.
        loader: iterable of batches whose first two items are inputs and integer
            class labels.
        featurizer: callable mapping inputs to a 2-D feature tensor.

    Returns:
        NumPy array of shape ``(num_classes, feature_dim + 1)``. Each feature gets
        a constant 1 appended (and is L2-normalised when ``args.distance`` is
        ``'cosine'``) before averaging. Classes without samples yield a zero row.
    """
    feature_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch in loader:
            batch_features = featurizer(batch[0].to(args.device))
            feature_chunks.append(batch_features.float().cpu())
            label_chunks.append(batch[1])

    features = torch.cat(feature_chunks)
    targets = torch.cat(label_chunks)

    # Extra constant column -> (N, dim+1).
    bias_column = torch.ones(features.size(0), 1)
    features = torch.cat((features, bias_column), 1)
    if args.distance == 'cosine':
        row_norms = torch.norm(features, p=2, dim=1)
        features = (features.t() / row_norms).t()

    features = features.float().cpu().numpy()
    one_hot = np.eye(args.num_classes)[targets]          # (N, C)
    class_sums = one_hot.transpose().dot(features)       # (C, dim+1)
    class_counts = one_hot.sum(axis=0)[:, None]          # (C, 1)
    # The small constant keeps empty classes from dividing by zero.
    return class_sums / (1e-8 + class_counts)

def src_proto_trans_ens(args, loader, featurizer, epoch):
    """Compute per-class source prototypes from ``args.num_ens`` passes over ``loader``.

    The features of all passes are pooled, a constant 1 is appended to each
    (followed by L2 normalisation when ``args.distance`` is ``'cosine'``), and
    the mean per class is returned.

    Side effect: the pooled labels are pickled to
    ``./tsne/labels/<args.order_name>/src.pickle`` on every call; the directory
    is created if it does not exist.

    Args:
        args: configuration namespace; ``device``, ``num_ens``, ``distance``,
            ``num_classes`` and ``order_name`` are read.
        loader: iterable of batches whose first two items are inputs and integer
            class labels.
        featurizer: callable mapping inputs to a 2-D feature tensor.
        epoch: unused; kept for interface compatibility.

    Returns:
        NumPy array of shape ``(num_classes, feature_dim + 1)``. Classes without
        samples yield a zero row.
    """
    # Imported locally: this module does not import them at the top level.
    import os
    import pickle

    fea_chunks = []
    label_chunks = []
    with torch.no_grad():
        for _ in range(args.num_ens):
            for batch in loader:
                feas = featurizer(batch[0].to(args.device))
                fea_chunks.append(feas.float().cpu())
                label_chunks.append(batch[1])

    all_fea = torch.cat(fea_chunks)
    all_label = torch.cat(label_chunks)

    all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)   # (N, dim+1)

    label_path = f"./tsne/labels/{args.order_name}/src.pickle"
    # Create the target directory so the dump cannot fail on a fresh checkout.
    os.makedirs(os.path.dirname(label_path), exist_ok=True)
    with open(label_path, "wb") as fw:
        pickle.dump(all_label, fw)

    if args.distance == 'cosine':
        all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()

    all_fea = all_fea.float().cpu().numpy()   # (N, dim+1)
    K = args.num_classes
    # Explicit NumPy conversion keeps the one-hot lookup an integer-array index.
    aff = np.eye(K)[all_label.numpy()]        # (N, C)
    protos = aff.transpose().dot(all_fea)     # (C, dim+1) per-class feature sums
    protos = protos / (1e-8 + aff.sum(axis=0)[:, None])   # divide by class counts

    return protos

def Entropy_(input_):
    """Return the row-wise entropy ``-sum(p * log(p + 1e-5))`` of a 2-D NumPy array.

    The small constant inside the logarithm keeps zero entries finite.
    """
    smoothing = 1e-5
    log_probs = np.log(input_ + smoothing)
    return np.sum(-input_ * log_probs, axis=1)

def Entropy_torch(input_):
    """Return the row-wise entropy ``-sum(p * log(p + 1e-5))`` of a 2-D tensor.

    The small constant inside the logarithm keeps zero entries finite.
    """
    smoothing = 1e-5
    log_probs = torch.log(input_ + smoothing)
    return (-input_ * log_probs).sum(dim=1)

def above_thres(args, lam):
    """Return a list with one independent ``U(0, 1)`` draw per element of ``lam``.

    ``args`` is accepted but not used. The values come from NumPy's global
    random state.
    """
    draws = []
    for _ in lam:
        draws.append(np.random.uniform(0, 1))
    return draws

def ContrastiveLoss(args, mixed_feature, src_labels, protos_mean):
    """Cross-entropy between feature-to-prototype similarities and the labels.

    Every feature gets a constant 1 appended, is compared with every class
    prototype, and the similarities divided by ``args.ctemp`` are used as
    logits for a cross-entropy loss against ``src_labels``.

    Args:
        args: configuration namespace. Attributes read: ``device``, ``distance``
            (``'cosine'``, ``'euclidean'`` or ``'sqeuclidean'``), ``fea_dim``
            (divide euclidean scores by ``sqrt(dim)`` / squared ones by ``dim``)
            and ``ctemp``.
        mixed_feature: ``(batch, dim)`` feature tensor on ``args.device``.
        src_labels: ``(batch,)`` tensor of class indices.
        protos_mean: ``(num_classes, dim + 1)`` prototypes (NumPy array or tensor).

    Returns:
        Scalar float32 loss tensor.

    Raises:
        ValueError: if ``args.distance`` is not one of the supported metrics.
    """
    if args.distance not in ('cosine', 'euclidean', 'sqeuclidean'):
        raise ValueError("Unsupported distance: {!r}".format(args.distance))

    ones = torch.ones(mixed_feature.size(0), 1, device=args.device)
    # Cast both operands to float32 up front so every metric sees one dtype
    # (NumPy prototypes arrive as float64).
    features = torch.cat((mixed_feature, ones), 1).float()
    prototypes = torch.as_tensor(protos_mean).detach().to(args.device).float()
    fea_dim = features.shape[1] - 1

    if args.distance == 'cosine':
        # (batch, 1, D) vs (1, C, D) -> (batch, C)
        sim_matrix = nn.CosineSimilarity(dim=2)(features.unsqueeze(1), prototypes.unsqueeze(0))
    elif args.distance == 'euclidean':
        sim_matrix = -torch.cdist(features, prototypes, p=2)
        if args.fea_dim:
            sim_matrix = sim_matrix / math.sqrt(fea_dim)
    else:
        sim_matrix = -(torch.cdist(features, prototypes, p=2) ** 2)
        if args.fea_dim:
            sim_matrix = sim_matrix / fea_dim
    sim_matrix = sim_matrix / args.ctemp

    return nn.CrossEntropyLoss()(sim_matrix, src_labels)

def EMloss(args, logits):
    """Return the mean per-sample entropy of ``softmax(logits)`` over dim 1.

    ``args`` is accepted but not used.
    """
    probabilities = torch.softmax(logits, dim=1)
    per_sample_entropy = Entropy_torch(probabilities)
    return per_sample_entropy.mean()

def LDloss(args, logits):
    """Return ``log(num_classes)`` minus the entropy of the batch-mean prediction.

    The softmax of ``logits`` is averaged over the batch (dim 0); the value
    ``sum(m * log(m + 1e-5)) + log(args.num_classes)`` is returned, which is
    close to zero when the mean prediction is uniform over the classes.
    """
    smoothing = 1e-5
    mean_probs = torch.softmax(logits, dim=1).mean(dim=0)
    negative_entropy = (mean_probs * (mean_probs + smoothing).log()).sum()
    return negative_entropy + np.log(args.num_classes)

def lr_scheduler(optimizer, iter_num, max_iter, power=0.75, gamma=10):
    """Set every param group's ``lr`` to its ``lr0`` scaled by a polynomial decay.

    The scale is ``(1 + gamma * iter_num / max_iter) ** (-power)``. Each param
    group must already carry an ``'lr0'`` entry. The optimizer is modified in
    place and returned.
    """
    scale = (1 + gamma * iter_num / max_iter) ** (-power)
    for group in optimizer.param_groups:
        group.update(lr=group['lr0'] * scale)
    return optimizer

def op_copy(optimizer):
    """Store each param group's current ``lr`` under the key ``'lr0'``.

    Any existing ``'lr0'`` entry is overwritten. The optimizer is modified in
    place and returned.
    """
    for group in optimizer.param_groups:
        group.update(lr0=group['lr'])
    return optimizer