Attention as Inference via Fenchel Duality

TMLR Paper 534 Authors

24 Oct 2022 (modified: 28 Feb 2023) · Rejected by TMLR
Abstract: Attention has been widely adopted in many state-of-the-art deep learning models. While the significant performance improvements it brings have attracted great interest, attention remains poorly understood theoretically. This paper presents a new perspective for understanding attention by showing that it can be seen as a solver for a family of estimation problems. In particular, we describe a convex optimization problem that arises in a family of estimation tasks commonly appearing in the design of deep learning models. Rather than solving the convex optimization problem directly, we solve its Fenchel dual and derive a closed-form approximation of the optimal solution. Remarkably, the solution yields a generalized attention structure, a special case of which is equivalent to the popular dot-product attention adopted in transformer networks. We show that the T5 transformer has implicitly adopted the general form of the solution by demonstrating that this expression unifies the word mask and positional encoding functions. Finally, we discuss how the proposed attention structures can be integrated into practical models, and we empirically show that the convex optimization problem indeed provides a principled justification for the design of attention modules.
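For reference, the special case the abstract points to is standard scaled dot-product attention. The NumPy sketch below illustrates only that well-known form, not the paper's generalized, dual-derived attention structure; the function name and toy shapes are illustrative assumptions, not the authors' code.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Standard scaled dot-product attention (the special case the abstract refers to)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                       # pairwise query-key similarities
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)        # softmax over keys
    return weights @ V                                    # convex combination of values

# Toy usage: 3 queries, keys, and values of dimension 4
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((3, 4)) for _ in range(3))
print(scaled_dot_product_attention(Q, K, V).shape)        # (3, 4)
```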
Submission Length: Regular submission (no more than 12 pages of main content)
Assigned Action Editor: ~Colin_Raffel1
Submission Number: 534