from torch.utils import data
import torch
import numpy as np


class Dataset(data.Dataset):
    """Characterizes a dataset of fixed-length path segments for PyTorch.

    The archive holds ``npath`` paths of ``nstep`` time steps each, with ``n``
    values per step. Every sample is a contiguous window of ``nframe`` time
    steps cut from a single path, so each path contributes
    ``nstep - nframe + 1`` overlapping samples. Samples are ordered path by
    path, then by window start position.

    Args:
        data_dir: Path to a NumPy ``.npz`` archive containing a ``'data'``
            array of shape ``(npath, nstep, n)`` and a scalar ``'dt'`` entry.
        nframe: Length (in time steps) of each training path segment; must
            satisfy ``1 <= nframe <= nstep``.

    Raises:
        ValueError: If ``nframe`` is outside ``[1, nstep]``.
    """

    def __init__(self, data_dir, nframe):
        # Use the NpzFile as a context manager so its file handle is closed,
        # and copy the entries into a plain dict so ``data_dict`` stays usable
        # after the archive is closed.
        with np.load(data_dir) as archive:
            self.data_dict = {key: archive[key] for key in archive.files}
        self.data = torch.tensor(self.data_dict['data'])  # npath, nstep, n
        self.npath, self.nstep, self.n = self.data.shape
        # Reject window lengths that would give a negative/zero sample count
        # (which breaks len()) or empty windows.
        if not 1 <= nframe <= self.nstep:
            raise ValueError(
                f"nframe must be in [1, nstep={self.nstep}], got {nframe}"
            )
        self.data = torch.transpose(self.data, 1, 2)  # npath, n, nstep
        self.nframe = nframe  # length of a training path segment
        # Each path yields one sample per valid window start position.
        self.size = self.npath * (self.nstep - nframe + 1)
        self.dt = float(self.data_dict['dt'])

    def __len__(self):
        """Return the total number of samples (windows across all paths)."""
        return self.size

    def __getitem__(self, index):
        """Return the ``index``-th window as a tensor of shape ``(n, nframe)``.

        The flat index is split into (path, window start): consecutive indices
        walk along one path before moving on to the next.
        """
        windows_per_path = self.nstep - self.nframe + 1
        path, start = divmod(index, windows_per_path)
        item = self.data[path, :, start : start + self.nframe]  # n, nframe
        return item
