import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset
from torch.utils.data import DataLoader


# Per-channel normalization statistics applied to every CIFAR-100 split.
_CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
_CIFAR100_STD = (0.2673, 0.2564, 0.2762)


def _make_transforms(augmentation):
    """Return ``(train_transform, eval_transform)`` for CIFAR-100 images.

    Both transforms convert the image to a tensor and normalize it with the
    CIFAR-100 per-channel mean/std. When *augmentation* is true, the training
    transform first applies a random 32x32 crop (4 px edge padding) and a
    random horizontal flip. Otherwise the training transform is the
    evaluation transform.
    """
    normalize = transforms.Normalize(_CIFAR100_MEAN, _CIFAR100_STD)
    eval_transform = transforms.Compose([transforms.ToTensor(), normalize])
    if not augmentation:
        return eval_transform, eval_transform
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4, padding_mode='edge'),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    return train_transform, eval_transform


def _split_indices(n_train, max_perc, val_perc):
    """Return ``(train_indices, val_indices)`` over ``range(n_train)``.

    The trailing ``val_perc`` fraction of the indices is held out for
    validation. The training subset is the first ``max_perc`` fraction of
    the remaining (leading) indices, so the two subsets never overlap as
    long as ``max_perc <= 1`` (checked by the caller).

    Raises:
        ValueError: if the resulting training subset would be empty.
    """
    # Same arithmetic as the original hard-coded 10% split, so the default
    # val_perc=0.1 reproduces it exactly.
    split = int(n_train - (n_train * val_perc))
    limited_split = round(max_perc * split)
    if limited_split <= 0:
        raise ValueError(
            f"max_perc={max_perc!r} leaves no training samples "
            f"(only {split} non-validation samples available)"
        )
    indices = list(range(n_train))
    return indices[:limited_split], indices[split:]


def get_dataset(batch_size, max_perc=1.0, augmentation=False, download=False,
                n_workers=2, val_perc=0.1, root='./data'):
    """Build CIFAR-100 train/val/test DataLoaders.

    The official CIFAR-100 training split is divided into a training subset
    (leading indices) and a held-out validation subset (the trailing
    ``val_perc`` fraction). The official test split is used unchanged.

    Args:
        batch_size: Batch size used by all three loaders.
        max_perc: Fraction in (0, 1] of the non-validation training indices
            to keep. Values above 1 previously made the training subset
            overlap the validation subset and are now rejected.
        augmentation: If true, the training images get a random crop and a
            horizontal flip. Validation and test images are never augmented.
        download: Forwarded to torchvision for the train and test datasets.
        n_workers: ``num_workers`` for each DataLoader.
        val_perc: Fraction in [0, 1) of the training split held out for
            validation (default 0.1, the original hard-coded value).
        root: Directory where the dataset files are stored.

    Returns:
        ``(loaders, n_classes)``. ``loaders`` maps ``"train"``, ``"val"``
        and ``"test"`` to DataLoaders. Only the training loader shuffles.
        ``n_classes`` is the number of classes reported by the dataset.

    Raises:
        ValueError: if ``max_perc`` or ``val_perc`` is out of range, or the
            training subset would be empty.
    """
    # Validate before touching the filesystem / network.
    if not 0.0 < max_perc <= 1.0:
        raise ValueError(f"max_perc must be in (0, 1], got {max_perc!r}")
    if not 0.0 <= val_perc < 1.0:
        raise ValueError(f"val_perc must be in [0, 1), got {val_perc!r}")

    train_transform, eval_transform = _make_transforms(augmentation)

    # Train and val are drawn from the same official training split but need
    # different transforms, hence two dataset objects over the same files.
    full_train = torchvision.datasets.CIFAR100(
        root=root, train=True, download=download, transform=train_transform
    )
    train_idx, val_idx = _split_indices(len(full_train), max_perc, val_perc)

    train_loader = DataLoader(
        Subset(full_train, train_idx),
        batch_size=batch_size, shuffle=True, num_workers=n_workers
    )

    # download=False: the files were already fetched (if requested) above.
    full_val = torchvision.datasets.CIFAR100(
        root=root, train=True, download=False, transform=eval_transform
    )
    val_loader = DataLoader(
        Subset(full_val, val_idx),
        batch_size=batch_size, shuffle=False, num_workers=n_workers
    )

    testset = torchvision.datasets.CIFAR100(
        root=root, train=False, download=download, transform=eval_transform
    )
    test_loader = DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=n_workers
    )

    # Use the dataset's own class metadata rather than a hand-written list,
    # whose order would not necessarily match the label indices.
    n_classes = len(full_train.classes)
    loaders = {
        "train": train_loader,
        "val": val_loader,
        "test": test_loader,
    }
    return loaders, n_classes


if __name__ == "__main__":
    # Running this module directly builds the loaders once with
    # download=True so CIFAR-100 is fetched into the data directory; the
    # returned loaders are discarded.
    get_dataset(batch_size=1, download=True)
