import sys
import fire
from torchvision import datasets

def download_cifar10(loc):
    """Fetch both CIFAR10 splits into ``loc``. Analogous to scripts/mnist_data.py.

    Args:
        loc: Root directory handed to ``torchvision.datasets.CIFAR10``.
    """
    # Training split first, then the test split, each with download enabled.
    for is_train in (True, False):
        datasets.CIFAR10(loc, download=True, train=is_train)


if __name__ == "__main__":
    # Expose download_cifar10 as a command-line entry point via python-fire;
    # its `loc` argument becomes the CLI argument.
    fire.Fire(component=download_cifar10)
