TL;DR: Examination of how multi-step functional gradient descent may be performed exactly with appropriately designed self-attention and cross-attention layers, plus skip connections, with a demonstration on two real datasets.
Abstract: In-context learning based on attention models is examined for data with categorical outcomes, with inference in such models viewed from the perspective of functional gradient descent (GD). We develop a network composed of attention blocks, each employing a self-attention layer followed by a cross-attention layer, with associated skip connections. This model can exactly perform multi-step functional GD for in-context inference with categorical observations. We perform a theoretical analysis of this setup, relaxing many assumptions made in prior work in this line, including those on the class of attention mechanisms to which it applies. We demonstrate the framework empirically on synthetic data, image classification and language generation.
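To make "multi-step functional GD for in-context inference with categorical observations" concrete, here is a minimal sketch under stated assumptions; it is not the paper's construction or the linked code. It assumes a fixed RBF kernel standing in for the attention similarity, one-hot labels, and the softmax cross-entropy loss, so each step adds the kernel-weighted residual `y - softmax(f)` to the current logits. The context-to-context update plays the role loosely analogous to self-attention, the query-to-context update the role of cross-attention, and the additive updates stand in for skip connections. The names `functional_gd` and `rbf_kernel` and all parameters are hypothetical.

```python
import numpy as np

def softmax(z):
    # Row-wise softmax with max-subtraction for numerical stability.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def rbf_kernel(a, b, gamma=1.0):
    # Hypothetical kernel choice: k(a, b) = exp(-gamma * ||a - b||^2).
    sq = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-gamma * sq)

def functional_gd(x_ctx, y_ctx, x_qry, num_classes, steps=10, lr=1.0, gamma=1.0):
    """Multi-step functional GD on the softmax cross-entropy of the context.

    f_ctx holds the current logits f_t(x_i) at the context points and f_qry
    holds f_t(x) at the query points. Both are updated with the same
    kernel-weighted residual and added residually (skip-connection analogue).
    """
    n = x_ctx.shape[0]
    y_onehot = np.eye(num_classes)[y_ctx]     # (n, C) one-hot labels
    k_cc = rbf_kernel(x_ctx, x_ctx, gamma)    # context-context similarities
    k_qc = rbf_kernel(x_qry, x_ctx, gamma)    # query-context similarities
    f_ctx = np.zeros((n, num_classes))
    f_qry = np.zeros((x_qry.shape[0], num_classes))
    for _ in range(steps):
        # Negative functional gradient of cross-entropy, evaluated at f_t.
        residual = y_onehot - softmax(f_ctx)  # (n, C)
        f_ctx = f_ctx + (lr / n) * k_cc @ residual  # "self-attention" step
        f_qry = f_qry + (lr / n) * k_qc @ residual  # "cross-attention" step
    return f_qry

# Usage: two well-separated Gaussian clusters as the in-context examples.
rng = np.random.default_rng(0)
x_ctx = np.concatenate([rng.normal(-2.0, 0.5, (20, 2)),
                        rng.normal(2.0, 0.5, (20, 2))])
y_ctx = np.array([0] * 20 + [1] * 20)
x_qry = np.array([[-2.0, -2.0], [2.0, 2.0]])
logits = functional_gd(x_ctx, y_ctx, x_qry, num_classes=2, steps=20)
pred = logits.argmax(axis=1)
print(pred)  # [0 1]
```

Each query is classified by the labels of nearby context points, with no parameter training; only the number of GD steps, the step size and the kernel are fixed in advance.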
Lay Summary: The Transformer is widely used as a generative model in virtually all language models deployed in practice today. In spite of the success of such models, little is known about how they work. This paper seeks to provide insight into the mechanisms by which Transformers respond and adapt to the prompt provided as input. A key advance of this paper concerns the form of the observed data, which is categorical. By this we mean that the observations come from a finite (but large) discrete set, which corresponds to the vocabulary of tokens used in language models. We show that the Transformer can learn to perform prompt-dependent inference based on a widely studied optimization method called gradient descent. This insight suggests how the Transformer may adapt to prompts in language applications.
Link To Code: https://github.com/aarontwang/icl_attention_categorical
Primary Area: Deep Learning->Attention Mechanisms
Keywords: attention networks, in-context learning, Transformers
Submission Number: 6787