import torch
import torch.nn.functional as F
from tqdm.auto import tqdm


def test_unlearn_classification(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` with summed cross-entropy.

    The model's forward output must expose a ``logits`` attribute; class
    scores are taken along dimension 1 (``argmax(dim=1)``).

    Args:
        model: Module to evaluate; it is moved to ``device`` and switched to
            eval mode before the loop.
        device: Device that the model and every ``(data, target)`` batch are
            moved to.
        test_loader: Iterable of ``(data, target)`` batches that supports
            ``len()`` and exposes ``.dataset`` (whose length is used as the
            sample count).

    Returns:
        Tuple ``(test_loss, accuracy)``: the cross-entropy summed over all
        batches divided by ``len(test_loader.dataset)``, and the percentage
        of argmax predictions that match the targets.

    NOTE(review): both figures assume the loader visits every sample of
    ``test_loader.dataset`` exactly once (e.g. no ``drop_last``) — confirm.
    """
    model.to(device)
    model.eval()

    loss_accum = 0
    hits = 0

    with torch.no_grad():
        progress = tqdm(test_loader, total=len(test_loader))
        for inputs, labels in progress:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs).logits
            # reduction="sum" so dividing by the sample count below yields a
            # per-sample mean regardless of batch sizes.
            batch_loss = F.cross_entropy(logits, labels, reduction="sum")
            loss_accum += batch_loss.item()
            predicted = logits.argmax(dim=1, keepdim=True)
            hits += predicted.eq(labels.view_as(predicted)).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss = loss_accum / num_samples
    accuracy = 100.0 * hits / num_samples

    summary = "\nTest set: Average loss: {}, Accuracy: {}/{} ({}%)\n".format(
        test_loss, hits, num_samples, accuracy
    )
    print(summary)
    return test_loss, accuracy


def test_classification_model(
    model,
    test_loader,
    loss_fn,
    device,
):
    """Evaluate ``model`` on ``test_loader`` using a caller-supplied loss.

    The model's forward output must expose a ``logits`` attribute; class
    scores are taken along dimension 1 (``argmax(dim=1)``).

    Args:
        model: Module to evaluate; it is moved to ``device`` and switched to
            eval mode before the loop.
        test_loader: Iterable of ``(data, target)`` batches that supports
            ``len()`` and exposes ``.dataset`` (whose length is used as the
            sample count for accuracy).
        loss_fn: Callable ``loss_fn(logits, target)`` returning a
            single-element value per batch.
        device: Device that the model and every batch are moved to.

    Returns:
        Tuple ``(total_loss, accuracy)``: the arithmetic mean of the per-batch
        ``loss_fn`` values (``nan`` if the loader yields no batches), and the
        percentage of argmax predictions matching the targets.

    NOTE(review): the loss mean is unweighted across batches, so if
    ``loss_fn`` averages within a batch, a smaller final batch counts as
    much as a full one — confirm this is intended.
    """
    model.to(device)
    model.eval()
    batch_losses = []
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(test_loader, total=len(test_loader)):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = loss_fn(output.logits, target)
            # Convert to a host-side Python float right away instead of
            # keeping one device tensor per batch alive and later building a
            # tensor out of a list of (possibly CUDA) tensors. float() also
            # accepts a loss_fn that already returns a plain number.
            batch_losses.append(float(loss))
            pred = output.logits.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    # float64 keeps the accumulation accurate over many batches; an empty
    # list still yields nan, as before.
    total_loss = torch.tensor(batch_losses, dtype=torch.float64).mean().item()
    accuracy = 100.0 * correct / len(test_loader.dataset)
    print(
        "\nTest set: Average loss: {}, Accuracy: {}/{} ({}%)\n".format(
            total_loss,
            correct,
            len(test_loader.dataset),
            accuracy,
        )
    )
    return total_loss, accuracy
