import os

from torchvision.datasets import MNIST, CIFAR10, CIFAR100
from torchvision import transforms

from .bregmnist import BregMNIST
from .bregcifar import BregCIFAR10
from .cifar_clust import ClustCIFAR
from .cropdist import CropDist

# Root directory where datasets are stored/downloaded. The original location
# is a user-specific placeholder, so allow overriding it through the
# DATA_PATH environment variable; when unset, the old value is used.
DATA_PATH = os.environ.get('DATA_PATH', '/home/XXXX/data')

# Maps a lowercase dataset name to its dataset class.
# NOTE(review): `CropDist` is imported at the top of this module but has no
# entry here — confirm whether that omission is intentional.
dataset_lookup = {
    'bregmnist': BregMNIST,
    'bregcifar': BregCIFAR10,
    'clustcifar': ClustCIFAR,
    'mnist': MNIST,
    'cifar10': CIFAR10,
    'cifar100': CIFAR100,
}


def get_dataset(dataset, data_path=None):
    """Build the train and test splits for a CIFAR dataset.

    Args:
        dataset: Name of the dataset; must be ``'cifar10'`` or ``'cifar100'``.
        data_path: Root directory passed to the torchvision dataset
            constructor. Defaults to the module-level ``DATA_PATH``.

    Returns:
        A ``(train_ds, test_ds)`` tuple. The training split is constructed
        with ``download=True`` and uses random-crop (32, padding 4) and
        horizontal-flip augmentation before tensor conversion and
        normalization. The test split is only converted to a tensor and
        normalized.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Pick the dataset class up front. An unsupported name then fails with a
    # clear error, not an UnboundLocalError at the return statement.
    if dataset == 'cifar10':
        dataset_cls = CIFAR10
    elif dataset == 'cifar100':
        dataset_cls = CIFAR100
    else:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; "
            "expected 'cifar10' or 'cifar100'"
        )

    if data_path is None:
        data_path = DATA_PATH

    # NOTE(review): CIFAR-100 reuses the CIFAR-10 per-channel mean/std here.
    # These std values are also commonly reported to differ from the
    # dataset's actual per-channel std. They are kept unchanged so that
    # preprocessing matches previously trained models — confirm before
    # changing.
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_ds = dataset_cls(data_path, train=True, download=True,
                           transform=train_transform)
    test_ds = dataset_cls(data_path, train=False, transform=test_transform)
    return train_ds, test_ds
