import os.path

from torchvision import datasets, transforms, models
import torch.utils.data

# Download the ImageNet-Patch dataset from https://zenodo.org/record/6568778

class ImageFolderWithEmptyDirs(datasets.ImageFolder):
    """
    ImageFolder variant that tolerates empty class folders.

    Every sub-directory of the root is reported as a class and is assigned
    its position in the sorted directory listing as its index. Directories
    that contain no entries are omitted from ``class_to_idx`` (presumably so
    ImageFolder does not reject them as classes without images), while the
    remaining classes keep the indices they would have had anyway.
    """

    def find_classes(self, directory):
        """Return ``(classes, class_to_idx)`` for ``directory``.

        ``classes`` lists every sub-directory name in sorted order;
        ``class_to_idx`` maps only the non-empty ones to their index in
        ``classes``.

        Raises:
            FileNotFoundError: if ``directory`` has no sub-directories.
        """
        classes = sorted(
            entry.name for entry in os.scandir(directory) if entry.is_dir()
        )
        if not classes:
            raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

        class_to_idx = {}
        for idx, name in enumerate(classes):
            # Empty folders are skipped but still consume their index, so
            # indices stay aligned with the full sorted list of folders.
            if os.listdir(os.path.join(directory, name)):
                class_to_idx[name] = idx
        return classes, class_to_idx


# Extract the downloaded archive, then point this at its top-level folder.
dataset_folder = 'data/ImageNet-Patch'

# Target labels for which patched images are provided
# (ImageNet class index -> class name).
available_labels = {
    487: 'cellular telephone',
    513: 'cornet',
    546: 'electric guitar',
    585: 'hair spray',
    804: 'soap dispenser',
    806: 'sock',
    878: 'typewriter keyboard',
    923: 'plate',
    954: 'banana',
    968: 'cup'
}

# Select the sub-folder holding the patches for this target label.
target_label = 954
if target_label not in available_labels:
    raise ValueError(
        f"target_label must be one of {sorted(available_labels)}, got {target_label}."
    )

dataset_folder = os.path.join(dataset_folder, str(target_label))
# Standard ImageNet per-channel mean/std normalization.
normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])
# Named `transform` (not `transforms`) so the torchvision module is not shadowed.
transform = transforms.Compose([
    transforms.ToTensor(),
    normalizer,
])

dataset = ImageFolderWithEmptyDirs(dataset_folder, transform=transform)
model = models.resnet50(pretrained=True)
loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=5)
model.eval()

# Evaluate on at most `batches` randomly drawn batches.
batches = 10
correct, attack_success, total = 0, 0, 0
# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(loader):
        if batch_idx == batches:
            break
        pred = model(images).argmax(dim=1)
        # .item() keeps the counters as plain Python ints.
        correct += (pred == labels).sum().item()
        attack_success += (pred == target_label).sum().item()
        total += pred.shape[0]

if total == 0:
    raise RuntimeError(f"No images were loaded from {dataset_folder}.")

# Fraction of evaluated images still predicted as their folder's class.
accuracy = correct / total
# Fraction of evaluated images predicted as the patch's target label.
attack_sr = attack_success / total

print("Robust Accuracy: ", accuracy)
print("Attack Success: ", attack_sr)
