import torch
from torch.utils.data import Sampler

class TimestepSampler(Sampler):
    """Sampler that selects one fixed timestep of one trajectory.

    The selected flat dataset index is
    ``dataset.cum_lengths[trajectory_index] + timestep``.

    Each item produced by iteration is a dict with two keys:

    * ``'idx'``: the flat dataset index to fetch.
    * ``'cache_trajectory'``: always ``True``.

    NOTE(review): this assumes the dataset's ``__getitem__`` accepts such a
    dict in place of a bare integer index -- confirm against the dataset
    implementation.
    NOTE(review): ``timestep`` is not checked against the trajectory's
    length, because the layout of ``cum_lengths`` beyond the entry read here
    is not visible from this class.

    Args:
        dataset: Object exposing an indexable ``cum_lengths`` attribute.
        trajectory_index: Position in ``dataset.cum_lengths`` used as the
            base offset.
        timestep: Offset added to the base offset to get the sampled index.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, dataset, trajectory_index=0, timestep=50, **kwargs):
        self.dataset = dataset
        self.trajectory_index = trajectory_index

        # Base offset of the requested trajectory in the flat index space,
        # and the single flat index this sampler points at.
        self.start_index = dataset.cum_lengths[trajectory_index]
        self.const_idx = self.start_index + timestep

        # __iter__ and __len__ both rely on this list; it was previously
        # never assigned, so iterating raised AttributeError.
        self.indices = [self.const_idx]

    def __iter__(self):
        """Yield one index/metadata dict per entry in ``self.indices``."""
        # A generator makes the return value a real iterator (a bare dict is
        # rejected by iter()) and gives each sample its own fresh dict.
        for idx in self.indices:
            yield {
                'idx': idx,
                'cache_trajectory': True,
            }

    def __len__(self):
        """Return the number of items produced by one pass of iteration."""
        # const_idx is an absolute position in the dataset, not a count.
        return len(self.indices)
