Transformers Learn Higher-Order Optimization Methods for In-Context Learning: A Study with Linear Models

Published: 07 Nov 2023, Last Modified: 13 Dec 2023, M3L 2023 Poster
Keywords: in-context learning, transformers, linear regression
TL;DR: A transformer model trained on linear regression learns to implement an algorithm more similar to higher-order optimization methods than to gradient descent.
Abstract: Transformers are remarkably good at *in-context learning* (ICL), i.e., learning from demonstrations without parameter updates, but how they perform ICL remains a mystery. Recent work suggests that Transformers may learn in-context by internally running Gradient Descent, a first-order optimization method. In this paper, we instead demonstrate that Transformers learn to implement higher-order optimization methods to perform ICL. Focusing on in-context linear regression, we show that Transformers learn to implement an algorithm very similar to *Iterative Newton's Method*, a higher-order optimization method, rather than Gradient Descent. Empirically, we show that predictions from successive Transformer layers closely match successive iterations of Newton's Method, with the number of matched iterations growing *linearly* with depth: each middle layer roughly computes 3 iterations. In contrast, *exponentially* more Gradient Descent steps are needed to match each additional Transformer layer; this suggests that Transformers have a rate of convergence comparable to that of higher-order methods such as Iterative Newton, which are exponentially faster than Gradient Descent. We also show that Transformers can learn in-context on ill-conditioned data, a setting where Gradient Descent struggles but Iterative Newton succeeds. Finally, we present theoretical results that support our empirical findings and closely correspond to them: we prove that Transformers can implement $k$ iterations of Newton's Method with $\mathcal{O}(k)$ layers.
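To make the Newton-versus-Gradient-Descent contrast concrete, here is a minimal NumPy sketch (not the authors' code, and it does not model a Transformer) on a noiseless, ill-conditioned least-squares problem. It assumes the Newton-Schulz form of Iterative Newton: with $S = X^\top X$, iterate $M_{k+1} = 2M_k - M_k S M_k$ from $M_0 = \alpha S$ and predict with $\hat{w}_k = M_k X^\top y$. The scaling $\alpha = 1/\lVert S \rVert_2^2$ is our own conservative choice and may differ from the paper's constant. Since $I - M_k S = (I - \alpha S^2)^{2^k}$, each iteration squares the residual factor, whereas $t$ Gradient Descent steps with step size $1/\lambda_{\max}(S)$ shrink the error only by $(I - S/\lambda_{\max}(S))^t$. The helpers `iterative_newton`, `gradient_descent`, `make_ill_conditioned` and `rel_err`, and all problem sizes, are hypothetical.

```python
import numpy as np


def iterative_newton(X, y, num_iters):
    """Iterative Newton (Newton-Schulz) sketch for least squares.

    Refines an approximate inverse of S = X^T X with
    M_{k+1} = 2 M_k - M_k S M_k and returns w_k = M_k X^T y for k = 0..num_iters.
    """
    S = X.T @ X
    # Assumption: alpha = 1 / ||S||_2^2 puts every eigenvalue of alpha * S^2 in (0, 1]
    # for full-column-rank X, which guarantees convergence; the paper's constant may differ.
    alpha = 1.0 / np.linalg.norm(S, 2) ** 2
    M = alpha * S
    iterates = [M @ X.T @ y]
    for _ in range(num_iters):
        M = 2 * M - M @ S @ M  # the residual I - M S is squared at every step
        iterates.append(M @ X.T @ y)
    return iterates


def gradient_descent(X, y, num_steps):
    """Gradient Descent on 0.5 * ||X w - y||^2 from w = 0; returns w_0..w_num_steps."""
    S = X.T @ X
    eta = 1.0 / np.linalg.norm(S, 2)  # step size 1 / lambda_max(S)
    w = np.zeros(X.shape[1])
    iterates = [w]
    for _ in range(num_steps):
        # The error w - w_ls is multiplied by (I - eta * S) at every step.
        w = w - eta * (X.T @ (X @ w - y))
        iterates.append(w)
    return iterates


def make_ill_conditioned(n, d, cond, rng):
    """Hypothetical noiseless data: cond(X) = cond, so cond(X^T X) = cond**2."""
    U, _ = np.linalg.qr(rng.standard_normal((n, d)))  # n x d, orthonormal columns
    V, _ = np.linalg.qr(rng.standard_normal((d, d)))  # d x d orthogonal
    singular_values = np.logspace(0, -np.log10(cond), d)  # from 1 down to 1 / cond
    X = U @ np.diag(singular_values) @ V.T
    w_star = rng.standard_normal(d)
    return X, X @ w_star, w_star


def rel_err(w, w_star):
    """Relative Euclidean error of an estimate w against the true weights."""
    return np.linalg.norm(w - w_star) / np.linalg.norm(w_star)


rng = np.random.default_rng(0)
X, y, w_star = make_ill_conditioned(n=50, d=5, cond=10.0, rng=rng)
newton = iterative_newton(X, y, num_iters=20)
gd = gradient_descent(X, y, num_steps=20)
for k in (0, 5, 10, 15, 20):
    print(f"k={k:2d}  Newton rel. error={rel_err(newton[k], w_star):.2e}  "
          f"GD rel. error={rel_err(gd[k], w_star):.2e}")
```

On this toy problem one would expect the Newton iterates to reach the least-squares solution to near machine precision within about 20 iterations, while 20 Gradient Descent steps still leave a sizeable error along the small-eigenvalue directions of $S$; increasing `cond` should widen the gap, mirroring the ill-conditioned setting described above.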
Submission Number: 40