# Python package dependencies, in pip requirements-file format.
#
# -f (short for --find-links) gives pip an additional page of package links to
# search when resolving the requirements listed in this file.
# NOTE(review): presumably this page is here so the accelerator-specific wheels
# pulled in by the jax[cuda] extra can be found -- confirm the URL is still the
# right one for the jax versions this file allows.
-f https://storage.googleapis.com/jax-releases/jax_releases.html
# jax with its "cuda" extra; the specifier is a lower bound only.
jax[cuda]>=0.3.9
# Lower-bounded dependencies: any release at or above the stated version is accepted.
absl-py>=0.12.0
numpy>=1.21.5
tensorflow-cpu>=2.7.0  # Using tensorflow-cpu to have all GPU memory for JAX.
tensorflow-datasets>=4.4.0
matplotlib>=3.5.0
clu>=0.0.3
flax>=0.3.5
chex>=0.0.7
# Exactly pinned dependencies (== instead of >=).
# NOTE(review): the reason these two are pinned while the rest are only
# lower-bounded is not recorded in this file -- confirm before relaxing them.
optax==0.1.0
ml-collections==0.1.0
# No version constraint: pip installs whichever release it resolves.
# NOTE(review): consider adding a lower bound for reproducibility.
scikit-image