\section{Introduction}


Large language models (LLMs) such as GPT-1 \citep{rns+18}, BERT \citep{dclt18}, GPT-2 \citep{rwc+19}, GPT-3 \citep{bmr+20}, ChatGPT \citep{cha22}, GPT-4 \citep{o23}, OPT \citep{zrg+22}, Llama \citep{tli+23}, and Llama 2 \citep{tms+23} have demonstrated impressive capabilities in natural language processing (NLP). These models understand and generate complex language, enabling a wide range of applications such as sentiment analysis \citep{zdl+23}, language translation \citep{aaa+23}, question answering \citep{bhs+23}, and text summarization \citep{pd23}. Despite this strong performance, training and optimizing these massive models remains computationally challenging, and there is still substantial room for improvement.




The primary technical foundation supporting the capabilities of LLMs is the attention matrix $A \in \R^{n \times n}$ \citep{rns+18,vsp+17,bmr+20,dclt18}. The central idea of attention is to learn representations that emphasize the most relevant parts of the input. Specifically, the attention mechanism measures the similarity between query vectors and key vectors via inner products. These similarity scores are normalized into attention weights, which indicate the relative importance of each input token, and the output representation is the corresponding weighted average of the value vectors. By leveraging attention, LLMs learn to focus on the crucial parts of the input and to gather pertinent information efficiently, which helps them process long texts and capture intricate semantic relationships. In particular, self-attention lets an LLM relate different segments of the same input sequence, enhancing its contextual understanding.
Mathematically, the attention computation is defined as follows:
\begin{definition}[The $\ell$-th layer forward computation]\label{def:forward_comp}
Let $n, d$ be positive integers, where $n$ denotes the number of input tokens and $d$ denotes the dimension of the token embeddings. Let ${\bf 1}_n$ be the $n$-dimensional vector whose entries are all $1$. Let $\diag : \R^n \to \R^{n \times n}$ map a vector $x \in \R^n$ to the diagonal matrix whose $(i, i)$-th entry is $x_i$ for all $i \in [n]$ and whose off-diagonal entries are all $0$.
Given weights $Q, K, V \in \R^{d \times d}$, let $X_{\ell} \in \R^{n \times d}$ denote the input of the $\ell$-th layer. The output $X_{\ell+1} \in \R^{n \times d}$ is defined as
\begin{align*}
 X_{\ell+1} \gets D^{-1} \exp(X_{\ell} Q K^\top X_{\ell}^\top) X_{\ell} V
\end{align*} 
where $D:= \diag( \exp( X_{\ell} Q K^\top X_{\ell}^\top ){\bf 1}_n )$ and $\exp(A)_{i, j} = \exp(A_{i, j})$ for all matrices $A$.
\end{definition}
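The forward computation in Definition~\ref{def:forward_comp} can be illustrated with a minimal sketch in plain Python (matrices as lists of rows, no external libraries). The helper names \texttt{matmul}, \texttt{transpose}, \texttt{rand\_matrix}, and \texttt{attention\_forward} are hypothetical and chosen only for illustration. The sketch omits the usual numerical safeguard of subtracting each row's maximum before exponentiating; that shift would not change the result, because the normalization by $D^{-1}$ cancels it.

```python
import math
import random


def matmul(P, R):
    """Multiply a p x q matrix P by a q x r matrix R (both given as lists of rows)."""
    return [[sum(P[i][k] * R[k][j] for k in range(len(R)))
             for j in range(len(R[0]))]
            for i in range(len(P))]


def transpose(P):
    """Return the transpose of P."""
    return [list(col) for col in zip(*P)]


def attention_forward(X_l, Q, K, V):
    """Return (X_{l+1}, D^{-1} A), where A = exp(X_l Q K^T X_l^T) entrywise."""
    scores = matmul(matmul(matmul(X_l, Q), transpose(K)), transpose(X_l))  # n x n
    A = [[math.exp(s) for s in row] for row in scores]
    # D = diag(A 1_n), so D^{-1} A divides every row of A by its row sum.
    S = []
    for row in A:
        total = sum(row)
        S.append([a / total for a in row])
    return matmul(S, matmul(X_l, V)), S


def rand_matrix(rows, cols):
    """Random matrix with entries drawn uniformly from [-1, 1]."""
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


# Usage on a small random instance with n = 4 tokens and d = 3.
random.seed(0)
n, d = 4, 3
X_l, Q, K, V = rand_matrix(n, d), rand_matrix(d, d), rand_matrix(d, d), rand_matrix(d, d)
X_next, S = attention_forward(X_l, Q, K, V)
print([sum(row) for row in S])  # each row of D^{-1} A sums to 1 up to rounding
```

Setting $Q = 0$ makes every row of $D^{-1} A$ uniform, so each output row becomes the average row of $X_\ell V$; this is a quick sanity check on the normalization.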



\begin{figure*}[!ht]
    \centering
    \includegraphics[width = 0.9\linewidth]{attention_optimization.pdf}
    \includegraphics[width = 0.63\linewidth]{D_X.pdf}
    \caption{The visualization of the attention optimization problem (see Definition~\ref{def:attention}). Let $A_1, A_2, A_3, B \in \R^{n \times d}$ and $X, Y \in \R^{d \times d}$. We first get $\exp(A_1 X A_2^\top) \in \R^{n \times n}$ by multiplying $A_1$, $X$, and $A_2^\top$ and applying $\exp$ entrywise. Then, we obtain $D(X) \in \R^{n \times n}$ by computing $\diag( \exp(A_1 X A_2^\top ) {\bf 1}_n )$. After that, we multiply $D(X)^{-1}$, $\exp(A_1 X A_2^\top )$, $A_3$, and $Y$, and subtract $B$ from the product. Finally, we minimize the squared Frobenius norm of this difference over $X$ and $Y$. The blue rectangles represent the $n \times d$ matrices, the purple rectangle represents the $n$-dimensional vector, the red squares represent the $d \times d$ matrices, and the green squares represent the $n \times n$ diagonal matrices.}
    \label{fig:attention_optimization}
\end{figure*}

The matrix $D^{-1} \underbrace{\exp(X_{\ell} Q K^\top X_{\ell}^\top)}_{:= A} \in \R^{n \times n}$ is traditionally written as $\mathrm{Softmax}(X_{\ell} Q K^\top X_{\ell}^\top) \in \R^{n \times n}$, where $X_{\ell} Q$ and $X_{\ell} K$ play the roles of the query and key matrices; the usual scaling factor $1/\sqrt{d}$ can be absorbed into $Q$. Each entry of $A$ measures how much attention one input token pays to another. $D^{-1}$ normalizes each row of the attention matrix, i.e., each row of $D^{-1} A \in \R^{n \times n}$ sums to $1$. $X_{\ell} V \in \R^{n \times d}$ is the value matrix, which stores the representation of each input token. The output is therefore a weighted combination of the values, in which more important values (as determined by the attention mechanism) contribute more. In Definition~\ref{def:forward_comp}, we fully expand the $\mathrm{Softmax}$ unit and depart from the traditional notation to highlight the focus of our paper: finding $X = Q K^\top \in \R^{d \times d}$ and $Y = V \in \R^{d \times d}$ that minimize the following objective:
\begin{definition}[Attention optimization]\label{def:attention}
    Let $B \in \R^{n \times d}$. Given inputs $A_1, A_2, A_3 \in \R^{n \times d}$, we define the loss $L : \R^{d \times d} \times \R^{d \times d} \to \R$ by
    \begin{align*}
        L(X, Y) := \| D(X)^{-1} \exp(A_1 X A_2^\top) A_3 Y - B \|_F^2,
    \end{align*}
    where the diagonal matrix $D(X) \in \R^{n \times n}$ is defined as $D(X) := \diag( \exp(A_1 X A_2^\top ) {\bf 1}_n )$. The attention optimization problem is $\min_{X,Y \in \R^{d \times d}} L(X, Y)$.
\end{definition}
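For concreteness, the loss $L(X, Y)$ in Definition~\ref{def:attention} can be evaluated directly, as in the following plain-Python sketch (all helper names are hypothetical). It forms the full $n \times n$ matrix $\exp(A_1 X A_2^\top)$ and therefore takes $O(n^2 d)$ time when $d \le n$; it is meant only to pin down the objective, not to be efficient. The usage example assumes the self-attention setting $A_1 = A_2 = A_3 = X_\ell$ and generates $B$ from known weights, so the loss vanishes at those weights.

```python
import math
import random


def matmul(P, R):
    """Multiply a p x q matrix P by a q x r matrix R (lists of rows)."""
    return [[sum(P[i][k] * R[k][j] for k in range(len(R)))
             for j in range(len(R[0]))]
            for i in range(len(P))]


def transpose(P):
    """Return the transpose of P."""
    return [list(col) for col in zip(*P)]


def attention_predict(A1, A2, A3, X, Y):
    """Return D(X)^{-1} exp(A1 X A2^T) A3 Y."""
    E = [[math.exp(s) for s in row] for row in matmul(matmul(A1, X), transpose(A2))]
    S = []
    for row in E:
        total = sum(row)  # the diagonal entry of D(X) for this row
        S.append([e / total for e in row])
    return matmul(matmul(S, A3), Y)


def attention_loss(A1, A2, A3, B, X, Y):
    """Return L(X, Y) = ||D(X)^{-1} exp(A1 X A2^T) A3 Y - B||_F^2."""
    R = attention_predict(A1, A2, A3, X, Y)
    return sum((r - b) ** 2 for r_row, b_row in zip(R, B) for r, b in zip(r_row, b_row))


def rand_matrix(rows, cols):
    """Random matrix with entries drawn uniformly from [-1, 1]."""
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


# Self-attention instance: A1 = A2 = A3 = X_l, and B is produced by known weights.
random.seed(1)
n, d = 5, 2
X_l = rand_matrix(n, d)
X_true, Y_true = rand_matrix(d, d), rand_matrix(d, d)
B = attention_predict(X_l, X_l, X_l, X_true, Y_true)
print(attention_loss(X_l, X_l, X_l, B, X_true, Y_true))  # 0.0 at the generating weights
```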

Here, $X = QK^\top$ and $Y = V$ are the weights we want to learn, $A_1, A_2, A_3$ are the inputs of the layer (in self-attention, $A_1 = A_2 = A_3 = X_{\ell}$), and $B$ is the target output $X_{\ell+1}$. Evaluating the objective exactly takes $O(n^2 d)$ time: the attention matrix $A = \exp(A_1 X A_2^\top)$ has $n^2$ entries, so any algorithm that explicitly computes all of them cannot run in sub-quadratic time. In real-world applications, $n \gg d$ \citep{as23}, so prior works mainly focus on approximating the attention computation to obtain algorithms that run in sub-quadratic time in $n$.


\paragraph{Limitations of Prior Works}

Attention computation has been analyzed in many recent works \citep{as23,bsz23,gsyz23_quantum,dms23,syz23,dls23,gms23,gsy23_coin}, but none of them handles the full attention optimization problem. Each of these works simplifies the problem in Definition~\ref{def:attention} in a different way. For example, \cite{zhdk23,bsz23} treat $A_1 X$, $A_2$, and $A_3 Y$ as given matrices $Q$, $K$, and $V \in \R^{n \times d}$, respectively, and approximate
\begin{align*}
    D^{-1} \exp(Q K^\top) V, \quad \text{where } D := \diag(\exp(Q K^\top) {\bf 1}_n).
\end{align*}
\cite{kmz23} replace the $\exp$ function in Definition~\ref{def:attention} with polynomials. Another major line of work studies the softmax regression problem and its variants, which ignore the factor $A_3 Y$ entirely.
\begin{definition}[Single softmax regression \citep{dls23} and multiple softmax regression \citep{gsx23_incontext}]\label{def:softmax}
    Given a matrix $A \in \R^{n \times d}$ and a vector $c \in \R^n$, the single softmax regression problem is defined as
    \begin{align*}
     {\bf Part~1.}   \min_{x \in \R^d} \| \langle \exp(Ax) , {\bf 1}_n \rangle^{-1} \exp(Ax) - c \|_2^2 .
    \end{align*}

    Let $D(X) \in \R^{n \times n}$ be defined as in Definition~\ref{def:attention} and $C \in \R^{n \times n}$. Given $A_1, A_2 \in \R^{n \times d}$ and $X \in \R^{d \times d}$, the multiple softmax regression problem is defined as
    \begin{align*}
    {\bf Part~2.} \min_{X \in \R^{d \times d}} \| D(X)^{-1} \exp(A_1 X A_2^\top) - C \|_F^2.
\end{align*}
\end{definition}
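The relation between the two parts can be checked numerically. Row $i$ of $A_1 X A_2^\top$ equals $(A_2 x_i)^\top$ with $x_i := X^\top (A_1)_{i,*} \in \R^d$, so the {\bf Part~2} objective is the sum over $i \in [n]$ of {\bf Part~1} objectives with $A = A_2$, $x = x_i$, and $c = C_{i,*}$. The plain-Python sketch below (function names are hypothetical) evaluates both sides on a random instance.

```python
import math
import random


def softmax(v):
    """<exp(v), 1>^{-1} exp(v), computed with a max shift for numerical stability."""
    m = max(v)
    e = [math.exp(t - m) for t in v]
    total = sum(e)
    return [t / total for t in e]


def single_softmax_loss(A, x, c):
    """Part 1: || <exp(Ax), 1_n>^{-1} exp(Ax) - c ||_2^2."""
    Ax = [sum(a_j * x_j for a_j, x_j in zip(row, x)) for row in A]
    return sum((p - c_l) ** 2 for p, c_l in zip(softmax(Ax), c))


def multiple_softmax_loss(A1, A2, X, C):
    """Part 2: || D(X)^{-1} exp(A1 X A2^T) - C ||_F^2, computed row by row."""
    d = len(X)
    total = 0.0
    for a_i, c_i in zip(A1, C):
        a_i_X = [sum(a_i[k] * X[k][j] for k in range(d)) for j in range(d)]  # row i of A1 X
        scores = [sum(u * v for u, v in zip(a_i_X, row)) for row in A2]  # row i of A1 X A2^T
        total += sum((p - c_il) ** 2 for p, c_il in zip(softmax(scores), c_i))
    return total


def rand_matrix(rows, cols):
    """Random matrix with entries drawn uniformly from [-1, 1]."""
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(2)
n, d = 6, 3
A1, A2 = rand_matrix(n, d), rand_matrix(n, d)
X = rand_matrix(d, d)
C = rand_matrix(n, n)

# x_i = X^T (A1)_{i,*}, so that A2 x_i is the i-th row of A1 X A2^T.
xs = [[sum(X[k][j] * a_i[k] for k in range(d)) for j in range(d)] for a_i in A1]
row_sum = sum(single_softmax_loss(A2, x_i, c_i) for x_i, c_i in zip(xs, C))
print(abs(row_sum - multiple_softmax_loss(A1, A2, X, C)))  # agrees up to rounding
```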
As observed in \cite{gsx23_incontext,gsy23_coin}, the objective in {\bf Part 1} of Definition~\ref{def:softmax} can be viewed as a single row of the objective in {\bf Part 2} of Definition~\ref{def:softmax}. When studying multiple softmax regression, \cite{dms23} impose an additional assumption and consider only symmetric matrices of the form
\begin{align*}
    D(X)^{-1} \exp(A_2 A_2^\top),
\end{align*}
but in exchange, they measure the error in the stronger $\ell_\infty$ norm. \cite{gsy23_hyper,gsx23_incontext} study rescaled versions of single and multiple softmax regression, respectively, namely
\begin{align*}
     \min_{x \in \R^d} \| \exp(Ax) - \langle \exp(Ax) , {\bf 1}_n \rangle c \|_2^2
\end{align*}
and 
\begin{align*}
    \min_{X \in \R^{d \times d}} \| \exp(A_1 X A_2^\top) - D(X) C \|_F^2.
\end{align*}
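The rescaled objectives differ from the original ones only by data-dependent positive weights. Writing $\alpha(x) := \langle \exp(Ax), {\bf 1}_n \rangle > 0$, we have
\begin{align*}
    \| \exp(Ax) - \alpha(x) c \|_2^2 = \alpha(x)^2 \, \| \alpha(x)^{-1} \exp(Ax) - c \|_2^2,
\end{align*}
and applying the same identity to each row $i \in [n]$, with $\alpha(x)$ replaced by $D(X)_{i,i}$, gives
\begin{align*}
    \| \exp(A_1 X A_2^\top) - D(X) C \|_F^2 = \sum_{i \in [n]} D(X)_{i,i}^2 \, \| ( D(X)^{-1} \exp(A_1 X A_2^\top) )_{i,*} - C_{i,*} \|_2^2.
\end{align*}
In particular, each rescaled objective vanishes exactly when the corresponding objective in Definition~\ref{def:softmax} does, although the two need not share minimizers otherwise.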

We note that all of these softmax-related regression problems consider simpler variants to obtain sub-quadratic time algorithms: each optimizes a loss over a single variable. Specifically, they minimize the loss over the query-key product $X = QK^\top$ while ignoring the value weights $Y = V$. However, simplifying the attention optimization problem in this way may significantly degrade model performance and may require additional training or fine-tuning, which in turn creates deployment challenges \citep{dlz+23}. It is therefore natural to ask:
\begin{center}
    {\it How fast can we optimize the attention weights during training without any simplification of Definition~\ref{def:attention}?} 
\end{center}

\paragraph{Our Result}

Although \cite{as23} shows that a single forward computation of attention can be approximated in $o(n^2)$ time without explicitly forming the $n \times n$ attention matrix, how fast the loss function can be optimized by iterative methods remains an open problem.
In this paper, we therefore provide a complete analysis of the attention optimization problem in Definition~\ref{def:attention} without any simplification, which, to the best of our knowledge, has not been done before. Additionally, we establish a provable guarantee for optimizing the attention function in the case of a single-layer attention network. 


\begin{theorem}[Informal version of our main theorem (Theorem~\ref{thm:main})]\label{thm:main_informal}
Given $A_1, A_2, A_3, B \in \R^{n \times d}$, there exists an algorithm (Algorithm~\ref{alg:main_result}) that runs in $O( (\Tmat(n,d,n) + \Tmat(n,d,d) + d^{2\omega})\log(1/\epsilon))$ time and solves the attention optimization problem (Definition~\ref{def:attention}) up to $\epsilon$ accuracy with probability $1 - 1/\poly(n)$. Here $\omega \approx 2.37$\footnote{$\omega$ denotes the exponent of matrix multiplication \citep{w12,lg14,aw21,dwz23,lg23,wxxz23}, $\Tmat(a,b,c)$ denotes the time of multiplying an $a \times b$ matrix by a $b \times c$ matrix, and $\Tmat(n,n,n) = n^{\omega}$. See Section~\ref{sub:preli:fast_matrix_multi} for more details on the matrix multiplication notation.}.
\end{theorem}
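To make the bound in Theorem~\ref{thm:main_informal} concrete, we can use the schoolbook upper bound $\Tmat(a,b,c) = O(abc)$, which holds for all $a, b, c$ (fast matrix multiplication can only improve it). For $d \le n$, this gives
\begin{align*}
    O\big( (\Tmat(n,d,n) + \Tmat(n,d,d) + d^{2\omega}) \log(1/\epsilon) \big)
    = O\big( (n^2 d + n d^2 + d^{2\omega}) \log(1/\epsilon) \big)
    = O\big( (n^2 d + d^{2\omega}) \log(1/\epsilon) \big),
\end{align*}
since $n d^2 \le n^2 d$. The last term satisfies $d^{2\omega} = \Tmat(d^2, d^2, d^2)$, the cost of multiplying two $d^2 \times d^2$ matrices, where $d^2$ is the number of entries of $X$ (and of $Y$). For fixed $d$, the bound therefore grows quadratically in $n$ and linearly in $\log(1/\epsilon)$.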

Optimizing the attention objective is a necessary subproblem in the overall training of an LLM, even though it is not sufficient on its own because an LLM contains additional layers. Faster and more scalable algorithms for attention optimization can therefore help reduce the computational burden of training LLMs.

To establish the correctness of our algorithm, we conduct a comprehensive analysis of the positive semi-definite (PSD) property and the Lipschitz continuity of the Hessian matrix constructed from the attention matrix. These two properties justify employing $\mathsf{TensorSRHT}$ and Newton's method, ensuring fast computation and convergence, respectively. 
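As a toy illustration of why a PSD Hessian matters for Newton's method, the following plain-Python sketch applies damped Newton's method to the regularized objective $f(x) = \log \langle \exp(Ax), {\bf 1}_n \rangle - \langle b, x \rangle + \frac{\lambda}{2}\|x\|_2^2$. Its Hessian $A^\top(\diag(p) - p p^\top)A + \lambda I_d$, with $p = \langle \exp(Ax), {\bf 1}_n \rangle^{-1} \exp(Ax)$, is PSD and positive definite for $\lambda > 0$, so every Newton direction is a descent direction. The objective, the backtracking line search, the stopping rule, and all function names are assumptions made for illustration: this is not Algorithm~\ref{alg:main_result}, it does not optimize the attention loss, and it uses no $\mathsf{TensorSRHT}$ sketching. The Lipschitz continuity of the Hessian is the property that standard analyses use to show fast local convergence of the full Newton step.

```python
import math


def objective(A, b, lam, x):
    """Return f(x), its gradient, and its Hessian for
    f(x) = log(sum_i exp((Ax)_i)) - <b, x> + (lam / 2) * ||x||_2^2."""
    n, d = len(A), len(x)
    z = [sum(A[i][j] * x[j] for j in range(d)) for i in range(n)]
    m = max(z)  # shift before exponentiating for numerical stability
    e = [math.exp(t - m) for t in z]
    s = sum(e)
    p = [t / s for t in e]  # softmax(Ax)
    f = (m + math.log(s) - sum(bj * xj for bj, xj in zip(b, x))
         + 0.5 * lam * sum(xj * xj for xj in x))
    Atp = [sum(A[i][j] * p[i] for i in range(n)) for j in range(d)]  # A^T p
    grad = [Atp[j] - b[j] + lam * x[j] for j in range(d)]
    # Hessian A^T (diag(p) - p p^T) A + lam * I: PSD, positive definite for lam > 0.
    H = [[sum(A[i][j] * p[i] * A[i][k] for i in range(n)) - Atp[j] * Atp[k]
          + (lam if j == k else 0.0) for k in range(d)] for j in range(d)]
    return f, grad, H


def solve(H, g):
    """Solve H s = g by Gaussian elimination with partial pivoting."""
    d = len(g)
    M = [row[:] + [gi] for row, gi in zip(H, g)]
    for col in range(d):
        piv = max(range(col, d), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, d):
            factor = M[r][col] / M[col][col]
            for k in range(col, d + 1):
                M[r][k] -= factor * M[col][k]
    s = [0.0] * d
    for r in range(d - 1, -1, -1):
        s[r] = (M[r][d] - sum(M[r][k] * s[k] for k in range(r + 1, d))) / M[r][r]
    return s


def newton(A, b, lam, x0, tol=1e-14, max_iter=50):
    """Damped Newton's method with backtracking (Armijo) line search.
    Stops when half the squared Newton decrement, g^T H^{-1} g / 2, is below tol."""
    x = list(x0)
    for _ in range(max_iter):
        f, g, H = objective(A, b, lam, x)
        step = solve(H, g)  # the Newton update is x - t * step
        decrement = sum(gi * si for gi, si in zip(g, step))  # g^T H^{-1} g >= 0
        if decrement / 2 <= tol:
            break
        t = 1.0
        while t > 1e-12:
            x_new = [xi - t * si for xi, si in zip(x, step)]
            if objective(A, b, lam, x_new)[0] <= f - 0.25 * t * decrement:
                break
            t *= 0.5
        x = x_new
    return x


A = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5]]
b = [0.3, -0.2]
x_star = newton(A, b, lam=0.1, x0=[0.0, 0.0])
print(x_star, objective(A, b, 0.1, x_star)[1])  # the gradient is numerically close to zero
```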



\paragraph{Notation}
 
We use $\mathbb{N}$ to denote the set of positive integers. Let $n,d \in \mathbb{N}$. We define $[n] := \{1, 2, \dots, n\}$.  
Let $x, y \in \R^d$. For all $i \in [d]$, we define $x_i \in \R$ as the $i$-th entry of $x$. We define $\langle \cdot, \cdot \rangle : \R^d \times \R^d \to \R$ as $\langle x, y \rangle := \sum_{i = 1}^d x_i y_i$. 
For $p \in \{1, 2\}$, we define $\|x\|_p := (\sum_{i \in [d]} |x_i|^p)^{1/p}$, and we define $\|x\|_\infty := \max_{i \in [d]} |x_i|$. We use ${\bf 1}_d$ and ${\bf 0}_d$ to denote the $d$-dimensional vectors whose entries are all $1$'s and $0$'s, respectively.
 
Let $A \in \R^{n \times d}$. For all $i \in [n]$ and $j \in [d]$, we use $A_{i, j} \in \R$ to denote the $(i, j)$-th entry of $A$, and $A_{i, *} \in \R^d$ and $A_{*, j} \in \R^n$ to denote the $i$-th row and the $j$-th column of $A$, respectively, i.e., $(A_{i, *})_j = A_{i, j} = (A_{*, j})_i$. We use $A^\top \in \R^{d \times n}$ to denote the transpose of $A$. For $X \in \R^{d \times d}$, we define $x = \vect(X) \in \R^{d^2}$ by $\vect(X)_{(i - 1) d + j} = X_{i,j}$ for all $i, j \in [d]$. For $x \in \R^d$, we define $\diag(x) \in \R^{d \times d}$ by $\diag(x)_{i, i} = x_i$ for all $i \in [d]$, with all other entries of $\diag(x)$ equal to $0$. $\| A \|_F \in \R$ and $\|A\| \in \R$ denote the Frobenius norm and the spectral norm of $A \in \R^{n \times d}$, respectively, where $\|A\|_F := \sqrt{\sum_{i \in [n]} \sum_{j \in [d]} |A_{i, j}|^2}$ and $\| A \| := \max_{x \in \R^d \setminus \{{\bf 0}_d\}} \| A x \|_2 / \| x \|_2$. Let $\A \in \R^{n^2 \times d^2}$. For each $j_1 \in [n]$, we use $\A_{j_1} \in \R^{n \times d^2}$ to denote the $j_1$-th block of $n$ consecutive rows of $\A$. Let $C, D \in \R^{d \times d}$ be symmetric matrices. We write $C \succeq D$ if $y^\top C y \geq y^\top D y$ for all $y \in \R^{d}$, and we say that $C$ is positive semidefinite (PSD) if $y^\top C y \geq 0$ for all $y \in \R^{d}$. We use $I_d$ to denote the $d \times d$ identity matrix. 
Let $A \in \R^{n_1 \times d_1}$ and $B \in \R^{n_2 \times d_2}$. We define the Kronecker product of $A$ and $B$, denoted $A \otimes B \in \R^{n_1 n_2 \times d_1 d_2}$, by $(A \otimes B)_{(i_1 - 1) n_2 + i_2, (j_1-1)d_2+j_2} := A_{i_1,j_1} B_{i_2,j_2}$ for all $i_1 \in [n_1], j_1 \in [d_1], i_2 \in [n_2], j_2 \in [d_2]$.
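The index conventions for $\vect(\cdot)$ and $\otimes$ above fit together: with row-major vectorization, $\vect(A_1 X A_2^\top) = (A_1 \otimes A_2) \vect(X)$ for $A_1, A_2 \in \R^{n \times d}$ and $X \in \R^{d \times d}$, since both sides have $((i-1)n + l)$-th entry $\sum_{j, k \in [d]} (A_1)_{i,j} X_{j,k} (A_2)_{l,k}$. Taking, for illustration only, $\A := A_1 \otimes A_2 \in \R^{n^2 \times d^2}$ (the notation above does not fix $\A$), the block $\A_{j_1}$ then satisfies $\A_{j_1} \vect(X) = (A_1 X A_2^\top)_{j_1, *}$. The plain-Python sketch below, with hypothetical helper names, checks both identities numerically.

```python
import random


def vec(X):
    """Row-major vectorization: vec(X)[(i - 1) * d + j] = X[i][j] (1-indexed)."""
    return [x for row in X for x in row]


def kron(A, B):
    """Kronecker product: entry ((i1-1)*n2 + i2, (j1-1)*d2 + j2) is A[i1][j1] * B[i2][j2]."""
    return [[a * b_val for a in A_row for b_val in B_row]
            for A_row in A for B_row in B]


def matmul(P, R):
    """Multiply a p x q matrix P by a q x r matrix R (lists of rows)."""
    return [[sum(P[i][k] * R[k][j] for k in range(len(R)))
             for j in range(len(R[0]))]
            for i in range(len(P))]


def transpose(P):
    """Return the transpose of P."""
    return [list(col) for col in zip(*P)]


def matvec(M, v):
    """Multiply matrix M by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def rand_matrix(rows, cols):
    """Random matrix with entries drawn uniformly from [-1, 1]."""
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(3)
n, d = 3, 2
A1, A2, X = rand_matrix(n, d), rand_matrix(n, d), rand_matrix(d, d)
lhs = vec(matmul(matmul(A1, X), transpose(A2)))  # vec(A1 X A2^T), length n^2
big = kron(A1, A2)                                # n^2 x d^2
rhs = matvec(big, vec(X))
print(max(abs(u - v) for u, v in zip(lhs, rhs)))  # zero up to rounding

j1 = 1                             # 0-indexed block number
block = big[j1 * n:(j1 + 1) * n]   # the n consecutive rows forming block j1
```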

\paragraph{Roadmap}

In Section~\ref{sec:related_work}, we introduce related research work. In Section~\ref{sec:tech_overview}, we provide an overview of the techniques we will use throughout the rest of the paper. 
In Section~\ref{sec:discussion}, we present a discussion of our theoretical results.
In Section~\ref{sec:conclusion}, we conclude the paper.
