# Copyright (c) OpenMMLab. All rights reserved.
from typing import Iterator

import torch
from mmengine.dataset import DefaultSampler

from mmpretrain.registry import DATA_SAMPLERS


@DATA_SAMPLERS.register_module()
class SequentialSampler(DefaultSampler):
    """Sequential sampler which supports different subsample policy.

    Indices are never shuffled: ``shuffle=False`` is always passed to
    :class:`DefaultSampler`, so every iteration yields the same order.

    Args:
        dataset (Sized): The dataset.
        round_up (bool): Whether to add extra samples to make the number of
            samples evenly divisible by the world size. Defaults to True.
        subsample_type (str): The method to subsample data on different rank.
            Supported type:

            - ``'default'``: Original torch behavior. Sample the examples one
              by one for each GPU in turns. For instance, 8 examples on 2 GPUs,
              GPU0: [0,2,4,6], GPU1: [1,3,5,7]
            - ``'sequential'``: Split all examples into n contiguous chunks
              sequentially. For instance, 8 examples on 2 GPUs,
              GPU0: [0,1,2,3], GPU1: [4,5,6,7]. If the examples cannot be
              split evenly (only possible with ``round_up=False``), the first
              ranks get one extra example each, e.g. 9 examples on 2 GPUs,
              GPU0: [0,1,2,3,4], GPU1: [5,6,7,8]

            Defaults to ``'default'``.

        Remaining keyword arguments (such as ``dataset`` and ``round_up``)
        are forwarded to :class:`DefaultSampler`. ``shuffle`` must not be
        passed, since it is already supplied as ``False`` (Python raises a
        ``TypeError`` for the duplicate keyword).

    Raises:
        ValueError: If ``subsample_type`` is not one of the supported types.
    """

    def __init__(self, subsample_type: str = 'default', **kwargs) -> None:
        # Validate first so a bad argument fails before parent init runs.
        if subsample_type not in ['default', 'sequential']:
            raise ValueError(f'Unsupported subsample typer "{subsample_type}",'
                             ' please choose from ["default", "sequential"]')

        super().__init__(shuffle=False, **kwargs)
        self.subsample_type = subsample_type

    def __iter__(self) -> Iterator[int]:
        """Iterate the indices assigned to the current rank.

        Returns:
            Iterator[int]: Dataset indices for ``self.rank``, in dataset
            order (with wrap-around padding when ``round_up`` is True).
        """
        indices = torch.arange(len(self.dataset)).tolist()

        # Add extra samples to make it evenly divisible. An empty dataset
        # cannot be repeated (and would divide by zero), so skip it.
        if self.round_up and indices:
            repeats = self.total_size // len(indices) + 1
            indices = (indices * repeats)[:self.total_size]

        # subsample
        if self.subsample_type == 'default':
            # Interleaved: rank r takes every world_size-th index from r.
            indices = indices[self.rank:self.total_size:self.world_size]
        elif self.subsample_type == 'sequential':
            # Contiguous chunks. If total_size is not divisible by
            # world_size, the first `remainder` ranks take one extra index,
            # so no index is dropped and every rank gets as many indices
            # as it would in 'default' mode.
            base, remainder = divmod(self.total_size, self.world_size)
            start = self.rank * base + min(self.rank, remainder)
            end = start + base + (1 if self.rank < remainder else 0)
            indices = indices[start:end]

        return iter(indices)
