import os
import torchvision
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, MNIST, FashionMNIST
import torchvision.transforms as transforms

def Dataset(dataset_name, batch_size=128, DATA_DIR='datasets', num_workers=4):
    """Build a shuffled train DataLoader and an ordered test DataLoader.

    MNIST and FashionMNIST images are converted to 3 channels and resized to
    32 pixels, so every supported dataset feeds the model the same kind of
    input as CIFAR10. All datasets are normalized with one shared per-channel
    mean/std (the standard ImageNet statistics), not dataset-specific values.

    Args:
        dataset_name: One of ``'cifar10'``, ``'mnist'`` or ``'fashion_mnist'``.
        batch_size: Mini-batch size used by both loaders.
        DATA_DIR: Root directory the dataset is read from / downloaded into.
        num_workers: Worker processes per DataLoader (defaults to 4, the
            previously hard-coded value).

    Returns:
        A ``(trainloader, testloader)`` tuple. The train loader applies random
        augmentation and shuffles; the test loader does neither.

    Raises:
        NotImplementedError: If ``dataset_name`` is not one of the supported
            names.
    """
    # Validate and pick per-dataset settings before building any transforms,
    # so an unsupported name fails fast.
    if dataset_name == 'cifar10':
        print('Dataset: CIFAR10.')
        dataset_cls = CIFAR10
        to_rgb32 = []  # no channel/size conversion is applied to CIFAR10
        hflip = True
    elif dataset_name == 'mnist':
        print('Dataset: MNIST.')
        dataset_cls = MNIST
        to_rgb32 = [transforms.Grayscale(3), transforms.Resize(32)]
        # A mirrored digit is generally not a valid digit, so MNIST is not
        # augmented with horizontal flips.
        hflip = False
    elif dataset_name == 'fashion_mnist':
        print('Dataset: FashionMNIST.')
        dataset_cls = FashionMNIST
        to_rgb32 = [transforms.Grayscale(3), transforms.Resize(32)]
        hflip = True
    else:
        raise NotImplementedError("Only cifar10, mnist, fashion mnist are allowed.")

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    augment = [transforms.RandomHorizontalFlip()] if hflip else []
    # Pad every side by 4 pixels, then take a random 32x32 crop.
    augment.append(transforms.RandomCrop(32, 4))

    train_transform = transforms.Compose(
        to_rgb32 + augment + [transforms.ToTensor(), normalize])
    test_transform = transforms.Compose(
        to_rgb32 + [transforms.ToTensor(), normalize])

    trainset = dataset_cls(root=DATA_DIR, train=True, download=True,
                           transform=train_transform)
    testset = dataset_cls(root=DATA_DIR, train=False, download=True,
                          transform=test_transform)

    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                             num_workers=num_workers)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers)
    return trainloader, testloader
