import collections
from typing import Any, Dict, Sequence, Union
import flax
import jax
import jax.numpy as jnp


# Immutable record with five positional/named fields, in this order:
# observations, actions, rewards, masks, next_observations.
# The field names suggest a batch of RL transitions; presumably each field
# holds an array stacked over the batch dimension -- confirm with callers.
Batch = collections.namedtuple(
    "Batch", "observations actions rewards masks next_observations"
)


# Frozen (immutable) flax mapping from string keys to arbitrary values;
# presumably used for model parameter trees -- confirm with callers.
Params = flax.core.FrozenDict[str, Any]

# A shape: any sequence of ints (e.g. a list or tuple of dimensions).
Shape = Sequence[int]

# String-keyed dict of float values; presumably logged metrics/diagnostics.
InfoDict = Dict[str, float]

# Loose placeholder aliases, kept as Any so any value type-checks.
Array = Any
Dtype = Any  # TODO: could be narrowed to a concrete dtype type instead of Any.
