Scaling Deep Learning Training with MPMD Pipeline Parallelism

Published: 11 Feb 2025, Last Modified: 13 May 2025 · MLSys 2025 (with shepherding) · CC BY 4.0
Keywords: Large-Language Models, Distributed Machine Learning, Pipeline Parallelism, Single-Program Multiple-Data, Multiple-Program Multiple-Data
Abstract: We present JaxPP, a system for efficiently scaling the training of large deep learning models with flexible pipeline parallelism. We introduce a seamless programming model that allows implementing user-defined pipeline schedules for gradient accumulation. JaxPP automatically distributes tasks, corresponding to pipeline stages, over a cluster of nodes and automatically infers the communication among them. We implement an MPMD runtime for asynchronous execution of SPMD tasks. The pipeline parallelism implementation of JaxPP improves hardware utilization by up to $1.16\times$ over the best-performing SPMD configuration.
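The abstract refers to user-defined pipeline schedules for gradient accumulation over microbatches. As a rough, hypothetical illustration of that underlying pattern (not JaxPP's actual API, which is not shown on this page), the JAX sketch below accumulates gradients across microbatches with `lax.scan`; the names `loss_fn` and `accumulate_grads` are placeholders for a single stage's computation.

```python
# Hypothetical sketch of gradient accumulation over microbatches, the pattern a
# pipeline schedule interleaves across stages. This is NOT JaxPP's API.
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Toy linear model standing in for one pipeline stage's forward pass.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def accumulate_grads(params, xs, ys):
    """Sum gradients over microbatches before a single optimizer step."""
    def body(acc, batch):
        x, y = batch
        g = jax.grad(loss_fn)(params, x, y)
        return jax.tree_util.tree_map(jnp.add, acc, g), None
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    grads, _ = jax.lax.scan(body, zeros, (xs, ys))
    return grads

params = {"w": jnp.ones((4, 1)), "b": jnp.zeros((1,))}
kx, ky = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(kx, (8, 16, 4))  # 8 microbatches of 16 examples, 4 features
ys = jax.random.normal(ky, (8, 16, 1))
grads = accumulate_grads(params, xs, ys)  # same pytree structure as params
```

In a pipeline-parallel setting, each stage would run only its slice of the model, and a schedule would interleave these per-microbatch forward/backward tasks across stages; the sketch above only shows the accumulation loop on a single device.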
Submission Number: 253