DrJAX: Scalable and Differentiable MapReduce Primitives in JAX

Published: 18 Jun 2024 · Last Modified: 09 Jul 2024 · WANT@ICML 2024 Poster · CC BY 4.0
Keywords: parallel machine learning, distributed machine learning, software, jax, mapreduce, federated learning
TL;DR: We present DrJAX, a JAX-based library designed to support large-scale distributed and parallel machine learning algorithms that use MapReduce-style operations.
Abstract: We present DrJAX, a JAX-based library designed to support large-scale distributed and parallel machine learning algorithms that use MapReduce-style operations. DrJAX leverages JAX's sharding mechanisms to enable native targeting of TPUs and state-of-the-art JAX runtimes, including Pathways. DrJAX embeds building blocks for MapReduce computations as primitives in JAX. This yields three key benefits. First, DrJAX computations can be translated directly to XLA HLO, enabling flexible integration with a wide array of ML training platforms. Second, DrJAX computations are fully differentiable. Third, DrJAX computations can be interpreted out to existing batch-processing compute systems, including traditional MapReduce systems like Apache Beam and cross-device compute systems like those powering federated learning applications. We show that DrJAX provides an easily programmable, performant, and scalable framework for parallelized algorithm development.
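To make the abstract's MapReduce framing concrete, here is a minimal sketch of the broadcast/map/reduce round structure it describes, written in plain JAX rather than DrJAX's actual API. All names (`NUM_CLIENTS`, `client_update`, `federated_round`) are hypothetical illustrations of the pattern; the sketch also demonstrates the differentiability claim, since gradients flow through the broadcast, map, and reduce stages as through any other JAX function.

```python
# Hypothetical sketch of a MapReduce-style training round in plain JAX.
# This is NOT DrJAX's API; it only illustrates the broadcast -> map ->
# reduce pattern the abstract refers to, and that such a round is an
# ordinary, differentiable JAX computation.
import jax
import jax.numpy as jnp

NUM_CLIENTS = 4  # assumed number of parallel "clients"

def local_loss(params, batch):
    x, y = batch
    pred = x @ params
    return jnp.mean((pred - y) ** 2)

def client_update(params, batch, lr=0.1):
    # One local gradient step per client (the "map" stage).
    grads = jax.grad(local_loss)(params, batch)
    return params - lr * grads

def federated_round(server_params, client_batches):
    # Broadcast: replicate the server model across the clients axis.
    client_params = jnp.broadcast_to(
        server_params, (NUM_CLIENTS,) + server_params.shape)
    # Map: run the local update on every client in parallel.
    updated = jax.vmap(client_update)(client_params, client_batches)
    # Reduce: average the client models back into a server model.
    return jnp.mean(updated, axis=0)

server_params = jnp.zeros((3,))
xs = jax.random.normal(jax.random.PRNGKey(0), (NUM_CLIENTS, 8, 3))
ys = jax.random.normal(jax.random.PRNGKey(1), (NUM_CLIENTS, 8))

# The round is an ordinary JAX function: it can be jit-compiled and
# sharded across devices by the runtime.
new_params = jax.jit(federated_round)(server_params, (xs, ys))

# Differentiability: the gradient of a post-round evaluation loss with
# respect to the *initial* server parameters flows through broadcast,
# map, and reduce.
def round_then_eval(params, batches, eval_batch):
    return local_loss(federated_round(params, batches), eval_batch)

g = jax.grad(round_then_eval)(server_params, (xs, ys), (xs[0], ys[0]))
print(new_params, g)
```

In this sketch the map stage is expressed with `jax.vmap`, so the compiler is free to parallelize it across devices; DrJAX instead embeds such building blocks as JAX primitives, which is what enables the direct lowering to XLA HLO and the alternative interpretations (e.g. to Apache Beam) described above.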
Submission Number: 31