chex
jax[cuda12_pip]
jaxlib
optax
tensorflow_probability
