import math
from typing import Iterator, Optional, TypeVar

import torch
import torch.distributed as dist
from torch.utils.data.dataset import Dataset
from torch.utils.data.sampler import Sampler


__all__ = ["DistributedSampler"]


# Covariant type variable used to parameterize the sampler, mirroring
# ``torch.utils.data.Sampler``'s own generic signature.
_T_co = TypeVar("_T_co", covariant=True)


class DistributedSampler(Sampler[_T_co]):
    r"""Sampler that restricts data loading to a rank-exclusive subset of a dataset,
    with support for resuming part-way through an epoch.

    Every rank builds the same (optionally shuffled) index list, pads or trims
    it to a multiple of ``num_replicas``, and keeps every ``num_replicas``-th
    index starting at its own ``rank``. When resuming, each rank additionally
    skips ``consumed_samples // num_replicas`` of its own indices.

    .. note::
        ``num_samples`` is computed once from ``len(dataset)`` in ``__init__``,
        so the dataset's length must not change afterwards.

    Args:
        dataset: Dataset used for sampling; only ``len(dataset)`` is read.
        num_replicas (int, optional): Number of participating processes.
            Defaults to ``torch.distributed.get_world_size()``.
        rank (int, optional): Rank of the current process, in
            ``[0, num_replicas - 1]``. Defaults to ``torch.distributed.get_rank()``.
        shuffle (bool, optional): If ``True`` (default), indices are permuted
            with a generator seeded by ``seed + epoch``.
        seed (int, optional): Base shuffle seed; should be identical on all
            ranks. Default: ``0``.
        drop_last (bool, optional): If ``True``, drop the tail of the data so it
            divides evenly across replicas. If ``False`` (default), pad by
            repeating indices from the start of the list.
        consumed_samples (int, optional): Number of samples already consumed in
            the current epoch, counted as a total (it is divided by
            ``num_replicas`` to get this rank's skip count). Default: ``0``.

    Raises:
        RuntimeError: If ``num_replicas``/``rank`` must be looked up but
            ``torch.distributed`` is unavailable.
        ValueError: If ``rank`` is out of range, or ``consumed_samples`` is
            negative or larger than one epoch.
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        consumed_samples: int = 0,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last

        dataset_len = len(self.dataset)
        if self.drop_last and dataset_len % self.num_replicas != 0:
            # Round down to the nearest length that divides evenly, so every
            # rank receives the same number of samples; the tail is dropped.
            self.num_samples = math.ceil((dataset_len - self.num_replicas) / self.num_replicas)
        else:
            self.num_samples = math.ceil(dataset_len / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed
        # Per-rank number of indices to skip. The (misspelled) attribute name
        # is kept as-is because external code may read it.
        self.consumed_indicies = self._consumed_indices_for(consumed_samples)

    def _consumed_indices_for(self, consumed_samples: int) -> int:
        """Convert a total ``consumed_samples`` count into this rank's skip count.

        Raises:
            ValueError: If ``consumed_samples`` is negative, or would skip more
                indices than this rank holds in one epoch (which previously made
                ``__len__`` negative and broke iteration).
        """
        if consumed_samples < 0:
            raise ValueError(f"consumed_samples should be non-negative, got {consumed_samples}")
        consumed_indices = consumed_samples // self.num_replicas
        if consumed_indices > self.num_samples:
            raise ValueError(
                f"consumed_samples={consumed_samples} exceeds the {self.total_size} samples in one epoch"
            )
        return consumed_indices

    def __iter__(self) -> Iterator[_T_co]:
        """Yield this rank's remaining indices for the current epoch."""
        if self.shuffle:
            # Seeding with seed + epoch gives every rank the same permutation
            # for a given epoch, and a different one each epoch.
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        if not self.drop_last:
            # Pad by repeating indices from the start until the list divides evenly.
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # Trim the tail so the list divides evenly.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # Strided split: this rank takes positions rank, rank + R, rank + 2R, ...
        indices = indices[self.rank : self.total_size : self.num_replicas]
        # Skip this rank's share of samples already consumed before a resume.
        indices = indices[self.consumed_indicies :]
        assert len(indices) == self.num_samples - self.consumed_indicies

        return iter(indices)

    def __len__(self) -> int:
        """Number of indices this rank will still yield in the current epoch."""
        return self.num_samples - self.consumed_indicies

    def set_epoch(self, epoch: int, consumed_samples: int = 0) -> None:
        r"""Set the epoch and the resume position for the next iteration.

        With ``shuffle=True`` the epoch is added to ``seed`` to seed the
        permutation, so call this before each epoch to get a fresh ordering.

        Args:
            epoch (int): Epoch number.
            consumed_samples (int, optional): Samples already consumed in this
                epoch, counted as a total across replicas. Default: ``0``.

        Raises:
            ValueError: If ``consumed_samples`` is negative or exceeds one epoch.
                No state is modified in that case.
        """
        consumed_indices = self._consumed_indices_for(consumed_samples)
        self.epoch = epoch
        self.consumed_indicies = consumed_indices
