Abstract: Pre-training Transformers in FP4 precision is a promising way to obtain substantial speedups, but it comes with considerable accuracy loss. The Microscaling (MX) data format provides fine-grained per-group quantization that improves the representational ability of FP4 and is supported by the next-generation Blackwell GPU architecture. However, training with the MXFP4 data format still results in significant degradation, and there is a lack of systematic research into the reason.
In this work, we propose TetraJet, a novel training method for more accurate FP4 training. We comprehensively evaluate all of the quantizers involved in training and identify the weight oscillation problem in the forward pass as the main source of degradation in MXFP4 training. We therefore introduce two novel methods, the EMA Quantizer (Q-EMA) and the Adaptive Ramping Optimizer (Q-Ramping), to resolve the oscillation problem. Extensive experiments on Vision Transformers demonstrate that TetraJet consistently outperforms existing 4-bit training methods, and that Q-EMA and Q-Ramping provide additional gains by effectively reducing oscillation. Our method cuts accuracy degradation by more than 50% compared to the baseline and can even achieve performance competitive with full-precision training.
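For concreteness, below is a minimal NumPy sketch of MXFP4-style per-group quantization: each group of 32 elements shares one power-of-two scale (E8M0 in the OCP MX specification) and elements are rounded to the nearest FP4 (E2M1) value. This follows the public MX format definition rather than the paper's actual kernels, and the function name `mxfp4_quantize` is illustrative only.

```python
import numpy as np

# Representable magnitudes of FP4 E2M1, the element type used by MXFP4.
FP4_E2M1_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def mxfp4_quantize(x, group_size=32):
    """Simulated MXFP4 quantize-dequantize of a 1-D array (illustration only)."""
    x = np.asarray(x, dtype=np.float64)
    pad = (-len(x)) % group_size
    groups = np.pad(x, (0, pad)).reshape(-1, group_size)

    # Shared power-of-two scale per group, as in the OCP MX spec:
    # shared exponent = floor(log2(group max)) - max exponent of E2M1 (= 2).
    amax = np.abs(groups).max(axis=1, keepdims=True)
    amax = np.where(amax == 0, 1.0, amax)          # avoid log2(0) for all-zero groups
    scale = 2.0 ** (np.floor(np.log2(amax)) - 2)

    # Round each scaled element to the nearest representable FP4 magnitude.
    # Values above 6 are implicitly clamped to 6, the largest E2M1 magnitude.
    scaled = groups / scale
    idx = np.abs(np.abs(scaled)[..., None] - FP4_E2M1_VALUES).argmin(axis=-1)
    deq = np.sign(scaled) * FP4_E2M1_VALUES[idx] * scale
    return deq.reshape(-1)[: len(x)]

# Example: quantization error on random Gaussian weights.
w = np.random.randn(128)
w_q = mxfp4_quantize(w)
print("mean abs quantization error:", np.abs(w - w_q).mean())
```

The per-group scale is what distinguishes MX formats from plain per-tensor FP4: each block of 32 values gets its own dynamic range, which keeps quantization error small even when magnitudes vary widely across a tensor.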
Lay Summary: Pre-training large AI models like Transformers usually requires a lot of computing power. One way to make training faster and more efficient is to use lower-precision numbers like 4-bit floating point (FP4). While promising, FP4 often leads to a noticeable drop in model accuracy. A newer 4-bit format called "Microscaling FP4" shows promise, but it still struggles with performance issues, and the reasons behind this are not well understood.
In our work, we introduce **TetraJet**, a new training method designed to improve the accuracy of FP4 training with unbiased gradient calculation. Unlike previous approaches that may use low precision in only parts of the model, TetraJet uses FP4 for **activations, weights, and gradients**, making it a truly low-precision training solution. Further analysis shows that a key issue is "weight oscillation": unstable back-and-forth changes in model parameters during training. To fix this, we introduce a training framework with two simple but powerful techniques: one that smooths out the quantization process (**Q-EMA**) and another that stabilizes optimization (**Q-Ramping**).
Together, these methods make low-precision training much more reliable. We show that our approach significantly reduces accuracy loss by more than half, and in some cases, performs almost as well as training with full precision, but with much lower computational cost.
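The weight-oscillation problem mentioned above can be pictured with a toy sketch (a generic illustration, not TetraJet's algorithm): a latent full-precision weight sitting near a quantization boundary keeps flipping between two representable FP4 values across optimizer steps, even though each individual update is small. The starting value, learning rate, and alternating gradient below are made up purely for illustration.

```python
import numpy as np

# All signed FP4 E2M1 values, assuming a unit shared scale for simplicity.
LEVELS = np.array([-6, -4, -3, -2, -1.5, -1, -0.5, 0,
                   0.5, 1, 1.5, 2, 3, 4, 6], dtype=np.float64)

def quantize(w):
    """Round a scalar to the nearest FP4 (E2M1) value."""
    return LEVELS[np.abs(LEVELS - w).argmin()]

w = 1.26                      # latent weight just above the 1.0 / 1.5 boundary (1.25)
lr, grad = 0.02, 1.0
for step in range(6):
    w_q = quantize(w)         # the forward pass sees only the quantized weight
    w -= lr * grad            # small latent update
    grad = -grad              # toy model of the loss reacting to each flip
    print(f"step {step}: latent w = {w:+.3f}, quantized = {w_q:+.2f}")
```

Running this prints a quantized weight that alternates between 1.00 and 1.50 while the latent weight barely moves, which is the kind of instability Q-EMA and Q-Ramping are designed to suppress.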
Link To Code: https://github.com/thu-ml/TetraJet-MXFP4Training
Primary Area: Deep Learning->Algorithms
Keywords: Efficient Machine Learning, Low-Precision Training, Quantization, FP4, Microscaling
Submission Number: 8519