import os
import torch
from torch.utils.data import random_split
from torchvision import transforms, datasets

def get_cifar10_loaders(batch_size=256, augment=True, shuffle_train=True,
                        cifar10_dataset_class=datasets.CIFAR10, data_root='data'):
    """Build CIFAR-10 train/test DataLoaders.

    Args:
        batch_size: Batch size used for both the train and the test loader.
        augment: If True, the training transform applies RandomCrop(32, padding=4)
            and RandomHorizontalFlip before tensor conversion/normalization.
            The test transform never augments.
        shuffle_train: Whether the training loader shuffles each epoch.
        cifar10_dataset_class: Dataset class to instantiate; called with
            ``root=..., train=..., transform=..., download=True``.
            Defaults to ``torchvision.datasets.CIFAR10``.
        data_root: Parent directory for the data; the dataset lives in
            ``<data_root>/cifar10_data``. Defaults to ``'data'`` (relative to
            the current working directory), matching the previous behavior.

    Returns:
        dict with keys ``"train_loader"``, ``"test_loader"`` and
        ``"num_classes"`` (always 10).
    """
    data_mean = (0.4914, 0.4822, 0.4465)
    # NOTE(review): these stddev values are the ones widely copied across
    # CIFAR-10 codebases; they are reportedly not the exact per-channel pixel
    # std of the training set (~(0.247, 0.243, 0.261)). Kept unchanged so
    # models trained with this pipeline stay compatible — confirm before changing.
    data_stddev = (0.2023, 0.1994, 0.2010)

    data_augmentation_transforms = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ] if augment else []

    transform_train = transforms.Compose([
        *data_augmentation_transforms,
        transforms.ToTensor(),
        transforms.Normalize(data_mean, data_stddev),
    ])

    # Evaluation data is only converted and normalized, never augmented.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(data_mean, data_stddev),
    ])

    root = os.path.join(data_root, 'cifar10_data')
    original_train_dataset = cifar10_dataset_class(
        root=root, train=True, transform=transform_train, download=True)
    original_test_dataset = cifar10_dataset_class(
        root=root, train=False, transform=transform_test, download=True)

    train_loader = torch.utils.data.DataLoader(
        dataset=original_train_dataset,
        batch_size=batch_size,
        shuffle=shuffle_train,
    )
    # The test order is kept fixed so evaluation is reproducible.
    test_loader = torch.utils.data.DataLoader(
        dataset=original_test_dataset,
        batch_size=batch_size,
        shuffle=False,
    )

    return {"train_loader": train_loader,
            "test_loader": test_loader,
            "num_classes": 10}


def get_cifar100_loaders(batch_size=256, augment=True, shuffle_train=True,
                         cifar100_dataset_class=datasets.CIFAR100, data_root='data'):
    """Build CIFAR-100 train/test DataLoaders.

    Args:
        batch_size: Batch size used for both the train and the test loader.
        augment: If True, the training transform applies RandomCrop(32, padding=4)
            and RandomHorizontalFlip before tensor conversion/normalization.
            The test transform never augments.
        shuffle_train: Whether the training loader shuffles each epoch.
        cifar100_dataset_class: Dataset class to instantiate; called with
            ``root=..., train=..., transform=..., download=True``.
            Defaults to ``torchvision.datasets.CIFAR100``. Mirrors the
            injectable class parameter of the CIFAR-10 factory.
        data_root: Parent directory for the data; the dataset lives in
            ``<data_root>/cifar100_data``. Defaults to ``'data'`` (relative to
            the current working directory), matching the previous behavior.

    Returns:
        dict with keys ``"train_loader"``, ``"test_loader"`` and
        ``"num_classes"`` (always 100).
    """
    data_mean = (0.5071, 0.4867, 0.4408)
    data_stddev = (0.2675, 0.2565, 0.2761)

    data_augmentation_transforms = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ] if augment else []

    transform_train = transforms.Compose([
        *data_augmentation_transforms,
        transforms.ToTensor(),
        transforms.Normalize(data_mean, data_stddev),
    ])

    # Evaluation data is only converted and normalized, never augmented.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(data_mean, data_stddev),
    ])

    root = os.path.join(data_root, 'cifar100_data')
    original_train_dataset = cifar100_dataset_class(
        root=root, train=True, transform=transform_train, download=True)
    original_test_dataset = cifar100_dataset_class(
        root=root, train=False, transform=transform_test, download=True)

    train_loader = torch.utils.data.DataLoader(
        dataset=original_train_dataset,
        batch_size=batch_size,
        shuffle=shuffle_train,
    )
    # The test order is kept fixed so evaluation is reproducible.
    test_loader = torch.utils.data.DataLoader(
        dataset=original_test_dataset,
        batch_size=batch_size,
        shuffle=False,
    )

    return {"train_loader": train_loader,
            "test_loader": test_loader,
            "num_classes": 100}


# Registry: dataset name -> factory that returns a dict of loaders.
# get_dataset() dispatches through this mapping.
dataset_factories = dict(
    cifar10=get_cifar10_loaders,
    cifar100=get_cifar100_loaders,
)

def get_available_datasets():
    """Return the registered dataset names as a live keys view of the registry."""
    registry = dataset_factories
    return registry.keys()


def get_dataset(name, *args, **kwargs):
    """Build the loaders for the dataset registered under ``name``.

    Args:
        name: Key into ``dataset_factories`` (e.g. ``'cifar10'``).
        *args, **kwargs: Forwarded unchanged to the selected factory.

    Returns:
        Whatever the selected factory returns.

    Raises:
        KeyError: If ``name`` is not a registered dataset. The message lists
            the valid names.
    """
    # Only the lookup is guarded, so a KeyError raised *inside* a factory
    # propagates untouched instead of being misreported as an unknown name.
    try:
        factory = dataset_factories[name]
    except KeyError:
        available = ', '.join(sorted(dataset_factories))
        raise KeyError(
            f"Unknown dataset {name!r}; available datasets: {available}"
        ) from None
    return factory(*args, **kwargs)
