Metric Transforms and Low Rank Representations of Kernels for Fast Attention

Published: 25 Sept 2024, Last Modified: 06 Nov 2024 · NeurIPS 2024 spotlight · CC BY-NC-SA 4.0
Keywords: hardness, impossibility, low rank transform, kernel method, LLM, attention
TL;DR: This paper studies possibility and impossibility results for kernel methods and metric transforms.
Abstract: We introduce a new linear-algebraic tool based on group representation theory, and use it to address three key problems in machine learning.

1. Past researchers have proposed fast attention algorithms for LLMs by approximating or replacing softmax attention with other functions, such as low-degree polynomials. The key property of these functions is that, when applied entry-wise to the matrix $QK^{\top}$, the result is a low-rank matrix when $Q$ and $K$ are $n \times d$ matrices and $n \gg d$. This suggests a natural question: what are all functions $f$ with this property? If other such $f$ exist and are quickly computable, they could be used in place of softmax to obtain fast subquadratic attention algorithms. It was previously known that low-degree polynomials have this property. We prove that low-degree polynomials are the only piecewise continuous functions with this property, which suggests that low-rank fast attention works only for functions that can be approximated by polynomials. Our work gives a converse to the polynomial method in algorithm design.

2. We prove the first full classification of all positive definite kernels that are functions of the Manhattan ($\ell_1$) distance. Our work generalizes an existing theorem at the heart of all kernel methods in machine learning: the classification of all positive definite kernels that are functions of the Euclidean distance.

3. The key problem in metric transforms, a mathematical theory used in geometry and machine learning, asks which functions transform pairwise distances in a semi-metric space $M$ into pairwise distances in a semi-metric space $N$, for specified $M$ and $N$. We provide the first full classification of functions that transform Manhattan distances to Manhattan distances. Our work generalizes the foundational work of Schoenberg, which fully classifies functions that transform Euclidean distances to Euclidean distances.

We additionally prove results about stable-rank-preserving functions that are potentially useful in algorithm design, and more. Our core new tool is called the representation theory of the hyperrectangle.
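To make the first result concrete, here is a minimal numerical sketch (not from the paper; the array shapes, the degree-3 polynomial $f$, and the use of NumPy are illustrative assumptions). It checks that applying a low-degree polynomial entry-wise to $QK^{\top}$ yields a matrix whose rank is bounded by $\binom{d+p}{p}$ rather than $n$, whereas an entry-wise exponential, as in softmax attention, does not enjoy such a bound.

```python
import numpy as np
from math import comb

# Illustrative sketch (assumed setup, not the paper's code): a degree-p
# polynomial applied entry-wise to Q K^T has rank at most C(d + p, p),
# independent of n, while an entry-wise exponential generally does not.
rng = np.random.default_rng(0)
n, d, p = 500, 8, 3                       # n >> d, polynomial degree p
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))
G = Q @ K.T                               # n x n matrix of inner products

poly = sum(G**k for k in range(p + 1))    # f(x) = 1 + x + ... + x^p, entry-wise
expG = np.exp(G)                          # entry-wise exp, as in softmax attention

print("rank bound C(d+p, p):", comb(d + p, p))                       # 165 for d=8, p=3
print("numerical rank of f(QK^T):", np.linalg.matrix_rank(poly))     # at most 165
print("numerical rank of exp(QK^T):", np.linalg.matrix_rank(expG))   # close to n
```

The paper's classification result says, conversely, that piecewise continuous functions lacking this polynomial structure cannot have the low-rank property.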
Primary Area: Learning theory
Submission Number: 1720