from torchvision.datasets import ImageFolder


class IndexedImageFolder(ImageFolder):
    """ImageFolder whose class sub-folders are named by their integer label.

    Every class sub-directory must be named with an integer literal (e.g.
    ``0``, ``3``, ``17``). That integer is used directly as the target,
    instead of the position-based index the parent class would assign.

    Args:
        num_channels: Number of channels returned per sample; must be 1 or 3.
            With 1, only the first channel of each sample is kept.
        **kwargs: Forwarded unchanged to ``ImageFolder`` (``root``,
            ``transform``, ``loader``, ...).

    Raises:
        ValueError: If ``num_channels`` is not 1 or 3, or if a class folder
            name is not an integer literal.
    """

    def __init__(self, *, num_channels, **kwargs):
        # Validate before the parent constructor scans the dataset directory,
        # and raise rather than ``assert`` so the check survives ``python -O``.
        if num_channels not in (1, 3):
            raise ValueError("number of channels can be only 1 or 3!")
        self.num_channels = num_channels
        super().__init__(**kwargs)

    def find_classes(self, directory):
        """Return the parent's class list and a name -> int(name) mapping.

        ``classes`` is returned unchanged from the parent, so its order need
        not match the integer labels; use ``class_to_idx`` to map a folder
        name to its label.

        Raises:
            ValueError: If any class folder name is not an integer literal.
        """
        classes, parent_class_to_idx = super().find_classes(directory)
        try:
            class_to_idx = {name: int(name) for name in parent_class_to_idx}
        except ValueError as err:
            raise ValueError(
                f"{type(self).__name__} expects integer-named class folders "
                f"in {directory!r}: {err}"
            ) from err
        return classes, class_to_idx

    def __getitem__(self, index):
        """Return ``(sample, target)``; keep only the first channel if
        ``num_channels == 1``."""
        sample, target = super().__getitem__(index)
        if self.num_channels == 1:
            # Slice (rather than index) so the leading channel axis is kept.
            # NOTE(review): assumes ``transform`` yields a channels-first
            # tensor (e.g. ToTensor); a raw PIL image cannot be sliced this way.
            sample = sample[0:1, :, :]
        return sample, target


class IndexedImageFolderPath(IndexedImageFolder):
    """IndexedImageFolder variant that also returns each sample's file path.

    Items are ``(sample, target, path)`` tuples.
    """

    def __getitem__(self, index):
        sample, target = super().__getitem__(index)
        # Read from ``samples`` -- the list torchvision's DatasetFolder loads
        # from -- rather than the ``imgs`` alias, so the returned path stays
        # in sync with the loaded sample if ``samples`` is ever reassigned.
        return sample, target, self.samples[index][0]
