import torchvision
import torchvision.transforms as transforms

def get_data(state, event, splitname, with_transform=True):
    """Build the CIFAR-100 dataset for one split and record its metadata in ``state``.

    Always sets ``state.all["dataset.num_classes"]`` (100) and
    ``state.all["dataset.num_channels"]`` (3). It then handles the split:

    * ``"train"`` / ``"test"``: builds a ``torchvision.datasets.CIFAR100`` under
      ``../datasets`` (downloading it if needed). The dataset's transform is the
      optional user transforms, then ``ToTensor``, then ``Normalize``. The
      dataset length is stored in ``state.all["trainset_size"]`` or
      ``state.all["testset_size"]``.
    * ``"val"``: this module provides no validation split. The size is recorded
      as 0 and ``None`` is returned.

    Args:
        state: object exposing a mutable ``all`` mapping used for shared metadata.
        event: event dispatcher; ``event.optional.dataset_transform(splitname)``
            is expected to return an iterable of transform lists (presumably one
            list per registered handler — confirm against the event system).
            Only used for the ``"train"`` / ``"test"`` splits.
        splitname: one of ``"train"``, ``"test"``, ``"val"``.
        with_transform: when False, the optional user transforms are skipped;
            ``ToTensor`` and ``Normalize`` are still applied.

    Returns:
        The ``CIFAR100`` dataset for ``"train"`` / ``"test"``, or ``None`` for ``"val"``.

    Raises:
        ValueError: if ``splitname`` is not one of the supported splits.
    """
    state.all["dataset.num_classes"] = 100
    state.all["dataset.num_channels"] = 3

    # Resolve the split before doing any work, so that unsupported/empty splits
    # never trigger the optional transform hooks.
    if splitname == "val":
        state.all["valset_size"] = 0
        # Legacy misspelled key, still written so existing readers keep working.
        state.all["valeset_size"] = 0
        return None
    if splitname not in ("train", "test"):
        raise ValueError("splitname '%s' not known." % splitname)

    # Optional, externally registered transforms come first: they run on the raw
    # sample, before it is converted to a tensor below.
    if with_transform:
        transform = [t for tr in event.optional.dataset_transform(splitname) for t in tr]
    else:
        transform = []

    # Per-channel (RGB) normalization statistics — presumably computed over the
    # CIFAR-100 training images.
    mean = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    std = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
    transform += [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]

    data = torchvision.datasets.CIFAR100(
        root='../datasets',
        train=(splitname == "train"),
        download=True,
        transform=transforms.Compose(transform),
    )
    # Produces "trainset_size" / "testset_size", matching the original keys.
    state.all["%sset_size" % splitname] = len(data)
    return data

def register(mf):
    """Hook this module into ``mf``.

    Registers :func:`get_data` as the handler of the ``'dataset'`` event,
    passing ``unique=True`` (presumably so that only one dataset provider can be
    registered for this event — confirm against the framework).
    """
    event_name = 'dataset'
    handler = get_data
    mf.register_event(event_name, handler, unique=True)

