import torch
from torch import Tensor
from torch.utils.data import Dataset, default_collate


class ImagesDataset(Dataset):
    """
    Map-style dataset of sliding windows over the first axis of a 3D tensor.

    Item ``i`` is ``{"y": y[i : i + window]}``: ``window`` consecutive slices
    of ``y`` along dim 0, so each item has shape ``[window, w, h]``.
    Consecutive items overlap by ``window - 1`` slices (stride 1).
    """

    def __init__(self, y: Tensor, window: int) -> None:
        """
        y: 3D Tensor of shape [N, w, h]
        window: number of consecutive slices along dim 0 per item (>= 1)

        Raises ValueError if ``y`` is not 3D or ``window < 1``.
        """
        if y.ndim != 3:
            raise ValueError("ImagesDataset expects a 3D tensor [N, w, h]")
        if window < 1:
            raise ValueError("window size should be >= 1")
        self.y = y
        self.window = window
        self.collate_fn = default_collate

    def __len__(self) -> int:
        """Return the number of full windows (0 when ``window`` exceeds ``N``)."""
        # Clamp at 0: ``len()`` raises ValueError if __len__ returns a negative value.
        return max(0, self.y.size(0) - self.window + 1)

    def __getitem__(self, index: int) -> "dict[str, Tensor]":
        """
        Return the window starting at ``index`` as ``{"y": tensor}``.

        Negative indices count from the end, as for built-in sequences.
        Raises IndexError when ``index`` is out of range.
        """
        length = len(self)
        if index < 0:
            index += length
        if not 0 <= index < length:
            # IndexError (not StopIteration) is the sequence protocol's
            # out-of-range signal; a StopIteration escaping into a caller's
            # iterator would be mistaken for normal end of iteration.
            raise IndexError(
                f"index out of range for ImagesDataset of length {length}"
            )
        return {"y": self.y[index : index + self.window]}
