diffrax==0.6.0
flax==0.8.1
h5py==3.11.0
ipython==8.12.3
jax==0.4.23
jaxopt==0.8.3
jaxtyping==0.2.34
lineax==0.0.7
matplotlib==3.8.3
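# NumPy 2.x cannot load extensions built against the 1.x ABI (jaxlib 0.4.23,
# scikit-learn 1.4.1.post1, matplotlib 3.8.3), so stay on the last 1.x release.
numpy==1.26.4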
optax==0.2.3
orbax==0.1.9
pandas==2.2.3
Pillow==11.0.0
scikit_learn==1.4.1.post1
scipy==1.14.1
tqdm==4.66.2
ott-jax==0.4.5
