import torch


class ToyDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(sample, label, mask)`` triples.

    Args:
        data: Indexable collection of samples. Its length defines the
            dataset length.
        labels: Label values. They are converted once, at construction,
            to an int64 tensor via ``torch.LongTensor``.
        masks: Indexable collection of per-sample masks, stored as given.
        transform: Optional callable applied to each sample (only the
            sample, not the label or mask) in ``__getitem__``.
    """

    def __init__(self, data, labels, masks, transform=None):
        self.data = data
        self.labels = torch.LongTensor(labels)
        self.masks = masks
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(sample, label, mask)`` triple at ``index``.

        The sample is passed through ``self.transform`` when one was given.
        """
        x = self.data[index]
        y = self.labels[index]
        z = self.masks[index]

        # Compare against None rather than testing truthiness: a callable
        # transform that defines __len__ (e.g. an empty container-like
        # module) would otherwise be treated as "no transform" and skipped.
        if self.transform is not None:
            # NOTE(review): masks are returned untransformed; if the
            # transform is spatial, confirm masks should not follow it.
            x = self.transform(x)

        return x, y, z

    def __len__(self):
        """Return the number of samples, taken from ``data``."""
        return len(self.data)