Learning to Shard: RL for Co-optimizing the Parallelism Degrees and Per-operator Sharding Dimensions in Distributed LLM Inference

NeurIPS 2025 Workshop MLForSys, Submission 13

Published: 30 Oct 2025 · Last Modified: 13 Nov 2025 · MLForSys 2025 · CC BY 4.0
Keywords: Large-scale machine learning system, Reinforcement learning, Distributed LLM inference, Sharding strategies
TL;DR: Our work uses RL to jointly optimize both parallelism degrees and per-operator sharding dimensions for distributed LLM inference, achieving improvements over metaheuristic baselines and discovering new strategies that outperform Megatron heuristics.
Abstract: Distributed LLM inference requires careful coordination of parallelization strategies across hundreds to thousands of NPUs to meet production SLOs. Current systems like Megatron-LM rely on static heuristics that separately configure parallelism degrees and per-operator sharding dimensions, leaving significant performance on the table as models scale and hardware topologies diversify. We introduce Learn to Shard, to our knowledge, the first RL-based approach to co-optimize both coarse-grained parallelism degrees and fine-grained per-operator sharding dimensions for distributed LLM inference. Our method employs an attention-based policy over an elite history that learns from high-performing strategies to efficiently navigate the vast combinatorial search space. Evaluated on H100 clusters with MoE models up to 1.6T parameters, Learn to Shard achieves up to 3.5$\times$ throughput improvement over metaheuristic baselines and 1.06$\times$ over Megatron heuristics.
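The abstract describes a search over a combinatorial space of coarse-grained parallelism degrees and fine-grained per-operator sharding dimensions, guided by an elite history of high-performing strategies. The sketch below illustrates that search structure only, under stated assumptions: the operator names, the analytic throughput model, and the mutation-based exploitation step are all hypothetical stand-ins (the paper uses an attention-based policy over the elite history, and real throughput would be measured on hardware or via simulation).

```python
import random

# Hypothetical toy setup; none of these names come from the paper.
NUM_GPUS = 8
OPS = ["attn_qkv", "attn_out", "mlp_up", "mlp_down"]

def valid_degrees(n):
    # (tensor_parallel, pipeline_parallel) pairs whose product equals n
    return [(tp, n // tp) for tp in (1, 2, 4, 8) if n % tp == 0]

def random_strategy():
    tp, pp = random.choice(valid_degrees(NUM_GPUS))
    # 0 vs. 1: shard each operator's weight along rows or columns
    shard = {op: random.choice([0, 1]) for op in OPS}
    return (tp, pp, tuple(sorted(shard.items())))

def throughput(strategy):
    # Stand-in analytic cost model; a real system would profile or simulate.
    tp, pp, shard = strategy
    comm = tp * 0.5 + pp * 0.2 + sum(d for _, d in shard) * 0.1
    compute = 8.0 / tp + 2.0 / pp
    return 100.0 / (comm + compute)

def learn_to_shard(iters=200, elite_size=8, explore=0.3, seed=0):
    random.seed(seed)
    elites = []  # (score, strategy), kept sorted best-first
    for _ in range(iters):
        if elites and random.random() > explore:
            # Exploit: mutate a strategy drawn from the elite history.
            # (The paper's attention-based policy plays this role instead.)
            _, (tp, pp, shard) = random.choice(elites)
            shard = dict(shard)
            op = random.choice(OPS)
            shard[op] = 1 - shard[op]  # flip one operator's sharding dim
            cand = (tp, pp, tuple(sorted(shard.items())))
        else:
            # Explore: sample a fresh joint strategy.
            cand = random_strategy()
        elites.append((throughput(cand), cand))
        elites = sorted(elites, key=lambda e: -e[0])[:elite_size]
    return elites[0]

best_score, best_strategy = learn_to_shard()
print(best_score, best_strategy)
```

The key structural point the sketch mirrors is the joint search: a single candidate fixes both the parallelism degrees (tp, pp) and every operator's sharding dimension, so the search can trade one off against the other rather than configuring them separately.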
Submission Number: 13