Trained Transformers Learn Linear Models In-Context

Published: 01 Nov 2023, Last Modified: 12 Dec 2023 · R0-FoMo Spotlight
Keywords: in-context learning, transformers, neural networks, self-attention, generalization, distribution shifts
TL;DR: We study the dynamics of gradient flow in two-layer linear transformers trained on random instances of linear regression tasks. We prove global convergence as well as characterize prediction error, including under distribution shift.
Abstract: Attention-based neural network sequence models such as transformers have the capacity to act as supervised learning algorithms: They can take as input a sequence of labeled examples and output predictions for unlabeled test examples. Indeed, recent work by Garg et al. has shown that when training GPT-2 architectures over random instances of linear regression problems, these models' predictions mimic those of ordinary least squares. Towards understanding the mechanisms underlying this phenomenon, we investigate the dynamics of in-context learning of linear predictors for a transformer with a single linear self-attention layer trained by gradient flow. We show that despite the non-convexity of the underlying optimization problem, gradient flow with a random initialization finds a global minimum of the objective function. Moreover, when given a prompt of labeled examples from a new linear prediction task, the trained transformer achieves small prediction error on unlabeled test examples. We further characterize the behavior of the trained transformer under distribution shifts.
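
For intuition about the setting described in the abstract, here is a minimal NumPy sketch of a single linear self-attention (LSA) layer predicting the label of a query example from a prompt of labeled linear-regression examples. The prompt format, the weight matrices `W_KQ` and `W_PV`, and their specific values are illustrative assumptions for demonstration only, not the paper's trained parameters or exact construction.

```python
# Minimal sketch (illustrative, not the paper's exact construction):
# a single linear self-attention layer applied to an in-context
# linear-regression prompt.
import numpy as np

rng = np.random.default_rng(0)
d, n = 5, 200  # input dimension, number of in-context examples

# Random linear regression task: y_i = <w, x_i>
w = rng.normal(size=d)
X = rng.normal(size=(n, d))
y = X @ w
x_query = rng.normal(size=d)

# Prompt as a (d+1) x (n+1) token matrix: each column is (x_i; y_i);
# the last column is the query (x_query; 0) whose label is to be predicted.
Z = np.zeros((d + 1, n + 1))
Z[:d, :n] = X.T
Z[d, :n] = y
Z[:d, n] = x_query

# One linear self-attention layer (no softmax):
#   f(Z) = Z + W_PV @ Z @ (Z.T @ W_KQ @ Z) / n
# Hypothetical weights: attend using the identity on the input block and
# write only into the label coordinate.
W_KQ = np.zeros((d + 1, d + 1))
W_KQ[:d, :d] = np.eye(d)
W_PV = np.zeros((d + 1, d + 1))
W_PV[d, d] = 1.0

out = Z + W_PV @ Z @ (Z.T @ W_KQ @ Z) / n
pred = out[d, n]  # bottom-right entry: predicted label for x_query

print("prediction:", pred)   # approx. x_query @ (X.T @ X / n) @ w
print("true label:", x_query @ w)
```

With these illustrative weights the prediction equals x_query^T (X^T X / n) w, which is close to the true label x_query^T w when the in-context examples are drawn from a near-isotropic distribution; this is only meant to make the abstract's setup concrete, not to reproduce the paper's learned solution.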
Submission Number: 84