from typing import Any, Dict

from flax.core import FrozenDict
from jaxtyping import Array, Bool, Float, Int, Shaped, UInt32
from numpy import ndarray


# Legacy JAX PRNG key as produced by ``jax.random.PRNGKey(seed)``: a uint32
# array of shape (2,). It was previously annotated as a float array, which
# made jaxtyping runtime checks reject genuine keys.
PRNGKey = UInt32[Array, '2']
# 0-d boolean array.
BoolScalar = Bool[Array, ""]
# Static array shape.
Shape = tuple[int, ...]
# 1-D float / int arrays over a named axis "b" (presumably the batch axis).
BFloat = Float[Array, "b"]
BInt = Int[Array, "b"]
# Scalars accepted either as Python numbers or as 0-d JAX arrays.
FloatScalar = float | Float[Array, ""]
IntScalar = int | Int[Array, ""]
# 1-D float array over a named axis "T" (presumably time / trajectory length).
TFloat = Float[Array, "T"]

# Per-step environment types. 'action_dim', 'obs_dim' and 'state_dim' are
# named jaxtyping axes. Their sizes are only checked for consistency when a
# runtime checker (e.g. jaxtyped + beartype) is applied.
Action = Float[Array, 'action_dim']
Reward = FloatScalar
Done = BoolScalar
# Auxiliary step information: string keys mapped to 0-d arrays of any dtype
# (``Shaped`` places no constraint on dtype). Uses the builtin ``dict``
# generic for consistency with the other aliases in this module.
Info = dict[str, Shaped[Array, '']]
Obs = Float[Array, 'obs_dim']
State = Float[Array, 'state_dim']

# Parameter container keyed by str: either flax's immutable FrozenDict or a
# plain mutable dict. Both forms are accepted wherever ``Params`` is used.
Params = FrozenDict[str, Any] | dict[str, Any]

