Numerical Fragility in Transformers: A Layer-wise Theory for Explaining, Forecasting, and Mitigating Instability
TL;DR: We develop a layer-wise forward-error theory for low-precision Transformers that explains, predicts, and mitigates numerical instability.
Abstract: We study numerical instability in Transformers trained in low precision and give a first-order, module-wise theory that predicts when and where forward errors amplify. For self-attention we derive a layer-wise bound that factors into three interpretable diagnostics: a score-scale ratio $\kappa_{\mathrm{score}}$, a row-wise softmax sensitivity $\kappa_{\mathrm{softmax}}$, and the value conditioning $\kappa(V)$. We prove a residual relaxation inequality showing that residual blocks attenuate depth-wise accumulation, and we give a precision- and width-aware LayerNorm indicator $\rho_{\mathrm{LN}}$ with a corresponding first-order error bound in the $\varepsilon$-dominated regime. These pieces combine into a unified forward-stability bound whose per-layer right-hand side is directly estimable during training.
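The three attention diagnostics can be estimated cheaply from a layer's pre-softmax scores and value matrix. Below is a minimal NumPy sketch; the abstract does not give the diagnostics' exact definitions, so the formulas for `kappa_score` and `kappa_softmax` here are illustrative stand-ins (a score-magnitude spread and the spectral norm of the row-wise softmax Jacobian $\mathrm{diag}(p) - pp^\top$), not the authors' definitions. Only $\kappa(V)$, taken as the standard 2-norm condition number, is unambiguous.

```python
import numpy as np

def softmax(z):
    """Numerically stable row-wise softmax."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attention_diagnostics(S, V):
    """Illustrative estimates of the three attention diagnostics.

    S: (n, n) pre-softmax score matrix; V: (n, d) value matrix.
    kappa_score and kappa_softmax below are hypothetical proxies,
    not the paper's exact definitions.
    """
    P = softmax(S)
    # Proxy for kappa_softmax: worst-case row sensitivity via the
    # softmax Jacobian diag(p) - p p^T, measured in spectral norm.
    kappa_softmax = max(
        np.linalg.norm(np.diag(p) - np.outer(p, p), 2) for p in P
    )
    # Proxy for kappa_score: spread of score magnitudes (max over mean).
    kappa_score = np.abs(S).max() / max(np.abs(S).mean(), 1e-12)
    # kappa(V): standard 2-norm condition number of the value matrix.
    kappa_V = np.linalg.cond(V, 2)
    return kappa_score, kappa_softmax, kappa_V
```

In practice these quantities would be logged per layer during training, which is what makes the per-layer right-hand side of the bound "directly estimable" as the abstract claims.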
On Tiny-ViT/CIFAR-10, we evaluate the bound and its components in three studies. (Exp-1) A decisive scatter shows that the combined predictor $ \kappa_{\mathrm{softmax}} \cdot (1 + \kappa_{\mathrm{score}}) \cdot \kappa(V) \cdot \lVert W_O \rVert_2 + \kappa_{\mathrm{eff}} + C_{\mathrm{LN}} $ tracks the observed FP32 $ \leftrightarrow $ LP forward mismatch across seeds, widths, and precisions; scaling by $ \varepsilon_{\mathrm{mach}} $ collapses mixed-precision points. (Exp-2) The time series of the maximal $ \kappa_{\mathrm{softmax}} $ acts as an early-warning signal, consistently leading increases in forward error by 16--24 steps with correlations of 0.65--0.82 (permutation-test $ p \approx 10^{-3} $) and a Precision@K of 0.89--1.00. (Exp-3) Guided by $ \rho_{\mathrm{LN}} $, a simple LayerNorm-$ \varepsilon $ intervention that targets $ \rho_* $ yields small but consistent stabilization (e.g., mean tail-loss reduction $ \approx 0.010 $ at $ \rho_* = 0.6 $, cap $ = 10^{-2} $), with no architectural changes and negligible overhead.
Overall, our theory provides actionable, unitless diagnostics that (i) explain when self-attention is numerically fragile, (ii) forecast instability in advance, and (iii) motivate a minimally invasive mitigation that trends beneficial in practice.
Submission Number: 2273