import torchvision
import torchvision.transforms as transforms

def get_data(state, event, splitname, with_transform=True, root='../datasets'):
    """Build the CIFAR-10 dataset for one split and record its metadata in ``state``.

    Args:
        state: framework state object; this function writes into ``state.all``
            the keys ``"dataset.num_classes"`` (10), ``"dataset.num_channels"``
            (3) and ``"<splitname>set_size"``.
        event: event dispatcher; ``event.optional.dataset_transform(splitname)``
            is queried for extra transforms when ``with_transform`` is true and
            the split has data.
        splitname: one of ``"train"``, ``"test"`` or ``"val"``.
        with_transform: if false, skip the optional plugin transforms; tensor
            conversion and normalization are always applied.
        root: directory passed to torchvision as the CIFAR-10 storage/download
            location (defaults to the previously hard-coded ``'../datasets'``).

    Returns:
        A ``torchvision.datasets.CIFAR10`` for ``"train"``/``"test"``, or
        ``None`` for ``"val"`` (no validation data is provided here).

    Raises:
        ValueError: if ``splitname`` is not a known split. It is raised before
            ``state`` is modified or any transform hook is invoked.
    """
    # Splits that carry data, mapped to torchvision's `train` flag.
    # "val" is a known split name but has no data.
    train_flags = {"train": True, "test": False}
    if splitname != "val" and splitname not in train_flags:
        raise ValueError("splitname '%s' not known." % splitname)

    state.all["dataset.num_classes"] = 10
    state.all["dataset.num_channels"] = 3

    if splitname == "val":
        state.all["valset_size"] = 0
        return None

    # Optional transforms: the hook result is flattened one level, i.e. it is
    # an iterable of transform sequences (presumably one per listener).
    if with_transform:
        transform = [t for tr in event.optional.dataset_transform(splitname) for t in tr]
    else:
        transform = []

    # NOTE(review): these are the ImageNet channel statistics, not CIFAR-10's
    # (roughly mean (0.4914, 0.4822, 0.4465)). They are kept unchanged so
    # existing experiments/checkpoints remain comparable; confirm intent.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform += [transforms.ToTensor(), normalize]

    data = torchvision.datasets.CIFAR10(
        root=root,
        train=train_flags[splitname],
        download=True,
        transform=transforms.Compose(transform),
    )
    # Produces "trainset_size" / "testset_size".
    state.all["%sset_size" % splitname] = len(data)
    return data

def register(mf):
    """Register ``get_data`` with ``mf`` as the handler for the ``'dataset'`` event.

    The registration passes ``unique=True``. Presumably this means at most one
    ``'dataset'`` provider may be registered; confirm against the framework.
    """
    event_name = 'dataset'
    handler = get_data
    mf.register_event(event_name, handler, unique=True)

