To reproduce the results, run the following (Linux only, due to issues with jaxlib):

    conda env create --file environment.yaml
    conda activate sde
    python wandb_script_multiseed.py
