import torch
from torch.utils.data import DataLoader, Subset, ConcatDataset
from torchvision import transforms
from torchvision.datasets import CIFAR10

from experiments.datasets import DataLoaders
from experiments.utils.train_validation_split import train_validation_split_different_transformer


def get_CIFAR10(root="./", *, augmentation: bool, normalize: bool):
    """Build the CIFAR-10 train, test and validation datasets.

    Data lives under ``root + "data/CIFAR10"`` and is downloaded if missing.

    Args:
        root: Prefix joined by plain string concatenation with
            ``"data/CIFAR10"``, so it should end with a path separator.
        augmentation: If True, the training transform applies a random crop
            (size 32, padding 4) and a random horizontal flip before
            converting to a tensor.
        normalize: If True, both the training and the test transform
            normalize each channel with the mean/std constants below.

    Returns:
        Tuple ``(input_size, num_classes, train_dataset, test_dataset,
        validation_dataset)``, where ``validation_dataset`` is the CIFAR-10
        *training* split viewed through the test (non-augmenting) transform.
    """
    input_size = 32
    num_classes = 10
    data_dir = root + "data/CIFAR10"
    # Per-channel mean / std shared by the train and test transforms.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)

    train_transforms = []
    if augmentation:
        train_transforms += [
            transforms.RandomCrop(input_size, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
    # ToTensor is needed with or without augmentation: Normalize only accepts
    # tensors, and the default DataLoader collate cannot batch PIL images.
    train_transforms.append(transforms.ToTensor())
    if normalize:
        train_transforms.append(transforms.Normalize(mean, std))
    train_transform = transforms.Compose(train_transforms)

    test_transforms = [transforms.ToTensor()]
    if normalize:
        test_transforms.append(transforms.Normalize(mean, std))
    test_transform = transforms.Compose(test_transforms)

    train_dataset = CIFAR10(data_dir, train=True, transform=train_transform, download=True)
    # Same training images as train_dataset, but without augmentation.
    validation_dataset = CIFAR10(data_dir, train=True, transform=test_transform, download=True)
    test_dataset = CIFAR10(data_dir, train=False, transform=test_transform, download=True)

    return input_size, num_classes, train_dataset, test_dataset, validation_dataset


# Default compute device: CUDA when available, otherwise CPU, so that
# importing and using this module also works on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def dataloaders(
        train_batch_size, test_batch_size, *, train_only: bool, augmentation: bool, normalize: bool = True,
        validation_size: int = 0, test_only: bool = False, num_workers: int = 4,
) -> DataLoaders:
    """Build CIFAR-10 data loaders for training and evaluation.

    Args:
        train_batch_size: Batch size of the shuffled training loader.
        test_batch_size: Batch size of the (unshuffled) evaluation loaders.
        train_only: If True, the test loader iterates the training dataset
            (with the training transform) instead of the CIFAR-10 test split.
        augmentation: Forwarded to ``get_CIFAR10``.
        normalize: Forwarded to ``get_CIFAR10``.
        validation_size: If > 0, passed to
            ``train_validation_split_different_transformer`` to split a
            validation set off the training dataset; otherwise no validation
            loader is built.
        test_only: If True, the training loaders iterate the CIFAR-10 test
            split instead of the training split.
        num_workers: Number of worker processes for every DataLoader.

    Returns:
        ``DataLoaders(train_loader, test_loader, train_eval_loader,
        validation_loader)``; ``validation_loader`` is None when no
        validation split was requested.

    Raises:
        ValueError: If both ``train_only`` and ``test_only`` are set.
    """
    # Validate before touching the datasets. A real exception, because an
    # assert would be stripped under ``python -O``.
    if train_only and test_only:
        raise ValueError("train_only and test_only are mutually exclusive")

    _, num_classes, train_dataset, test_dataset, validation_dataset = get_CIFAR10(augmentation=augmentation,
                                                                                  normalize=normalize)

    if train_only:
        test_dataset = train_dataset
    elif test_only:
        train_dataset = test_dataset

    if validation_size > 0:
        # NOTE(review): with test_only, train_dataset is the test split here
        # while validation_dataset is still built from the training split.
        # Confirm that the split helper is meant to receive that pairing.
        train_dataset, validation_dataset = train_validation_split_different_transformer(
            num_classes, train_dataset, validation_dataset, validation_size
        )
    else:
        # No validation split requested: drop the validation dataset.
        validation_dataset = None

    # Shuffled loader for optimisation. drop_last keeps every batch full-sized.
    train_loader = DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    def eval_loader(dataset):
        # Deterministic, complete pass over ``dataset`` for evaluation.
        return DataLoader(
            dataset, batch_size=test_batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
        )

    train_eval_loader = eval_loader(train_dataset)
    test_loader = eval_loader(test_dataset)
    validation_loader = eval_loader(validation_dataset) if validation_dataset is not None else None

    return DataLoaders(train_loader, test_loader, train_eval_loader, validation_loader)



def combined_dataloader(train_batch_size, test_batch_size, *, augmentation: bool, normalize: bool = True,
                        num_workers: int = 4) -> DataLoaders:
    """Build loaders over the CIFAR-10 training and test splits concatenated.

    The training, train-eval and test loaders all iterate the same
    ``ConcatDataset`` of the training split (with the training transform)
    followed by the test split (with the test transform).

    Args:
        train_batch_size: Batch size of the shuffled training loader.
        test_batch_size: Batch size of the two unshuffled evaluation loaders.
        augmentation: Forwarded to ``get_CIFAR10``.
        normalize: Forwarded to ``get_CIFAR10``.
        num_workers: Number of worker processes for every DataLoader.

    Returns:
        ``DataLoaders(train_loader, test_loader, train_eval_loader, None)``;
        no validation loader is built.
    """
    _, _, train_dataset, test_dataset, _ = get_CIFAR10(augmentation=augmentation, normalize=normalize)

    combined_dataset = ConcatDataset([train_dataset, test_dataset])

    # Shuffled loader for optimisation. drop_last keeps every batch full-sized.
    train_loader = DataLoader(
        combined_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    train_eval_loader = DataLoader(
        combined_dataset, batch_size=test_batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
    )
    test_loader = DataLoader(
        combined_dataset, batch_size=test_batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
    )

    # Pass the validation slot explicitly, exactly as ``dataloaders`` does when
    # no validation split is requested, instead of relying on a default.
    return DataLoaders(train_loader, test_loader, train_eval_loader, None)
