import torch
from torchvision import datasets, transforms


def get_MNIST_loaders(data_dir, batch_size, download=True, *,
                      image_size=(16, 16), num_workers=0):
    """Build train and test DataLoaders for MNIST.

    Every image is resized to ``image_size``, converted to a float tensor by
    ``ToTensor`` (values in [0, 1]) and normalized with mean 0.5 / std 0.5,
    which maps pixel values onto [-1, 1].

    Args:
        data_dir: Root directory passed to ``torchvision.datasets.MNIST``.
        batch_size: Batch size used by both loaders.
        download: If true, let the training-set constructor fetch MNIST into
            ``data_dir`` when it is not already there.
        image_size: Forwarded unchanged to ``transforms.Resize``. Defaults to
            ``(16, 16)``, the size this helper has always produced.
        num_workers: Number of loader worker processes for each DataLoader
            (0 means loading happens in the main process, the DataLoader
            default).

    Returns:
        ``(trainloader, testloader)``. The training loader shuffles, and the
        test loader iterates in a fixed order.
    """
    transform = transforms.Compose([transforms.Resize(image_size),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (0.5,))])
    train_set = datasets.MNIST(root=data_dir, train=True, transform=transform,
                               download=download)
    # download=False: torchvision's MNIST download (triggered above) fetches
    # the files for both splits, so the test split only needs to be opened.
    test_set = datasets.MNIST(root=data_dir, train=False, transform=transform,
                              download=False)
    trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers)
    return trainloader, testloader


def get_CIFAR10_loader(data_dir, batch_size, download=True, *,
                       image_size=(32, 32), num_workers=0):
    """Build train and test DataLoaders for CIFAR-10.

    Both splits are resized to ``image_size``, converted to float tensors and
    normalized per channel. The training split additionally applies a random
    horizontal flip for augmentation. The test split is left deterministic.

    Args:
        data_dir: Root directory passed to ``torchvision.datasets.CIFAR10``.
        batch_size: Batch size used by both loaders.
        download: If true, let the training-set constructor fetch CIFAR-10
            into ``data_dir`` when it is not already there.
        image_size: Forwarded unchanged to ``transforms.Resize``. Defaults to
            ``(32, 32)``, the size this helper has always produced.
        num_workers: Number of loader worker processes for each DataLoader
            (0 means loading happens in the main process, the DataLoader
            default).

    Returns:
        ``(trainloader, testloader)``. The training loader shuffles, and the
        test loader iterates in a fixed order.
    """
    # Per-channel (R, G, B) normalization constants. These appear to be the
    # commonly cited CIFAR-10 training-set mean/std; verify before reuse.
    cifar_norm_mean = (0.49139968, 0.48215827, 0.44653124)
    cifar_norm_std = (0.24703233, 0.24348505, 0.26158768)

    def _resize_and_normalize():
        # Shared tail of both pipelines. Fresh instances are built on every
        # call so the train and test transforms do not share objects.
        return [transforms.Resize(image_size),
                transforms.ToTensor(),
                transforms.Normalize(cifar_norm_mean, cifar_norm_std)]

    transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),
                                          *_resize_and_normalize()])
    transform_valid = transforms.Compose(_resize_and_normalize())

    train_set = datasets.CIFAR10(root=data_dir, train=True,
                                 transform=transform_train, download=download)
    # download=False: the CIFAR-10 archive fetched for the training split
    # above also contains the test batch, so it only needs to be opened.
    test_set = datasets.CIFAR10(root=data_dir, train=False,
                                transform=transform_valid, download=False)
    trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers)
    return trainloader, testloader
