import torch
from torchvision import datasets, transforms

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self,
                 *,
                 dataset,
                 transform=None,
                 target_transform=None,
                 discard_classes=[]):
        
        self.original_dataset = dataset
        self.transform = transform
        self.target_transform = target_transform
        self.discard_classes = discard_classes
        
        # Filter out discarded classes
        if len(self.discard_classes)!=0 :
            self._filter_classes()

    def _filter_classes(self):
        new_data = []
        new_targets = []

        for idx, (data, target) in enumerate(self.original_dataset):
            if target not in self.discard_classes:
                new_data.append(data)
                new_targets.append(target)

        self.data = new_data
        self.targets = new_targets

    def __getitem__(self, index):
        if self.discard_classes:
                img, target = self.data[index], self.targets[index]
        else:
            img, target = self.original_dataset.__getitem__(index)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        if len(self.discard_classes)!=0:
            return len(self.data)
        else:
            return len(self.original_dataset)