import torch
from torch.utils.data import Sampler
import numpy as np

class TrajectoryDistributedSampler(Sampler):
    """Shard a dataset across distributed replicas at trajectory granularity.

    Each replica receives a contiguous range of whole trajectories (never a
    partial one) and yields the flat dataset indices of every step in those
    trajectories, optionally shuffled.

    The dataset must expose:
      * ``lengths`` -- only its ``len()`` is used, as the trajectory count;
      * ``cum_lengths`` -- ``cum_lengths[i]`` / ``cum_lengths[i + 1]`` are
        used as the [start, end) flat-index bounds of trajectory ``i``
        (presumably prefix sums of ``lengths`` with a leading 0 -- confirm
        against the dataset implementation).

    NOTE: trajectories are split ``num_trajectories // num_replicas`` per
    rank and the last rank absorbs the remainder, so ranks can receive
    different numbers of samples (trajectory lengths differ as well).

    Args:
        dataset: Dataset providing ``lengths`` and ``cum_lengths``.
        shuffle: If True, shuffle this replica's indices each epoch.
        seed: Base seed; the effective seed is ``seed + epoch``.
        num_replicas: Number of replicas. Defaults to the world size of the
            initialized ``torch.distributed`` process group, or 1 when
            distributed training is unavailable or not initialized.
        rank: Rank of this replica. Defaults to the process-group rank, or 0
            when distributed training is unavailable or not initialized.

    Raises:
        ValueError: If ``num_replicas < 1`` or ``rank`` is outside
            ``[0, num_replicas)``.
    """

    def __init__(self,
                 dataset,
                 shuffle=True,
                 seed=0,
                 *,
                 num_replicas=None,
                 rank=None,
                 ):
        if num_replicas is None or rank is None:
            import torch.distributed as dist
            # Check availability/initialization explicitly instead of relying
            # on the exception type raised by get_world_size(), which differs
            # across PyTorch versions (RuntimeError vs ValueError), and is an
            # AttributeError on builds without distributed support.
            distributed = dist.is_available() and dist.is_initialized()
            if num_replicas is None:
                num_replicas = dist.get_world_size() if distributed else 1
            if rank is None:
                rank = dist.get_rank() if distributed else 0

        if num_replicas < 1:
            raise ValueError(f'num_replicas must be >= 1, got {num_replicas}')
        if not 0 <= rank < num_replicas:
            raise ValueError(
                f'rank must be in [0, {num_replicas}), got {rank}')

        print(f'rank: {rank}, num_replicas: {num_replicas}')

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

        # Compute the number of trajectories
        self.num_trajectories = len(dataset.lengths)

        # Contiguous block of trajectories for this replica; the last rank
        # also takes the remainder left over by the integer division. When
        # there are fewer trajectories than replicas, only the last rank
        # receives any.
        self.trajectories_per_replica = self.num_trajectories // self.num_replicas
        self.start_index = self.trajectories_per_replica * self.rank
        self.end_index = self.start_index + self.trajectories_per_replica
        if self.rank == self.num_replicas - 1:
            self.end_index = self.num_trajectories

        # Flat dataset indices of every step in this replica's trajectories.
        cum_lengths = dataset.cum_lengths
        self.indices = [
            index
            for trajectory_index in range(self.start_index, self.end_index)
            for index in range(cum_lengths[trajectory_index],
                               cum_lengths[trajectory_index + 1])
        ]

    def __iter__(self):
        """Iterate this replica's indices, shuffled if ``self.shuffle``.

        Shuffling uses a fresh ``torch.Generator`` seeded with
        ``seed + epoch``, so the order is reproducible and changes only when
        ``set_epoch`` is called with a different epoch.
        """
        if not self.shuffle:
            return iter(self.indices)

        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        order = torch.randperm(len(self.indices), generator=g).tolist()
        return iter([self.indices[i] for i in order])

    def __len__(self):
        """Return the number of indices this replica yields per epoch."""
        return len(self.indices)

    def set_epoch(self, epoch):
        """Set the epoch used to derive the shuffle seed (``seed + epoch``)."""
        self.epoch = epoch
