jax==0.4.31
jaxlib==0.4.31
pytest==8.3.5
numpy==1.24.4
torch==2.5.1
numpyro==0.18.0
torchdiffeq==0.2.5
hydra-core==1.3.2
pandas==2.2.3
dm-haiku==0.0.14
optax==0.2.4
jaxtyping==0.3.2
ipython==9.2.0
matplotlib==3.10.3
pyro-ppl==1.9.1
sbibm==1.0.8
sbi==0.21.0
