import logging
import random
import torch
import torch.utils.data
import os
# Import-time setup: expose only CUDA device "0" to this process (this only
# affects CUDA if it happens before CUDA is initialised), and seed Python's
# built-in `random` module with a fixed value. Note that torch's own RNG is
# not seeded here.
os.environ.update(CUDA_VISIBLE_DEVICES="0")
random.seed(2022)

def eval_acc(net, testloader, criteria, device):
    """Evaluate ``net`` on every batch of ``testloader``.

    ``net`` is switched to eval mode (and is left in eval mode on return);
    all forward passes run under ``torch.no_grad()``.

    Args:
        net: Model called as ``net(img)``. It must return a 2-tuple whose
            first element is the prediction; the second element is ignored.
        testloader: Iterable of 2-element batches ``(img, label)``; each
            element is moved to ``device`` with ``.to(device)``.
        criteria: Loss callable invoked as ``criteria(pred, label)``.
        device: Target passed to ``Tensor.to``.

    Returns:
        ``(mean_test_loss, mean_test_acc)``: the loss and the accuracy, each
        averaged over batches. The loss keeps the type ``criteria`` returns
        (a tensor for torch losses). Both are unweighted per-batch means, so
        a smaller final batch counts as much as a full one.

    Raises:
        ValueError: If ``testloader`` yields no batches.
    """
    net.eval()
    with torch.no_grad():
        total_loss = 0
        total_acc = 0.0
        num_batch = 0
        for num_batch, batch in enumerate(testloader, start=1):
            img, label = tuple(t.to(device) for t in batch)
            pred, _ = net(img)
            # Sum every batch's loss so the mean below covers all batches,
            # not just the last one.
            total_loss = total_loss + criteria(pred, label)
            # Assumes pred is (batch, num_classes) scores so argmax(1) gives
            # the predicted class per sample -- TODO confirm with the model.
            total_acc += pred.argmax(1).eq(label).sum().item() / len(label)
        if num_batch == 0:
            # Without this guard the division below would be by zero (and the
            # loss would never have been computed).
            raise ValueError("testloader yielded no batches")
        mean_test_loss = total_loss / num_batch
        mean_test_acc = total_acc / num_batch
    return mean_test_loss, mean_test_acc