import torch
import numpy as np

import utils
from imagenet import get_x_y_from_data_dict


def _forward_without_grad(model, criterion, image, target):
    """Run one forward pass with autograd disabled.

    Returns ``(output, loss)`` where both values have been cast with
    ``.float()``.
    """
    with torch.no_grad():
        output = model(image)
        loss = criterion(output, target)
    return output.float(), loss.float()


def validate(val_loader, model, criterion, args):
    """Evaluate ``model`` on ``val_loader`` and return its top-1 accuracy.

    Args:
        val_loader: Iterable of batches. When ``args.imagenet_arch`` is
            truthy each batch is handed to ``get_x_y_from_data_dict``;
            otherwise each batch must unpack into ``(image, target)``.
        model: Callable network; it is switched to eval mode via
            ``model.eval()`` before any batch is processed.
        criterion: Callable invoked as ``criterion(output, target)``.
        args: Namespace-like object; only ``args.imagenet_arch`` is read.

    Returns:
        A tuple ``(top1_avg, predicts_all)``:

        * ``top1_avg`` is the ``avg`` attribute of the ``utils.AverageMeter``
          fed with the per-batch top-1 precision.
        * ``predicts_all`` is a flat NumPy array built from the per-batch
          predictions returned by ``utils.accuracy_predicts``.

    Note:
        Predictions are only gathered on the non-ImageNet path. On the
        ImageNet path ``predicts_all`` is returned as an empty array.
    """
    # The loss meter is updated per batch but is not part of the return value.
    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()

    # switch to evaluate mode
    model.eval()

    cuda_available = torch.cuda.is_available()

    # Seeding with an empty float64 array keeps the result dtype identical to
    # what repeated ``np.append`` calls onto ``np.array([])`` would produce,
    # and guarantees a valid (empty) result when no batch adds predictions.
    predict_chunks = [np.array([])]

    if args.imagenet_arch:
        device = torch.device("cuda:0") if cuda_available else torch.device("cpu")
        for data in val_loader:
            image, target = get_x_y_from_data_dict(data, device)
            output, loss = _forward_without_grad(model, criterion, image, target)

            # measure accuracy and record loss
            prec1 = utils.accuracy(output.detach(), target)[0]
            batch_size = image.size(0)
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
    else:
        for image, target in val_loader:
            # Only move the batch to the GPU when one exists, so evaluation
            # also works on CPU-only hosts instead of raising in ``.cuda()``.
            if cuda_available:
                image = image.cuda()
                target = target.cuda()

            output, loss = _forward_without_grad(model, criterion, image, target)

            # measure accuracy and record loss
            prec1, predicts = utils.accuracy_predicts(output.detach(), target)
            prec1 = prec1[0]
            batch_size = image.size(0)
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)

            # ``np.append`` without an axis flattens its operand; flatten here
            # and concatenate once after the loop to avoid re-copying the
            # whole accumulated array on every batch.
            predict_chunks.append(np.ravel(predicts))

    predicts_all = np.concatenate(predict_chunks)
    return top1.avg, predicts_all
