from torch.utils.data import Dataset
import torch

class _WindowIndexError(IndexError, ValueError):
    """Raised by ``WindowedPredictionDataset.__getitem__`` for a bad index.

    Derives from ``IndexError`` so that plain ``for item in dataset``
    iteration (which falls back to ``__getitem__`` and stops on
    ``IndexError``) terminates cleanly, and from ``ValueError`` so callers
    that caught the previously raised exception type keep working.
    """


class WindowedPredictionDataset(Dataset):
    """Dataset of non-overlapping windows cut along axis 1 of the inputs.

    Every array is sliced as ``array[:, start:stop]``, so each one needs at
    least two dimensions, with axis 1 being the windowed axis (presumably
    time -- confirm against the caller).

    Item ``i`` covers ``[i * window_size, (i + 1) * window_size)``. For each
    of the ``n_future_pred`` offsets ``f`` the window is shifted by ``f``
    steps, and the per-offset slices are stacked along a new axis 1. The
    label is taken from the video, shifted one further step ahead.

    Args:
        video: NumPy array used both as the input scene and as the label
            source, or ``None``. With ``None`` the ``inputs`` and ``label``
            entries of every item are empty tensors.
        velocity: NumPy array.
        rot_velocity: NumPy array.
        positions: NumPy array; its ``shape[1]`` determines the dataset
            length.
        thetas: NumPy array.
        window_size: Number of steps per window.
        n_future_pred: Number of shifted windows stacked per item.
    """

    def __init__(
        self, video,
        velocity, rot_velocity,
        positions, thetas,
        window_size, n_future_pred=1
    ):
        # torch.from_numpy shares memory with the source array (no copy).
        # The same video backs both the model input and the label; a missing
        # video is tolerated for both, not only for the input.
        self.scene_in = torch.from_numpy(video) if video is not None else None
        self.scene_out = torch.from_numpy(video) if video is not None else None

        self.velocity = torch.from_numpy(velocity)
        self.rot_velocity = torch.from_numpy(rot_velocity)

        self.positions = torch.from_numpy(positions)
        self.thetas = torch.from_numpy(thetas)

        self.window_size = window_size
        self.n_future_pred = n_future_pred

    def _stack_windows(self, source, start, stop, shift=0):
        """Stack the ``n_future_pred`` shifted windows of *source* on axis 1."""
        return torch.stack(
            [
                source[:, start + f + shift:stop + f + shift]
                for f in range(self.n_future_pred)
            ],
            dim=1,
        )

    def __getitem__(self, index):
        """Return ``(inputs, vel, rot_vel, pos, thet, label)`` for a window.

        ``inputs`` is the unshifted video window (an empty tensor when no
        video was given). The other entries hold ``n_future_pred`` shifted
        windows stacked along axis 1; ``label`` is shifted one extra step
        ahead (and is an empty tensor when no video was given).

        Raises:
            IndexError: If ``index`` is not in ``[0, len(self))``. The
                exception is also a ``ValueError`` for backward
                compatibility.
        """
        if not (0 <= index < len(self)):
            raise _WindowIndexError("Index out of range")

        start = index * self.window_size
        stop = start + self.window_size

        if self.scene_in is not None:
            inputs = self.scene_in[:, start:stop]
        else:
            inputs = torch.Tensor([])

        vel = self._stack_windows(self.velocity, start, stop)
        rot_vel = self._stack_windows(self.rot_velocity, start, stop)
        pos = self._stack_windows(self.positions, start, stop)
        thet = self._stack_windows(self.thetas, start, stop)

        if self.scene_out is not None:
            # Labels are the video one step ahead of each shifted window.
            label = self._stack_windows(self.scene_out, start, stop, shift=1)
        else:
            label = torch.Tensor([])

        return inputs, vel, rot_vel, pos, thet, label

    def __len__(self):
        """Return the number of windows, never less than zero.

        ``n_future_pred`` windows are held back at the end so that the
        shifted slices of the last item stay inside the data.
        """
        n_windows = self.positions.shape[1] // self.window_size
        # len() must not return a negative number; clamp short sequences to 0.
        return max(0, n_windows - self.n_future_pred)
