import torch
from tqdm import tqdm
from torch.functional import F


def evaluate(model_1, dataloader, device):
    """Compute top-1 classification accuracy of ``model_1`` over ``dataloader``.

    Each batch is read as a mapping with ``'data'`` and ``'label'`` entries,
    and only element ``[0]`` of each entry is used (presumably the batch is
    wrapped per device by the data pipeline -- TODO confirm against caller).
    The predicted class is the argmax over dim 1 of the model output, so the
    output is assumed to be shaped (batch, num_classes) -- verify.

    NOTE(review): the model is NOT switched to ``eval()`` mode here, so
    layers such as dropout / batch-norm stay in whatever mode the caller
    left them in.

    Args:
        model_1: Callable mapping a batch of inputs to per-class scores.
        dataloader: Iterable of batches as described above.
        device: Device that inputs and labels are moved to before inference.

    Returns:
        The accuracy (correct / total) as a tensor, or ``None`` if a
        ``RuntimeError`` was raised during evaluation or the dataloader
        produced no examples.
    """
    n_correctly_classified, total = 0, 0
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader):
                images, labels = batch['data'], batch['label']
                images, labels = images[0].to(device), labels[0].to(device)

                pred = model_1(images).argmax(dim=1)

                # Kept as a tensor (no .item()), so the running count stays
                # on `device`; `total` is a plain Python int.
                n_correctly_classified += (pred == labels).sum()
                total += labels.shape[0]
    except RuntimeError as e:
        print("Error in evaluation: ", e)
        return None

    if total == 0:
        # An empty dataloader would otherwise raise ZeroDivisionError (0 / 0
        # on Python ints), which the RuntimeError handler above never caught.
        print("Error in evaluation: ", "no examples were evaluated")
        return None

    accuracy = n_correctly_classified / total
    print(f"Accuracy on {total} examples:  {accuracy}")
    return accuracy
