import sys
import os.path as osp
from torchvision.datasets import SVHN, CIFAR10, CIFAR100

from dassl.utils import mkdir_if_missing


def extract_and_save_image(dataset, save_dir):
    """Write every image of ``dataset`` to ``save_dir``, grouped by label.

    Each sample is saved as ``<save_dir>/<label>/<index>.jpg``, where the
    label is zero-padded to 3 characters and the 1-based sample index is
    zero-padded to 5 characters.

    If ``save_dir`` already exists the function prints a notice and returns
    without writing anything.

    NOTE: the existence check is all-or-nothing -- a folder left behind by an
    interrupted run is treated as already extracted. Delete it manually to
    force a fresh extraction.

    Args:
        dataset: sized, indexable collection whose items unpack into
            ``(img, label)``; ``img`` must provide a ``save(path)`` method.
        save_dir (str): directory that receives the per-class sub-folders.
    """
    if osp.exists(save_dir):
        print('Folder "{}" already exists'.format(save_dir))
        return

    print('Extracting images to "{}" ...'.format(save_dir))
    mkdir_if_missing(save_dir)

    # Class folders already created during this call. There are far fewer
    # classes than images, so remembering them avoids asking the helper to
    # re-check the same folder once per image.
    created_dirs = set()

    for i in range(len(dataset)):
        img, label = dataset[i]
        class_dir = osp.join(save_dir, str(label).zfill(3))
        if class_dir not in created_dirs:
            mkdir_if_missing(class_dir)
            created_dirs.add(class_dir)
        # File names are 1-based: the first sample becomes 00001.jpg.
        impath = osp.join(class_dir, str(i + 1).zfill(5) + ".jpg")
        img.save(impath)


def download_and_prepare(name, root):
    """Download a dataset through torchvision and extract it to image folders.

    The train and test splits are written to ``<root>/<name>/train`` and
    ``<root>/<name>/test`` by ``extract_and_save_image``.

    Args:
        name (str): dataset identifier; one of ``"cifar10"``, ``"cifar100"``
            or ``"svhn"``.
        root (str): directory handed to torchvision as the dataset root; the
            extracted image folders are created beneath it.

    Raises:
        ValueError: if ``name`` is not one of the supported identifiers.
    """
    print("Dataset: {}".format(name))
    print("Root: {}".format(root))

    if name == "cifar10":
        train = CIFAR10(root, train=True, download=True)
        # No download flag for the test split -- presumably the archive
        # fetched for the train split already contains it (TODO confirm).
        test = CIFAR10(root, train=False)
    elif name == "cifar100":
        train = CIFAR100(root, train=True, download=True)
        # Same as above: relies on the train-split download.
        test = CIFAR100(root, train=False)
    elif name == "svhn":
        # Each SVHN split is requested with its own download flag.
        train = SVHN(root, split="train", download=True)
        test = SVHN(root, split="test", download=True)
    else:
        raise ValueError(
            'Unknown dataset "{}" (expected one of: cifar10, cifar100, svhn)'
            .format(name)
        )

    train_dir = osp.join(root, name, "train")
    test_dir = osp.join(root, name, "test")

    extract_and_save_image(train, train_dir)
    extract_and_save_image(test, test_dir)


if __name__ == "__main__":
    # The dataset root directory is the single required command-line
    # argument; fail with a usage hint instead of an IndexError traceback.
    if len(sys.argv) < 2:
        sys.exit("Usage: python {} <root>".format(sys.argv[0]))

    # Prepare every supported dataset under the same root, in this order.
    for dataset_name in ("cifar10", "cifar100", "svhn"):
        download_and_prepare(dataset_name, sys.argv[1])
