__author__ = 'Qi'
# Created by on 12/3/21.
import torch, copy, time
import numpy as np
from myutils import adjust_lr_lambda, AverageMeter, accuracy, save_checkpoint_epoch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import  metrics
from torch.autograd import Variable
import wandb
from mydataset import get_imbalanced_dataset
from preprocess import get_transform_medium_scale_data
import math
import os
from myDataLoader import myDataLoader_imagenet, get_train_val_test_feature_loader, get_train_val_test_loader
# from main import validate
from mydataset import featLT

def FastDRO(args, model_new, results):
    """Train ``model_new`` with a per-minibatch re-weighted (DRO) cross-entropy.

    For each minibatch the per-sample cross-entropy losses are turned into
    weights ``p_i = exp((loss_i - max_loss) / (args.lamda0 + lamda))`` normalised
    to sum to one, where ``lamda`` is returned by
    ``binary_search_FastDRO_lambda``.  The weights are detached and the weighted
    loss ``sum(p_i * loss_i)`` is minimised with SGD (momentum 0.9).

    Args:
        args: run configuration.  Reads ``random_seed``, ``lr``, ``weight_decay``,
            ``resumed_epoch``, ``epochs``, ``lamda0``, ``print_freq``,
            ``dataset``, ``lamda`` and the ``curlr`` / ``curbeta`` attributes
            written by ``adjust_curlr_beta``.
        model_new: network whose call returns a pair; the first element is
            used as the logits.
        results: logger object exposing ``add(**kwargs)`` and ``save()``.

    Returns:
        None.  Metrics go to ``results``, wandb and stdout.
    """
    wandb.init(config=args, project="SCCMA", entity="qiqi-helloworld")

    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    train_loader, val_loader, test_loader = get_train_val_test_loader(args, None)
    ivd_criterion = nn.CrossEntropyLoss(reduction='none')  # per-sample losses
    CE_criterion = nn.CrossEntropyLoss()                   # mean loss, for monitoring
    optimizerW = torch.optim.SGD(model_new.parameters(), lr=args.lr, momentum=0.9,
                                 weight_decay=args.weight_decay)

    best_test_acc1 = 0
    start_time = time.time()
    for epoch in range(args.resumed_epoch, args.epochs):
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        model_new.train()
        adjust_curlr_beta(epoch, args, optimizerW)

        for batch_idx, (_, inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs, _ = model_new(inputs)
            ivd_loss = ivd_criterion(outputs, targets)
            # Subtracting the batch maximum keeps the exponent non-positive.
            max_loss = torch.max(ivd_loss).item()
            lamda = binary_search_FastDRO_lambda(args, ivd_loss)
            stb_loss = ivd_loss - max_loss
            exploss = torch.exp(stb_loss / (args.lamda0 + lamda))
            p = exploss / torch.sum(exploss)
            p.detach_()  # weights are constants w.r.t. the model parameters

            droloss = torch.sum(p * ivd_loss)

            model_new.zero_grad()
            droloss.backward()
            optimizerW.step()

            # Running statistics use the plain (unweighted) cross-entropy.
            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            loss = CE_criterion(outputs, targets)
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1, inputs.size(0))
            top5.update(acc5, inputs.size(0))

            if batch_idx % args.print_freq == 0:
                if args.epochs <= 10:
                    # Short runs: evaluate inside the epoch and log per iteration.
                    train_loss, train_prec1, train_prec5, train_auc_score = validate(
                        args, train_loader, model_new, CE_criterion, epoch)
                    val_loss, val_prec1, val_prec5, val_auc_score = validate_cifar_val_loader(
                        args, val_loader, model_new, CE_criterion, epoch)
                    print('iter acc1', epoch*len(train_loader)+batch_idx, 4*len(train_loader), train_prec1, val_prec1)
                    wandb.log({'iter acc1': train_prec1, 'iter val acc1': val_prec1},
                              step=epoch*len(train_loader)+batch_idx)
                    # validate() switches the model to eval mode; switch back so
                    # the remaining batches of this epoch train in train mode.
                    model_new.train()
                else:
                    print('Epoch: [{0}][{1}/{2}]\t'
                             'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                             'Train Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                             'Train Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                    epoch, batch_idx, len(train_loader),
                    loss=losses, top1=top1, top5=top5))

        if args.epochs > 10:
            # Long runs: evaluate once per epoch.
            train_loss, train_prec1, train_prec5, train_auc_score = validate(
                args, train_loader, model_new, CE_criterion, epoch)
            if 'cifar' in args.dataset:
                val_loss, val_prec1, val_prec5, val_auc_score = validate_cifar_val_loader(
                    args, val_loader, model_new, CE_criterion, epoch)
            else:
                val_loss, val_prec1, val_prec5, val_auc_score = validate(
                    args, val_loader, model_new, CE_criterion, epoch)

            if test_loader is not None:
                test_loss, test_prec1, test_prec5, test_auc_score = validate(
                    args, test_loader, model_new, CE_criterion, epoch)

            overall_running_time = (time.time() - start_time) // 60
            # When no test split exists the validation metrics stand in for it.
            best_test_acc1 = max(best_test_acc1, val_prec1) if test_loader is None else max(best_test_acc1, test_prec1)
            results.add(epoch=epoch, val_loss=val_loss,
                            train_prec1=train_prec1, val_prec1=val_prec1,
                            train_prec5=train_prec5, val_prec5=val_prec5,
                            test_prec1 = test_prec1 if test_loader is not None else val_prec1,
                            test_prec5 = test_prec5 if test_loader is not None else val_prec5,
                            overall_running_time=overall_running_time)
            results.save()

            ##### Print on the screen.
            if epoch % 2 == 0:
                output = ('Train: [{0}/{1}], lr: {lr:.5f}\t'
                              'Train Loss {train_loss:.4f} Val Loss {val_loss:.4f}\t'
                              'Train Prec@1 {train_prec1:.3f} Val Prec@1 {val_prec1:.3f} \t'
                              'Train Prec@5 {train_prec5:.3f} Val Prec@5 {val_prec5:.3f}'.format(
                        epoch, args.epochs, train_loss=train_loss, val_loss=val_loss,
                        train_prec1=train_prec1, val_prec1=val_prec1, train_prec5=train_prec5, val_prec5=val_prec5,
                        lr=args.curlr))
                print(output)
                # NOTE(review): this reports args.lamda, not the per-batch
                # ``lamda`` computed above - confirm that is intended.
                print("Lambda Variable value: ", str(args.lamda))
                print('Total number of running time is {overall_running_time:.3f}'.format(
                        overall_running_time=overall_running_time))

            wandb.log({"lr": args.curlr, 'Optimized Lambda Variable': args.lamda}, step=epoch)
            wandb.log({"train loss": train_loss, 'train acc1': train_prec1, 'train acc5': train_prec5}, step=epoch)
            wandb.log({"test loss": val_loss,
                       'test acc1': val_prec1 if test_loader is None else test_prec1,
                       'test acc5': val_prec5 if test_loader is None else test_prec5}, step=epoch)
            wandb.log({"best test acc": best_test_acc1, 'beta': args.curbeta}, step=epoch)

def PDSGDDRO(args, model_new, results):
    """Train ``model_new`` with primal-dual SGD over a per-sample weight vector.

    A weight vector ``DROP`` (one entry per training sample, initialised to
    ``1 / len(train_loader.dataset)``) is kept on the GPU.  Each step minimises
    ``sum(DROP[batch] * loss * data_len / batch_len)`` with SGD and then updates
    ``DROP`` through ``PD_updatesP``.

    Args:
        args: run configuration.  Reads ``random_seed``, ``epochs``, ``dataset``,
            ``lr``, ``weight_decay``, ``gpus``, ``type``, ``print_freq``,
            ``batch_size``, ``start_training_time``, ``lamda`` and the
            ``curlr`` / ``curbeta`` attributes written by ``adjust_curlr_beta``.
        model_new: network whose call returns a pair; the first element is
            used as the logits.
        results: logger object exposing ``add(**kwargs)`` and ``save()``.

    Returns:
        None.  Metrics go to ``results``, wandb and stdout.
    """

    wandb.init(config=args, project="SCCMA", entity="qiqi-helloworld")

    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)
    train_loader, val_loader, test_loader = get_train_val_test_loader(args, None)
    numRounds = len(train_loader) * args.epochs
    data_len = len(train_loader.dataset)
    number_of_runs_per_epoch = len(train_loader)

    if 'cifar' in args.dataset:
        P_length = 50000
    else:
        P_length = len(train_loader.dataset)

    indices = None
    if 'cifar' in args.dataset:
        # presumably the training split is a Subset of the full CIFAR train
        # set, so sample ids range over all 50000 images - TODO confirm.
        indices = train_loader.dataset.indices
        print('length  of training data:', len(indices))

    optimizerW = torch.optim.SGD(model_new.parameters(), lr=args.lr,  momentum=0.9, weight_decay=args.weight_decay)
    global DROP
    DROP = torch.tensor([1/ data_len] * P_length)
    uniP = torch.tensor([1/ data_len] * P_length)
    DROP = DROP.cuda()
    uniP = uniP.cuda()

    if args.gpus and len(args.gpus) >= 1:
        model_new = torch.nn.DataParallel(model_new)
    CE_criterion = nn.CrossEntropyLoss()
    criterion_DRO = nn.CrossEntropyLoss(reduction='none')
    CE_criterion.type(args.type)
    model_new.type(args.type)
    print("len train loader:", len(train_loader))
    best_prec1, lamda =  0, 0
    rounds, epoch = 0, 0
    # Built once: every batch needs the complement of its own sample ids.
    all_sample_ids = set(range(P_length))

    # Strict '<': after ``args.epochs`` passes ``rounds == numRounds`` and
    # training must stop ('<=' would run one extra epoch, and would never
    # terminate for an empty loader).
    while rounds < numRounds:

        print(rounds)
        if rounds % number_of_runs_per_epoch == 0:
            print('{0}/{1}, {2} epochs finished'.format(rounds, numRounds, rounds //number_of_runs_per_epoch))

        adjust_curlr_beta(epoch, args, optimizerW)
        model_new.train()
        for batch_idx, (minibatchIdx, inputs, targets) in enumerate(train_loader):

            rounds += 1
            otherbatchIdx = list(all_sample_ids - set(minibatchIdx.tolist()))
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs, _ = model_new(inputs)
            loss_DRO = criterion_DRO(outputs, targets)

            # Rescale the batch losses by data_len / batch_len before weighting.
            unbiased_loss_DRO = loss_DRO * data_len / len(minibatchIdx)
            loss = torch.sum(DROP[minibatchIdx]* unbiased_loss_DRO)
            model_new.zero_grad()
            loss.backward()
            optimizerW.step()  # Updates W
            DROP, lamda = PD_updatesP(args, DROP, uniP, minibatchIdx, otherbatchIdx,
                                      unbiased_loss_DRO, indices)  # UpdatesP, with close-form
            if args.epochs <= 10:
                if rounds % args.print_freq == 0:
                    train_loss, train_prec1, train_prec5, train_auc_score = validate(args, train_loader, model_new,
                                                                                CE_criterion, epoch)
                    val_loss, val_prec1, val_prec5, val_auc_score = validate_cifar_val_loader(args, val_loader, model_new, CE_criterion, epoch)
                    print('iter acc1', rounds, numRounds, train_prec1, val_prec1)
                    wandb.log({'iter acc1': train_prec1, 'iter val acc1': val_prec1}, step=rounds)
                    # validate() switches the model to eval mode; switch back so
                    # the remaining batches of this epoch train in train mode.
                    model_new.train()


        epoch += 1
        if args.epochs > 10:
            print(">>>>>> train :", len(train_loader), '>>>>>>>> val :', len(val_loader), len(val_loader.dataset))
            train_loss, train_prec1, train_prec5, train_auc_score = validate(args,
            train_loader, model_new, CE_criterion, epoch)

            if 'cifar' in args.dataset:
                val_loss, val_prec1, val_prec5, val_auc_score = validate_cifar_val_loader(args,
                            val_loader, model_new, CE_criterion, epoch)
            else:
                val_loss, val_prec1, val_prec5, val_auc_score = validate(args, val_loader, model_new,
                                                                                     CE_criterion, epoch)

            if test_loader is not None:
                # NOTE(review): the last argument is only used by forward() as a
                # seed offset; the expression differs from the ``epoch`` passed
                # for the other splits - confirm whether that is intended.
                test_loss, test_prec1, test_prec5, test_auc_score = validate(args, test_loader, model_new, CE_criterion,
                                                                                     P_length // (args.batch_size  * rounds + 1))

            overall_training_time = time.time() - args.start_training_time
            # NOTE(review): 128 looks like a hard-coded batch size; verify
            # against args.batch_size before relying on this epoch value.
            results.add(epoch=(rounds * 128) // P_length + 1,
                                train_loss = train_loss, val_loss=val_loss,
                                train_prec1=train_prec1, val_prec1=val_prec1, test_prec1 = test_prec1 if test_loader is not None else val_prec1,
                                train_prec5 = train_prec5, val_prec5=val_prec5, test_prec5 = test_prec5 if test_loader is not None else val_prec5,
                                overall_training_time=overall_training_time, auc_score=train_auc_score)
            results.save()


            # Validation metrics stand in for test metrics when no test split exists.
            tmp_test_prec = val_prec1 if test_loader is None else test_prec1
            if best_prec1 < tmp_test_prec:
                best_prec1 = tmp_test_prec

            output = ('Train: [{0}/{1}]\t'
                                'Epoch: {2}\t'
                                'Train Loss {train_loss:.4f} Val Loss {val_loss:.4f}\t'
                                'Train Prec@1 {train_prec1:.3f} Val Prec@1 {val_prec1:.3f}  Test Prec@1 {test_prec1:.3f} \t'
                                'Train Prec@5 {train_prec5:.3f} Val Prec@5 {val_prec5:.3f} Test Prec@5 {test_prec5:.3f}\t'
                                'Best Eval Prec@1 {best_prec1:.3f}'.format(
                        rounds, numRounds, epoch, train_loss=train_loss, val_loss=val_loss, test_prec1 = val_prec1 if test_loader is None else test_prec1,
                        test_prec5 = val_prec5 if test_loader is None else test_prec5,
                        train_prec1=train_prec1, val_prec1=val_prec1, train_prec5=train_prec5, val_prec5=val_prec5,
                        best_prec1=best_prec1))
            wandb.log({"lr": args.curlr, 'Optimized Lambda Variable': args.lamda}, step=epoch)
            wandb.log({"train loss": train_loss, 'train acc1': train_prec1, 'train acc5': train_prec5}, step=epoch)
            wandb.log({"test loss": val_loss, 'test acc1': val_prec1 if test_loader is None else test_prec1, 'test acc5': val_prec5 if test_loader is None else test_prec5}, step=epoch)
            wandb.log({"best test acc": best_prec1, 'beta': args.curbeta}, step=epoch)
            wandb.log({'lambda constrained': lamda}, step=epoch)

            print(output)


def FastDROLambdamax(args, ivd_loss):
    """Return the first value in 0, 1, 3, 7, ... whose KL value is below ``args.rho``.

    Candidates follow ``lamda -> 2 * lamda + 1`` starting at 0; each is checked
    with ``FastDRO_KL(args, ivd_loss, lamda) < args.rho``.
    """
    candidate = 0
    # Keep growing the candidate until the KL constraint is satisfied.
    while not (FastDRO_KL(args, ivd_loss, candidate) < args.rho):
        candidate = 2 * candidate + 1
    return candidate

def binary_search_FastDRO_lambda(args, ivd_loss):
    """Bisect for a ``lamda`` whose ``FastDRO_KL`` value is within 1e-3 of ``args.rho``.

    ``FastDROLambdamax`` supplies an upper bound ``rlambda`` that already
    satisfies ``KL < args.rho``.  It probes 0, 1, 3, 7, ..., so the probe just
    before ``rlambda`` is ``(rlambda - 1) / 2``; that probe did not satisfy the
    constraint and is used as the lower end of the bracket.

    Args:
        args: configuration; ``rho`` is read here, the remaining fields are
            read by ``FastDRO_KL``.
        ivd_loss: per-sample losses of the current minibatch.

    Returns:
        0 when ``lamda = 0`` already satisfies the constraint; otherwise the
        first midpoint whose KL lies in ``[rho - 1e-3, rho + 1e-3]``, or the
        current upper bound if the iteration budget is exhausted.
    """
    tolerance = 1e-3
    max_bisections = 50  # guarantees termination if the band is never hit

    rlambda = FastDROLambdamax(args, ivd_loss)
    if rlambda == 0:
        # The constraint holds without any extra regularisation.
        return rlambda

    llambda = (rlambda - 1) / 2
    for _ in range(max_bisections):
        mid = (rlambda + llambda) / 2
        kl_mid = FastDRO_KL(args, ivd_loss, mid)  # evaluated once per step
        if kl_mid > args.rho + tolerance:
            llambda = mid
        elif kl_mid < args.rho - tolerance:
            rlambda = mid
        else:
            return mid

    # The upper bound is only ever moved to points with KL < rho.
    return rlambda

def FastDRO_KL(args, ivd_loss, lamda):
    """Return the KL term between the uniform weights and the DRO weights.

    The DRO weights are ``p_i = exp((loss_i - max_loss) / (args.lamda0 + lamda))``
    normalised to sum to one.  The value returned is
    ``F.kl_div(p.log(), uniform, reduction='sum')``, i.e.
    ``sum_i u_i * (log u_i - log p_i)`` with ``u_i = 1 / n``.

    ``n`` is the number of losses actually passed in, so the reference is a
    proper distribution even for a final batch smaller than
    ``args.batch_size``.

    Args:
        args: configuration; only ``lamda0`` is read.
        ivd_loss: 1-D tensor of per-sample losses.
        lamda: temperature added to ``args.lamda0``.

    Returns:
        A 0-dim tensor holding the divergence.
    """
    sample_losses = ivd_loss.detach()  # the weights never need gradients
    # Shift by the maximum so every exponent is <= 0.
    shifted = sample_losses - torch.max(sample_losses).item()
    exploss = torch.exp(shifted / (args.lamda0 + lamda))
    p = exploss / torch.sum(exploss)
    uniP = torch.ones_like(exploss) / sample_losses.numel()
    return F.kl_div(p.log(), uniP, reduction='sum')



def PD_updatesP(args, DROP, uniP, minibatchIdx, otherbatchIdx, loss, indices = None):
    """Return the updated sample-weight vector and the ``lamda`` used for it.

    The candidate update is
    ``(DROP[i] * exp(args.plr * (loss_i - max_loss))) ** e`` for samples in the
    minibatch and ``(DROP[j] * exp(-args.plr * max_loss)) ** e`` for all other
    samples, with ``e = 1 / ((lamda + args.lamda0) * args.plr + 1)``.
    ``lamda`` is 0 when the KL value from ``PD_KL`` at ``lamda = 0`` is at most
    ``args.rho``; otherwise it comes from ``binary_search_PD_lambda``.

    Args:
        args: configuration; reads ``rho``, ``plr`` and ``lamda0``.
        DROP: current weight vector.
        uniP: reference weight vector passed through to ``PD_KL``.
        minibatchIdx: ids of the samples in the current minibatch.
        otherbatchIdx: ids of all remaining samples.
        loss: losses of the minibatch samples, aligned with ``minibatchIdx``.
        indices: optional subset of ids; when given, only that subset is
            normalised to sum to one.

    Returns:
        ``(new_weights, lamda)``.  If the KL value at ``lamda = 0`` is infinite
        a copy of ``DROP`` is returned unchanged together with 0.
    """
    max_loss = torch.max(loss).detach()
    stb_loss = loss - max_loss

    # Evaluated once and reused for both checks below.
    kl_at_zero = PD_KL(args, DROP, uniP, minibatchIdx, otherbatchIdx, 0, loss, indices)
    if math.isinf(kl_at_zero):
        # Skip this update; the current weights carry over to the next iteration.
        return copy.deepcopy(DROP), 0

    if kl_at_zero <= args.rho:
        lamda = 0
    else:
        lamda = binary_search_PD_lambda(args, DROP, uniP, minibatchIdx, otherbatchIdx, loss, indices)

    exponent = 1 / ((lamda + args.lamda0) * args.plr + 1)
    newP = torch.ones_like(DROP)
    newP[minibatchIdx] = (DROP[minibatchIdx] * torch.exp(args.plr * stb_loss)).pow(exponent)
    newP[otherbatchIdx] = (DROP[otherbatchIdx] * torch.exp(0 - args.plr * max_loss)).pow(exponent)
    newP.detach_()

    if indices is not None:
        newP[indices] = newP[indices] / torch.sum(newP[indices])
    else:
        # PD_KL normalises the full vector in this case when it measures the
        # divergence, so the returned weights must be normalised the same way;
        # otherwise the exp(-plr * max_loss) shift would shrink them every step.
        newP = newP / torch.sum(newP)

    # newP is a fresh local tensor, so no defensive copy is required.
    return newP, lamda

def binary_search_PD_lambda(args, P, uniP, minibatchIdx, otherbatchIdx, loss, indices):
    """Search for a ``lamda`` whose ``PD_KL`` value is within 1e-3 of ``args.rho``.

    First the upper bound is grown along 1, 3, 7, ... until its KL value is no
    longer above ``args.rho``; the previous bound becomes the lower end of the
    bracket.  The bracket is then bisected.

    Args:
        args: configuration; ``rho`` is read here, other fields by ``PD_KL``.
        P, uniP, minibatchIdx, otherbatchIdx, loss, indices: forwarded
            unchanged to ``PD_KL``.

    Returns:
        The first midpoint whose KL lies in ``[rho - 1e-3, rho + 1e-3]``, or
        the current upper bound if the iteration budget is exhausted.
    """
    tolerance = 1e-3
    max_bisections = 100  # guarantees termination if the band is never hit

    def kl_at(lamda):
        return PD_KL(args, P, uniP, minibatchIdx, otherbatchIdx, lamda, loss, indices)

    lower, upper = 0, 1
    # Determine the interval between the lambda lower and upper bound.
    while kl_at(upper) > args.rho:
        lower = upper
        upper = 2 * upper + 1

    for _ in range(max_bisections):
        mid = (lower + upper) / 2
        kl_mid = kl_at(mid)  # evaluated once per step
        if kl_mid > args.rho + tolerance:
            lower = mid
        elif kl_mid < args.rho - tolerance:
            upper = mid
        else:
            return mid

    # The upper bound is only ever placed on points with KL <= rho.
    return upper

def PD_KL(args, P, uniP, minibatchIdx, otherbatchIdx, lamda, loss, indices):
    """Return the KL term of the candidate weight update for a given ``lamda``.

    The candidate weights are ``(P[i] * exp(args.plr * (loss_i - max_loss))) ** e``
    on the minibatch and ``(P[j] * exp(-args.plr * max_loss)) ** e`` elsewhere,
    with ``e = 1 / ((lamda + args.lamda0) * args.plr + 1)``.  They are
    normalised over ``indices`` when it is given, otherwise over all entries,
    and compared with ``uniP`` via ``F.kl_div(log(weights), uniP, 'sum')``.
    """
    peak = torch.max(loss).detach()
    exponent = 1 / ((lamda + args.lamda0) * args.plr + 1)

    candidate = torch.zeros_like(P)
    in_batch = P[minibatchIdx] * torch.exp(args.plr * (loss - peak))
    out_of_batch = P[otherbatchIdx] * torch.exp(-args.plr * peak)
    candidate[minibatchIdx] = in_batch.pow(exponent)
    candidate[otherbatchIdx] = out_of_batch.pow(exponent)

    if indices is None:
        # Normalise and compare over the whole vector.
        candidate = candidate / torch.sum(candidate)
        return F.kl_div(candidate.log(), uniP, reduction='sum')

    # Normalise and compare only over the selected subset.
    subset = candidate[indices]
    candidate[indices] = subset / torch.sum(subset)
    return F.kl_div(candidate[indices].log(), uniP[indices], reduction='sum')


    # print(">>>>:", F.kl_div(newP.log(), uniP, reduction = 'sum'))


def _forward_batches(args, data_loader, model_new, criterion, epoch, training, optimizer, has_index):
    """Shared loop behind ``forward`` and ``forward_cifar_val_loader``.

    Args:
        args: configuration; reads ``auc``, ``gpus`` and ``type``.
        data_loader: iterable of batches; ``(index, inputs, target)`` triples
            when ``has_index`` is true, ``(inputs, target)`` pairs otherwise.
        model_new: network whose call returns a pair; the first element is
            used as the output.
        criterion: loss applied to ``(output, target)``.
        epoch: offset added to 777 to seed torch's global RNG.
        training: when true, ``optimizer`` takes one step per batch; when
            false, gradient tracking is disabled for the forward pass.
        optimizer: optimiser used when ``training`` is true.
        has_index: selects the batch layout described above.

    Returns:
        ``(average loss, average top-1, average top-5, auc_score)``;
        ``auc_score`` is 0 unless ``args.auc`` is set.
    """
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    auc_score = 0
    if args.auc:
        target_list = np.array([])
        pred_target_list = np.array([])
    torch.manual_seed(777 + epoch)

    for batch in data_loader:
        if has_index:
            _, inputs, target = batch
        else:
            inputs, target = batch

        if args.gpus is not None:
            target = target.cuda()
        model_input = inputs.type(args.type)

        # The legacy ``Variable(..., volatile=True)`` flag has no effect on
        # PyTorch >= 0.4, so gradient tracking is switched off explicitly
        # during evaluation to avoid building autograd graphs.
        with torch.set_grad_enabled(training):
            output, _ = model_new(model_input)
            loss = criterion(output, target)

        if args.auc:
            target_list = np.concatenate((target_list, target.cpu().numpy()))
            _, pred_target = torch.max(output.data, 1)
            pred_target_list = np.concatenate((pred_target_list, pred_target.cpu().numpy()))

        if isinstance(output, list):
            output = output[0]
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()  # Updated new

    if args.auc:
        # The ROC curve is built from the predicted class ids, with class 1
        # treated as the positive label.
        fpr, tpr, threshold = metrics.roc_curve(target_list, pred_target_list, pos_label=1)
        auc_score = metrics.auc(fpr, tpr)
    return losses.avg, top1.avg, top5.avg, auc_score


def forward(args, data_loader, model_new, criterion, epoch=0, training=True, optimizer=None, var = False):
    """Run one pass over a loader that yields ``(index, inputs, target)`` batches.

    Returns ``(average loss, average top-1, average top-5, auc_score)``.
    ``var`` is accepted for compatibility and is not used.
    """
    return _forward_batches(args, data_loader, model_new, criterion, epoch,
                            training, optimizer, has_index=True)


def forward_cifar_val_loader(args, data_loader, model_new, criterion, epoch=0, training=True, optimizer=None, var = False):
    """Run one pass over a loader that yields ``(inputs, target)`` batches.

    Returns ``(average loss, average top-1, average top-5, auc_score)``.
    ``var`` is accepted for compatibility and is not used.
    """
    return _forward_batches(args, data_loader, model_new, criterion, epoch,
                            training, optimizer, has_index=False)

def validate_cifar_val_loader(args, data_loader, model_new, criterion, epoch):
    """Evaluate ``model_new`` on a loader that yields ``(inputs, target)`` batches.

    Switches the model to eval mode and returns whatever
    ``forward_cifar_val_loader`` returns for a non-training pass.
    """
    model_new.eval()
    eval_options = dict(training=False, optimizer=None, var=False)
    return forward_cifar_val_loader(args, data_loader, model_new, criterion, epoch, **eval_options)

def validate(args, data_loader, model_new, criterion, epoch):
    """Evaluate ``model_new`` on a loader that yields ``(index, inputs, target)`` batches.

    Switches the model to eval mode and returns whatever ``forward`` returns
    for a non-training pass.
    """
    model_new.eval()
    eval_options = dict(training=False, optimizer=None, var=False)
    return forward(args, data_loader, model_new, criterion, epoch, **eval_options)


def MBSGD(args, model_new, results):
    """Train ``model_new`` with plain minibatch SGD on the mean cross-entropy.

    After every epoch the model is evaluated on the training split, the
    validation split and, if one exists, the test split.  When a test split
    exists its top-1 / top-5 values are the ones stored in ``results`` and
    printed under the "Val" labels.

    Args:
        args: run configuration.  Reads ``random_seed``, ``lr``,
            ``weight_decay``, ``resumed_epoch``, ``epochs``, ``dataset``,
            ``lamda`` and the ``curlr`` / ``curbeta`` attributes written by
            ``adjust_curlr_beta``.
        model_new: network whose call returns a pair; the first element is
            used as the logits.
        results: logger object exposing ``add(**kwargs)`` and ``save()``.

    Returns:
        None.  Metrics go to ``results``, wandb and stdout.
    """

    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    train_loader, val_loader, test_loader = get_train_val_test_loader(args, None)
    CE_criterion = nn.CrossEntropyLoss()
    optimizerW = torch.optim.SGD(model_new.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
    wandb.init(config=args, project="SCCMA", entity="qiqi-helloworld")

    best_acc1 = 0
    start_time = time.time()

    for epoch in range(args.resumed_epoch, args.epochs):
        model_new.train()
        adjust_curlr_beta(epoch, args, optimizerW)
        for _, inputs, targets in train_loader:
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs, _ = model_new(inputs)
            loss = CE_criterion(outputs, targets)
            model_new.zero_grad()
            loss.backward()
            optimizerW.step()

        # Training metrics are measured in eval mode after the epoch.
        train_loss, train_prec1, train_prec5, train_auc_score = validate(args, train_loader, model_new, CE_criterion,
                                                                         epoch)

        if 'cifar' in args.dataset:
            val_loss, val_prec1, val_prec5, val_auc_score = validate_cifar_val_loader(args, val_loader, model_new,
                                                                                      CE_criterion, epoch)
        else:
            val_loss, val_prec1, val_prec5, val_auc_score = validate(args, val_loader, model_new, CE_criterion, epoch)

        if test_loader is not None:
            test_loss, test_prec1, test_prec5, test_auc_score = validate(args, test_loader, model_new, CE_criterion,
                                                                         epoch)
            # Test metrics replace the validation ones in the record and printout.
            reported_prec1, reported_prec5 = test_prec1, test_prec5
        else:
            reported_prec1, reported_prec5 = val_prec1, val_prec5

        overall_running_time = (time.time() - start_time) // 60
        best_acc1 = max(best_acc1, val_prec1)
        results.add(epoch=epoch, val_loss=val_loss,
                    train_prec1=train_prec1, val_prec1=reported_prec1,
                    train_prec5=train_prec5, val_prec5=reported_prec5,
                    overall_running_time=overall_running_time)
        results.save()

        ##### Print on the screen.
        output = ('Train: [{0}/{1}], lr: {lr:.5f}\t'
                    'Train Loss {train_loss:.4f} Val Loss {val_loss:.4f}\t'
                    'Train Prec@1 {train_prec1:.3f} Val Prec@1 {val_prec1:.3f} \t'
                    'Train Prec@5 {train_prec5:.3f} Val Prec@5 {val_prec5:.3f}'.format(
            epoch, args.epochs, train_loss=train_loss, val_loss=val_loss,
            train_prec1=train_prec1, val_prec1=reported_prec1,
            train_prec5=train_prec5, val_prec5=reported_prec5,
            lr=args.curlr))
        print(output)
        print("Lambda Variable value: ", str(args.lamda))
        print('Total number of running time is {overall_running_time:.3f}'.format(
                overall_running_time=overall_running_time))

        wandb.log({"lr": args.curlr, 'Optimized Lambda Variable': args.lamda}, step=epoch)
        wandb.log({"train loss": train_loss, 'train acc1': train_prec1, 'train acc5': train_prec5}, step=epoch)
        wandb.log({"test loss": val_loss, 'test acc1': val_prec1, 'test acc5': val_prec5}, step=epoch)
        wandb.log({"best test acc": best_acc1, 'beta': args.curbeta}, step=epoch)



        # if epoch == args.epochs-1:
        #     torch.save({'model': model_new.state_dict(), 'epoch': epoch}, os.path.join('./models/'+ args.dataset + '_last.pth'))

def adjust_curlr_beta(epoch, args, optimizer = None):
    """Set ``args.curlr`` and ``args.curbeta`` for the given epoch.

    The schedule is selected by the total epoch count ``args.epochs``; any
    count without a dedicated schedule keeps the base learning rate.  The base
    values ``args.lr`` and ``args.beta`` are never modified, so calling this
    repeatedly for the same epoch always yields the same result.

    Args:
        epoch: zero-based epoch index (converted to one-based internally).
        args: configuration; reads ``epochs``, ``lr`` and ``beta`` and writes
            ``curlr`` and ``curbeta``.
        optimizer: optional optimiser; when given, every parameter group's
            ``'lr'`` is set to ``args.curlr``.
    """
    epoch = epoch + 1
    total_epochs = args.epochs
    beta_scale = None  # only the 200-epoch schedule decays beta

    if total_epochs == 3:
        curlr = args.lr * 0.1 if epoch > 2 else args.lr
    elif total_epochs == 2:
        curlr = args.lr * 0.1 if epoch > 1 else args.lr
    elif total_epochs == 30 or total_epochs == 40:
        curlr = args.lr * 0.1 if epoch >= 20 else args.lr
    elif total_epochs == 60:
        if epoch <= 5:
            curlr = args.lr * epoch / 5  # linear warm-up
        elif epoch >= 30:
            curlr = args.lr * 0.1
        else:
            curlr = args.lr
    elif total_epochs == 90:
        curlr = args.lr * 0.1 if epoch > 60 else args.lr
    elif total_epochs == 120 or total_epochs == 150:
        if epoch <= 5:
            curlr = args.lr * epoch / 5  # linear warm-up
        elif epoch > 90:
            curlr = args.lr * 0.1
        else:
            curlr = args.lr
    elif total_epochs == 200:
        if epoch <= 5:
            curlr = args.lr * epoch / 5  # linear warm-up
        elif epoch >= 180:
            curlr = args.lr * 0.0001
            beta_scale = 0.0001
        elif epoch >= 160:
            curlr = args.lr * 0.01
            beta_scale = 0.01
        else:
            curlr = args.lr
    else:
        curlr = args.lr

    args.curlr = curlr
    # The decayed value goes to ``curbeta``; scaling ``args.beta`` in place
    # would compound the decay on every call.  ``curbeta`` is set in every
    # branch because callers read it after each epoch.
    base_beta = getattr(args, 'beta', None)
    args.curbeta = base_beta if beta_scale is None else base_beta * beta_scale

    if optimizer is not None:
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.curlr
