jax>=0.4.14
jaxlib>=0.4.14
multipledispatch
numpy
tqdm

[cpu]
jax[cpu]>=0.4.14

[cuda]
jax[cuda]>=0.4.14

[dev]
dm-haiku
flax
funsor>=0.4.1
graphviz
jaxns==2.4.8
matplotlib
optax>=0.0.6
pylab-sdk
pyyaml
requests
tensorflow_probability>=0.18.0

[doc]
ipython
nbsphinx>=0.8.9
readthedocs-sphinx-search>=0.3.2
sphinx>=5
sphinx_rtd_theme
sphinx-gallery

[examples]
arviz
jupyter
matplotlib
pandas
seaborn
scikit-learn
wordcloud

[test]
importlib-metadata<5.0
ruff>=0.1.8
pytest>=4.1
pyro-api>=0.1.1
scipy>=1.9

[tpu]
jax[tpu]>=0.4.14
