import argparse
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Command-line hyperparameters, parsed once at import time into the
# module-level `args` namespace that the code below reads.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--batch_size',
    type=int,
    default=10,
    help='number of samples per batch (default: %(default)s)',
)
# NOTE(review): --lr is not consumed in this section; presumably it is used by
# an optimizer defined further down the file — confirm.
parser.add_argument(
    '--lr',
    type=float,
    default=1e-3,
    help='learning rate (default: %(default)s)',
)
args = parser.parse_args()

# Fail fast with a clear CLI error instead of letting DataLoader raise a
# ValueError later for a zero or negative batch size.
if args.batch_size < 1:
    parser.error('--batch_size must be a positive integer')

# Convert PIL images to tensors; kept as a Compose so further transforms
# can be appended without restructuring.
transform = transforms.Compose([transforms.ToTensor()])

# MNIST train/test splits, downloaded into ../data on first use.
train_data = datasets.MNIST('../data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('../data', train=False, download=True, transform=transform)

# Shuffle the training set so each epoch sees samples in a fresh order
# (DataLoader defaults to shuffle=False). The test set stays in a fixed
# order so evaluation is deterministic.
train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=args.batch_size)