jax[cuda12]==0.4.26
ipdb>=0.13.13
wandb>=0.17.0
flax>=0.8.3
jaxtyping>=0.2.28
einops>=0.8.0
matplotlib>=3.8.4
tqdm>=4.66.4
colour>=0.1.5
seaborn>=0.13.2
equinox>=0.11.4
jraph>=0.0.6.dev0
tensorflow-probability>=0.24.0
attrs>=23.2.0
numpy>=1.26.4
scipy>=1.12.0
control>=0.9.4
optax>=0.1.9
rich>=13.7.0
pyyaml>=6.0.1