import torchvision
import torchvision.transforms as transforms

def get_data(state, event, splitname, with_transform=True):
    """Build the image-folder dataset for the requested split.

    Records the dataset's class and channel counts in ``state.all`` and
    returns a ``torchvision.datasets.ImageFolder`` whose transform pipeline
    is the optional, event-provided transforms followed by tensor conversion
    and normalization.

    Args:
        state: Object whose ``all`` mapping receives the keys
            ``dataset.num_classes`` and ``dataset.num_channels``.
        event: Event dispatcher. ``event.optional.dataset_transform(splitname)``
            must return an iterable of iterables of transforms; they are
            flattened in order. Only used when ``with_transform`` is true.
        splitname: One of ``"train"``, ``"test"`` or ``"val"``.
        with_transform: When false, the optional transforms are skipped and
            only tensor conversion and normalization are applied.

    Returns:
        The ``ImageFolder`` dataset rooted at
        ``../datasets/tiny_imagenet/<splitname>`` (a relative path, so it is
        resolved against the current working directory).

    Raises:
        ValueError: If ``splitname`` is not a known split.
    """
    state.all["dataset.num_classes"] = 200
    state.all["dataset.num_channels"] = 3

    # Fail fast: reject unknown splits before invoking any transform hooks
    # or building the pipeline, so a typo never triggers unrelated work.
    known_splits = ("train", "test", "val")
    if splitname not in known_splits:
        raise ValueError("splitname '%s' not known." % splitname)

    # get optional transforms (flattened list of everything the hooks return)
    if with_transform:
        pipeline = [
            t
            for group in event.optional.dataset_transform(splitname)
            for t in group
        ]
    else:
        pipeline = []

    # data specific transform: per-channel normalization statistics
    # (presumably computed on the Tiny ImageNet training set -- TODO confirm)
    mean = [0.4802, 0.4481, 0.3975]
    std = [0.2302, 0.2265, 0.2262]
    pipeline.append(transforms.ToTensor())
    pipeline.append(transforms.Normalize(mean=mean, std=std))

    # compose dataset; every split lives in a sub-directory named after it
    root = "../datasets/tiny_imagenet/" + splitname
    return torchvision.datasets.ImageFolder(
        root=root, transform=transforms.Compose(pipeline)
    )
    
def register(mf):
    """Register ``get_data`` on ``mf`` as the unique handler of the 'dataset' event."""
    event_name = 'dataset'
    handler = get_data
    mf.register_event(event_name, handler, unique=True)

