JaxMARL: Multi-Agent RL Environments and Algorithms in JAX

Published: 26 Sept 2024, Last Modified: 13 Nov 2024
Venue: NeurIPS 2024 Track Datasets and Benchmarks Poster
License: CC BY 4.0
Keywords: reinforcement learning, multi-agent, multi-agent reinforcement learning, jax
TL;DR: We provide an efficient JAX-based MARL repository, with new and existing environments and algorithms.
Abstract: Benchmarks are crucial in the development of machine learning algorithms, and the available environments significantly influence reinforcement learning (RL) research. Traditionally, RL environments run on the CPU, which limits their scalability with the computational resources typically available in academia. However, recent advancements in JAX have enabled the wider use of hardware acceleration, making massively parallel RL training pipelines and environments possible. While this has been successfully applied to single-agent RL, it has not yet been widely adopted for multi-agent scenarios. In this paper, we present JaxMARL, the first open-source, easy-to-use code base that combines GPU-enabled efficiency with support for a large number of commonly used MARL environments and popular baseline algorithms. Our experiments show that, in terms of wall-clock time, our JAX-based training pipeline is up to 12,500 times faster than existing approaches. This enables efficient and thorough evaluations, potentially alleviating the evaluation crisis in the field. We also introduce and benchmark SMAX, a vectorised, simplified version of the popular StarCraft Multi-Agent Challenge, which removes the need to run the StarCraft II game engine. This not only enables GPU acceleration, but also provides a more flexible MARL environment, unlocking the potential for self-play, meta-learning, and other future applications in MARL. The code is available at https://github.com/flairox/jaxmarl.
Supplementary Material: pdf
Submission Number: 814