from torch.utils.data import Dataset


class DatasetWithReward(Dataset):
    """A ``torch.utils.data.Dataset`` that exposes a flat-reward transform hook.

    This base class defines only a placeholder for ``flat_reward_transform``.
    NOTE(review): presumably concrete datasets override it with a real
    transform. Confirm this against the subclasses before relying on it.
    """

    def flat_reward_transform(self, r):
        """Placeholder transform for the flat reward ``r``.

        The base implementation ignores ``r`` and always returns ``None``.
        """
        # Explicit equivalent of the original bare ``pass`` body.
        return None
