# Python package dependencies, in pip requirements-file format
# (presumably installed with `pip install -r <this file>`; confirm the file name).
# Every PyPI requirement below sets only a minimum version (`>=`) and no upper
# bound, so newer releases of each package are accepted at install time.
# NOTE(review): JAX's compiled backend (jaxlib) is not listed here. Older jax
# releases required jaxlib to be installed separately. Confirm how jaxlib is
# provided in the target environment.
jax >= 0.2.27
flax >= 0.4.0
ml_collections >= 0.1.0
tqdm >= 4.60.0
optax >= 0.0.6
absl-py >= 0.12.0
scipy >= 1.6.0
wandb >= 0.12.14
# distrax is installed directly from its GitHub repository (a VCS URL with no
# tag or commit). pip therefore fetches whatever the default branch holds at
# install time, and the installed version can change between installs.
# NOTE(review): for reproducible installs, consider pinning a commit
# (`...distrax@<sha>`) or a released PyPI version, if one provides the APIs this
# project uses. The `deepmind` GitHub organization may now redirect to
# `google-deepmind`. Confirm that the URL still resolves.
distrax @ git+https://github.com/deepmind/distrax