import torch
import numpy as np

import torchvision
import torchvision.transforms as transforms


# Static description of the MNIST dataset: its name, the label string for
# each of the ten digit classes, the class count, and the image dimensions
# ('w'/'h' = 28x28 pixels, 'd' = 1, i.e. presumably a single channel).
DATA_DESC = dict(
    data='mnist',
    classes=tuple(str(digit) for digit in range(10)),
    num_classes=10,
    w=28,
    h=28,
    d=1,
)


def load_mnist(data_dir, use_augmentation=False, validation=False, *, val_size=1024, seed=None):
    """
    Returns MNIST train and test datasets, optionally with train/validation samplers.

    Arguments:
        data_dir (str): path to data directory; MNIST is downloaded there if missing.
        use_augmentation (bool): currently ignored -- both the train and test sets
            use only ``transforms.ToTensor()``. Kept for interface compatibility.
        validation (bool): if True, hold out ``val_size`` randomly chosen training
            indices and also return samplers over the two disjoint index subsets.
        val_size (int): number of training examples held out for validation when
            ``validation`` is True (default 1024, the previously hard-coded value).
        seed (int or None): seed for the train/validation split. ``None`` shuffles
            with the global NumPy RNG (``np.random.shuffle``), as before.
    Returns:
        ``(train_dataset, test_dataset)`` when ``validation`` is False, otherwise
        ``(train_dataset, test_dataset, train_sampler, val_sampler)``, where both
        samplers are ``SubsetRandomSampler`` objects indexing into ``train_dataset``.
    Raises:
        ValueError: if ``validation`` is True and ``val_size`` is outside
            ``[0, len(train_dataset)]``.
    """
    test_transform = transforms.Compose([transforms.ToTensor()])
    # No augmentation is implemented: the training set reuses the test transform.
    train_transform = test_transform

    train_dataset = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
    test_dataset = torchvision.datasets.MNIST(root=data_dir, train=False, download=True, transform=test_transform)
    if not validation:
        return train_dataset, test_dataset

    dataset_size = len(train_dataset)
    if not 0 <= val_size <= dataset_size:
        raise ValueError(f"val_size must be in [0, {dataset_size}], got {val_size}")

    indices = list(range(dataset_size))
    if seed is None:
        # Preserve the historical behaviour: callers may seed np.random globally.
        np.random.shuffle(indices)
    else:
        np.random.default_rng(seed).shuffle(indices)
    # The first val_size shuffled indices form the validation split; the rest train.
    train_indices, val_indices = indices[val_size:], indices[:val_size]

    # Both samplers draw from disjoint index subsets of the same training dataset.
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indices)
    val_sampler = torch.utils.data.sampler.SubsetRandomSampler(val_indices)
    return train_dataset, test_dataset, train_sampler, val_sampler