from typing import Callable, Tuple

import jax
import jax.numpy as jnp

# Type used to annotate PRNG keys.
#
# In current JAX, keys are ordinary arrays and ``jax.Array`` is the
# recommended annotation. The old ``jax.random.KeyArray`` alias is legacy:
# it was deprecated in favour of ``jax.Array`` and later removed, and on the
# versions that still ship it, merely accessing it emits a
# DeprecationWarning. So check for ``jax.Array`` first and never touch the
# deprecated attribute when the modern type is available.
if hasattr(jax, "Array"):
    KeyArray = jax.Array
else:
    # Very old JAX without ``jax.Array``: use the legacy key alias if present,
    # otherwise annotate keys as plain jnp arrays.
    KeyArray = getattr(jax.random, "KeyArray", jnp.ndarray)

# Score function signature: (array, float) -> array.
ScoreFunction = Callable[[jnp.ndarray, float], jnp.ndarray]

# Score function that takes two extra dict arguments ahead of the array and
# float: (dict, dict, array, float) -> array. The dicts' contents are not
# defined in this module (presumably parameters/state — confirm with callers).
ParametrisedScoreFunction = Callable[
    [dict, dict, jnp.ndarray, float],
    jnp.ndarray,
]

# Signature of a single SDE update step:
#     (key: KeyArray, array, float) -> (array, array)
# What each of the two returned arrays holds is not visible in this module —
# TODO confirm against the concrete update implementations.
SDEUpdateFunction = Callable[
    [KeyArray, jnp.ndarray, float], Tuple[jnp.ndarray, jnp.ndarray]
]
