Keywords: Large Language Models, Distributed Machine Learning, Pipeline Parallelism, Single-Program Multiple-Data, Multiple-Program Multiple-Data
Abstract: We present JaxPP, a system for efficiently scaling the training of large deep learning
models with flexible pipeline parallelism.
We introduce a seamless programming model that lets users implement their own pipeline
schedules for gradient accumulation.
JaxPP automatically distributes tasks, each corresponding to a pipeline stage, over
a cluster of nodes and infers the communication among them.
We implement an MPMD runtime for the asynchronous execution of SPMD tasks.
The pipeline parallelism implementation of JaxPP improves hardware utilization by up
to $1.16\times$ with respect to the best-performing SPMD configuration.
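The user-defined pipeline schedules mentioned in the abstract interleave gradient accumulation over microbatches across stages. As a point of reference only, here is a minimal sketch of that accumulation pattern in plain JAX; it does not use the JaxPP API, and all model and data shapes are hypothetical.

```python
# Hypothetical sketch: gradient accumulation over microbatches in plain JAX
# (not the JaxPP programming model).
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Tiny linear model with a mean-squared-error loss.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def accumulate_grads(params, microbatches):
    # Sum gradients over microbatches, then average, instead of computing
    # one gradient over a single large batch.
    def step(acc, mb):
        g = jax.grad(loss_fn)(params, mb["x"], mb["y"])
        return jax.tree_util.tree_map(jnp.add, acc, g), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(step, zeros, microbatches)
    n = microbatches["x"].shape[0]
    return jax.tree_util.tree_map(lambda g: g / n, total)

params = {"w": jnp.ones((4, 1)), "b": jnp.zeros((1,))}
# 8 microbatches of 16 examples each, stacked along a leading axis for scan.
mbs = {"x": jnp.ones((8, 16, 4)), "y": jnp.ones((8, 16, 1))}
grads = accumulate_grads(params, mbs)
```

In a pipeline-parallel setting, a schedule decides how these per-microbatch forward and backward computations are staggered across stages; the sketch above only shows the accumulation itself on a single device.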
Submission Number: 253