from typing import Sequence, Tuple, Union

import torch
from torch.utils.data import Dataset
from torch.utils.data.dataset import Subset


class PredictiveDataset(Dataset):
    """Dataset that splits each sample into an input prefix and a target suffix.

    Every sample is sliced along its first dimension: all but the last
    ``cutoff`` entries form the input, and the last ``cutoff`` entries,
    flattened to 1-D, form the target.
    """

    def __init__(self, data: Union[Subset, torch.Tensor], cutoff: int):
        """Wrap ``data`` so that indexing yields ``(input, target)`` pairs.

        Args:
            data: A ``Subset`` (possibly nested, e.g. from ``random_split``)
                whose innermost dataset is a tensor with samples along its
                first dimension, or such a tensor directly. Only the samples
                selected by the subset indices are exposed.
            cutoff: Number of trailing entries of each sample used as the
                target. Must be a positive integer; an integral float such as
                ``2.0`` is accepted and converted.

        Raises:
            ValueError: If ``cutoff`` is not a positive integer.
        """
        super().__init__()
        cutoff_int = int(cutoff)
        # A zero cutoff would make ``e[:-0]`` empty and ``e[-0:]`` the whole
        # sample, and fractional values cannot be used as slice bounds.
        if cutoff_int != cutoff or cutoff_int < 1:
            raise ValueError(
                f"cutoff must be a positive integer, got {cutoff!r}"
            )
        self.cutoff = cutoff_int

        # Unwrap Subsets, composing their index lists so that position ``i`` of
        # this dataset maps to row ``indices[i]`` of the base tensor. Reading
        # ``data.dataset`` alone would expose every row of the base tensor and
        # silently ignore the split.
        indices = None
        base = data
        while isinstance(base, Subset):
            level = [int(i) for i in base.indices]
            indices = level if indices is None else [level[i] for i in indices]
            base = base.dataset
        self.data: torch.Tensor = base
        self.indices: Sequence[int] = (
            range(base.shape[0]) if indices is None else indices
        )

    def __getitem__(self, item: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(input, target)`` for the ``item``-th selected sample."""
        e = self.data[self.indices[item]]
        return e[:-self.cutoff], e[-self.cutoff:].flatten()

    def __len__(self) -> int:
        """Return the number of samples selected by the wrapped subset."""
        return len(self.indices)
