TL;DR: We prove that the softmax function cannot robustly approximate sharp functions as input size increases, and support this with controlled experiments on simple attention heads as well as on language models.
Abstract: A key property of reasoning systems is the ability to make sharp decisions on their input data. For contemporary AI systems, a key carrier of sharp behaviour is the softmax function, with its capability to perform differentiable query-key lookups. It is a common belief that the predictive power of networks leveraging softmax arises from "circuits" which sharply perform certain kinds of computations consistently across many diverse inputs. However, for these circuits to be robust, they would need to generalise well to arbitrary valid inputs. In this paper, we dispel this myth: even for tasks as simple as finding the maximum key, any learned circuitry must disperse as the number of items grows at test time. We attribute this to a fundamental inability of the softmax function to robustly approximate sharp functions as problem size increases, prove this phenomenon theoretically, and propose adaptive temperature as an ad hoc technique for improving the sharpness of softmax at inference time.
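The dispersion claim can be illustrated numerically. The sketch below is a minimal, self-contained illustration rather than the paper's experimental setup or its exact adaptive-temperature procedure: it shows how the largest softmax weight decays as the number of items grows, and applies a hypothetical adaptive-temperature loop (the entropy target and temperature schedule are assumptions for illustration only) to re-sharpen the distribution at inference time.

```python
# Minimal sketch (not the paper's setup): softmax dispersion with growing n,
# plus a hypothetical adaptive-temperature loop to re-sharpen the output.
import numpy as np

def softmax(z, temperature=1.0):
    z = np.asarray(z, dtype=np.float64) / temperature
    z = z - z.max()                      # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()

rng = np.random.default_rng(0)
gap = 2.0                                # fixed logit advantage of the "max" item

for n in [8, 64, 512, 4096]:
    logits = rng.normal(size=n)
    logits[0] += gap                     # item 0 is the intended maximum
    p = softmax(logits)

    # Hypothetical adaptive temperature: shrink T until the entropy of the
    # attention distribution falls below a target (1 nat here, an assumption),
    # sharpening the softmax at inference time without retraining.
    T, q = 1.0, p
    while -(q * np.log(q + 1e-12)).sum() > 1.0 and T > 1e-3:
        T *= 0.9
        q = softmax(logits, temperature=T)

    print(f"n={n:5d}  max weight={p.max():.3f}  "
          f"after adaptive T={T:.2f}: {q.max():.3f}")
```

Under these assumptions, the printed maximum weight at the default temperature shrinks as n grows (the dispersion effect), while the temperature-adjusted distribution stays sharp.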
Lay Summary: One of the most important components of modern deep learning models (including large language models) is the softmax function, which provides a mechanism for the AI model to carefully focus on the most important parts of the input for answering the given query -- for example, when translating a sentence between languages, an AI system may use softmax to highlight subject-object pairs, or adjective-noun relationships. It has long been presumed that this mechanism is critical for enabling AI systems to perform complex reasoning, and that the detected patterns of attention remain robust for all possible inputs. In our paper we dispel this myth, and show that, as you provide inputs that grow beyond the largest ones the AI has seen during training, the amounts of focus emitted by softmax -- at least, as it is currently leveraged in modern AI systems -- must provably disperse, converging to a situation where the model is unable to significantly focus on any specific part of the input. Appropriately addressing this challenge is therefore important for future work on open-ended or longer-horizon problems with AI.
Primary Area: Deep Learning->Theory
Keywords: softmax, size generalisation, attention, transformers, length generalisation, sharpness, entropy
Submission Number: 12291