import presets
import torch
import torch.utils.data
import torchvision
from val_robust import utils


def evaluate(net, data_loader, device="cuda"):
    """Run ``net`` over ``data_loader`` and report top-1 / top-5 accuracy.

    The model is switched to eval mode and every batch is scored under
    ``torch.inference_mode()``. A summary line is printed once the loader
    is exhausted.

    Args:
        net: Model to evaluate; it is called on each image batch. It is not
            moved to ``device`` here, so the caller must have placed it on
            the same device beforehand.
        data_loader: Iterable yielding ``(image, target)`` batches.
        device: Device each batch is moved to before the forward pass.
            Defaults to ``"cuda"``, which matches the previous hard-coded
            behaviour.

    Returns:
        ``metric_logger.acc1.global_avg``, i.e. the running top-1 value
        accumulated over all batches (presumably a percentage -- confirm
        against ``utils.accuracy``).
    """
    net.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Test:imgnet-sk"

    with torch.inference_mode():
        # 500 is presumably the logging interval in iterations -- confirm
        # against utils.MetricLogger.log_every.
        for image, target in metric_logger.log_every(data_loader, 500, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = net(image)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))

            # The first dimension of the image tensor is used as the sample
            # count for this batch, so a short final batch is passed with
            # its real size rather than the nominal one.
            batch_size = image.shape[0]
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)

    print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
    return metric_logger.acc1.global_avg


def load_data(valdir):
    """Build the evaluation dataset and a sequential sampler for ``valdir``.

    Args:
        valdir: Root directory handed to ``torchvision.datasets.ImageFolder``.

    Returns:
        A ``(dataset, sampler)`` tuple: the ``ImageFolder`` dataset using the
        evaluation preset (``crop_size=224``, ``resize_size=256``) and a
        ``SequentialSampler`` over that dataset.
    """
    # The preset is passed as ImageFolder's second positional argument,
    # which is the per-image transform slot.
    eval_transform = presets.ClassificationPresetEval(crop_size=224, resize_size=256)
    image_folder = torchvision.datasets.ImageFolder(valdir, eval_transform)
    in_order_sampler = torch.utils.data.SequentialSampler(image_folder)
    return image_folder, in_order_sampler


def eval_imgnet_sk(net, batch_size, num_workers, location):
    """Evaluate ``net`` on the image folder at ``location`` and print results.

    The model is moved to CUDA and put in eval mode, a ``DataLoader`` is
    built from ``load_data(location)``, and ``evaluate`` is run between two
    banner lines.

    Args:
        net: Model to evaluate; ``net.cuda()`` and ``net.eval()`` are called
            on it.
        batch_size: Batch size forwarded to the ``DataLoader``.
        num_workers: Worker count forwarded to the ``DataLoader``.
        location: Dataset root directory forwarded to ``load_data``.

    Returns:
        The value returned by ``evaluate`` for this loader.
    """
    banner = "*************ImageNet-sketch Results*****************"

    net.cuda()
    net.eval()

    sketch_dataset, sketch_sampler = load_data(location)
    sketch_loader = torch.utils.data.DataLoader(
        sketch_dataset,
        batch_size=batch_size,
        sampler=sketch_sampler,
        num_workers=num_workers,
        pin_memory=True,
    )

    print(banner)
    top1 = evaluate(net, sketch_loader)
    print(banner)
    return top1
