import torchvision
import torchvision.transforms as transforms


def cifar_dataset(num_classes, data_root='./data/'):
    """Build the CIFAR-10 or CIFAR-100 train/test datasets.

    The training split is augmented with a 32x32 random crop (4-pixel
    padding), a random horizontal flip and RandAugment wrapped in
    ``transforms.RandomApply`` (so RandAugment is applied only to a random
    subset of samples, using ``RandomApply``'s default probability), and is
    then converted to a tensor. The test split is only converted to a tensor.

    Args:
        num_classes: 10 for CIFAR-10 or 100 for CIFAR-100.
        data_root: Directory the dataset files live in; they are downloaded
            there if not already present (``download=True``).

    Returns:
        A ``(trainset, testset)`` tuple of torchvision CIFAR datasets.

    Raises:
        ValueError: If ``num_classes`` is not 10 or 100.
    """
    # Fail fast with a clear message instead of an opaque AttributeError
    # from the getattr() lookup below (e.g. "no attribute 'CIFAR7'").
    if num_classes not in (10, 100):
        raise ValueError(
            'num_classes must be 10 or 100 for CIFAR, got %r' % (num_classes,))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([transforms.RandAugment()]),
        transforms.ToTensor(),
    ])

    transform_test = transforms.ToTensor()

    # Resolves to torchvision.datasets.CIFAR10 or torchvision.datasets.CIFAR100.
    dataset = getattr(torchvision.datasets, 'CIFAR%d' % num_classes)

    trainset = dataset(root=data_root,
                       train=True,
                       transform=transform_train,
                       download=True)
    testset = dataset(root=data_root,
                      train=False,
                      transform=transform_test,
                      download=True)
    return trainset, testset
