MPX: Model Predictive Control in JAX

Published: 29 Apr 2026, Last Modified: 29 May 2026
Venue: ICRA Workshop on FOR 2nd Edition
License: CC BY 4.0
Keywords: Model Predictive Control, Numerical Optimization, GPU Acceleration, Differentiable Optimization
Abstract: In this work, we present MPX, a Python-based library for Model Predictive Control (MPC) that leverages GPU acceleration through JAX. The library is designed for speed and ease of use thanks to its JAX-based implementation: users can write simple Python code that is just-in-time compiled to achieve high-performance execution on hardware accelerators. The framework enables whole-body MPC at sufficiently high update rates for deployment on real robotic systems. System dynamics are seamlessly integrated through MuJoCo XLA. MPX has been developed with learning applications in mind; its GPU compatibility enables efficient parallelization over large batches of environments, making it well-suited for training scenarios with MPC in the loop. Additionally, the solver is fully differentiable, unlocking new possibilities for tightly integrated learning approaches. The related code can be found at https://github.com/iit-DLSLab/mpx
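The abstract makes three claims: just-in-time compilation, batching over many environments, and a differentiable solver. Each corresponds to a standard JAX transformation (`jax.jit`, `jax.vmap`, `jax.grad`). The sketch below shows that general pattern on a toy problem. It does not use MPX or MuJoCo XLA. Every name in it is a hypothetical assumption: the double-integrator `dynamics`, `solve_mpc`, the horizon, the step size, and the iteration count. Plain gradient descent on a single-shooting cost stands in for whatever solver MPX actually implements.

```python
import jax
import jax.numpy as jnp

DT = 0.1        # hypothetical integration step
HORIZON = 20    # hypothetical number of control steps
LR = 0.05       # hypothetical gradient-descent step size
ITERS = 200     # hypothetical fixed number of solver iterations


def dynamics(x, u):
    # Hypothetical double integrator: x = [position, velocity], u = [acceleration].
    return jnp.array([x[0] + DT * x[1], x[1] + DT * u[0]])


def rollout(x0, us):
    # Integrate the dynamics over the horizon; lax.scan compiles to a single loop.
    def step(x, u):
        x_next = dynamics(x, u)
        return x_next, x_next
    _, xs = jax.lax.scan(step, x0, us)
    return xs  # shape (HORIZON, 2)


def cost(us, x0, weight):
    # Quadratic state cost plus a control-effort penalty scaled by `weight`.
    xs = rollout(x0, us)
    return jnp.sum(xs ** 2) + weight * jnp.sum(us ** 2)


@jax.jit
def solve_mpc(x0, weight):
    # Single-shooting MPC: fixed-count gradient descent on the control sequence.
    # A static trip count keeps fori_loop reverse-mode differentiable.
    grad_fn = jax.grad(cost)  # gradient w.r.t. the control sequence `us`

    def body(_, us):
        return us - LR * grad_fn(us, x0, weight)

    return jax.lax.fori_loop(0, ITERS, body, jnp.zeros((HORIZON, 1)))


# Batching: solve one MPC problem per initial state in parallel with vmap.
x0_batch = jnp.array([[1.0, 0.0], [-0.5, 0.2], [0.0, 1.0]])
batched_solve = jax.vmap(solve_mpc, in_axes=(0, None))
us_batch = batched_solve(x0_batch, 0.1)  # shape (3, HORIZON, 1)


# Differentiability: sensitivity of the first applied control to the cost weight.
def first_control(x0, weight):
    return solve_mpc(x0, weight)[0, 0]


sensitivity = jax.grad(first_control, argnums=1)(jnp.array([1.0, 0.0]), 0.1)
```

Because the solver is an ordinary JAX function, the same code can be compiled once, vectorized across environments, and differentiated end to end. This is the property the abstract highlights for training with MPC in the loop. Here the fixed iteration count is a simplification that keeps the loop differentiable. It does not guarantee convergence.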
Submission Number: 33