Transformer Efficiently Learns Low-dimensional Target Functions In-context

Published: 18 Jun 2024, Last Modified: 03 Jul 2024, TF2M 2024 Poster, CC BY 4.0
Keywords: learning theory, in-context learning, representation learning, Transformer, single-index model
TL;DR: Transformers can efficiently learn low-dimensional nonlinear functions in-context and outperform algorithms that work directly on the test prompt.
Abstract: Transformers can efficiently learn in-context from example demonstrations. We study in-context learning (ICL) of a nonlinear function class via a transformer with a nonlinear MLP layer: given a class of single-index target functions $f_*(\boldsymbol{x}) = \sigma_*(\langle\boldsymbol{x},\boldsymbol{\beta}\rangle)$, where the index features $\boldsymbol{\beta}\in\mathbb{R}^d$ are drawn from a rank-$r\ll d$ subspace, we show that a nonlinear transformer optimized by gradient descent learns $f_*$ in-context with a prompt length that depends only on the dimension $r$ of the function class. In contrast, an algorithm that learns $f_*$ directly on the test prompt incurs a statistical complexity that scales with the ambient dimension $d$. Our result highlights the adaptivity of ICL to low-dimensional structure in the function class.
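To make the setup concrete, below is a minimal sketch (not from the paper) of the data-generating process described in the abstract: single-index targets $f_*(\boldsymbol{x}) = \sigma_*(\langle\boldsymbol{x},\boldsymbol{\beta}\rangle)$ with index features drawn from a rank-$r$ subspace, packaged into an in-context prompt of demonstrations. All concrete choices here (the link function, dimensions `d`, `r`, prompt length `N`, and helper names) are illustrative assumptions, not the paper's specification.

```python
import numpy as np

# Illustrative sizes (assumptions): ambient dimension, subspace rank, prompt length.
d, r, N = 64, 4, 32

rng = np.random.default_rng(0)
# Fixed rank-r subspace shared across tasks, represented by an orthonormal basis U.
U, _ = np.linalg.qr(rng.standard_normal((d, r)))

def sigma_star(z):
    """Example nonlinear link function (the paper allows a general sigma_*)."""
    return np.tanh(z)

def sample_task():
    """Draw unit-norm index features beta from the rank-r subspace spanned by U."""
    coeffs = rng.standard_normal(r)
    return U @ (coeffs / np.linalg.norm(coeffs))

def make_prompt(beta, n=N):
    """In-context prompt: n demonstrations (x_i, f_*(x_i)) plus a held-out query."""
    X = rng.standard_normal((n, d))
    y = sigma_star(X @ beta)
    x_query = rng.standard_normal(d)
    return X, y, x_query, sigma_star(x_query @ beta)

beta = sample_task()
X, y, x_query, y_query = make_prompt(beta)
print(X.shape, y.shape, x_query.shape)  # (32, 64) (32,) (64,)
```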
Submission Number: 48