import torch
import torchvision

# Static metadata for the MNIST dataset as used by this module.
PARAMS_MNIST = {
    'dataset_name': 'MNIST',
    'n_data': 60_000,       # presumably the size of the training split
    'shape': (1, 28, 28),   # per-image tensor shape, presumably (C, H, W)
    'mean': (0.1307,),      # per-channel mean used for Normalize
    'std': (0.3081,),       # per-channel std used for Normalize
    'root': '',             # dataset root directory ('' = relative to cwd)
}

# Normalization statistics are aliased from PARAMS_MNIST rather than
# re-typed, so the dict and these constants can never disagree.
MNIST_MEAN = PARAMS_MNIST['mean']
MNIST_STD = PARAMS_MNIST['std']

# Preprocessing pipeline applied to every MNIST sample loaded by get_mnist.
transform = torchvision.transforms.Compose(
    [
        # Convert the raw image into a float tensor.
        torchvision.transforms.ToTensor(),
        # Standardize pixel values with the MNIST channel statistics.
        torchvision.transforms.Normalize(mean=MNIST_MEAN, std=MNIST_STD),
    ]
)

def get_mnist(test=False, *, root=None, download=True):
    """Return the MNIST train or test split as a torchvision dataset.

    Every sample is preprocessed with the module-level ``transform``.

    Args:
        test: If True, return the test split; otherwise the training split.
        root: Directory holding (or receiving) the dataset files. ``None``
            (the default) uses ``PARAMS_MNIST['root']``.
        download: Whether torchvision may download the files when they are
            not already present under ``root``.

    Returns:
        A ``torchvision.datasets.MNIST`` instance.
    """
    if root is None:
        # Use the root declared in the module's dataset metadata instead of a
        # second hard-coded copy of the same value.
        root = PARAMS_MNIST['root']
    return torchvision.datasets.MNIST(
        root=root,
        train=not test,
        download=download,
        transform=transform,
    )

def get_mnist_data_loader(batch_size=128, test=False, *, shuffle=None,
                          **loader_kwargs):
    """Build a DataLoader over the MNIST train or test split.

    Args:
        batch_size: Number of samples per batch.
        test: If True, load the test split; otherwise the training split.
        shuffle: Whether the loader reshuffles the data. ``None`` (the
            default) keeps the previous behavior: shuffle the training split
            and keep the test split in order.
        **loader_kwargs: Extra keyword arguments forwarded unchanged to
            ``torch.utils.data.DataLoader`` (e.g. ``num_workers``,
            ``pin_memory``, ``drop_last``).

    Returns:
        A ``torch.utils.data.DataLoader`` over ``get_mnist(test)``.
    """
    if shuffle is None:
        # Randomize training order; keep evaluation order deterministic.
        shuffle = not test
    return torch.utils.data.DataLoader(
        get_mnist(test),
        batch_size=batch_size,
        shuffle=shuffle,
        **loader_kwargs,
    )
