"""Utilities for making datasets."""
import torch


class UnlabelledDataset(torch.utils.data.Dataset):
    """Expose only the data portion of a labelled dataset.

    Indexing this wrapper fetches the corresponding sample from the wrapped
    dataset and returns its first element (the data); the remaining
    elements (e.g. the label) are discarded. Length is delegated unchanged.

    Args:
        dataset (torch.Dataset): Labelled dataset whose samples are
            indexable, with the data stored at position ``0``.
    """

    def __init__(self, dataset):
        # Keep a reference to the wrapped dataset; no copying is done.
        self.dataset = dataset

    def __len__(self):
        wrapped = self.dataset
        return len(wrapped)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        # Position 0 holds the data; anything after it is dropped.
        data = sample[0]
        return data


class FlattenTransform(object):
    """Transform that flattens a tensor into a 1-D tensor.

    Uses ``reshape`` rather than ``view`` so that non-contiguous inputs
    (for example the output of a transpose or permute earlier in a
    transform pipeline) are supported. ``reshape`` returns a view when the
    memory layout allows it and a copy otherwise.
    """

    def __call__(self, tensor):
        """Flatten ``tensor``.

        Args:
            tensor (torch.Tensor): Input tensor of any shape.

        Returns:
            torch.Tensor: 1-D tensor holding all elements of ``tensor`` in
            row-major (logical) order.
        """
        # ``view(-1)`` raises on non-contiguous tensors; ``reshape(-1)``
        # behaves identically for contiguous ones and copies only if needed.
        return tensor.reshape(-1)
