import os
from filelock import FileLock
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


def get_cifar_loaders(fpath, num_class, batch_size):
    """Get CIFAR-10 or CIFAR-100 dataloaders.

    The training split is downloaded into ``fpath`` if missing. The download
    runs while holding a file lock at ``fpath + ".lock"`` so that processes
    sharing the directory serialize it.

    Args:
        fpath: str, root directory of the CIFAR dataset.
        num_class: int, 10 for CIFAR-10 or 100 for CIFAR-100.
        batch_size: int, batch size used by both loaders.

    Returns:
        tuple ``(train_loader, test_loader)`` of ``DataLoader`` objects.

    Raises:
        ValueError: if ``num_class`` is neither 10 nor 100.
    """
    # Validate first: previously an unsupported value fell through the
    # if/elif below and crashed later with an UnboundLocalError on `data`.
    if num_class not in (10, 100):
        raise ValueError(
            f"num_class must be 10 (CIFAR-10) or 100 (CIFAR-100), got {num_class!r}"
        )

    # Acknowledgement: https://github.com/kuangliu/pytorch-cifar
    # NOTE(review): these are the CIFAR-10 mean/std from the repo above; they
    # are also applied to CIFAR-100 — confirm that is intended.
    transform_train = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )

    # Evaluation pipeline: no random augmentation, same normalization.
    transform_test = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )

    data = datasets.CIFAR10 if num_class == 10 else datasets.CIFAR100

    # Only the training split passes download=True; the test split is read
    # from the files fetched here.
    with FileLock(os.path.expanduser(fpath + ".lock")):
        train_loader = DataLoader(
            data(fpath, train=True, download=True, transform=transform_train),
            batch_size=batch_size,
            shuffle=True,
        )
    # NOTE(review): the test loader is shuffled as well; confirm intended.
    test_loader = DataLoader(
        data(fpath, train=False, transform=transform_test),
        batch_size=batch_size,
        shuffle=True,
    )

    return train_loader, test_loader


def get_mnist_loaders(fpath, batch_size):
    """Build train/test DataLoaders for MNIST.

    The training split is downloaded into ``fpath`` if needed. The download
    is performed while holding a file lock at ``fpath + ".lock"``. Both
    splits share one preprocessing pipeline: resize so the shorter side is
    32 pixels, convert to a tensor, and normalize with a single-channel
    mean/std.

    Args:
        fpath: str, root directory of the MNIST dataset.
        batch_size: int, batch size used by both loaders.

    Returns:
        tuple ``(train_loader, test_loader)``; both loaders shuffle.
    """
    pipeline = transforms.Compose(
        [
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )

    lock_path = os.path.expanduser(fpath + ".lock")
    with FileLock(lock_path):
        train_set = datasets.MNIST(
            fpath, train=True, download=True, transform=pipeline
        )
        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    # The test split is not downloaded here; it relies on the files fetched
    # under the lock above.
    test_set = datasets.MNIST(fpath, train=False, transform=pipeline)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader

def get_celeba_loaders(fpath, batch_size):
    """Get CelebA dataloaders.

    Every image is:
    1. randomly flipped horizontally;
    2. center-cropped to 148x148;
    3. resized to 64 pixels on the shorter side;
    4. converted to a tensor;
    5. mapped from [0, 1] to [-1, 1].

    The same pipeline, including the random flip, is used for the test split.

    Args:
        fpath: str, root directory of the CelebA dataset.
        batch_size: int, batch size used by both loaders.

    Returns:
        tuple ``(train_loader, test_loader)``. Both loaders shuffle and drop
        the last incomplete batch.
    """
    # ToTensor yields values in [0, 1]; rescale affinely to [-1, 1].
    # (An unused per-pixel "scale" Lambda was removed here.)
    to_symmetric_range = transforms.Lambda(lambda x: 2 * x - 1.0)

    celeba_transforms = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.CenterCrop(148),
            transforms.Resize(64),
            transforms.ToTensor(),
            to_symmetric_range,
        ]
    )

    # Only the train split passes download=True; the test split is read from
    # the files fetched under this lock.
    with FileLock(os.path.expanduser(fpath + ".lock")):
        train_loader = DataLoader(
            datasets.CelebA(
                fpath, split='train', download=True, transform=celeba_transforms
            ),
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
        )
    # NOTE(review): the test split also gets the random horizontal flip and
    # is shuffled; confirm this is intended for evaluation.
    test_loader = DataLoader(
        datasets.CelebA(fpath, split='test', transform=celeba_transforms),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )
    return train_loader, test_loader
