A Fast Optimization View: Reformulating Single Layer Attention in LLM Based on Tensor and SVM Trick, and Solving It in Matrix Multiplication Time

Published: 07 May 2025 · Last Modified: 13 Jun 2025 · UAI 2025 Poster · CC BY 4.0
Keywords: Attention optimization, sketching, regression
TL;DR: We present a new algorithm for optimizing the one-layer attention mechanism in Large Language Models (LLMs).
Abstract: Large language models (LLMs) have played a pivotal role in transforming many aspects of daily life. Solving attention regression is a fundamental task in optimizing LLMs. In this work, we provide a provable guarantee for the one-layer attention network objective: given the input matrices of a layer, $A_1, A_2, A_3 \in \mathbb{R}^{n \times d}$, our goal is to minimize the loss function \begin{align*} L(X,Y) = \sum_{j_0=1}^n \sum_{i_0=1}^d \left( \langle \langle \exp( \mathsf{A}_{j_0} x ) , {\bf 1}_n \rangle^{-1} \cdot \exp( \mathsf{A}_{j_0} x ), A_{3} Y_{*,i_0} \rangle - b_{j_0,i_0} \right)^2, \end{align*} where $\mathsf{A}_{j_0} \in \mathbb{R}^{n \times d^2}$ is the $j_0$-th block of the Kronecker product of $A_1$ and $A_2$, the matrix $B \in \mathbb{R}^{n \times d}$ is the output of the layer with $(j_0, i_0)$-th entry $b_{j_0,i_0} \in \mathbb{R}$, $Y_{*,i_0} \in \mathbb{R}^d$ is the $i_0$-th column of $Y$, and $x \in \mathbb{R}^{d^2}$ is the vectorization of $X$. In self-attention, $Q, K, V \in \mathbb{R}^{d \times d}$ denote the query, key, and value weight matrices, respectively. Unlike prior works that rely on simplified, single-variable versions of the attention computation problem, our multivariate loss function captures a complete, unsimplified attention layer, treating all of these weights as variables via $X = QK^\top \in \mathbb{R}^{d \times d}$ and $Y = V \in \mathbb{R}^{d \times d}$. We propose an iterative greedy algorithm that trains a neural network under the loss $L(X,Y)$ to error $\epsilon \in (0, 0.1)$. The algorithm runs in $O\big( ({\cal T}_{\mathrm{mat}}(n,n,d) + {\cal T}_{\mathrm{mat}}(n,d,d) + d^{2\omega}) \log(1/\epsilon) \big)$ time, where ${\cal T}_{\mathrm{mat}}(a,b,c)$ denotes the time to multiply an $a \times b$ matrix by a $b \times c$ matrix and $\omega \approx 2.37$ is the exponent of matrix multiplication.
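To make the objective concrete: under the block-Kronecker convention in the abstract, $\exp(\mathsf{A}_{j_0} x)$ is the $j_0$-th row of $\exp(A_1 X A_2^\top)$, so $L(X,Y)$ can be read as the Frobenius-norm residual of row-softmax attention applied to $A_3 Y$ against $B$. Below is a minimal NumPy sketch that evaluates the loss in this matrix form; it is an illustration only (the function name, unstabilized softmax, and random test data are ours, not from the paper), and it does not implement the paper's iterative greedy training algorithm.

```python
import numpy as np

def attention_loss(A1, A2, A3, B, X, Y):
    """Evaluate L(X, Y) = || softmax_rows(A1 @ X @ A2.T) @ A3 @ Y - B ||_F^2.

    This matches the entrywise definition in the abstract when A_{j0} is the
    j0-th block of the Kronecker product of A1 and A2 and x = vec(X), so that
    exp(A_{j0} x) equals the j0-th row of exp(A1 @ X @ A2.T).
    """
    S = np.exp(A1 @ X @ A2.T)                    # exp(A_{j0} x), stacked over j0: n x n
    D_inv = 1.0 / S.sum(axis=1, keepdims=True)   # normalizers <exp(A_{j0} x), 1_n>^{-1}
    pred = (D_inv * S) @ (A3 @ Y)                # row-softmax attention applied to A3 @ Y: n x d
    return np.sum((pred - B) ** 2)               # sum over (j0, i0) of squared residuals

# Hypothetical usage on random data
rng = np.random.default_rng(0)
n, d = 8, 4
A1, A2, A3, B = (rng.standard_normal((n, d)) for _ in range(4))
X, Y = rng.standard_normal((d, d)), rng.standard_normal((d, d))
print(attention_loss(A1, A2, A3, B, X, Y))
```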
Submission Number: 265