import os
import tempfile

import torchvision
from tqdm.auto import tqdm

# Short class names indexed by the integer label torchvision's STL10 returns.
# STL-10 labels 0-9 follow the dataset's class_names.txt order:
# airplane, bird, car, cat, deer, dog, horse, monkey, ship, truck.
# The short aliases ('plane' for airplane, 'ape' for monkey) are kept from the
# original naming, but the ORDER must match the labels or every file name
# would carry the wrong class.
CLASSES = (
    'plane',  # 0: airplane
    'bird',   # 1
    'car',    # 2
    'cat',    # 3
    'deer',   # 4
    'dog',    # 5
    'horse',  # 6
    'ape',    # 7: monkey
    'ship',   # 8
    'truck',  # 9
)


def _dump_split(split_set, out_dir, offset=0):
    """Save every (image, label) pair of ``split_set`` as a PNG in ``out_dir``.

    Files are named ``<class>_<index>.png``. ``<index>`` is the 5-digit,
    zero-padded value ``offset + i``; the offset keeps indices unique when
    several splits are written into the same flat directory.
    """
    for i in tqdm(range(len(split_set))):
        image, label = split_set[i]
        filename = os.path.join(out_dir, f"{CLASSES[label]}_{offset + i:05d}.png")
        image.save(filename)


def main(out_dir='../stl'):
    """Download STL-10 (train + test) and dump all images as PNGs into ``out_dir``.

    Train images get indices ``0 .. len(train) - 1``. Test images continue
    numbering from there, so both splits share one flat directory without
    name clashes. If ``out_dir`` already exists, nothing is downloaded or
    written.

    Args:
        out_dir: Destination directory, resolved relative to the current
            working directory. Defaults to ``'../stl'``.
    """
    # Check before downloading: the archive is large, and the download would
    # be wasted if the output directory could not be created afterwards.
    if os.path.exists(out_dir):
        print(f"skipping since {out_dir} already exists.")
        return

    print("downloading...")
    with tempfile.TemporaryDirectory() as tmp_dir:
        trainset = torchvision.datasets.STL10(
            root=tmp_dir, split='train', download=True
        )
        testset = torchvision.datasets.STL10(
            root=tmp_dir, split='test', download=True
        )

        # Dump while the temporary download directory still exists, so we do
        # not depend on the dataset objects having cached everything in
        # memory before cleanup.
        print("dumping images...")
        os.makedirs(out_dir)
        _dump_split(trainset, out_dir)
        _dump_split(testset, out_dir, offset=len(trainset))




if __name__ == "__main__":
    # Run the download/dump only when executed as a script, never on import.
    main()
