import time
import torch
import os
import math
import numpy as np
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
# from utils.kernel_kmeans import KernelKMeans
import gc
# import ipdb

def train(train_loader_source, train_loader_source_batch, train_loader_target, train_loader_target_batch, model, learn_cen, learn_cen_2, criterion_cons, optimizer, itern, epoch, new_epoch_flag, src_cs, args, assigned_labels, TarLoss, **kwargs):
    """Run one optimisation step on one target batch and one source batch.

    The total loss is the sum of
      * a consistency loss between the model outputs on the target batch and
        on its second ("gray") view, weighted by ``lam * args.a3`` and divided
        by ``args.aug_denom``;
      * ``lam * TarLoss(...)`` on the target outputs;
      * ``SrcClassifyLoss(...)`` on the source batch.

    Only the configuration used for the paper is implemented. Every other
    option value raises ``ValueError('Error!')``.

    Args:
        train_loader_source: Source loader, used to restart the source
            iterator once it is exhausted.
        train_loader_source_batch: ``enumerate`` iterator over the source
            loader yielding ``(i, (input, target, index))``.
        train_loader_target: Target loader, used to restart the target
            iterator once it is exhausted.
        train_loader_target_batch: ``enumerate`` iterator over the target
            loader yielding ``(i, (input, input_gray, target, indices))``.
        model: Network returning a 3-tuple whose last element is the
            classifier output.
        learn_cen, learn_cen_2: Accepted for interface compatibility; not used
            by the implemented code path.
        criterion_cons: Callable computing the consistency loss between two
            model outputs.
        optimizer: Optimizer whose param groups are named ``'conv'`` or
            ``'ca_cl'`` (required by ``adjust_learning_rate``).
        itern: Iteration counter; progress is logged every
            ``args.print_freq`` iterations.
        epoch: Current epoch, used for the lambda and learning-rate schedules.
        new_epoch_flag: When truthy, the penalty weight is printed.
        src_cs: Per-sample source weights, indexed by the source batch indices.
        args: Experiment configuration namespace.
        assigned_labels: Unused by the implemented code path.
        TarLoss: Target loss callable (see ``get_tar_loss``).
        **kwargs: Ignored.

    Returns:
        ``(train_loader_source_batch, train_loader_target_batch)`` — the
        possibly restarted batch iterators.

    Raises:
        ValueError: For any configuration outside the implemented path.
    """
    batch_time = AverageMeter()
    # NOTE: data_time is never updated in this function, so the "DT" column
    # of the progress line always shows 0.000.
    data_time = AverageMeter()
    top1_source = AverageMeter()
    losses = AverageMeter()
    # switch to train mode
    model.train()

    # setting lambda based on the epoch number
    if args.lambda_method == 'inv_dao':
        lam = 2 / (1 + math.exp(-1 * 10 * epoch / args.max)) - 1
    elif args.lambda_method == 'inv_cosine':
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper
    elif args.lambda_method == 'inv_exp':
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper
    else:
        # Fail fast with a clear message instead of an unbound `lam` later on.
        raise ValueError(f'{args.lambda_method} is not a valid lambda_method.')

    if not args.src_cls:
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper
    weight = lam

    if args.cons_loss_weight_one:
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper
    # setting lambda 3 for the self-supervised loss
    cons_loss_weight = weight * args.a3

    # adjust learning rate based on the epoch number
    # , apply different learning rates to different layers as discussed in the paper
    adjust_learning_rate(optimizer, epoch, args)

    end = time.time()

    # prepare target data (the base augmented and the second augmented view);
    # only "gray agreement without aug agreement" is implemented
    if not (args.gray_tar_agree and not args.aug_tar_agree):
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper
    target_batch, train_loader_target_batch = _next_batch(train_loader_target_batch, train_loader_target)
    (input_target, input_target_gray, _target_target, indices) = target_batch

    if args.tar_loss_idx == 0 or args.tar_loss_idx == 2:  # all the situations that pseudo-labels are needed
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper
    target_target = None

    input_target_var = Variable(input_target)
    input_target_gray_var = Variable(input_target_gray)

    if args.init_cen_on in ('st', 't'):
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper

    # model forward on target (the base augmented and the second augmented)
    _, _, ca_t = model(input_target_var)
    _, _, ca_t_gray = model(input_target_gray_var)

    if args.ssbtconst:
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper
    ca_t_ss = ca_t

    loss = 0

    # calculate the self-supervised loss
    loss += cons_loss_weight * criterion_cons(ca_t_ss, ca_t_gray) / args.aug_denom

    # calculate the adaptation loss and the regularization loss
    loss += weight * TarLoss(args, epoch, ca_t, target_target, em=(args.cluster_method == 'em'))

    if args.learn_embed:
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper

    # prepare source data
    source_batch, train_loader_source_batch = _next_batch(train_loader_source_batch, train_loader_source)
    (input_source, target_source, index) = source_batch
    target_source = target_source.cuda(non_blocking=True)
    input_source_var = Variable(input_source)

    # model forward on source
    _, _, ca_s = model(input_source_var)
    prec1_s = accuracy(ca_s.data, target_source, topk=(1,))[0]
    top1_source.update(prec1_s.item(), input_source.size(0))

    # calculate the fidelity loss
    loss += SrcClassifyLoss(args, ca_s, target_source, index, src_cs, lam, fit=args.src_fit)

    losses.update(loss.data.item(), input_target.size(0))
    if args.test_it:
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper

    # loss backward and network update
    model.zero_grad()
    loss.backward()
    optimizer.step()

    batch_time.update(time.time() - end)
    if itern % args.print_freq == 0:
        print('Train - epoch [{0}/{1}]\t'
              'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'S@1 {s_top1.val:.3f} ({s_top1.avg:.3f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
               epoch, args.max, batch_time=batch_time,
               data_time=data_time, s_top1=top1_source, loss=losses))
        # Context manager guarantees the handle is closed even if write() fails.
        with open(os.path.join(args.log, 'log.txt'), 'a') as log:
            log.write("\nTrain - epoch: %d, top1_s acc: %3f, loss: %4f" % (epoch, top1_source.avg, losses.avg))
    if new_epoch_flag:
        print('The penalty weight is %3f' % weight)

    return train_loader_source_batch, train_loader_target_batch


def _next_batch(batch_iter, loader):
    """Return ``(batch, batch_iter)``; restart from ``loader`` when exhausted.

    ``batch_iter`` is an ``enumerate`` iterator, so each item is
    ``(step, batch)`` and only the batch is returned.
    """
    try:
        return next(batch_iter)[1], batch_iter
    except StopIteration:
        batch_iter = enumerate(loader)
        return next(batch_iter)[1], batch_iter

def get_tar_loss(args):
    """Return the target-domain loss selected by ``args.tar_loss_idx``.

    The index selects from, in order: ``TarDisClusterLoss``,
    ``TarCombinedLoss1``, ``TarCombinedLoss2``, ``TarEntropyLoss``,
    ``TarCombinedLoss3``, ``TarCombinedLoss4``, ``TarCombinedLoss5``,
    ``TarClassMeanEntropyLoss``, ``TarNothingLoss``.

    Every returned callable has the signature
    ``(args, epoch, output, target, softmax=True, em=..., **kwargs)``.
    ``output`` is a 2-D tensor with one row per sample. When ``softmax`` is
    true the rows are treated as logits, otherwise each row is normalised by
    its sum.

    Args:
        args: Configuration namespace; ``tar_loss_idx`` is read here, and
            the selected loss may read ``a1``, ``a2``, ``ao`` and ``beta``.

    Returns:
        The selected loss function.
    """

    def _probs(output, softmax):
        # Turn the raw model output into per-sample class probabilities.
        if softmax:
            return F.softmax(output, dim=1)
        return output / output.sum(1, keepdim=True)

    def _x_log_x(x):
        # x * log(x) with the 0 * log(0) = 0 convention. Exact zeros are
        # replaced by 1 *inside* the log only, so neither the value nor the
        # gradient becomes NaN; every non-zero entry is computed as before.
        safe_x = torch.where(x == 0, torch.ones_like(x), x)
        return x * torch.log(safe_x)

    def TarDisClusterLoss(args, epoch, output, target, softmax=True, em=False, **kwargs):
        # The implementation is not part of this file (the original body
        # returned an undefined name); fail like the other disabled paths.
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper

    def TarCombinedLoss1(args, epoch, output, target, softmax=True, em=None, **kwargs):
        prob_p = _probs(output, softmax)
        # increase confidence of target assigning to classes & make sure majority of targets are not assigned to a single class or only a few classes
        return EntropyLoss(prob_p) - EntropyLoss(prob_p.mean(0))

    def EntropyLoss(x):
        # Mean (over all leading dims) Shannon entropy of the last dimension.
        return -1 * _x_log_x(x).sum(-1).mean()

    def TarCombinedLoss2(args, epoch, output, target, softmax=True, em=None, **kwargs):
        prob_p = _probs(output, softmax)

        # One-hot encoding of the assigned pseudo labels, created on the same
        # device and with the same dtype as the probabilities.
        prob_q1 = torch.zeros_like(prob_p).scatter_(1, target.unsqueeze(1), 1.0)

        if (epoch == 0) or args.ao:
            prob_q = prob_q1
        else:
            prob_q2 = prob_p / prob_p.sum(0, keepdim=True).pow(0.5)
            prob_q2 = prob_q2 / prob_q2.sum(1, keepdim=True)
            prob_q = (1 - args.beta) * prob_q1 + args.beta * prob_q2

        if softmax:
            loss = - (prob_q * F.log_softmax(output, dim=1)).sum(1).mean()
        else:
            loss = - (prob_q * prob_p.log()).sum(1).mean()
        # increase confidence of target assigning to classes & make sure majority of targets are not assigned to a single class or only a few classes
        return EntropyLoss(prob_p) + loss

    def TarEntropyLoss(args, epoch, output, target, softmax=True, em=None, **kwargs):
        prob_p = _probs(output, softmax)
        # increase confidence of target assigning to classes
        return EntropyLoss(prob_p)

    def TarCombinedLoss3(args, epoch, output, target, softmax=True, em=None, ca_t=None, **kwargs):
        prob_p = _probs(output, softmax)
        # increase confidence of target assigning to classes & make sure majority of targets are not assigned to a single class or only a few classes
        loss = EntropyLoss(prob_p) - EntropyLoss(prob_p.mean(0))
        if ca_t is None:
            return loss
        return loss + kl_loss(ca_t, output)

    def kl_loss(x, y):
        # KL(softmax(x) || softmax(y)), summed over classes, averaged over the batch.
        x = F.softmax(x, dim=1)
        y = F.log_softmax(y, dim=1)
        return F.kl_div(y, x, reduction='batchmean')

    def TarCombinedLoss4(args, epoch, output, target, softmax=True, em=None, ca_t=None, **kwargs):
        prob_p = _probs(output, softmax)
        # increase confidence of target assigning to classes & make sure majority of targets are not assigned to a single class or only a few classes
        if ca_t is None:
            return EntropyLoss(prob_p) - EntropyLoss(prob_p.mean(0))
        return 0.5 * EntropyLoss(prob_p) - EntropyLoss(prob_p.mean(0)) + 0.5 * kl_loss(ca_t, output)

    def TarCombinedLoss5(args, epoch, output, target, softmax=True, em=None, ca_t=None, **kwargs):
        prob_p = _probs(output, softmax)
        # increase confidence of target assigning to classes & make sure majority of targets are not assigned to a single class or only a few classes
        if ca_t is None:
            # add an offset (entropy of the uniform distribution) so the logged
            # loss is non-negative; being constant, it does not affect training
            num_classes = output.size(-1)
            uniform = torch.ones(1, num_classes, device=output.device) / num_classes
            return args.a1 * EntropyLoss(prob_p) + args.a2 * (-EntropyLoss(prob_p.mean(0)) + EntropyLoss(uniform))

        # Per-sample mix of the entropy term and the KL term. Each term is
        # weighted by the *other* distribution's detached share of the summed
        # entropies, so the weights themselves receive no gradient.
        sample_entropy_p = EntropyLossNoReduction(prob_p)
        sample_entropy_p_detached = sample_entropy_p.detach()
        sample_entropy_ca_t_detached = EntropyLossNoReduction(F.softmax(ca_t, dim=-1)).detach()
        sum_sample_entropy_detached = sample_entropy_p_detached + sample_entropy_ca_t_detached
        loss = (sample_entropy_ca_t_detached / sum_sample_entropy_detached) * sample_entropy_p + (sample_entropy_p_detached / sum_sample_entropy_detached) * kl_loss_no_reduction(ca_t, output)
        loss = loss.mean()
        loss += - EntropyLoss(prob_p.mean(0))
        return loss

    def EntropyLossNoReduction(x):
        # Shannon entropy of the last dimension, one value per sample.
        return -1 * _x_log_x(x).sum(-1)

    def kl_loss_no_reduction(x, y):
        # KL(softmax(x) || softmax(y)) per sample (summed over classes only).
        x = F.softmax(x, dim=1)
        y = F.log_softmax(y, dim=1)
        return F.kl_div(y, x, reduction='none').sum(-1)

    def TarClassMeanEntropyLoss(args, epoch, output, target, softmax=True, em=None, **kwargs):
        prob_p = _probs(output, softmax)
        # negative entropy of the batch-mean prediction
        return - EntropyLoss(prob_p.mean(0))

    def TarNothingLoss(args, epoch, output, target, softmax=True, em=None, **kwargs):
        return 0.0

    loss_list = [TarDisClusterLoss, TarCombinedLoss1, TarCombinedLoss2, TarEntropyLoss, TarCombinedLoss3, TarCombinedLoss4, TarCombinedLoss5, TarClassMeanEntropyLoss, TarNothingLoss]

    return loss_list[args.tar_loss_idx]
# define the fidelity loss (Methodology section of the paper)
def SrcClassifyLoss(args, output, target, index, src_cs, lam, softmax=True, fit=False):
    """Weighted classification loss on a source batch.

    Each sample's cross-entropy is multiplied by its weight ``src_cs[index]``
    and the weighted values are averaged over the batch.

    Args:
        args: Configuration; reads ``label_smoothing`` and ``src_mix_weight``.
        output: 2-D tensor of classifier outputs, one row per sample.
        target: 1-D integer tensor of class indices.
        index: Indices into ``src_cs`` for the samples of this batch.
        src_cs: Per-sample source weights.
        lam: Unused by the implemented code path.
        softmax: If true, ``output`` holds logits; otherwise each row is
            normalised by its sum. Ignored when label smoothing is active.
        fit: If true, the one-hot target ``q`` is replaced by
            ``(1 - p) * q + p * p``. Ignored when label smoothing is active.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If ``args.src_mix_weight`` is set (not implemented).
    """
    if args.label_smoothing > 0:
        # Per-sample loss so that the source weights can be applied before averaging.
        loss_per_sample = F.cross_entropy(output, target, label_smoothing=args.label_smoothing, reduction='none')
        if args.src_mix_weight:
            raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper
        src_weights = src_cs[index]
        return (src_weights * loss_per_sample).mean()

    if softmax:
        prob_p = F.softmax(output, dim=1)
    else:
        prob_p = output / output.sum(1, keepdim=True)

    # One-hot targets on the same device/dtype as the predictions (the former
    # torch.cuda.FloatTensor was pinned to the current CUDA device).
    prob_q = torch.zeros_like(prob_p).scatter_(1, target.unsqueeze(1), 1.0)
    if fit:
        prob_q = (1 - prob_p) * prob_q + prob_p * prob_p

    if args.src_mix_weight:
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper
    src_weights = src_cs[index]

    log_p = F.log_softmax(output, dim=1) if softmax else prob_p.log()
    return - (src_weights * (prob_q * log_p).sum(1)).mean()

def aug_nearing_loss(args, first_emb, second_emb):
    """Disabled placeholder: always raises ``ValueError('Error!')``.

    The arguments are accepted but never read.
    """
    raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper

def validate(val_loader, model, criterion, epoch, args):
    """Evaluate ``model`` on ``val_loader`` and return the average top-1 precision.

    Progress is printed every ``args.print_freq`` batches and a summary line
    is appended to ``<args.log>/log.txt``.

    Args:
        val_loader: Iterable yielding ``(input, target, _)`` batches.
        model: Network returning a 3-tuple whose last element is the
            classifier output.
        criterion: Loss callable applied to ``(output, target)``.
        epoch: Epoch number, used only for reporting.
        args: Configuration; reads ``num_classes``, ``print_freq`` and ``log``.

    Returns:
        The sample-weighted average top-1 precision (in percent).
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # per-class sample and hit counters, kept on the CPU
    total_vector = torch.zeros(args.num_classes, dtype=torch.float32)
    correct_vector = torch.zeros(args.num_classes, dtype=torch.float32)

    end = time.time()
    for i, (images, target, _) in enumerate(val_loader):
        target = target.cuda(non_blocking=True)
        batch_size = images.size(0)

        # forward
        with torch.no_grad():
            _, _, output = model(Variable(images))
            loss = criterion(output, target)

        # compute and record loss and accuracy
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        total_vector, correct_vector = accuracy_for_each_class(output.data, target, total_vector, correct_vector)  # compute class-wise accuracy
        losses.update(loss.data.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test on T test set - [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(val_loader), batch_time=batch_time,
                   loss=losses, top1=top1, top5=top5))

    # Class-wise accuracy is computed but currently neither printed nor returned.
    acc_for_each_class = 100.0 * correct_vector / total_vector
    print(' * Test on T test set - Prec@1 {top1.avg:.3f}, Prec@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))
    # Context manager guarantees the handle is closed even if write() fails.
    with open(os.path.join(args.log, 'log.txt'), 'a') as log:
        log.write("\n             Test on T test set - epoch: %d, loss: %4f, Top1 acc: %3f, Top5 acc: %3f" % (epoch, losses.avg, top1.avg, top5.avg))

    return top1.avg

    
def validate_compute_cen(val_loader_target, val_loader_source, model, criterion, epoch, args, compute_cen=True):
    """Evaluate ``model`` on the target training set and collect its outputs.

    Centroid computation is not implemented: ``compute_cen=True`` (the
    default) raises ``ValueError``. With ``compute_cen=False`` the feature and
    centroid accumulators are never filled, so the returned feature tensors
    are all zeros and the centroids are the result of dividing zero sums by
    zero counts (``args.eps`` is added only for the target centroids).

    Args:
        val_loader_target: Target loader yielding ``(input, target, index)``;
            ``dataset.imgs`` provides the dataset size.
        val_loader_source: Source loader; only ``dataset.imgs`` is read.
        model: Network returning ``(feature, feature_2, output)``.
        criterion: Loss callable applied to ``(output, target)``.
        epoch: Epoch number, used only for reporting.
        args: Configuration; reads ``num_classes``, ``num_neurons``,
            ``cluster_method``, ``eps``, ``nclayers``, ``print_freq`` and
            ``log``.
        compute_cen: Must be falsy; see above.

    Returns:
        ``(top1_avg, c_src, c_src_2, c_tar, c_tar_2, c_srctar, c_srctar_2,
        source_features, source_features_2, source_targets, target_features,
        target_features_2, target_targets, pseudo_labels)`` where
        ``target_targets`` holds the labels and ``pseudo_labels`` the raw
        model outputs of every target sample, placed at its dataset index.

    Raises:
        ValueError: If ``compute_cen`` is truthy or ``args.cluster_method``
            is ``'spherical_kmeans'``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # source-side accumulators (feature width 2048 is hard-coded)
    num_source = len(val_loader_source.dataset.imgs)
    source_features = torch.zeros(num_source, 2048, dtype=torch.float32, device='cuda')
    source_features_2 = torch.zeros(num_source, args.num_neurons * 4, dtype=torch.float32, device='cuda')
    source_targets = torch.zeros(num_source, dtype=torch.long, device='cuda')
    c_src = torch.zeros(args.num_classes, 2048, dtype=torch.float32, device='cuda')
    c_src_2 = torch.zeros(args.num_classes, args.num_neurons * 4, dtype=torch.float32, device='cuda')
    count_s = torch.zeros(args.num_classes, 1, dtype=torch.float32, device='cuda')

    if compute_cen:
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper

    # target-side accumulators
    num_target = len(val_loader_target.dataset.imgs)
    target_features = torch.zeros(num_target, 2048, dtype=torch.float32, device='cuda')
    target_features_2 = torch.zeros(num_target, args.num_neurons * 4, dtype=torch.float32, device='cuda')
    target_targets = torch.zeros(num_target, dtype=torch.long, device='cuda')
    pseudo_labels = torch.zeros(num_target, args.num_classes, dtype=torch.float32, device='cuda')
    c_tar = torch.zeros(args.num_classes, 2048, dtype=torch.float32, device='cuda')
    c_tar_2 = torch.zeros(args.num_classes, args.num_neurons * 4, dtype=torch.float32, device='cuda')
    count_t = torch.zeros(args.num_classes, 1, dtype=torch.float32, device='cuda')

    # per-class sample and hit counters, kept on the CPU
    total_vector = torch.zeros(args.num_classes, dtype=torch.float32)
    correct_vector = torch.zeros(args.num_classes, dtype=torch.float32)

    end = time.time()
    for i, (images, target, index) in enumerate(val_loader_target):  # the iteration over the target dataset
        data_time.update(time.time() - end)
        target = target.cuda(non_blocking=True)
        batch_size = images.size(0)

        with torch.no_grad():
            _, _, output = model(Variable(images))

        # store labels and raw outputs at each sample's dataset position
        index = index.cuda()
        target_targets[index] = target.clone()
        pseudo_labels[index, :] = output.data.clone()

        # compute and record loss and accuracy
        loss = criterion(output, target)
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        total_vector, correct_vector = accuracy_for_each_class(output.data, target, total_vector, correct_vector)  # compute class-wise accuracy
        losses.update(loss.data.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Test on T training set - [{0}][{1}/{2}]\t'
                  'T {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'D {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'T@1 {tc_top1.val:.3f} ({tc_top1.avg:.3f})\t'
                  'T@5 {tc_top5.val:.3f} ({tc_top5.avg:.3f})\t'
                  'L {tc_loss.val:.4f} ({tc_loss.avg:.4f})'.format(
                   epoch, i, len(val_loader_target), batch_time=batch_time,
                   data_time=data_time, tc_top1=top1, tc_top5=top5, tc_loss=losses))

    # compute global class centroids
    if (args.cluster_method == 'spherical_kmeans'):
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper

    # The joint centroids must be formed before the per-domain sums are
    # normalised in place below.
    c_srctar = (c_src + c_tar) / (count_s + count_t)
    c_srctar_2 = (c_src_2 + c_tar_2) / (count_s + count_t)
    c_src /= count_s
    c_src_2 /= count_s
    c_tar /= (count_t + args.eps)
    if args.nclayers == 2:
        c_tar_2 /= (count_t + args.eps)

    # Class-wise accuracy is computed but currently neither printed nor returned.
    acc_for_each_class = 100.0 * correct_vector / total_vector

    print(' * Test on T training set - Prec@1 {tc_top1.avg:.3f}, Prec@5 {tc_top5.avg:.3f}'.format(tc_top1=top1, tc_top5=top5))

    # Context manager guarantees the handle is closed even if write() fails.
    with open(os.path.join(args.log, 'log.txt'), 'a') as log:
        log.write("\nTest on T training set - epoch: %d, tc_loss: %4f, tc_Top1 acc: %3f, tc_Top5 acc: %3f" % (epoch, losses.avg, top1.avg, top5.avg))

    return top1.avg, c_src, c_src_2, c_tar, c_tar_2, c_srctar, c_srctar_2, source_features, source_features_2, source_targets, target_features, target_features_2, target_targets, pseudo_labels


def source_select(source_features, source_targets, target_features, pseudo_labels, train_loader_source, epoch, cen, args):
    """Disabled placeholder: always raises ``ValueError('Error!')``.

    The arguments are accepted but never read.
    """
    raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper    

def kernel_k_means(target_features, target_targets, pseudo_labels, train_loader_target, epoch, model, args, best_prec, change_target=True):
    """Disabled placeholder: always raises ``ValueError('Error!')``.

    The arguments are accepted but never read.
    """
    raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper

def k_means(target_features, target_targets, train_loader_target, epoch, model, c, args, best_prec, change_target=True):
    """Disabled placeholder: always raises ``ValueError('Error!')``.

    The arguments are accepted but never read.
    """
    raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper
    
def spherical_k_means(target_features, target_targets, train_loader_target, epoch, model, c, args, best_prec, change_target=True):
    """Disabled placeholder: always raises ``ValueError('Error!')``.

    The arguments are accepted but never read.
    """
    raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper

class AverageMeter(object):
    """Keep the most recent value of a metric and its running weighted mean.

    Attributes set by ``reset``/``update``: ``val`` (last value), ``sum``
    (weighted sum of all values), ``count`` (total weight) and ``avg``
    (``sum / count``).
    """

    def __init__(self):
        # reset() is the single place where the attributes are created.
        self.reset()

    def reset(self):
        """Return the meter to its initial, all-zero state."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new samples and refresh ``avg``."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        total, seen = self.sum, self.count
        self.avg = total / seen


def adjust_learning_rate(optimizer, epoch, args):
    """Set the learning rate of every parameter group for ``epoch``.

    Only the ``'dao'`` plan is implemented:
    ``lr = args.lr / (1 + 10 * epoch / args.max) ** 0.75``.
    Groups named ``'conv'`` receive ``0.1 * lr``; groups named ``'ca_cl'``
    receive ``lr``.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts that
            each carry a ``'name'`` key.
        epoch: Current epoch.
        args: Configuration; reads ``lr_plan``, ``lr`` and ``max``.

    Raises:
        ValueError: If ``args.lr_plan`` is a known but disabled plan, is not a
            valid plan at all, or a parameter group has an unexpected name.
    """
    # Plans that exist as options but whose code is not part of this file.
    disabled_plans = ('cosine', 'exp', 'linear', 'step', 'nothing')

    if args.lr_plan == 'dao':
        lr = args.lr / math.pow((1 + 10 * epoch / args.max), 0.75)
    elif args.lr_plan in disabled_plans:
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper
    else:
        # Previously unreachable behind a generic placeholder raise.
        raise ValueError(f'{args.lr_plan} is not a valid lr_plan.')

    for param_group in optimizer.param_groups:
        if param_group['name'] == 'conv':
            param_group['lr'] = lr * 0.1
        elif param_group['name'] == 'ca_cl':
            param_group['lr'] = lr
        else:
            raise ValueError('The required parameter group does not exist.')


def accuracy(output, target, topk=(1,)):
    """Return precision@k (in percent) for every ``k`` in ``topk``.

    ``output`` holds one row of class scores per sample and ``target`` the
    true class index of each sample. The result is a list with one
    single-element tensor per requested ``k``, in the order given.
    """
    largest_k = max(topk)
    num_samples = target.size(0)

    # Indices of the `largest_k` best-scoring classes of each sample, best
    # first, transposed so that row r holds every sample's rank-r prediction.
    ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    percent_per_hit = 100.0 / num_samples
    # A sample counts for precision@k if its label is among the first k rows.
    return [
        hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit)
        for k in topk
    ]


def accuracy_for_each_class(output, target, total_vector, correct_vector):
    """Accumulate per-class sample counts and top-1 hit counts.

    Args:
        output: 2-D tensor of class scores, one row per sample.
        target: Tensor of true class indices, one per sample.
        total_vector: 1-D tensor; entry ``c`` is increased by the number of
            samples in this batch whose label is ``c``.
        correct_vector: 1-D tensor; entry ``c`` is increased by the number of
            those samples that were classified correctly.

    Both vectors are updated in place and also returned.

    Returns:
        ``(total_vector, correct_vector)``.
    """
    _, pred = output.topk(1, 1, True, True)
    labels = target.reshape(-1)
    # Flatten instead of squeeze so that a batch of one stays 1-D.
    hits = pred.reshape(-1).eq(labels)

    # index_add_ accumulates repeated class indices, replacing the
    # per-sample Python loop.
    total_idx = labels.to(device=total_vector.device, dtype=torch.long)
    total_vector.index_add_(0, total_idx, torch.ones_like(total_idx, dtype=total_vector.dtype))

    correct_idx = labels.to(device=correct_vector.device, dtype=torch.long)
    correct_vector.index_add_(0, correct_idx, hits.to(device=correct_vector.device, dtype=correct_vector.dtype))

    return total_vector, correct_vector

