import torchvision

"""
BabelImageNet from https://arxiv.org/pdf/2306.08658.pdf
Adapted from https://github.com/gregor-ge/Babel-ImageNet, thanks to the authors
"""
class BabelImageNet(torchvision.datasets.ImageNet):
    """ImageNet restricted to a subset of classes, with labels renumbered.

    Keeps only the samples whose original ImageNet class index appears in
    ``idxs``. Each kept sample is relabelled with the position of its class in
    ``idxs``, so labels run from ``0`` to ``len(idxs) - 1``. Kept samples are
    ordered class by class, following the order of ``idxs``.

    Args:
        root: ImageNet root directory, forwarded to the parent class.
        idxs: Original ImageNet class indices to keep, in the order that
            defines the new labels. Must be a sequence supporting ``len`` and
            ``.index`` (e.g. a list or tuple), because ``__getitem__`` uses
            ``idxs.index`` to remap labels.
        split: Dataset split forwarded to the parent class (default ``"val"``).
        download: Accepted and ignored; it is not forwarded to the parent
            class (presumably so callers can pass it uniformly to several
            dataset classes).
        **kwargs: Forwarded to ``torchvision.datasets.ImageNet`` (e.g.
            ``transform``, ``target_transform``).

    Raises:
        ValueError: If an index in ``idxs`` has no samples in the loaded split.
    """

    def __init__(self, root: str, idxs, split: str = "val", download=None, **kwargs) -> None:
        super().__init__(root, split, **kwargs)

        # Group sample positions by original class rather than assuming a fixed
        # number of contiguous samples per class: the val split is balanced,
        # but class sizes differ in the train split.
        positions_by_class = {}
        for pos, (_, target) in enumerate(self.samples):
            positions_by_class.setdefault(target, []).append(pos)

        missing = [idx for idx in idxs if idx not in positions_by_class]
        if missing:
            raise ValueError(f"class indices with no samples in split {split!r}: {missing}")

        select_idxs = []
        new_targets = []
        for new_label, idx in enumerate(idxs):
            members = positions_by_class[idx]
            select_idxs.extend(members)
            new_targets.extend([new_label] * len(members))

        self.targets = new_targets
        self.imgs = [self.imgs[i] for i in select_idxs]
        self.samples = [self.samples[i] for i in select_idxs]
        self.idxs = idxs

    def __getitem__(self, i):
        """Return ``(img, label)``, where ``label`` is the class's position in ``idxs``."""
        # The parent returns the sample's original ImageNet class index (after
        # any target_transform); remap it to its position in ``idxs``.
        img, target = super().__getitem__(i)
        target = self.idxs.index(target)
        return img, target