Transformers Implement Functional Gradient Descent to Learn Non-Linear Functions In Context

Published: 02 May 2024, Last Modified: 25 Jun 2024, ICML 2024 Poster, CC BY 4.0
Abstract: Many neural network architectures are known to be Turing complete and can thus, in principle, implement arbitrary algorithms. However, Transformers are unique in that they can implement gradient-based learning algorithms *under simple parameter configurations*. This paper provides theoretical and empirical evidence that (non-linear) Transformers naturally learn to implement gradient descent *in function space*, which in turn enables them to learn non-linear functions in context. Our results apply to a broad class of combinations of non-linear architectures and non-linear in-context learning tasks. Additionally, we show that the optimal choice of non-linear activation depends in a natural way on the class of functions that need to be learned.
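To make the core idea concrete, the sketch below shows (under illustrative assumptions, not the paper's exact construction) how one step of gradient descent in function space on in-context examples coincides with a single softmax-free attention operation whose query-key scores pass through a non-linear activation. The function names (`functional_gd_step`, `attention_layer`) and the choice of `tanh` as the non-linearity are hypothetical, chosen only to illustrate the correspondence.

```python
import numpy as np

def kernel(u, v, act=np.tanh):
    # Non-linear similarity between a query and a context input.
    # `act` stands in for the attention non-linearity (illustrative choice).
    return act(u @ v)

def functional_gd_step(xs, ys, x_query, lr=0.1, act=np.tanh):
    """One step of gradient descent in function space on the squared loss,
    starting from the zero function f_0 = 0.

    The functional gradient of (1/2) * sum_i (f(x_i) - y_i)^2 at f_0 = 0,
    evaluated at x_query, is -sum_i k(x_query, x_i) * y_i, so the updated
    function predicts lr * sum_i k(x_query, x_i) * y_i at the query point."""
    return lr * sum(kernel(x_query, x_i, act) * y_i for x_i, y_i in zip(xs, ys))

def attention_layer(xs, ys, x_query, lr=0.1, act=np.tanh):
    """A single softmax-free attention layer: non-linear activation applied to
    query-key scores, with the in-context labels as values. Numerically
    identical to the functional gradient-descent step above."""
    scores = act(np.stack(xs) @ x_query)   # non-linear query-key scores
    return lr * scores @ np.array(ys)      # weighted sum of label values

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    xs = [rng.standard_normal(4) for _ in range(8)]   # in-context inputs
    ys = [np.tanh(x.sum()) for x in xs]               # a non-linear target
    xq = rng.standard_normal(4)
    print(functional_gd_step(xs, ys, xq))  # same prediction...
    print(attention_layer(xs, ys, xq))     # ...two views of one update step
```

The point of the sketch is only that the weighted sum over context labels produced by a non-linear attention operation has the same form as a kernel (functional) gradient-descent update; which activation works best would, per the abstract, depend on the class of target functions.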
Submission Number: 2401