# Pinned Python dependencies; every package below is locked to an exact version.
# Install with: pip install -r <this file>
#
# `-f` (--find-links) adds an extra page that pip searches for wheels. It is
# needed for the jaxlib pin below, whose version carries a local tag (+cuda12...).
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# jaxlib build tagged `cuda12.cudnn89`.
# NOTE(review): the tag suggests a CUDA 12 / cuDNN 8.9 build - confirm it matches the target machines.
jaxlib==0.4.28+cuda12.cudnn89
# cuDNN wheel pinned to an 8.9.x release; presumably chosen to match the `cudnn89` jaxlib tag above.
nvidia-cudnn-cu12==8.9.7.29
# Same version number as the jaxlib pin above; bump the two together.
jax==0.4.28
tensorflow==2.16.1
flax==0.8.3
numpyro==0.14.0
pandas==2.1.0
tqdm==4.66.1
# Hydra core plus its joblib and submitit launcher plugins.
hydra-core==1.3.2
hydra-joblib-launcher==1.2.0
hydra-submitit-launcher==1.2.0
wandb==0.15.10
ml-collections==0.1.1
ott-jax==0.4.6
# NOTE(review): with tensorflow 2.16.x the TF-backed tensorflow_probability import
# may additionally need the `tf-keras` package - confirm whether this project
# imports the TF backend or only the JAX substrate before adding it.
tensorflow-probability==0.24.0
dm-haiku==0.0.12
distrax==0.1.5
blackjax==1.0.0
mergedeep==1.3.4
matplotlib==3.8.4