import torch
from torch.utils.data import Sampler
import random

class SingleTrajectorySampler(Sampler):
    """Sampler that endlessly yields the dataset indices of a single trajectory.

    The index range for trajectory ``trajectory_index`` runs from
    ``dataset.cum_lengths[trajectory_index]`` (inclusive) to
    ``dataset.cum_lengths[trajectory_index + 1]`` (exclusive). Iteration never
    terminates on its own: once every index of the trajectory has been yielded,
    the index list is reshuffled (if enabled) and a new pass begins. The
    consumer is responsible for stopping, e.g. after a fixed number of steps.
    An empty trajectory yields nothing.

    NOTE(review): ``__len__`` reports the trajectory length, not the number of
    items ``__iter__`` produces (which is unbounded for a non-empty
    trajectory). A ``DataLoader`` built on this sampler reports a finite
    ``len()`` but iterates forever.

    Args:
        dataset: Object exposing an indexable ``cum_lengths``. Presumably these
            are cumulative trajectory lengths with a leading 0 (confirm against
            the dataset class).
        trajectory_index: Which trajectory to sample from. Negative values are
            NOT normalized like Python sequence indices. For example, ``-1``
            pairs the last and first entries of ``cum_lengths``.
        shuffle: If True, shuffle the trajectory's indices at construction and
            again at the start of every subsequent pass.
        seed: Seed for the private ``random.Random`` used for shuffling.
            ``None`` gives a non-reproducible order.
        **kwargs: Accepted and ignored. Presumably this lets a shared sampler
            config be passed through (TODO confirm).
    """

    def __init__(self, 
                 dataset, 
                 trajectory_index=0, 
                 shuffle=False, 
                 seed=None, 
                 **kwargs
                 ):
        self.dataset = dataset
        self.trajectory_index = trajectory_index
        self.shuffle = shuffle
        self.seed = seed
        # Private RNG so shuffling does not disturb the global `random` state.
        self.generator = random.Random(seed)

        # Determine the range of indices for the specified trajectory
        self.start_index = dataset.cum_lengths[trajectory_index]
        self.end_index = dataset.cum_lengths[trajectory_index + 1]
        self.indices = list(range(self.start_index, self.end_index))

        # Internal iterator state. It is shared by every call to __iter__, so a
        # fresh iter() resumes the current pass instead of restarting it.
        self._reset_iterator()

    def _reset_iterator(self):
        """Shuffle ``self.indices`` in place if enabled, then start a new pass."""
        if self.shuffle:
            self.generator.shuffle(self.indices)
        self._iter = iter(self.indices)

    def __iter__(self):
        """Yield the trajectory's indices forever, one pass after another.

        Continues from the shared internal iterator. After each exhausted pass,
        the iterator is reset (reshuffling first if enabled). Yields nothing if
        the trajectory is empty.
        """
        if not self.indices:
            # With nothing to yield, the reset loop below would spin forever.
            # Retrying next() on an empty iterator would instead leak
            # StopIteration out of this generator (RuntimeError, PEP 479).
            return
        while True:
            yield from self._iter
            # Current pass exhausted: reshuffle (if enabled) and start over.
            self._reset_iterator()

    def __len__(self):
        """Return the number of indices in the trajectory (one pass)."""
        return len(self.indices)
