## ⬇️ Install

We recommend using our Dockerfile. With Docker and the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/index.html) installed, the image can be built with `$ make build` and run with `$ make run`.

To install from source, first ensure that the correct [JAX version](https://github.com/google/jax#installation) for your system is installed, then install our dependencies with `$ pip install -e .`.
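Before installing the remaining dependencies, it can help to confirm which JAX build and accelerator backend Python actually sees. The helper below is a minimal sketch and not part of this repository: `jax_backend_summary` is a hypothetical name, and it relies only on `jax.__version__` and `jax.devices()`.

```python
import importlib.util


def jax_backend_summary():
    """Describe the installed JAX version and the devices it can see."""
    # Avoid an ImportError so the check also works before JAX is installed.
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"

    import jax

    devices = jax.devices()  # devices of the default backend
    platforms = sorted({d.platform for d in devices})
    return f"jax {jax.__version__}: {len(devices)} device(s) on {', '.join(platforms)}"


print(jax_backend_summary())
```

If the summary lists only `cpu` on a machine with a GPU, the installed JAX build is likely a CPU-only one and should be replaced before training.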

### ⚠️ XLand

XLand-Minigrid has a different JAX requirement to JaxMARL and JaxUED. As such, the XLand code is held separately within `xland/`, which contains its own Dockerfile and Makefile.

## 🎯 Reproducing results

### 🧗🏼‍♂️ Train policies

All training scripts can be found within `sfl/train`. We include a set of configuration files, contained within `sweep_configs`, to launch experiments across a number of seeds using `wandb` sweeps, along with a helper script for starting sweeps, `start_wandb_sweep.py`. Using this script, SFL on single-agent JaxNav can be run across 4 GPUs with 1 agent per GPU as:

```bash
$ python start_wandb_sweep.py sweep_configs/jaxnav-sa_sfl_10seeds.yaml 0:4 1
```
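The two trailing arguments in the command above select the GPUs and the number of agents per GPU. As a rough sketch of how such arguments could be expanded into one entry per agent, the snippet below assumes that `0:4` is a half-open range of GPU indices (0, 1, 2, 3), which matches the four GPUs described above. The name `plan_agents` is hypothetical and this is not the implementation in `start_wandb_sweep.py`; pinning each agent with `CUDA_VISIBLE_DEVICES` is likewise an assumption about how agents might be isolated.

```python
def plan_agents(gpu_range, agents_per_gpu):
    """Expand a 'start:stop' GPU range into one environment dict per agent.

    Assumption: the range is half-open, so '0:4' covers GPUs 0, 1, 2 and 3.
    """
    start, stop = (int(part) for part in gpu_range.split(":"))
    if stop <= start:
        raise ValueError(f"empty GPU range: {gpu_range!r}")
    if agents_per_gpu < 1:
        raise ValueError("agents_per_gpu must be at least 1")

    plan = []
    for gpu in range(start, stop):
        for _ in range(agents_per_gpu):
            # Each agent process would only see the single GPU it is pinned to.
            plan.append({"CUDA_VISIBLE_DEVICES": str(gpu)})
    return plan


for env in plan_agents("0:4", 1):
    print(env)
```

Under this reading, raising the last argument to 2 would place two agents on each of the four GPUs.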

We use `wandb` for logging; your API key and entity can be set within the Dockerfile.

### 📊 Evaluate performance

You can either use your own trained policies (downloaded from `wandb`) or our saved checkpoints (located within `checkpoints/`). For all settings (JaxNav single agent, JaxNav multi agent, MiniGrid and XLand), evaluation is a three-step process. The scripts are located within `sfl/deploy` for the first three settings and within `xland/eval` for XLand.

1. A set number of levels are generated using `*_0_generate_levels.py`. These levels are saved to `sfl/eval/ENV_NAME`, with `ENV_NAME` being either `jaxnav` or `minigrid`.
2. Rollouts for the methods under consideration on these levels are collected with `*_1_rollout.py`. **Run this twice, once per seed** (we use 0 and 1). Results from these rollouts are saved as CSV files to `sfl/data/eval/results`.
3. The performance of all methods is analysed by `*_2_analyse.py`, with results plotted and saved to `results/`.
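The ordering of the three steps above, including the repeated rollout step, can be sketched as a small planning helper. This is an illustrative sketch rather than a script shipped with the repository: `plan_evaluation` and its `prefix` argument are hypothetical, the prefix matched by `*` is left for you to supply, and since the way a seed is passed to the rollout script is not specified here, the plan only records which seed each run is for.

```python
from pathlib import Path

# Step suffixes as named in the list above; the part matched by `*` is not filled in.
STEP_SUFFIXES = ("_0_generate_levels.py", "_1_rollout.py", "_2_analyse.py")
ROLLOUT_SEEDS = (0, 1)  # the rollout step is run once per seed


def plan_evaluation(script_dir, prefix):
    """Return (script_path, seed) pairs in the order they should be run.

    seed is None for the steps that are not run per seed.
    """
    plan = []
    for suffix in STEP_SUFFIXES:
        script = Path(script_dir) / f"{prefix}{suffix}"
        if not script.is_file():
            raise FileNotFoundError(script)
        if suffix == "_1_rollout.py":
            # One rollout run per seed, before the analysis step.
            plan.extend((script, seed) for seed in ROLLOUT_SEEDS)
        else:
            plan.append((script, None))
    return plan
```

Calling `plan_evaluation("sfl/deploy", prefix)` with the prefix of the setting you want to evaluate yields four entries: level generation, two rollout runs (seeds 0 and 1), then analysis.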

If you instead wish to analyse and visualise performance on the hand-designed test sets, you can use `sfl/deploy/deploy_on_singletons.py` for JaxNav and `sfl/deploy/deploy_minigrid_on_singeltons.py` for MiniGrid. For the sampled test sets used with JaxNav, use `sfl/deploy/deploy_on_sampled_set.py`.

## 🧭 JaxNav

This JAX-based environment for 2D geometric navigation is introduced with this work, but its code and documentation are held within [JaxMARL](https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/environments/jaxnav).
