jax>=0.4.13
flax
distrax
chex
tensorflow-datasets
