import torchvision
import torchvision.transforms as transforms
import torch
import os

# Absolute path of the 'data' directory that sits next to this file.
# abspath() keeps the location stable even when __file__ is relative (e.g. a
# script launched on Python < 3.9) and the working directory changes later.
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')

# Create the directory at import time so it exists before any download starts;
# exist_ok=True makes this a no-op when the directory is already there.
os.makedirs(data_dir, exist_ok=True)

def download_all_datasets():
    """Fetch the train and test splits of CIFAR-10, CIFAR-100 and MNIST.

    Every dataset class is instantiated twice (``train=True`` and then
    ``train=False``) with ``root=data_dir`` and ``download=True``; the
    constructed dataset objects are discarded. A progress line is printed
    before each dataset and a final message once all calls have returned.
    """
    # (progress message, dataset class), listed in the order they are fetched.
    targets = (
        ("Downloading CIFAR-10...", torchvision.datasets.CIFAR10),
        ("Downloading CIFAR-100...", torchvision.datasets.CIFAR100),
        ("Downloading MNIST...", torchvision.datasets.MNIST),
    )

    for message, dataset_cls in targets:
        print(message)
        # train=True first, then train=False, matching the per-dataset order.
        for is_train in (True, False):
            dataset_cls(root=data_dir, train=is_train, download=True)

    print("All datasets have been successfully downloaded!")

# Script entry point: start the downloads only when this file is executed
# directly. Importing the module does not call download_all_datasets().
if __name__ == "__main__":
    download_all_datasets()