from torchvision import datasets


class ImageNet(datasets.ImageNet):
    """torchvision ImageNet dataset whose items are dicts instead of tuples.

    Each item is returned as ``{'x': image, 'y': label}``. Both values are
    what the parent ``datasets.ImageNet.__getitem__`` returns, so
    ``transform`` / ``target_transform`` have already been applied.
    """

    def __init__(self, root, split='train', transform=None, **kwargs):
        """Build the dataset.

        Args:
            root: Root directory of the ImageNet dataset.
            split: Dataset split to load; forwarded unchanged to the parent
                class (``'train'`` by default).
            transform: Optional callable applied to each image by the parent
                class.
            **kwargs: Any further keyword arguments accepted by
                ``datasets.ImageNet`` (e.g. ``target_transform``, ``loader``).
                Previously these could not be passed through this wrapper.
        """
        super().__init__(root=root, split=split, transform=transform, **kwargs)

    def __getitem__(self, index):
        """Return the sample at ``index`` as ``{'x': image, 'y': label}``."""
        image, label = super().__getitem__(index)
        # Dict output lets downstream code address fields by name rather
        # than by tuple position.
        return {'x': image, 'y': label}
