#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @python: 3.6

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


def test_img(net_g, datatest, args):
    net_g.eval()
    # testing
    test_loss = 0
    correct = 0
    data_loader = DataLoader(datatest, batch_size=args.bs, num_workers=args.num_workers)
    l = len(data_loader)
    for idx, (data, target) in enumerate(data_loader):
        if int(args.gpu) != -1:
            data, target = data.to(args.device), target.to(args.device)
        data = data.to(torch.float64)
        # We need to do the one_hot_encoding when using linear regression.
        #target_one_hot_encoded = nn.functional.one_hot(target, 10)
        #target_one_hot_encoded = target_one_hot_encoded.to(torch.float64)
        log_probs = net_g(data)
        # sum up batch loss
        test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()

        # get the index of the max log-probability
        y_pred = log_probs.data.max(1, keepdim=True)[1]

        # y_pred = torch.round(log_probs)
        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

    test_loss /= len(data_loader.dataset)
    accuracy = 100.00 * correct / len(data_loader.dataset)
    if args.verbose:
        print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(data_loader.dataset), accuracy))
    return accuracy.item(), test_loss

