import torchvision
import torchvision.transforms as transforms

def get_data(state, event, splitname, with_transform=True):
    """Build the image-folder dataset for one split and record its metadata.

    Stores the number of classes and channels in ``state.all``, assembles the
    transform pipeline (optional transforms, then tensor conversion and
    normalization), loads the requested split from the dataset directory and
    stores the split size in ``state.all`` before returning the dataset.

    Args:
        state: Object whose ``all`` mapping receives the metadata entries.
        event: Object providing ``event.optional.dataset_transform(splitname)``;
            its result is iterated as a sequence of transform sequences and
            flattened. Only queried when ``with_transform`` is true.
        splitname: One of ``"train"``, ``"test"`` or ``"val"``.
        with_transform: When false, the optional transforms are skipped and
            only ``ToTensor`` + ``Normalize`` are applied.

    Returns:
        The ``torchvision.datasets.ImageFolder`` instance for the split.

    Raises:
        ValueError: If ``splitname`` matches none of the known splits.
    """
    state.all["dataset.num_classes"] = 10
    state.all["dataset.num_channels"] = 3

    # Flatten the optional transform groups into a single flat pipeline.
    pipeline = []
    if with_transform:
        for group in event.optional.dataset_transform(splitname):
            pipeline.extend(group)

    # Dataset-specific per-channel normalization statistics.
    channel_normalize = transforms.Normalize(
        mean=[0.47889522, 0.47227842, 0.43047404],
        std=[0.24205776, 0.23828046, 0.25874835],
    )
    pipeline.append(transforms.ToTensor())
    pipeline.append(channel_normalize)
    composed = transforms.Compose(pipeline)

    root = '../datasets/cinic10'
    # (split name / sub-directory, key under which the split size is stored)
    known_splits = (
        ("train", "trainset_size"),
        ("test", "testset_size"),
        ("val", "valset_size"),
    )
    for name, size_key in known_splits:
        if splitname == name:
            dataset = torchvision.datasets.ImageFolder(root + '/' + name, transform=composed)
            state.all[size_key] = len(dataset)
            return dataset

    # The pipeline is built before this check, so unknown names fail only here.
    raise ValueError("splitname '%s' not known." % splitname)


def register(mf):
    """Register ``get_data`` on ``mf`` for the ``'dataset'`` event.

    Args:
        mf: Object exposing ``register_event(name, func, unique=...)``; the
            handler is registered with ``unique=True``.
    """
    event_name = 'dataset'
    mf.register_event(event_name, get_data, unique=True)

