"""Type definitions."""
from typing import Callable, NamedTuple, Union

import haiku as hk
import jax.numpy as jnp
import numpy as np

# An array value that may be either a NumPy array or a JAX array.
Array = Union[
    np.ndarray,
    jnp.ndarray,
]
# Either a plain callable or a Haiku module. Presumably this is whatever gets
# applied as a forward function -- confirm against callers.
Forwardable = Union[
    Callable,
    hk.Module,
]


class ActorOutput(NamedTuple):
  """Immutable record of five `Array` fields.

  Because this is a NamedTuple, field order is part of the interface:
  positional construction and tuple unpacking follow the order
  observation, reward, is_first, is_last, action.
  """

  observation: Array  # observation array; shape not fixed here
  reward: Array  # reward array; scalar vs. batched not established -- TODO confirm
  is_first: Array  # presumably flags the first step of an episode -- confirm with producer
  is_last: Array  # presumably flags the last step of an episode -- confirm with producer
  action: Array  # action array; shape/dtype not fixed here
