import torchvision.datasets as datasets


class ImageFolderCustom(datasets.ImageFolder):
    """``ImageFolder`` variant that can hide a subset of samples.

    A *reduction set* is a collection of indices into ``self.samples`` that
    should be excluded. While a reduction set is active, dataset index ``i``
    maps onto the ``i``-th remaining sample, in ascending order of original
    index. ``__len__`` then reports the number of remaining samples.

    ``__getitem__`` returns ``(sample, target, idx)``, where ``idx`` is the
    index into ``self.samples`` of the returned item (the original,
    unreduced index).
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=datasets.folder.default_loader, reduction_set=None):
        super().__init__(root, transform=transform,
                         target_transform=target_transform, loader=loader)
        self.reduction_set = None
        # Reduced index -> original index; None while no reduction is active.
        self.ind_map = None
        self.length = len(self.samples)
        # Build the index map now. Merely storing reduction_set would leave
        # __getitem__ ignoring it.
        if reduction_set is not None:
            self.set_reduction_set(reduction_set)

    def reset_reduction(self):
        """Remove any active reduction so every sample is visible again."""
        self.reduction_set = None
        self.ind_map = None
        self.length = len(self.samples)

    def set_length(self, length):
        """Override the value reported by ``__len__``."""
        self.length = length

    def __len__(self):
        return self.length

    def set_reduction_set(self, reduction_set):
        """Hide the samples whose indices appear in *reduction_set*.

        Args:
            reduction_set: iterable of indices into ``self.samples`` to
                exclude. Indices outside ``range(len(self.samples))`` are
                ignored.

        Returns:
            The number of samples that remain visible. ``self.length`` is
            updated to this value, so ``__len__`` stays consistent with the
            index map.
        """
        excluded = set(reduction_set)
        # Iterating range() keeps the mapping in ascending original-index
        # order. Iterating a set would make the order an implementation
        # detail.
        self.ind_map = [i for i in range(len(self.samples))
                        if i not in excluded]
        self.reduction_set = reduction_set
        self.length = len(self.ind_map)
        return self.length

    def __getitem__(self, idx):
        """Return ``(sample, target, original_index)`` for dataset index *idx*."""
        if self.reduction_set is not None and self.ind_map is not None:
            # Translate the reduced-view index into an index into self.samples.
            idx = self.ind_map[idx]

        path, target = self.samples[idx]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target, idx
        