from __future__ import print_function, division

import sys
import time
import torch
from .util import AverageMeter, accuracy
from einops import reduce, repeat
import copy

def validate(val_loader, model, criterion, opt, mask_list = None):
    """Evaluate ``model`` on ``val_loader`` and report loss, top-1 and top-5 accuracy.

    Args:
        val_loader: iterable yielding ``(input, target)`` batches. Inputs are
            cast to float, and both inputs and targets are moved to the GPU
            when CUDA is available.
        model: called as ``model(input, features, is_feat=False)`` when
            ``mask_list`` is None, otherwise as
            ``model(input, masks, features, is_feat=False)``. The first
            element of the returned pair is used as the output.
        criterion: loss function called as ``criterion(output, target)``.
        opt: options object; ``opt.print_freq`` sets how often (in batches)
            progress is printed.
        mask_list: optional sequence of mask tensors, each of shape
            ``(c, h, w)``. For every batch, each mask is repeated along a new
            leading batch axis to ``(batch, c, h, w)`` before being passed to
            the model. The caller's list is not modified.

    Returns:
        ``(top1.avg, top5.avg, losses.avg, mask_list_copy)``, where
        ``mask_list_copy`` is a deep copy of the ``mask_list`` argument
        (``None`` when no masks were given).
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    use_cuda = torch.cuda.is_available()

    # switch to evaluate mode
    model.eval()
    # Deep copy of the original masks; returned to the caller unchanged.
    mask_list_copy = copy.deepcopy(mask_list)

    # Move the masks to the device once, not on every batch. Building a new
    # list (instead of assigning into ``mask_list``) keeps the caller's list
    # untouched, so no per-batch deepcopy/restore is needed.
    device_masks = None
    if mask_list is not None:
        device_masks = [mask.cuda() if use_cuda else mask for mask in mask_list]

    with torch.no_grad():
        end = time.time()
        for idx, (images, target) in enumerate(val_loader):

            images = images.float()
            if use_cuda:
                images = images.cuda()
                target = target.cuda()
            batch_size = images.size(0)

            features = []

            # compute output
            if device_masks is None:
                output, _ = model(images, features, is_feat=False)
            else:
                # (c, h, w) -> (b, c, h, w). Rebuilt per batch because the
                # last batch from the loader may be smaller than the others.
                batch_masks = [
                    repeat(mask, 'c h w-> b c h w', b=batch_size)
                    for mask in device_masks
                ]
                output, _ = model(images, batch_masks, features, is_feat=False)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), batch_size)
            top1.update(acc1, batch_size)
            top5.update(acc5, batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if idx % opt.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       idx, len(val_loader), batch_time=batch_time, loss=losses,
                       top1=top1, top5=top5))

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg, top5.avg, losses.avg, mask_list_copy