import numpy as np
from torch.utils.data import DistributedSampler as _DistributedSampler


class DistributedVideoSampler(_DistributedSampler):
    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle
        assert not self.shuffle, "Specific for video sequential testing."
        self.num_samples = len(dataset)

        first_frame_indices = []
        for i, img_info in enumerate(self.dataset.data_infos):
            if img_info["frame_id"] == 0:
                first_frame_indices.append(i)

        chunks = np.array_split(first_frame_indices, num_replicas)
        split_flags = [c[0] for c in chunks]
        split_flags.append(self.num_samples)

        self.indices = [
            list(range(split_flags[i], split_flags[i + 1]))
            for i in range(self.num_replicas)
        ]

    def __iter__(self):
        indices = self.indices[self.rank]
        return iter(indices)
