# JAX and the JAX-based libraries used by this project.
# `-f` is pip's --find-links option: pip also searches the page below for
# distributions. Presumably this is how the CUDA builds requested by the
# `cuda11_pip` extra are located -- TODO confirm against the JAX install docs.
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# JAX is pinned to one exact release, installed with the `cuda11_pip` extra.
jax[cuda11_pip]==0.4.16
# NOTE(review): everything below is unpinned while jax is pinned above, so pip
# picks whatever releases it can resolve against jax 0.4.16. Consider pinning
# these for reproducible installs -- verify which versions are compatible.
distrax
chex
optax
orbax
rlax
flax
numpy
scipy

# TensorFlow packages, progress-bar helpers, and developer tooling.
# `>=` entries set only a minimum version; `==` entries are exact pins.
tensorflow>=2.13.0
tqdm
tqdm-multiprocess==0.0.11
pre-commit==3.3.3
tensorflow_datasets>=4.9.2

# Experiment tracking, image/plot output, and flag/config handling.
# All entries are unpinned, so pip installs the newest release it can resolve.
wandb
Pillow
matplotlib
absl-py
ml-collections

# Miscellaneous utilities (both unpinned): einops for tensor rearrangement,
# tyro for command-line interfaces.
einops
tyro
