#jax>=0.4.0
numpy==1.26.0
jax==0.4.25
jaxlib==0.4.25
optax>=0.1.3
numpy>=1.19.0
tensorflow>=2.9.0
tensorflow_datasets>=4.3.0
wandb
stadion
pot
tqdm
torch
torchsde
matplotlib