We use Mamba (https://mamba.readthedocs.io/en/latest/) to manage packages. After installing Mamba, run the following commands on a machine with NVIDIA GPUs to install the Python packages.

mamba create -n jax_flax python=3.11
mamba activate jax_flax
mamba install "jaxlib=*=*cuda*" jax -c conda-forge
mamba install cuda-nvcc -c nvidia
mamba install scikit-learn
mamba install pytorch torchvision cpuonly -c pytorch
pip install sacred==0.8.4 chex==0.1.82 optax==0.1.7 dm-tree==0.1.8 dm-haiku
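
After installation, it can be useful to confirm that the pinned pip packages resolved to the expected versions. The following is a minimal sketch of such a check, not part of the original setup: the `PINS` table and the `check_pins` helper are hypothetical names introduced here, and the script relies only on the standard-library `importlib.metadata` module. `dm-haiku` is omitted because the command above does not pin it.

```python
from importlib import metadata

# Versions pinned by the pip command above (dm-haiku is unpinned there).
PINS = {
    "sacred": "0.8.4",
    "chex": "0.1.82",
    "optax": "0.1.7",
    "dm-tree": "0.1.8",
}


def check_pins(pins):
    """Return {package: (expected, installed, status)} for each pinned package.

    status is "ok", "mismatch", or "missing".
    """
    report = {}
    for name, expected in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed is None:
            status = "missing"
        elif installed == expected:
            status = "ok"
        else:
            status = "mismatch"
        report[name] = (expected, installed, status)
    return report


if __name__ == "__main__":
    # Print one line per pinned package.
    for name, (expected, installed, status) in check_pins(PINS).items():
        print(f"{name}: expected {expected}, found {installed} [{status}]")
```

Run the script inside the activated jax_flax environment. To check separately that JAX sees the GPU, run `python -c "import jax; print(jax.devices())"`, which should list CUDA devices rather than only CPU devices if the GPU build was installed correctly.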
