# For GPU use of jax/jaxopt you may also need the CUDA compiler: conda install -c nvidia cuda-nvcc
numpy
jax==0.4.7
numpyro==0.11.0
flax==0.6.4
jupyter
matplotlib
torch
tqdm
h5py
omegaconf
jaxopt==0.6
tensorboard