from torch.utils.data import Dataset


class CustomRegressionDataset(Dataset):
    """Map-style dataset pairing each input sample with its regression label.

    ``dataset[idx]`` returns the tuple ``(input[idx], labels[idx])``. Both
    containers are stored as given (no copy) and are indexed lazily, so any
    objects supporting ``len()`` and ``[]`` indexing can be used.

    Args:
        input: Indexable collection of input samples. The name shadows the
            ``input`` builtin but is kept so keyword callers keep working.
        labels: Indexable collection of targets, position-aligned with
            ``input``.

    Raises:
        ValueError: If ``input`` and ``labels`` differ in length.
    """

    def __init__(self, input, labels):
        # Fail fast on misaligned data. Without this check, labels shorter
        # than input cause an IndexError mid-epoch, and extra labels are
        # silently ignored because __len__ only looks at input.
        if len(input) != len(labels):
            raise ValueError(
                "input and labels must have the same length, "
                f"got {len(input)} and {len(labels)}"
            )
        self.input = input
        self.labels = labels

    def __len__(self):
        """Return the number of (input, label) pairs."""
        return len(self.input)

    def __getitem__(self, idx):
        """Return the ``(input, label)`` pair at position ``idx``."""
        return self.input[idx], self.labels[idx]
