-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
jax[tpu]==0.4.20
gym >= 0.26
numpy==1.24.3
distrax==0.1.2
flax==0.7.0
ml_collections >= 0.1.0
tqdm >= 4.60.0
chex==0.1.6
optax==0.1.5
absl-py >= 0.12.0
scipy >= 1.6.0
wandb >= 0.12.14
tensorflow==2.13.0
einops >= 0.6.1
imageio >= 2.31.1
moviepy >= 1.0.3
pre-commit==3.3.3
