numpy>=1.25,<1.26
pandas==1.5.3
jax>=0.4.19
jaxlib>=0.4.19
flax>=0.7.4
optax>=0.1.7
chex>=0.1.83
wandb>=0.13
ipython>=7.34.0
GitPython>=3.1.29