Abstract: The resources required to train transformer models grow with model size, motivating hardware-friendly low-precision computation formats such as FP8 and OCP MX. These formats introduce the mixed-precision problem, whose search space is exponentially large and which degrades training stability during extensive training beyond 200B tokens.
Based on our observation that FP mixed-precision training shares the same issues as conventional mixed-precision Quantization-Aware Training (QAT), including the oscillation problem of Straight-Through Estimator (STE)-based QAT, we propose Gaussian weight sampling. The proposed method, GaussWS, addresses the mixed-precision problem by extending Pseudo-Quantization Training (PQT) with an FP-friendly noise distribution and a GPU-friendly noise generation method.
We demonstrate, both analytically and empirically, that Gaussian weight sampling is scalable, i.e., it supports low-precision FP formats down to MXFP4. The proposed method is efficient, incurring computational overhead as low as 0.47% of Llama2 training throughput (tokens per second) on the A100 GPU and requiring 2 bytes per parameter in GPU memory.
We demonstrate that the proposed method is stable, closely following or even surpassing the pre-training performance of the BF16 baseline with the GPT2-124M model on the OpenWebText dataset, the Llama2-134M model on the C4 dataset (up to 300B tokens), and the Llama2-1B model on the C4 dataset (up to 100B tokens).
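The abstract does not spell out the exact noise formulation, so the following is only a minimal PyTorch sketch of the general idea behind pseudo-quantization training with Gaussian weight noise. The class name GaussianNoisyLinear, the per-row scale, and the noise_scale value are illustrative assumptions, not the paper's GaussWS method.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GaussianNoisyLinear(nn.Linear):
    """Linear layer whose weights are perturbed with Gaussian noise during
    training, emulating the rounding error of a low-precision FP format.
    This is a generic PQT-style stand-in, not the paper's GaussWS."""

    def __init__(self, in_features, out_features, bias=True, noise_scale=2**-3):
        super().__init__(in_features, out_features, bias=bias)
        # Illustrative relative noise magnitude; a real PQT setup would tie
        # this to the target format's quantization step (e.g., MXFP4).
        self.noise_scale = noise_scale

    def forward(self, x):
        if self.training:
            # Per-output-row absolute max as a stand-in for a block scale.
            scale = self.weight.detach().abs().amax(dim=1, keepdim=True)
            noise = torch.randn_like(self.weight) * scale * self.noise_scale
            weight = self.weight + noise  # gradients still flow to self.weight
        else:
            weight = self.weight
        return F.linear(x, weight, self.bias)

# Usage: drop-in replacement for nn.Linear inside a transformer block.
layer = GaussianNoisyLinear(512, 512)
out = layer(torch.randn(8, 512))
```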
Paper Type: Long
Research Area: Efficient/Low-Resource Methods for NLP
Research Area Keywords: quantization
Languages Studied: English
Submission Number: 3766