import random as rd
import numpy as np
import itertools
from typing import Iterable, overload


# Memo tables for `_indices_perm` / `_indices_prod`, keyed by `(n, k)` and
# holding the materialised list of index tuples computed for that key.
_INDICES_PERM_STORAGE: dict[tuple[int, int], list[tuple[int, ...]]] = dict()
_INDICES_PROD_STORAGE: dict[tuple[int, int], list[tuple[int, ...]]] = dict()


def _indices_perm(n: int, k: int, _use_storage: bool = True) -> list[tuple[int, ...]]:
    """
    Return all length-`k` permutations of `range(n)` as a list of index tuples.

    Tuples are in `itertools.permutations` order; `k > n` gives an empty list.
    When `_use_storage` is true the list is memoized in `_INDICES_PERM_STORAGE`
    under `(n, k)`. Every call, including the one that fills the cache, returns
    a fresh shallow copy. Callers may therefore mutate the result without
    corrupting later lookups.
    """
    if _use_storage:
        try:
            cached = _INDICES_PERM_STORAGE[n, k]
        except KeyError:
            pass
        else:
            return cached.copy()

    out = list(itertools.permutations(range(n), k))
    if _use_storage:
        # Keep `out` private to the cache. Returning it directly would let the
        # first caller's mutations leak into every subsequent cache hit.
        _INDICES_PERM_STORAGE[n, k] = out
        return out.copy()
    return out


def _indices_prod(n: int, k: int, _use_storage: bool = True) -> list[tuple[int, ...]]:
    """
    Return all `k`-tuples over `range(n)` (the Cartesian power) as a list.

    Tuples are in `itertools.product` (lexicographic) order; `k == 0` gives
    ``[()]``. When `_use_storage` is true the list is memoized in
    `_INDICES_PROD_STORAGE` under `(n, k)`. Every call, including the one that
    fills the cache, returns a fresh shallow copy. Callers may therefore mutate
    the result without corrupting later lookups.
    """
    if _use_storage:
        try:
            cached = _INDICES_PROD_STORAGE[n, k]
        except KeyError:
            pass
        else:
            return cached.copy()

    out = list(itertools.product(*(range(n) for _ in range(k))))
    if _use_storage:
        # Keep `out` private to the cache. Returning it directly would let the
        # first caller's mutations leak into every subsequent cache hit.
        _INDICES_PROD_STORAGE[n, k] = out
        return out.copy()
    return out


def permutations[T](items: list[T], n: int, k: int | None = None,
                    _use_storage=True):
    """
    Generate at most `k` sub-lists with `n` elements from `items`. 
    """

    indices = _indices_perm(len(items), n, _use_storage)

    if k is not None and k < len(indices):
        indices = rd.choices(indices, k=k)

    for i in indices:
        yield tuple(items[j] for j in i)


def _idx2tuple(i: int, ns: tuple[int, ...]) -> tuple[int, ...]:
    """
    Convert the flat index `i` into a multi-index for an array of shape `ns`.

    The layout is row-major (C order): the last axis varies fastest. The
    leading axis is reduced modulo `ns[0]`, so an `i` past the end wraps
    around instead of raising.

    Raises:
        ValueError: if `ns` is empty.
    """
    if not ns:
        raise ValueError("_idx2tuple requires a non-empty shape `ns`")

    digits: list[int] = []
    # Peel off the trailing axes, fastest-varying first. The leading axis is
    # handled separately because it wraps modulo ns[0].
    for size in reversed(ns[1:]):
        i, digit = divmod(i, size)
        digits.append(digit)
    digits.append(i % ns[0])
    digits.reverse()
    return tuple(digits)


def selfproducts[T](items: list[T], n: int, k: int | None = None) -> Iterable[tuple[T, ...]]:
    """
    Generate at most `k` sub-lists with `n` elements from `items`.
    If `k` is None, all possible combinations are generated.
    """

    m = len(items)
    if k is None or (k >= (M := m**n)):
        yield from itertools.product(*((items,) * n))
    else:
        ms = (len(items),) * n
        for i in map(int, np.random.choice(M, size=k, replace=False)):
            js = _idx2tuple(i, ms)
            yield tuple(items[j] for j in js)

@overload
def indices(shape: Iterable[int]) -> Iterable[tuple[int, ...]]: ...
@overload
def indices(shape: Iterable[int], dims: tuple[int, ...]) -> Iterable[tuple[int | slice, ...]]: ...
def indices(shape: Iterable[int], dims: tuple[int, ...] | None = None) -> Iterable[tuple[int | slice, ...]]:
    if dims is None:
        return itertools.product(*(range(n) for n in shape))
    else:
        return _indices_with_slice(shape, dims)


def _indices_with_slice(shape: Iterable[int], dims: tuple[int, ...]) -> Iterable[tuple[int | slice, ...]]:
    shape = tuple(shape)
    shape_of_dims = tuple(shape[d] for d in dims)
    temp: list[slice | int] = [slice(None)] * len(shape)
    for idx in itertools.product(*(range(n) for n in shape_of_dims)):
        for d, i in zip(dims, idx):
            temp[d] = i
        yield(tuple(temp))
