import sys
import os.path as osp
from torchvision.datasets import STL10

from dassl.utils import mkdir_if_missing


def extract_and_save_image(dataset, save_dir):
    """Write every sample of ``dataset`` to an image file in ``save_dir``.

    If ``save_dir`` already exists, a notice is printed and nothing else
    happens (the dataset is not touched). Otherwise the directory is created
    via ``mkdir_if_missing`` and each sample is saved as
    ``<index>_<label>.jpg``, where the index is zero-padded to six digits
    and a label equal to ``-1`` is spelled ``none``.

    Args:
        dataset: object supporting ``len()`` and integer indexing, where
            each item unpacks to ``(img, label)`` and ``img`` exposes a
            ``save(path)`` method.
        save_dir: path of the directory receiving the image files.
    """
    if osp.exists(save_dir):
        print('Folder "{}" already exists'.format(save_dir))
    else:
        print('Extracting images to "{}" ...'.format(save_dir))
        mkdir_if_missing(save_dir)

        num_samples = len(dataset)
        for idx in range(num_samples):
            img, label = dataset[idx]
            # A label of -1 is written as "none" in the filename.
            tag = "none" if label == -1 else str(label)
            filename = "_".join([str(idx).zfill(6), tag]) + ".jpg"
            img.save(osp.join(save_dir, filename))


def download_and_prepare(root):
    """Load the STL10 splits under ``root`` and extract each to a folder.

    The ``train``, ``test`` and ``unlabeled`` splits are instantiated in
    that order; only the ``train`` instantiation passes ``download=True``.
    Each split is then handed to ``extract_and_save_image`` with the target
    directory ``<root>/<split name>``.

    Args:
        root: dataset root directory passed to ``STL10``; also the parent
            of the per-split output folders.
    """
    # Build all datasets before extracting anything, train first since it
    # is the one that requests the download.
    splits = [("train", STL10(root, split="train", download=True))]
    for name in ("test", "unlabeled"):
        splits.append((name, STL10(root, split=name)))

    for name, data in splits:
        extract_and_save_image(data, osp.join(root, name))


if __name__ == "__main__":
    # Usage: the first command-line argument is the dataset root directory.
    # Exit with a readable message instead of an IndexError traceback when
    # it is missing; any additional arguments are ignored.
    if len(sys.argv) < 2:
        sys.exit("Usage: python {} <root>".format(sys.argv[0]))
    download_and_prepare(sys.argv[1])
