# Welcome to JaxZSC!

This repository implements popular zero-shot coordination (ZSC) and ad-hoc teamwork algorithms in JAX, in the style of CleanRL.
It focuses on simple, easy-to-understand single-file implementations.
It is geared towards researchers who want to extend these algorithms or use them as baselines in their projects.

![An image displaying the core principles of UPD: an ego agent seeds a partner generator that generates new partners for the ego agent to train with, scored by a learnability function.](readme_assets/Teaser1CV5.png?raw=true "Unsupervised Partner Design")

It also implements **Unsupervised Partner Design** (UPD), our new ad-hoc teamwork method inspired by work on unsupervised environment design.
Note that throughout this repository we refer to UPD as DPD (Dual Partner Design), which was the working name of the algorithm during development.

| Algorithm |  Paper                                                                                                   |
| --------- | -------------------------------------------------------------------------------------------------------- |
| SP        | -                                                                                                        |
| FCP       | https://arxiv.org/abs/2110.08176                                                                         |
| MEP       | https://arxiv.org/abs/2112.11701                                                                         |
| E3T       | https://papers.nips.cc/paper_files/paper/2023/file/07a363fd2263091c2063998e0034999c-Paper-Conference.pdf |
| UPD       | Ours                                                                                                     |

# Structure

```txt
src/
--| agents/
--| --| overcooked/                  # Hardcoded policies from ROTATE
--| --| actors.py                    # Neural networks
--| --| agent_interface.py           # Agent interface class

--| envs/
--| --| overcooked/                  # Augmented Overcooked from JaxMARL and ROTATE
--| --| ogc/                         # The Overcooked Generalisation Challenge

--| jaxzsc/                          # Algorithms
--| --| best_response/               # Algorithms for best-response (BR) training against a population
--| --| brdiv/                       # BRDiv algorithm for eval partners
--| --| dpd/                         # Our unsupervised partner design method
--| --| --| dpd_ippo_*_w_bias_rnn.py # Our main contribution.
--| --| e3t/                         # E3T baseline
--| --| evaluation/                  # Evaluation scripts
--| --| fcp/                         # FCP baseline
--| --| mep/                         # MEP baseline
--| --| sp/                          # Self-play baseline

--| sweep_configs/ # Configs to be used with run_sweep.py.
--| run_sweep.py # Helper for running sweeps.
```

# Reproducibility

Our experiments were mostly conducted using Weights & Biases (W&B) sweeps.
The sweep configurations retrain all models used in our work, which makes it easy to rerun experiments with a single command.
Feel free to adjust the configurations as needed.
Specifically, running

```bash
python src/run_sweep.py $sweep_file_name $which_gpus $how_many_runs_per_gpu
```

will reproduce the checkpoints used in our work for the given sweep.
For example:

```bash
python src/run_sweep.py dpd_ippo_overcooked_w_bias_rnn_learnability 0,1 4
```

This will retrain all UPD agents for all layouts using 6 seeds on GPUs 0 and 1, with 4 training runs per GPU (30 agents in total).
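
To make the arithmetic in this example concrete, the sketch below works out how many training runs are active at once and roughly how many rounds the sweep needs. It is not taken from `run_sweep.py`: the helper `sweep_plan` is hypothetical, and it assumes that every GPU slot picks up the next pending run as soon as it is free and that all runs take about the same time.

```python
import math


def sweep_plan(which_gpus: str, runs_per_gpu: int, total_runs: int) -> dict:
    """Hypothetical helper: estimate the parallelism of a sweep.

    Assumes each GPU hosts `runs_per_gpu` runs at a time and that all
    runs take roughly equally long (an assumption, not a guarantee).
    """
    # "0,1" -> [0, 1]; mirrors the comma-separated GPU argument of the command.
    gpu_ids = [int(part) for part in which_gpus.split(",") if part.strip()]
    if not gpu_ids or runs_per_gpu < 1:
        raise ValueError("need at least one GPU and one run per GPU")
    concurrent_runs = len(gpu_ids) * runs_per_gpu
    return {
        "gpu_ids": gpu_ids,
        "concurrent_runs": concurrent_runs,
        "rounds": math.ceil(total_runs / concurrent_runs),
    }


# 30 agents in total, GPUs 0 and 1, 4 runs per GPU.
plan = sweep_plan("0,1", runs_per_gpu=4, total_runs=30)
print(plan)
```

With these numbers, 8 runs are active at a time, so the 30 runs need about 4 rounds under the equal-duration assumption.
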
The experiments will generate an XPID that looks similar to this: `$RANDOM_NAME_$INFO_$layout_SEED_0`, where `0` is replaced with the seed (0 to 5).
Given such an XPID, evaluation is easy. For example, one could use:

```bash
python -m src.jaxzsc.evaluation.eval_overcooked_rnn_checkpoints_on_brdiv_hardcoded --base_xpid SomeXPID_SEED_0 --max_seed 5
```

This will test seeds 0-5 with the held-out partners and write the results to the console for easy copying into spreadsheet programs.
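
The XPID naming scheme lends itself to a small helper that lists the per-seed XPIDs and assembles the evaluation command. The following is a minimal sketch, not code from this repository: `expand_seed_xpids` and `build_eval_command` are hypothetical names, and the assumption that each seed's XPID differs only in the number after `_SEED_` is inferred from the pattern described above. The module path and the `--base_xpid` and `--max_seed` flags are copied from the evaluation command in this section.

```python
import re
import sys

# Module path copied from the evaluation command in this README.
EVAL_MODULE = "src.jaxzsc.evaluation.eval_overcooked_rnn_checkpoints_on_brdiv_hardcoded"


def expand_seed_xpids(base_xpid: str, max_seed: int) -> list[str]:
    """Hypothetical helper: list the XPIDs for seeds 0..max_seed (inclusive)."""
    match = re.fullmatch(r"(.*_SEED_)\d+", base_xpid)
    if match is None:
        raise ValueError(f"expected an XPID ending in _SEED_<n>, got {base_xpid!r}")
    prefix = match.group(1)
    return [f"{prefix}{seed}" for seed in range(max_seed + 1)]


def build_eval_command(base_xpid: str, max_seed: int) -> list[str]:
    """Hypothetical helper: build the evaluation command as an argument list."""
    return [
        sys.executable, "-m", EVAL_MODULE,
        "--base_xpid", base_xpid,
        "--max_seed", str(max_seed),
    ]


xpids = expand_seed_xpids("SomeXPID_SEED_0", max_seed=5)
command = build_eval_command("SomeXPID_SEED_0", max_seed=5)
print(xpids)
print(" ".join(command[1:]))
```

The argument list can be passed to `subprocess.run` from the repository root; building it as a list avoids shell quoting issues.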

# Installation

Our work has the same prerequisites as JaxMARL.
Please follow the installation instructions there to use this repo.
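
After following the JaxMARL instructions, a quick check that the core packages can be found may save a failed sweep. The snippet below is a sketch under assumptions: `jax` and `jaxmarl` are the usual import names for these projects, but the exact package set this repository needs is not listed here, and `check_prerequisites` is a hypothetical helper.

```python
import importlib.util


def check_prerequisites(packages=("jax", "jaxmarl")) -> dict:
    """Report which packages can be located, without importing them."""
    return {name: importlib.util.find_spec(name) is not None for name in packages}


status = check_prerequisites()
for name, found in status.items():
    print(f"{name}: {'found' if found else 'MISSING'}")
```

If `jax` is found, `jax.devices()` lists the accelerators JAX can see, which is worth confirming before passing GPU ids to `run_sweep.py`.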

# Citations

We make use of, or base our implementation on, the following works:

```bib
@inproceedings{rutherford2024noregrets,
    title={No Regrets: Investigating and Improving Regret Approximations for Curriculum Discovery},
    author={Alexander Rutherford and Michael Beukman and Timon Willi and Bruno Lacerda and Nick Hawes and Jakob Nicolaus Foerster},
    booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
    year={2024},
    url={https://arxiv.org/abs/2408.15099}
}
```

```bib
@inproceedings{flair2024jaxmarl,
    title={JaxMARL: Multi-Agent RL Environments and Algorithms in JAX},
    author={Alexander Rutherford and Benjamin Ellis and Matteo Gallici and Jonathan Cook and Andrei Lupu and Gar{\dh}ar Ingvarsson and Timon Willi and Ravi Hammond and Akbir Khan and Christian Schroeder de Witt and Alexandra Souly and Saptarashmi Bandyopadhyay and Mikayel Samvelyan and Minqi Jiang and Robert Tjarko Lange and Shimon Whiteson and Bruno Lacerda and Nick Hawes and Tim Rockt{\"a}schel and Chris Lu and Jakob Nicolaus Foerster},
    booktitle={The Thirty-eight Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
    year={2024},
}
```

```bib
@article{coward2024JaxUED,
  title={JaxUED: A simple and useable UED library in Jax},
  author={Samuel Coward and Michael Beukman and Jakob Foerster},
  journal={arXiv preprint},
  year={2024},
}
```

```bib
@misc{ruhdorfer2025overcookedgeneralisationchallenge,
      title={The Overcooked Generalisation Challenge},
      author={Constantin Ruhdorfer and Matteo Bortoletto and Anna Penzkofer and Andreas Bulling},
      year={2025},
      eprint={2406.17949},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2406.17949},
}
```

```bib
@misc{wang2025rotate,
  title={ROTATE: Regret-driven Open-ended Training for Ad Hoc Teamwork},
  author={Caroline Wang and Arrasy Rahman and Jiaxun Cui and Yoonchang Sung and Peter Stone},
  archivePrefix={arXiv},
  primaryClass={cs.AI},
  url = {http://arxiv.org/abs/2505.23686},
  eprint={2505.23686},
  year={2025}
}
```
