from torch.utils.data import DataLoader

from .fair import FairDataset
from .mnist import MNIST
from .celeba import CelebA, CelebALandmarks
from .dummy import DummyDataset
from .svhn import SVHN
from .cifar10 import CIFAR10


def get_dataset(name, tag, **kwargs):
    """Instantiate the dataset registered under ``name`` for split ``tag``.

    Args:
        name: Dataset identifier; one of ``'fair'``, ``'mnist'``, ``'celeba'``,
            ``'dummy'``, ``'svhn'`` or ``'cifar10'``.
        tag: Split name forwarded to the dataset constructor as ``tag=``
            (``get_dataloader`` in this module uses ``'train'``, ``'val'``
            and ``'test'``).
        **kwargs: Extra keyword arguments forwarded unchanged to the dataset
            constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        KeyError: If ``name`` is not one of the known dataset identifiers.
            The message includes the rejected name.
    """
    if name == 'fair':
        return FairDataset(tag=tag, **kwargs)
    if name == 'mnist':
        return MNIST(tag=tag, **kwargs)
    if name == 'celeba':
        # The CelebA wrapper is given 'valid' instead of 'val' for the
        # validation split; every other tag is passed through unchanged.
        return CelebA(tag=tag if tag != 'val' else 'valid', **kwargs)
    if name == 'dummy':
        return DummyDataset(tag=tag, **kwargs)
    if name == 'svhn':
        return SVHN(tag=tag, **kwargs)
    if name == 'cifar10':
        return CIFAR10(tag=tag, **kwargs)
    # Name the offending value so a typo in a config is easy to diagnose.
    raise KeyError('Unknown dataset name: {!r}'.format(name))


def get_dataloader(name, batch_size, **kwargs):
    """Build train/val/test DataLoaders for the dataset named ``name``.

    Args:
        name: Dataset identifier understood by ``get_dataset``.
        batch_size: Either a single int used for every split, or a sequence
            whose first entry is the train batch size and whose second entry
            is used for both the val and test loaders.
        **kwargs: Extra keyword arguments forwarded to ``get_dataset`` for
            every split.

    Returns:
        dict mapping ``'train'``, ``'val'`` and ``'test'`` to a
        ``torch.utils.data.DataLoader`` over the corresponding split.
    """
    train_dataset = get_dataset(name, tag='train', **kwargs)
    val_dataset = get_dataset(name, tag='val', **kwargs)
    test_dataset = get_dataset(name, tag='test', **kwargs)

    if isinstance(batch_size, int):
        batch_size = [batch_size, batch_size]
    train_batch_size = batch_size[0]
    eval_batch_size = batch_size[1]

    # Options shared by every split. Shuffling is decided per split below.
    common_options = {
        'drop_last': False,
        'num_workers': 4,
    }

    return {
        # Only the training split is shuffled. Val/test keep a fixed order so
        # evaluation is reproducible and per-sample outputs stay aligned with
        # the dataset's indexing across runs.
        'train': DataLoader(train_dataset, batch_size=train_batch_size,
                            shuffle=True, **common_options),
        'val': DataLoader(val_dataset, batch_size=eval_batch_size,
                          shuffle=False, **common_options),
        'test': DataLoader(test_dataset, batch_size=eval_batch_size,
                           shuffle=False, **common_options),
    }


__all__ = ['get_dataloader']