import copy

from torch.utils.data import Dataset


class SubsetWrapper(Dataset):
    """Wrap an indexable dataset and optionally transform each item's ``'x'`` entry.

    Items from the wrapped dataset are expected to be mappings with an ``'x'``
    key (presumably the model input; confirm against callers). When a transform
    is given, a shallow copy of the item is returned with ``'x'`` replaced by
    ``transform(item['x'])``. The wrapped dataset's own item is never mutated.
    Without a transform, the wrapped item is returned unchanged.

    Args:
        subset: Any object supporting ``len()`` and integer indexing. Presumably
            a ``torch.utils.data.Subset``; verify against callers.
        transform: Optional callable applied to ``item['x']`` on every access.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        """Return item ``index`` of the wrapped dataset, transformed if configured."""
        item = self.subset[index]
        # Compare against None instead of relying on truthiness: a transform
        # object may legitimately be falsy (e.g. it defines __len__ == 0).
        if self.transform is not None:
            # Copy first so the wrapped dataset's stored item is not overwritten.
            # Otherwise the transform would compound on repeated access/epochs.
            # The copy is shallow: a transform that mutates item['x'] in place
            # still affects the original value.
            item = copy.copy(item)
            item['x'] = self.transform(item['x'])
        return item

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.subset)