from torchvision import datasets

class SubsetDigits(datasets.MNIST):
    """MNIST dataset restricted to a chosen subset of digit labels.

    Only samples whose label is in ``digits`` are exposed. Positions
    ``0 .. len(self) - 1`` address just the selected samples, in the same
    relative order they have in the underlying dataset.

    Args:
        root, train, transform, target_transform, download: forwarded
            unchanged to ``datasets.MNIST.__init__``.
        digits: iterable of labels to keep (e.g. ``[0, 1]``, ``{3, 5}``,
            ``range(5)``). ``None`` keeps every sample; an empty iterable
            yields an empty dataset.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, digits=None):
        super().__init__(root, train=train, transform=transform, target_transform=target_transform, download=download)
        self.digits = digits
        if digits is not None:
            # Compare plain Python ints on both sides. The previous
            # ``tensor_element in digits`` test breaks for set-like ``digits``:
            # tensor hashing is identity-based, so no element ever matched.
            # Per-element tensor ``==`` against a list was also very slow.
            allowed = {int(d) for d in digits}
            targets = self.targets
            if hasattr(targets, "tolist"):
                labels = targets.tolist()  # one bulk conversion to Python ints
            else:
                labels = [int(t) for t in targets]
            self.indices = [i for i, label in enumerate(labels) if label in allowed]
        else:
            self.indices = list(range(len(self.targets)))

    def __len__(self):
        """Return the number of samples in the selected subset."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the ``idx``-th selected sample as ``(img, target)``.

        ``idx`` indexes the subset. It is mapped through ``self.indices`` to
        the underlying dataset position before delegating to
        ``datasets.MNIST.__getitem__``.
        """
        return super().__getitem__(self.indices[idx])