chex==0.1.85
jax==0.4.20
jaxlib==0.4.20
matplotlib==3.8.4
numpy==1.26.3
optax==0.1.7
pip==23.3.1
scipy==1.12.0
tqdm==4.65.0
