jax==0.4.31
jaxlib==0.4.31
optax==0.2.2
matplotlib==3.8.2
numpy==1.26.4
einops==0.8.0
tqdm==4.66.1
jax-tqdm==0.1.2
flax==0.8.5
chex==0.1.86
jaxtyping==0.2.24