import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset
from torch.utils.data import DataLoader


def get_dataset(batch_size, max_perc=1.0, augmentation=False, download=False, n_workers=6):
    """Build train / validation / test ``DataLoader`` objects for CIFAR-10.

    The CIFAR-10 training split is partitioned by index: the first 90% is
    the training pool and the last 10% is the validation set. ``max_perc``
    keeps only the leading fraction of the training pool; the validation
    and test sets are never reduced.

    Args:
        batch_size: Batch size used by all three loaders.
        max_perc: Fraction of the training pool to keep, in ``(0, 1]``.
        augmentation: If true, training images get a random 32x32 crop
            (4-pixel edge padding) and a random horizontal flip before
            normalization. Validation and test images are never augmented.
        download: Passed to the train and test ``CIFAR10`` constructors to
            fetch the data into ``./data`` when missing.
        n_workers: ``num_workers`` for every loader.

    Returns:
        A tuple ``(loaders, n_classes)`` where ``loaders`` is a dict with
        the keys ``"train"``, ``"val"`` and ``"test"`` and ``n_classes`` is
        the number of class names (10).

    Raises:
        ValueError: If ``max_perc`` is not in ``(0, 1]``.
    """
    # Reject bad fractions up front. A value above 1 would push the training
    # slice into the validation indices (data leakage), and a negative value
    # would silently slice from the end of the index list.
    if not 0.0 < max_perc <= 1.0:
        raise ValueError(
            "max_perc must be in the interval (0, 1], got {!r}".format(max_perc)
        )

    # Per-channel mean / std handed to Normalize (presumably CIFAR-10
    # training-set statistics -- TODO confirm).
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    simple_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    if augmentation:
        aug_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4, padding_mode='edge'),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])
    else:
        aug_transform = simple_transform

    # Train - Val - Test
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=download, transform=aug_transform
    )
    n_train = len(trainset)
    indices = list(range(n_train))
    # First 90% of the indices form the training pool, the rest is validation.
    split = int(n_train - (n_train * 0.1))
    limited_split = round(max_perc * split)
    trainset = Subset(trainset, indices[:limited_split])
    train_loader = DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=n_workers
    )

    # The validation set reuses the training split but with the plain
    # transform. The files were already fetched above if requested, so
    # download stays off here.
    valset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=False, transform=simple_transform
    )
    valset = Subset(valset, indices[split:])
    val_loader = DataLoader(
        valset, batch_size=batch_size, shuffle=False, num_workers=n_workers
    )

    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=download, transform=simple_transform
    )
    test_loader = DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=n_workers
    )

    classes = (
        'plane', 'car', 'bird', 'cat',
        'deer', 'dog', 'frog', 'horse', 'ship', 'truck'
    )
    n_classes = len(classes)
    loaders = {
        "train": train_loader,
        "val": val_loader,
        "test": test_loader
    }
    return loaders, n_classes

if __name__ == "__main__":
    import time

    # Manual smoke test: build the loaders (downloading the data if needed),
    # fetch one training batch and move it to the first CUDA device.
    loaders = get_dataset(50000, download=True)[0]
    # Only the first batch is used, so pull exactly one from the loader
    # instead of materializing every batch into a list.
    images = next(iter(loaders["train"]))[0]
    images = images.to("cuda:0")
    # NOTE(review): the pause presumably leaves time to inspect GPU memory
    # usage from outside the process -- confirm.
    time.sleep(5)
    print(images.shape)
