Can Transformers Solve Least Squares to High Precision?

Published: 21 Jun 2024, Last Modified: 26 Jul 2024, ES-FoMo-II 2024 Poster, CC BY 4.0
Keywords: in-context learning, high precision, least squares, Transformers, gated convolutions, linear regression
TL;DR: This work uncovers limitations of Transformers for high-precision numerical tasks, and proposes gated convolutional models as a promising alternative.
Abstract: Deep sequence models like Transformers have achieved remarkable results across language and vision tasks, but their ability to solve high-precision numerical problems, which is crucial in scientific settings, remains unclear. We explore the capabilities of existing models on the fundamental problem of least squares, motivated by recent work suggesting that Transformers can implement learning algorithms on in-context linear regression problems. Surprisingly, we observe that Transformers struggle to solve least squares to high precision, even in fully determined settings: their MSE plateaus at $10^{-5}$, $9$ orders of magnitude worse than simple algorithms like gradient descent. Probing for sources of low precision, we train on basic linear algebra operations and find that Transformers struggle to precisely learn a simple element-wise multiplication task. Since numerical methods rely heavily on linear algebra primitives, including multiplication, this result suggests that Transformers struggle to implement learning algorithms to high precision, in contrast to prior findings. Our key insight is that gated convolutional models can exactly implement arithmetic circuits, including multiplications and polynomials. Using gated convolutions, we instantiate a weight construction that directly solves least squares to high precision by explicitly implementing gradient descent. Finally, based on our analysis, we propose a simple alternative to standard in-context learning, in which we supervise models to explicitly learn the gradient update rule and apply it iteratively during inference. Using this framework, we achieve a $2$-orders-of-magnitude improvement over parameter-matched Transformers trained on standard in-context learning.
Submission Number: 68
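
For context on the gradient descent baseline that the abstract compares against, the following is a minimal NumPy sketch, not the authors' code. It applies the update $x \leftarrow x - \eta A^\top (A x - b)$ repeatedly to a noiseless, fully determined least squares problem in float64. The function names `gd_least_squares` and `mse`, the problem sizes, the step size $\eta = 1/\sigma_{\max}(A)^2$, and the iteration count are all assumptions chosen for illustration.

```python
import numpy as np


def gd_least_squares(A, b, num_steps=2000):
    """Minimize 0.5 * ||A x - b||^2 with plain gradient descent.

    Hypothetical helper for illustration; the step size 1 / sigma_max(A)^2
    is a standard safe choice, not a setting taken from the paper.
    """
    x = np.zeros(A.shape[1], dtype=A.dtype)
    # Largest singular value squared is the Lipschitz constant of the gradient.
    step = 1.0 / np.linalg.norm(A, 2) ** 2
    for _ in range(num_steps):
        grad = A.T @ (A @ x - b)  # gradient of 0.5 * ||A x - b||^2
        x = x - step * grad       # one gradient update
    return x


def mse(A, x, b):
    """Mean squared residual of the linear system."""
    return float(np.mean((A @ x - b) ** 2))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    A = rng.standard_normal((50, 5))   # more equations than unknowns
    x_true = rng.standard_normal(5)
    b = A @ x_true                     # noiseless targets
    x_hat = gd_least_squares(A, b)
    print("MSE after gradient descent:", mse(A, x_hat, b))
```

Each update uses only matrix-vector products and additions, which is consistent with the abstract's emphasis on multiplication as the primitive a model must realize precisely. The same loop structure also mirrors the proposed inference procedure of applying a learned gradient update iteratively, although here the update is computed exactly rather than learned.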