import torch
from torch.utils.data import Dataset

class TimeSeriesDataset(Dataset):
    """Map-style dataset returning one fixed-length window per stored sequence.

    Each index ``idx`` maps to row ``idx`` of ``data``. No sliding-window
    expansion is performed, so ``len(dataset) == len(data)``.
    """

    def __init__(self, data, timestamps, input_len, output_len, d_inner):
        """Initialize the time series dataset.

        Args:
            data: numpy.ndarray of shape (num_samples, seq_len, input_dim),
                the input time series data sequences
            timestamps: numpy.ndarray of shape (num_samples, seq_len + 1),
                the timestamp sequences for each sample
            input_len: int, length of input sequence
            output_len: int, length of output sequence
            d_inner: int, dimension of inner features. Stored as
                ``self.d_inner`` but not used by this class itself.

        Raises:
            ValueError: if ``input_len`` or ``output_len`` is negative; if
                ``seq_len`` is shorter than ``input_len + output_len`` (which
                would otherwise silently yield truncated windows); or if
                ``data`` and ``timestamps`` hold different numbers of samples.
        """
        if input_len < 0 or output_len < 0:
            raise ValueError(
                "input_len and output_len must be non-negative, got "
                f"input_len={input_len}, output_len={output_len}"
            )
        window = input_len + output_len
        seq_len = data.shape[1]
        # Slicing past the end does not raise; it truncates. Reject such
        # configurations up front instead of returning short targets.
        if seq_len < window:
            raise ValueError(
                f"sequence length {seq_len} is shorter than "
                f"input_len + output_len = {window}"
            )
        if len(timestamps) != len(data):
            raise ValueError(
                f"timestamps has {len(timestamps)} samples but data has "
                f"{len(data)}"
            )

        self.data = data
        self.timestamps = timestamps
        self.input_len = input_len
        self.output_len = output_len
        self.d_inner = d_inner

    def __len__(self):
        """Return the number of samples, i.e. ``len(self.data)``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Get a single sample from the dataset.

        Args:
            idx: int, index of the sample

        Returns:
            tuple:
                - x: array of shape (input_len+output_len, input_dim),
                    the input window followed by the output window. The
                    target ``y`` is therefore the trailing ``output_len``
                    steps of ``x``.
                - y: array of shape (output_len, input_dim),
                    the target sequence
                - timestamps: array of shape (seq_len + 1,), the full
                    timestamp row for the sample (not cropped to the window)

            The values are returned exactly as indexed from ``self.data`` and
            ``self.timestamps``. For NumPy inputs they are ndarray views,
            not ``torch.Tensor``. PyTorch's default ``collate_fn`` converts
            NumPy arrays to tensors when batching through a ``DataLoader``.
        """
        end = self.input_len + self.output_len

        # x spans both the input and output windows (see Returns above).
        x = self.data[idx, :end, :]
        y = self.data[idx, self.input_len:end, :]

        # The whole timestamp row is returned; callers crop it if needed.
        timestamps = self.timestamps[idx]

        return x, y, timestamps