# Random-number utilities: a port of Haiku's PRNGSequence and a helper that seeds global RNGs.
import random
import numpy as np
import jax
import jax.numpy as jnp
import collections
from typing import Iterator, Iterable, Union

# Alias used in the type annotations below.
# NOTE(review): this binds the key-constructor *function* jax.random.PRNGKey, not an
# array type, so static type checkers will not treat it as "the type of a key"
# (jax.Array may be the intended annotation) — confirm before changing, since
# callers may import and call this name.
PRNGKey = jax.random.PRNGKey
# Snapshot of a PRNGSequence: (current splitting key, keys already split off but
# not yet handed out). PRNGSequence.internal_state produces this shape.
PRNGSequenceState = tuple[PRNGKey, Iterable[PRNGKey]]


# Port of Haiku's PRNGSequence: an iterator over an endless stream of JAX PRNG keys.
class PRNGSequence(Iterator[PRNGKey]):
    """Iterator yielding an endless stream of JAX PRNG keys.

    Keys are derived by repeatedly splitting an internal key with
    ``jax.random.split``: the first key of each split is kept for the next
    split, and the remaining keys are buffered and handed out in order by
    ``next()`` / :meth:`take`. The full state can be captured with
    :attr:`internal_state` and restored either with
    :meth:`replace_internal_state` or by passing it to the constructor.
    """

    __slots__ = ("_key", "_subkeys")

    # How many keys ``__next__`` splits off when the buffer runs dry. This is
    # part of the observable key stream: changing it changes which keys are
    # produced for a given seed.
    _DEFAULT_RESERVE_SIZE = 42

    def __init__(self, key_or_seed: Union[PRNGKey, int, PRNGSequenceState]):
        """Creates a new :class:`PRNGSequence`.

        Args:
            key_or_seed: One of
                * a ``(key, subkeys)`` tuple as returned by
                  :attr:`internal_state`, which restores that exact state;
                * an integer seed: a Python ``int``, or a NumPy integer scalar /
                  zero-dimensional array of any integer dtype; it is turned into
                  a key with ``jax.random.PRNGKey``;
                * a PRNG key, used as-is (it is not validated).
        """
        if isinstance(key_or_seed, tuple):
            key, subkeys = key_or_seed
            self._key = key
            self._subkeys = collections.deque(subkeys)
            return

        if self._is_integer_seed(key_or_seed):
            key_or_seed = jax.random.PRNGKey(key_or_seed)

        self._key = key_or_seed
        self._subkeys = collections.deque()

    @staticmethod
    def _is_integer_seed(value) -> bool:
        """Returns whether ``value`` is an integer seed rather than a key."""
        if isinstance(value, int):
            return True
        # A seed may also be passed as a zero-dimensional array or NumPy scalar.
        # Any integer dtype is accepted (not only int32), so e.g. an np.int64
        # seed is converted instead of being stored as if it were a key.
        return (hasattr(value, "shape") and not value.shape and
                hasattr(value, "dtype") and
                jnp.issubdtype(value.dtype, jnp.integer))

    def reserve(self, num: int):
        """Splits additional ``num`` keys for later use.

        Does nothing unless ``num`` is positive.
        """
        if num > 0:
            # When storing keys we adopt a pattern of key0 being reserved for
            # future splitting and all other keys being provided to the user in
            # linear order. In terms of jax.random.split this looks like:
            #
            #     key, subkey1, subkey2 = jax.random.split(key, 3)  # reserve(2)
            #     key, subkey3, subkey4 = jax.random.split(key, 3)  # reserve(2)
            #
            # Where subkey1->subkey4 are provided to the user in order when
            # requested.
            new_keys = tuple(jax.random.split(self._key, num + 1))
            self._key = new_keys[0]
            self._subkeys.extend(new_keys[1:])

    @property
    def internal_state(self) -> PRNGSequenceState:
        """A ``(key, subkeys)`` snapshot; the buffered subkeys are copied into a tuple."""
        return self._key, tuple(self._subkeys)

    def replace_internal_state(self, state: PRNGSequenceState):
        """Overwrites the key and buffered subkeys with a ``(key, subkeys)`` snapshot."""
        key, subkeys = state
        self._key = key
        self._subkeys = collections.deque(subkeys)

    def __next__(self) -> PRNGKey:
        """Returns the next key, splitting a fresh batch if the buffer is empty."""
        if not self._subkeys:
            self.reserve(self._DEFAULT_RESERVE_SIZE)
        return self._subkeys.popleft()

    next = __next__

    def take(self, num: int) -> tuple[PRNGKey, ...]:
        """Returns the next ``num`` keys as a tuple.

        Only the keys missing from the buffer are split, in a single
        ``jax.random.split`` call, so ``next`` never has to refill here.
        """
        self.reserve(max(num - len(self._subkeys), 0))
        return tuple(next(self) for _ in range(num))


def fix_seed(seed: int):
    """Seed the global random number generators for reproducible runs.

    PyTorch (CPU generator plus the current and all CUDA devices) and
    TensorFlow are seeded when they can be imported; a framework that is not
    installed is skipped silently. Python's ``random`` module and NumPy's
    legacy global generator are always seeded. JAX is not touched.

    Args:
        seed: Value handed to every generator. NOTE(review): ``np.random.seed``
            only accepts values in ``[0, 2**32)`` — confirm callers stay in range.
    """
    # Optional frameworks first: a missing package just means nothing to seed.
    try:
        import torch
        for seed_fn in (torch.manual_seed, torch.cuda.manual_seed,
                        torch.cuda.manual_seed_all):
            seed_fn(seed)
    except ImportError:
        pass

    try:
        import tensorflow
        tensorflow.random.set_seed(seed)
    except ImportError:
        pass

    # The standard library and NumPy are always available.
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)
