import time
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from metrics.myAUC import AUCMeter
from utils.utils import AverageMeter
from metrics.accuracy import accuracy

def validate(val_loader, model, criterion_cls, criterion_gcn, args):
    """Run one evaluation pass over ``val_loader`` and report accuracy and AUC.

    The model is put in eval mode and every batch is evaluated under
    ``torch.no_grad()``. Only the last element of the model output
    (``output[-1]``) is scored: it is passed to ``criterion_cls`` and
    ``accuracy``, and its top-1 prediction is fed to the AUC meter.

    Args:
        val_loader: Iterable yielding ``(input, target, gcn_target, id_tar)``
            batches. ``input`` and ``target`` must support ``.cuda()``;
            ``gcn_target`` is not used here.
        model: Callable invoked as ``model(input, id_tar, 'val')``; its
            result is indexed with ``[-1]``.
        criterion_cls: Loss applied to ``output[-1]`` and the targets.
        criterion_gcn: Unused; kept so existing callers keep working.
        args: Namespace providing ``print_freq`` (progress is printed every
            ``print_freq`` batches).

    Returns:
        Tuple ``(top1.avg, eval_auc.get_auc())``: the batch-size-weighted
        average top-1 accuracy and the value reported by the AUC meter.
    """
    batch_time = AverageMeter()
    losses_cls = AverageMeter()
    top1 = AverageMeter()
    eval_auc = AUCMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (images, target, _gcn_target, id_tar) in enumerate(val_loader):
            target_var = target.cuda(non_blocking=True)
            input_var = images.cuda(non_blocking=True)
            # Number of samples in this batch; used to weight every running
            # average so a smaller final batch is accounted for correctly.
            batch_size = images.size(0)

            output = model(input_var, id_tar, 'val')
            logits = output[-1]

            loss_cls = criterion_cls(logits, target_var)
            # Both requested k values are 1, so the second result duplicates
            # the first and is discarded.
            acc1, _ = accuracy(logits, target_var, topk=(1, 1))

            # AUC is computed from the hard top-1 prediction of each sample.
            _, predi = logits.topk(1, 1, True, True)
            predi = predi.view(len(predi))

            losses_cls.update(loss_cls.item(), batch_size)
            eval_auc.update(predi, target_var)
            # Weight by the batch size (previously ``input[0].size(0)``, which
            # is the first sample's leading dimension, not the batch size).
            top1.update(acc1[0], batch_size)

            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss_cls {loss_cls.val:.4f} ({loss_cls.avg:.4f})\t'
                      'AUC {AUC}\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time,
                          loss_cls=losses_cls, AUC=eval_auc.get_auc(),
                          top1=top1))

    print(' * Acc@1 {top1.avg:.3f} AUC {AUC}'
          .format(top1=top1, AUC=eval_auc.get_auc()))

    return top1.avg, eval_auc.get_auc()