#!/bin/bash

# Set up the `jax` conda environment and install this project plus its
# bundled `brax` checkout in editable mode.
#
# Must be run from the repository root: it expects `jax_environment.yml`
# and a `brax/` directory in the current working directory.

# Abort on the first failing command so later installs never run against
# the wrong interpreter or directory. `-u` is intentionally omitted because
# conda's activation hooks may reference unset variables.
set -eo pipefail

command -v conda >/dev/null || {
  printf 'error: conda not found on PATH\n' >&2
  exit 1
}

# Create the environment only when it does not exist yet, so the script can
# be re-run; `conda env create` errors out on an existing environment.
# awk consumes all of its input, which avoids a SIGPIPE under pipefail.
if ! conda env list | awk '$1 == "jax" { found = 1 } END { exit !found }'; then
  conda env create -f jax_environment.yml
fi

# Load conda's shell functions into this non-interactive shell; without
# the hook, `conda activate` is unavailable inside a script.
eval "$(conda shell.bash hook)"
conda activate jax

# CUDA-enabled JAX wheels are served from the JAX release index, not PyPI.
pip install --upgrade "jax[cuda]" \
  -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install wandb ml_collections dm-haiku

# Editable install of brax with its `develop` extra. The subshell keeps the
# directory change local; the quotes stop `[develop]` from being treated as
# a glob pattern.
(
  cd brax
  pip install -e '.[develop]'
)

# Editable install of this project itself.
pip install -e .