import itertools
import warnings
from typing import Any, Optional
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
from dinov3.distributed import get_rank, get_world_size
class EpochSampler(Sampler):
    """Finite, per-epoch sampler over a tiled index list, strided across ranks.

    The indices ``0 .. sample_count - 1`` are repeated until at least ``size``
    entries exist. Exactly ``size`` of them are then kept: either the first
    ``size`` in order, or (when shuffling) a seeded draw without replacement
    from the tiled list. This rank yields every ``step``-th kept entry,
    starting at ``start``.

    Args:
        size: number of entries produced per epoch, before splitting across ranks.
        sample_count: number of distinct indices.
        shuffle: draw the ``size`` entries with a NumPy generator reseeded every epoch.
        seed: base seed. NOTE: a nonzero seed is multiplied by the epoch, so epoch 0
            always uses seed 0. With ``seed == 0`` the epoch number itself is the seed.
        start: offset of this rank; defaults to ``get_rank()``.
        step: stride between ranks; defaults to ``get_world_size()``.
    """

    def __init__(
        self,
        *,
        size: int,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
    ):
        self._size = size
        self._sample_count = sample_count
        self._shuffle = shuffle
        self._seed = seed
        self._start = start if start is not None else get_rank()
        self._step = step if step is not None else get_world_size()
        # Advanced externally through set_epoch(); drives the shuffle seed.
        self._epoch = 0

    def _epoch_order(self) -> np.ndarray:
        """Return the ``size`` indices of the current epoch, before rank striding."""
        repeats = (self._size + self._sample_count - 1) // self._sample_count
        pool = np.tile(np.arange(self._sample_count), repeats)
        if not self._shuffle:
            return pool[: self._size]
        epoch_seed = self._epoch if self._seed == 0 else self._seed * self._epoch
        generator = np.random.default_rng(epoch_seed)
        return generator.choice(pool, self._size, replace=False)

    def __iter__(self):
        yield from itertools.islice(self._epoch_order(), self._start, None, self._step)

    def __len__(self):
        # Ceiling of (size - start) / step: how many entries islice above yields.
        remaining = self._size - self._start
        return (remaining + self._step - 1) // self._step

    def set_epoch(self, epoch):
        """Select the epoch whose seed the next iteration will use."""
        self._epoch = epoch
def _get_numpy_dtype(size: int) -> Any:
    """Return ``np.int32`` if every index ``0 .. size - 1`` fits in it, else ``np.int64``."""
    largest_index = size - 1
    if largest_index <= np.iinfo(np.int32).max:
        return np.int32
    return np.int64
def _get_torch_dtype(size: int) -> Any:
    """Return ``torch.int32`` if every index ``0 .. size - 1`` fits in it, else ``torch.int64``."""
    largest_index = size - 1
    if largest_index <= torch.iinfo(torch.int32).max:
        return torch.int32
    return torch.int64
def _generate_randperm_indices(*, size: int, generator: torch.Generator):
    """Lazily yield a random permutation of ``range(size)`` as Python ints.

    This is a Fisher-Yates shuffle performed one position at a time. At each
    position an index is drawn uniformly from the not-yet-emitted tail using
    ``generator``, swapped into place, and yielded. Callers can therefore stop
    early without paying for the full permutation.
    """
    remaining = torch.arange(size, dtype=_get_torch_dtype(size))
    for pos in range(size):
        pick = torch.randint(pos, size, size=(1,), generator=generator).item()
        chosen, displaced = remaining[pick].item(), remaining[pos].item()
        remaining[pos], remaining[pick] = chosen, displaced
        yield chosen
class InfiniteSampler(Sampler):
    """Infinite sampler that walks a global index order and keeps this rank's stride.

    Each pass over the data enumerates ``range(sample_count)``: in order, or as a
    random permutation when ``shuffle`` is set. This rank keeps every ``step``-th
    index of the pass, starting at ``start``. Passes repeat forever.

    Args:
        sample_count: number of samples in the dataset.
        shuffle: draw a fresh random permutation for every pass. The ``torch.Generator``
            is seeded with ``seed`` once per call to ``iter()`` and is not reseeded
            between passes.
        seed: seed for the shuffling generator.
        start: offset of this rank; defaults to ``get_rank()``.
        step: stride between ranks; defaults to ``get_world_size()``.
        advance: number of this rank's indices to skip at the beginning of iteration.

    Raises:
        ValueError: on first ``next()`` if this rank's slice
            ``range(start, sample_count, step)`` is empty. Without this check the
            sampler would loop forever without yielding anything.
    """

    def __init__(
        self,
        *,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
        advance: int = 0,
    ):
        self._sample_count = sample_count
        self._seed = seed
        self._shuffle = shuffle
        self._start = get_rank() if start is None else start
        self._step = get_world_size() if step is None else step
        self._advance = advance

    def __iter__(self):
        # Both iterators below keep exactly these positions of every pass. If there
        # are none, their `while True` loops would spin forever, so fail loudly instead.
        if len(range(self._start, self._sample_count, self._step)) == 0:
            raise ValueError(
                f"InfiniteSampler has no samples for start={self._start}, "
                f"step={self._step}, sample_count={self._sample_count}"
            )
        if self._shuffle:
            iterator = self._shuffled_iterator()
        else:
            iterator = self._iterator()
        yield from itertools.islice(iterator, self._advance, None)

    def _iterator(self):
        """Cycle forever over this rank's stride of ``range(sample_count)``."""
        assert not self._shuffle
        while True:
            iterable = range(self._sample_count)
            yield from itertools.islice(iterable, self._start, None, self._step)

    def _shuffled_iterator(self):
        """Cycle forever over this rank's stride of successive random permutations."""
        assert self._shuffle
        generator = torch.Generator().manual_seed(self._seed)
        while True:
            iterable = _generate_randperm_indices(size=self._sample_count, generator=generator)
            yield from itertools.islice(iterable, self._start, None, self._step)
def _shuffle_tensor_slice(
    *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
) -> np.ndarray:
    """Return a shuffled NumPy copy of the strided slice ``tensor[start::step]``.

    Only ``len(tensor) // step`` elements are kept, and a warning reports how
    many trailing elements that drops. The result does not depend on how many
    elements ``start`` would otherwise reach. Elements are shuffled with the
    inside-out Fisher-Yates algorithm, drawing one ``torch.randint`` per
    element (after the first) from ``generator``.

    NOTE(review): source indices are ``start + i * step`` for ``i < len(tensor) // step``.
    They stay in bounds for ``0 <= start < step`` but may not for larger ``start``.
    """
    total = len(tensor)
    kept = total // step
    dropped = total - kept * step
    if dropped:
        warnings.warn(f"# of dropped samples: {dropped}", stacklevel=1)
    shuffled = np.empty(kept, dtype=_get_numpy_dtype(total))
    # Inside-out Fisher-Yates: each incoming element lands in a uniformly random
    # slot j in [0, i], and whatever was in slot j moves to slot i.
    for i, src in enumerate(range(start, start + kept * step, step)):
        j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i else 0
        shuffled[i] = shuffled[j]
        shuffled[j] = tensor[src].item()
    return shuffled
def _new_shuffle_tensor_slice(
    *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
) -> np.ndarray:
    """Vectorized shuffle of the strided slice ``tensor[start::step]`` via ``torch.randperm``.

    Keeps ``len(tensor) // step`` elements and warns about how many trailing
    elements that drops. Returns them in a random order drawn from
    ``generator``, as a NumPy array. The random draws differ from the
    per-element loop in ``_shuffle_tensor_slice``, so the same generator state
    is not expected to give the same order.
    """
    stop = len(tensor)
    count = stop // step
    drop_count = stop - step * count
    if drop_count:
        warnings.warn(f"# of dropped samples: {drop_count}", stacklevel=1)
    indices = torch.randperm(count, dtype=torch.int64, generator=generator)
    # For 0 <= start < step, ``tensor[start::step]`` holds either ``count`` or
    # ``count + 1`` elements. Indexing with a permutation of ``range(count)``
    # drops the possible extra trailing element.
    return tensor[start::step][indices].numpy()
def _make_seed(seed: int, start: int, iter_count: int) -> int:
    """Derive a generator seed for one rank offset and one pass over the data.

    The pass index is shifted into bits 24 and above. Distinct passes therefore
    never collide as long as ``seed + start`` stays within ``[0, 2**24)``.
    """
    pass_offset = iter_count << 24
    return seed + start + pass_offset
class ShardedInfiniteSampler(Sampler):
    """Infinite sampler that shards one fixed global permutation across ranks.

    When shuffling, a single permutation of ``range(sample_count)`` is drawn from
    ``seed`` at the start of iteration. On every pass this rank takes every
    ``step``-th entry of it, starting at ``start``, and reshuffles that slice
    with a generator reseeded from ``_make_seed(seed, start, pass_index)``.
    Without shuffling, it cycles over ``range(start, sample_count, step)``.

    Args:
        sample_count: number of samples in the dataset.
        shuffle: whether to shuffle.
        seed: base seed for the global permutation and the per-pass reseeding.
        start: offset of this rank; defaults to ``get_rank()``.
        step: stride between ranks; defaults to ``get_world_size()``.
        advance: number of this rank's indices to skip at the beginning of iteration.
            Whole passes are fast-forwarded arithmetically; only the remainder is
            skipped element by element.
        use_new_shuffle_tensor_slice: shuffle each rank's slice with the vectorized
            ``torch.randperm`` helper instead of the per-element loop. The two helpers
            consume the generator differently, so orders are not expected to match.

    Raises:
        ValueError: on first ``next()`` if this rank would receive no samples per
            pass. Without this check the sampler would loop forever without
            yielding anything.
    """

    def __init__(
        self,
        *,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
        advance: int = 0,
        use_new_shuffle_tensor_slice: bool = False,
    ):
        self._sample_count = sample_count
        self._seed = seed
        self._shuffle = shuffle
        self._start = get_rank() if start is None else start
        self._step = get_world_size() if step is None else step
        self._advance = advance
        # Index of the next pass; used to reseed the per-pass shuffle.
        self._iter_count = 0
        self._shuffle_tensor_slice_fn = (
            _new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice
        )

    def _items_per_iter(self) -> int:
        """Return how many indices this rank yields per pass over the data."""
        if self._shuffle:
            # Both slice-shuffle helpers keep ``len(perm) // step`` entries for every
            # rank (the tail is dropped), independent of ``start``.
            return self._sample_count // self._step
        return len(range(self._start, self._sample_count, self._step))

    def __iter__(self):
        per_iter = self._items_per_iter()
        if per_iter <= 0:
            raise ValueError(
                f"ShardedInfiniteSampler has no samples for start={self._start}, "
                f"step={self._step}, sample_count={self._sample_count}"
            )
        # ``advance`` counts indices of this rank's stream (it is applied with islice
        # below), so whole passes must be skipped in units of ``per_iter``. Dividing
        # by ``sample_count`` is only correct when step == 1.
        iter_count = self._advance // per_iter
        if iter_count > 0:
            self._advance -= iter_count * per_iter
            self._iter_count += iter_count
        if self._shuffle:
            iterator = self._shuffled_iterator()
        else:
            iterator = self._iterator()
        yield from itertools.islice(iterator, self._advance, None)

    def _iterator(self):
        """Cycle forever over ``range(start, sample_count, step)``."""
        assert not self._shuffle
        while True:
            iterable = range(self._sample_count)
            yield from itertools.islice(iterable, self._start, None, self._step)

    def _shuffled_iterator(self):
        """Yield this rank's reshuffled slice of one global permutation, pass after pass."""
        assert self._shuffle
        generator = torch.Generator()
        generator.manual_seed(self._seed)
        dtype = _get_torch_dtype(self._sample_count)
        perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator)
        while True:
            # Reseed per (rank offset, pass) so any pass can be regenerated directly,
            # which is what makes the arithmetic fast-forward in __iter__ possible.
            seed = _make_seed(self._seed, self._start, self._iter_count)
            generator.manual_seed(seed)
            iterable = self._shuffle_tensor_slice_fn(
                tensor=perm, start=self._start, step=self._step, generator=generator
            )
            yield from iterable
            self._iter_count += 1
