import torchvision
from torch.utils.data import DataLoader

from mind_the_pad.paths import dataset_folder


def preprocessor_mnist(mean=(0.1307,), std=(0.3081,)):
    """Build the default image transform used by the EMNIST loaders in this module.

    The pipeline converts each image with ``ToTensor`` and then standardises
    it channel-wise with ``Normalize(mean, std)``.

    Args:
        mean: Per-channel means passed to ``Normalize``. Defaults to the
            values this module has always used.
        std: Per-channel standard deviations passed to ``Normalize``.
            Defaults to the values this module has always used.

    Returns:
        A ``torchvision.transforms.Compose`` applying the two steps in order.
    """
    # NOTE(review): 0.1307 / 0.3081 are the mean/std commonly quoted for the
    # MNIST *digits* training set, while this module loads EMNIST 'letters'.
    # Whether they match the letters split is unverified -- TODO confirm, and
    # pass ``mean``/``std`` explicitly to override.
    return torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean, std),
        ]
    )


def letters_mnist_train_dl(
    batch_size,
    preprocessing_image=None,
    *,
    shuffle=False,
    **dataloader_kwargs,
):
    """Return a ``DataLoader`` over the EMNIST 'letters' training split.

    Args:
        batch_size: Number of samples per batch.
        preprocessing_image: Optional per-image transform. When ``None``,
            ``letters_mnist_train_dataset`` falls back to
            ``preprocessor_mnist()``.
        shuffle: Whether the loader reshuffles the data every epoch.
            Defaults to ``False`` to preserve the historical behaviour;
            training loops usually want ``True``.
        **dataloader_kwargs: Extra options forwarded unchanged to
            ``torch.utils.data.DataLoader`` (e.g. ``num_workers``,
            ``drop_last``, ``pin_memory``).

    Returns:
        A ``DataLoader`` yielding batches from the training split.
    """
    train_dataset = letters_mnist_train_dataset(preprocessing_image)
    return DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        **dataloader_kwargs,
    )


def letters_mnist_train_dataset(preprocessing_image=None):
    """Load the training split of EMNIST 'letters' from ``dataset_folder``.

    Args:
        preprocessing_image: Transform applied to each image. If ``None``,
            a fresh ``preprocessor_mnist()`` pipeline is used instead.

    Returns:
        The ``torchvision.datasets.EMNIST`` training dataset. ``download=True``
        is passed, so torchvision fetches the files if they are not already
        present under ``dataset_folder``.
    """
    # The default transform is only built when the caller did not supply one.
    transform = (
        preprocessing_image
        if preprocessing_image is not None
        else preprocessor_mnist()
    )
    return torchvision.datasets.EMNIST(
        dataset_folder,
        split='letters',
        train=True,
        download=True,
        transform=transform,
    )


def letters_mnist_test_dl(
    batch_size,
    preprocessing_image=None,
    *,
    shuffle=False,
    **dataloader_kwargs,
):
    """Return a ``DataLoader`` over the EMNIST 'letters' test split.

    Args:
        batch_size: Number of samples per batch.
        preprocessing_image: Optional per-image transform. When ``None``,
            ``letters_mnist_test_dataset`` falls back to
            ``preprocessor_mnist()``.
        shuffle: Whether the loader reshuffles the data every epoch.
            Defaults to ``False`` (the historical behaviour), which is the
            usual choice for evaluation.
        **dataloader_kwargs: Extra options forwarded unchanged to
            ``torch.utils.data.DataLoader`` (e.g. ``num_workers``,
            ``pin_memory``).

    Returns:
        A ``DataLoader`` yielding batches from the test split.
    """
    test_dataset = letters_mnist_test_dataset(preprocessing_image)
    return DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        **dataloader_kwargs,
    )


def letters_mnist_test_dataset(preprocessing_image=None):
    """Load the test split of EMNIST 'letters' from ``dataset_folder``.

    Args:
        preprocessing_image: Transform applied to each image. If ``None``,
            a fresh ``preprocessor_mnist()`` pipeline is used instead.

    Returns:
        The ``torchvision.datasets.EMNIST`` test dataset. ``download=True``
        is passed, so torchvision fetches the files if they are not already
        present under ``dataset_folder``.
    """
    # The default transform is only built when the caller did not supply one.
    transform = (
        preprocessing_image
        if preprocessing_image is not None
        else preprocessor_mnist()
    )
    return torchvision.datasets.EMNIST(
        dataset_folder,
        split='letters',
        train=False,
        download=True,
        transform=transform,
    )