import os

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from RandAugment import RandAugment

# Parent directory for downloaded datasets; get_dataset() keeps CIFAR-10 in
# the "cifar10" subfolder beneath it (i.e. /tmp/cifar10).
# NOTE(review): /tmp is wiped on reboot on many systems, which would force a
# fresh download each time — confirm this location is intended.
BASE_DIR: str = "/tmp/"

def get_dataset(
    dataset,
    *,
    root=None,
    rand_augment_n=1,
    rand_augment_m=1,
    download=True,
):
    """Build the train/test datasets and the class count for ``dataset``.

    Only ``"CIFAR10"`` is currently supported.

    Args:
        dataset: Dataset name; must be exactly ``"CIFAR10"``.
        root: Directory the dataset is read from / downloaded to. Defaults to
            ``BASE_DIR`` joined with ``"cifar10"``.
        rand_augment_n: First argument passed to ``RandAugment`` (``N``;
            presumably the number of augmentation ops per image). Default 1.
        rand_augment_m: Second argument passed to ``RandAugment`` (``M``;
            presumably the augmentation magnitude). Default 1.
        download: Forwarded as ``download=`` to both ``CIFAR10`` constructors.

    Returns:
        A ``(trainset, testset, num_classes)`` tuple. The training set applies
        RandAugment plus random affine/crop/flip augmentation before
        normalization. The test set is only converted to a tensor and
        normalized.

    Raises:
        ValueError: If ``dataset`` is not a supported dataset name.
    """
    if dataset != "CIFAR10":
        raise ValueError(f"Dataset {dataset} not supported yet.")

    if root is None:
        root = os.path.join(BASE_DIR, "cifar10")

    # Per-channel normalization shared by both pipelines.
    # NOTE(review): these std values are the ones widely copied from common
    # CIFAR-10 example code. The per-channel pixel std of the training set is
    # usually reported as ~(0.247, 0.243, 0.262). They are kept unchanged so
    # existing checkpoints stay compatible — confirm before changing.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)

    transform_train = transforms.Compose([
        # RandAugment runs first, before ToTensor, so it sees whatever CIFAR10
        # yields for a sample (presumably a PIL image).
        RandAugment(rand_augment_n, rand_augment_m),
        # NOTE(review): the translate of RandomAffine and the padded
        # RandomCrop both appear to shift the image; confirm that stacking
        # them is intended.
        transforms.RandomAffine(0.0, translate=(0.1, 0.1), shear=0, fill=0),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    trainset = CIFAR10(
        root=root, train=True, download=download, transform=transform_train
    )
    testset = CIFAR10(
        root=root, train=False, download=download, transform=transform_test
    )
    num_classes = 10
    return trainset, testset, num_classes
