from typing import Iterable, Union, Optional, Any

import jax.numpy as jnp
import numpy as np


# Single source of truth for "an array from either backend": every public
# alias below is bound to this exact union of the JAX and NumPy array types.
_JaxOrNumpyArray = Union[jnp.ndarray, np.ndarray]

# Rank-flavoured aliases. The 1D/2D/3D/4D/ND suffix is a hint for the reader
# only: all of these resolve to the same union, so neither type checkers nor
# the runtime enforce any particular number of dimensions.
Array1D = _JaxOrNumpyArray
Array2D = _JaxOrNumpyArray
Array3D = _JaxOrNumpyArray
Array4D = _JaxOrNumpyArray
ArrayND = _JaxOrNumpyArray

# Same union again under a separate name.
# NOTE(review): "Ctrl" presumably stands for control inputs -- confirm with
# the call sites; nothing in this module distinguishes it from the aliases
# above.
CtrlArray = _JaxOrNumpyArray
