import os
from torchvision import transforms, datasets
from shutil import copyfile


def _imagenet(test_dir='./imagenet/val/'):
    """Build an ImageFolder dataset over an ImageNet-style split directory.

    Each image is resized so its shorter side is 256 px, center-cropped to
    224x224 and converted to a tensor.

    Args:
        test_dir: Root directory whose immediate subdirectories are the class
            folders (the layout ``torchvision.datasets.ImageFolder`` expects).

    Returns:
        A ``torchvision.datasets.ImageFolder`` over ``test_dir``.
    """
    # transforms.Scale was deprecated in favour of transforms.Resize (identical
    # behaviour for an int size) and has been removed from newer torchvision.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    return datasets.ImageFolder(test_dir, transform)

# Module-level dataset over the ImageNet validation split; the helper
# functions below read its `classes` and `imgs` attributes.
test_dataset = _imagenet('./imagenet/val/')

#%% create class dirs
def create_classDirs(dest_dir, classes=None):
    """Create ``dest_dir`` and one empty subdirectory per class name.

    Directories that already exist are left untouched, so repeated calls are
    safe.

    Args:
        dest_dir: Destination root. It may or may not end with a path
            separator; missing parent directories are created too.
        classes: Iterable of class-folder names to create. Defaults to the
            ``classes`` attribute of the module-level ``test_dataset``.
    """
    if classes is None:
        classes = test_dataset.classes
    # makedirs(exist_ok=True) replaces the exists()/mkdir() pair: it also
    # creates missing parents and has no check-then-create race.
    os.makedirs(dest_dir, exist_ok=True)
    for class_name in classes:
        # os.path.join stays correct even when dest_dir lacks a trailing '/'.
        os.makedirs(os.path.join(dest_dir, class_name), exist_ok=True)

def save_images(dest_dir, interval, start=1, imgs=None, src_root=None):
    """Copy every ``interval``-th image of the source split into ``dest_dir``.

    Each copied file keeps its path relative to ``src_root`` (its class
    sub-folder and file name). Missing destination sub-folders are created
    on demand.

    Args:
        dest_dir: Destination root directory.
        interval: Step between selected image indices.
        start: Index of the first selected image. The default of 1 reproduces
            the historical selection (indices 1, 1 + interval, ...).
        imgs: Sequence of ``(path, label)`` pairs. Defaults to
            ``test_dataset.imgs``.
        src_root: Directory the image paths live under. Defaults to the
            module-level ``test_dir``.

    Raises:
        ValueError: If ``interval`` is 0, or an image path is not under
            ``src_root``.
    """
    if imgs is None:
        imgs = test_dataset.imgs
    if src_root is None:
        src_root = test_dir
    # Iterating from `start` (instead of reading imgs[i + 1]) selects the same
    # images as before but can never index one past the end of `imgs`.
    for index in range(start, len(imgs), interval):
        src = imgs[index][0]
        # relpath is robust where str.split(src_root)[1] was not (e.g. the root
        # string occurring twice in the path, or a differently spelled root).
        rel_path = os.path.relpath(src, src_root)
        if rel_path == os.pardir or rel_path.startswith(os.pardir + os.sep):
            raise ValueError(
                '{!r} is not under source root {!r}'.format(src, src_root))
        dst = os.path.join(dest_dir, rel_path)
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        copyfile(src, dst)


test_dir = './imagenet/val/'  ## source directory
attack_dir = './imagenet/val_attack/'
certify_dir = './imagenet/val_certify/'

# Target subset sizes. The copy interval is derived from them so the number of
# selected images matches the stated target. Assuming the standard
# 50,000-image ImageNet validation set, this gives an interval of 25 for the
# attack subset (unchanged) and 100 for certification; the previous hard-coded
# interval of 500 selected only 100 images, not the 500 stated below.
N_ATTACK_IMAGES = 2000
N_CERTIFY_IMAGES = 500
n_source_images = len(test_dataset.imgs)

#%% -----------------> Selecting 2000 images for experiments on adversarial attack

create_classDirs(attack_dir)
# max(1, ...) keeps the step valid if the source split is smaller than the target.
save_images(attack_dir, max(1, n_source_images // N_ATTACK_IMAGES))

#%% -----------------> Selecting 500 images for experiments on L2 certification
create_classDirs(certify_dir)
save_images(certify_dir, max(1, n_source_images // N_CERTIFY_IMAGES))
