"""Utilities relate to RNG."""
import numpy as np
import torch

from typing import Optional


def make_random_unique_indices(
    indexed_size: int,
    n_indices: int,
    n_groups: int,
    *,
    np_generator: Optional[np.random.Generator] = None,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.int64,
) -> torch.Tensor:
    """Generate groups of random unique indices.

    Each group (row) is sampled independently without replacement. Indices
    never repeat within a row, but the same index may appear in several rows.

    Args:
        indexed_size: Size of the vector being indexed into. Indices are drawn
            from ``range(indexed_size)``.
        n_indices: Number of unique indices to generate for each group.
        n_groups: Number of groups (rows) of indices. May be 0.
        np_generator: NumPy random generator used for sampling. When omitted,
            a fresh ``np.random.default_rng()`` is created.
        device: Device for the returned tensor. ``None`` keeps the tensor
            where ``torch.from_numpy`` places it (CPU).
        dtype: dtype of the returned tensor.

    Returns:
        A tensor of shape [n_groups, n_indices] with each row containing a set
        of unique indices sorted in ascending order.

    Raises:
        ValueError: If ``n_indices > indexed_size`` (not enough distinct
            values to fill a row), or if a size argument is negative.
    """
    if n_indices > indexed_size:
        raise ValueError(
            f"Cannot draw {n_indices} unique indices from a vector of size "
            f"{indexed_size}."
        )
    if np_generator is None:
        np_generator = np.random.default_rng()

    # Preallocate the 2-D result. This yields the documented
    # [n_groups, n_indices] shape and works when n_groups == 0, where
    # concatenating an empty list would raise.
    samples = np.empty((n_groups, n_indices), dtype=np.int64)
    # TODO: See how long this takes for n_groups=30_000.
    for row in samples:  # each `row` is a writable view into `samples`
        row[:] = np_generator.choice(
            indexed_size, size=n_indices, replace=False, shuffle=False
        )
    # `shuffle=False` only skips the final shuffle for speed; it does not sort
    # the sample. Sort explicitly to honor the ascending-order contract.
    samples.sort(axis=1)

    return torch.from_numpy(samples).to(device=device, dtype=dtype)
