import torch
from torch.utils.data import Sampler
import numpy as np

class TrajectoryDistributedSampler(Sampler):
    """Sampler that hands each distributed rank one contiguous slice of indices.

    The flat index space ``[0, sum(dataset.lengths))`` is cut into
    ``num_replicas`` equally sized, contiguous chunks and this rank iterates
    only over its own chunk. Trailing indices that do not divide evenly
    across the replicas are dropped.

    When ``torch.distributed`` is unavailable or no default process group has
    been initialised, the sampler behaves as a single replica (rank 0 of 1).

    Args:
        dataset: Object exposing an iterable ``lengths`` attribute; the sum of
            its entries defines the size of the index space.
        shuffle: If true, the rank's chunk is permuted anew on every pass.
        seed: Base seed; the permutation of a pass is seeded with
            ``seed + epoch``.
    """

    def __init__(self,
                 dataset,
                 shuffle: bool = True,
                 seed: int = 0
                 ):
        import torch.distributed as dist

        # Ask for the world size only when a process group really exists.
        # Relying on the exception type of get_world_size() is fragile (it
        # differs between PyTorch releases), so test the state explicitly.
        if dist.is_available() and dist.is_initialized():
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
        else:
            num_replicas = 1  # Default to 1 if not using distributed training
            rank = 0  # Default to 0 if not using distributed training

        print(f'TrajectoryDistributedSampler: rank: {rank}/{num_replicas}')
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

        total_length = sum(dataset.lengths)
        print(f'{total_length=}')

        # drop last few if they don't divide total length
        len_per_replica = total_length // num_replicas

        # sample contiguous indices for each rank (may facilitate loading from hdf5)
        offset = len_per_replica * rank
        self.indices = offset + np.arange(len_per_replica)
        if len(self.indices) > 0:
            print(f'range: [{self.indices[0]}, {self.indices[-1]}]')
        else:
            # Empty slice (no samples, or fewer samples than ranks): there is
            # no first/last element to report.
            print('range: []')

    def __iter__(self):
        """Yield this rank's indices, reshuffled per pass when ``shuffle`` is set.

        Every pass bumps ``self.epoch`` by one before drawing the permutation,
        so consecutive passes use different but reproducible orders.
        """
        self.epoch += 1

        if self.shuffle:
            # Deterministically shuffle based on epoch and seed
            rng = np.random.default_rng(self.seed + self.epoch)
            order = self.indices[rng.permutation(len(self.indices))]
        else:
            order = self.indices

        yield from order

    def __len__(self) -> int:
        """Return the number of indices assigned to this rank."""
        return len(self.indices)



if __name__ == '__main__':
    
    # Placeholder entry point: running this module as a script does nothing.
    pass
