from torchvision.datasets import CIFAR10, MNIST


class Indexed_CIFAR10(CIFAR10):
    """CIFAR10 dataset whose items also carry their own position in the dataset.

    Each item is a 3-tuple ``(img, target, index)`` instead of the parent's
    ``(img, target)`` pair. Knowing the index lets per-sample state be tracked
    across epochs, e.g. per-sample losses or weights.
    """

    def __getitem__(self, index):
        """Return the parent's ``(img, target)`` pair extended with ``index``.

        Args:
            index: Position of the sample, passed unchanged to the parent
                dataset's ``__getitem__``.

        Returns:
            A tuple ``(img, target, index)``.
        """
        # Unpack exactly two values, so a parent that returns a different
        # shape fails loudly instead of being silently extended.
        pair = super().__getitem__(index)
        sample, label = pair
        return sample, label, index


class Indexed_MNIST(MNIST):
    """MNIST dataset whose items also carry their own position in the dataset.

    Each item is a 3-tuple ``(img, target, index)`` instead of the parent's
    ``(img, target)`` pair. Knowing the index lets per-sample state be tracked
    across epochs, e.g. per-sample losses or weights.
    """

    def __getitem__(self, index):
        """Return the parent's ``(img, target)`` pair extended with ``index``.

        Args:
            index: Position of the sample, passed unchanged to the parent
                dataset's ``__getitem__``.

        Returns:
            A tuple ``(img, target, index)``.
        """
        # Unpack exactly two values, so a parent that returns a different
        # shape fails loudly instead of being silently extended.
        pair = super().__getitem__(index)
        sample, label = pair
        return sample, label, index

