import torchvision
import torchvision.transforms as transforms

def get_data(state, event, splitname, with_transform=True):
    """Build the MNIST dataset for one split and record its metadata in ``state``.

    ``dataset.num_classes`` (10) and ``dataset.num_channels`` (1) are always
    written to ``state.all``, even when ``splitname`` turns out to be invalid.

    Args:
        state: Object whose ``all`` mapping receives dataset metadata.
        event: Event dispatcher; ``event.optional.dataset_transform(splitname)``
            must return an iterable of iterables of transforms, which is
            flattened one level (presumably one list per registered handler).
        splitname: One of ``"train"``, ``"test"`` or ``"val"``.
        with_transform: If False, the optional transforms supplied via
            ``event`` are skipped; ``ToTensor`` and normalization still apply.

    Returns:
        A ``torchvision.datasets.MNIST`` instance for ``"train"``/``"test"``
        (downloaded into ``./download`` if missing), or ``None`` for ``"val"``
        because this loader provides no validation split.

    Raises:
        ValueError: If ``splitname`` is not ``"train"``, ``"test"`` or ``"val"``.
    """
    state.all["dataset.num_classes"] = 10
    state.all["dataset.num_channels"] = 1

    # Resolve the split before doing any work, so unknown names fail fast and
    # the transform hooks are not invoked for a split that yields no data.
    if splitname == "val":
        state.all["valset_size"] = 0
        return None
    if splitname not in ("train", "test"):
        raise ValueError("splitname '%s' not known." % (splitname,))

    # Optional transforms contributed through the ``dataset_transform`` event;
    # they run before ToTensor, i.e. on the raw PIL image.
    if with_transform:
        transform = [t for tr in event.optional.dataset_transform(splitname) for t in tr]
    else:
        transform = []

    # MNIST-specific normalization of the [0, 1] tensors produced by ToTensor.
    # NOTE: earlier versions used mean=0.4914 / std=0.2470, which are CIFAR-10
    # red-channel statistics; runs trained with those values saw differently
    # normalized inputs.
    normalize = transforms.Normalize(mean=[0.1307], std=[0.3081])
    transform += [transforms.ToTensor(), normalize]

    data = torchvision.datasets.MNIST(
        root='./download',
        train=(splitname == "train"),
        transform=transforms.Compose(transform),
        download=True,
    )
    # Records "trainset_size" or "testset_size".
    state.all["%sset_size" % (splitname,)] = len(data)
    return data


def register(mf):
    """Hook this module into the framework.

    Registers ``get_data`` as the handler of the ``dataset`` event, passing
    ``unique=True`` to ``mf.register_event``.

    Args:
        mf: Framework object exposing ``register_event``.
    """
    event_name = 'dataset'
    handler = get_data
    mf.register_event(event_name, handler, unique=True)

