\section{Technique Overview}\label{sec:tech_overview}


In this section, we introduce the primary technique employed in this paper. This serves as a summary of our theoretical analysis, which is deferred to the Appendix due to space limitations.

Specifically, in Section~\ref{sub:tech_overview:analysis}, we present the key mathematical properties used to analyze the attention optimization problem, as defined in Definition~\ref{def:attention}. In Section~\ref{sub:tech_overview:algorithm}, we describe the techniques for constructing and analyzing the essential properties of our main algorithm (see Algorithm~\ref{alg:main_result}).

\subsection{Theoretical Analysis}
\label{sub:tech_overview:analysis}


\paragraph{Big Picture}

In this section, we provide an overview of the key techniques used in our theoretical analysis. Our analysis of this multivariate loss function relies on a novel technique that leverages support vector machines (SVM) to reformulate the loss function:
\begin{align*}
    \| D(X)^{-1} \exp(A_1 X A_2^\top) A_3 Y - B \|_F^2
\end{align*}
into the form of inner products and Kronecker product
\begin{align}\label{eq:complicated}
    & ~ \sum_{j_0=1}^n \sum_{i_0=1}^d ( \langle \langle \exp( \mathsf{A}_{j_0} x ) , {\bf 1}_n \rangle^{-1} \notag\\ 
    \cdot & ~ \exp( \mathsf{A}_{j_0} x ), A_{3} Y_{*,i_0} \rangle - b_{j_0,i_0} )^2.
\end{align}
We define 
\begin{itemize}
    \item $u(x)_{j_0} := \exp( \A_{j_0} x )$,
    \item $\alpha(x)_{j_0}:= \langle \exp( \A_{j_0} x ), {\bf 1}_n \rangle$,
    \item $f(x)_{j_0} := \alpha(x)_{j_0}^{-1} u(x)_{j_0} $,
    \item $h(Y)_{i_0}:= A_3  Y_{*,i_0} $, and 
    \item $c(x,y)_{j_0,i_0}:= \langle f(x)_{j_0}, h(y)_{i_0} \rangle - b_{j_0,i_0}$.
\end{itemize}
to decompose Eq.~\eqref{eq:complicated} into small parts, whose gradients and Hessians we compute separately.
Unlike prior works that focus on single-variable loss functions \citep{gsyz23_quantum,dms23,syz23,dls23,gms23,gsy23_coin,gsx23_incontext}, our multivariate loss function has a more complex Hessian matrix: $H = \begin{bmatrix}
H_{x,x} & H_{x,y} \\
H_{y,x} & H_{y,y}
\end{bmatrix}$. We first present how we decompose the Hessian into blocks ($X, Y$). Then, we show that the diagonal sub Hessian matrices $H_{x,x}, H_{y,y} \in \mathbb{R}^{d^2 \times d^2}$ are positive semi-definite and provide an upper bound on the spectral norm of the off-diagonal sub Hessian matrices $H_{x,y}, H_{y,x} \in \mathbb{R}^{d^2 \times d^2}$. Next, we demonstrate that the full Hessian matrix $H \in \mathbb{R}^{2d^2 \times 2d^2}$, consisting of the sub matrices $H_{x,x}$, $H_{x,y}$, $H_{y,x}$, and $H_{y,y}$, is also positive semi-definite. Finally, we introduce techniques for proving that the Hessian is Lipschitz.
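
As a small illustration of how the definitions above are used, note that the loss in Eq.~\eqref{eq:complicated} equals $\sum_{j_0=1}^n \sum_{i_0=1}^d c(x,y)_{j_0,i_0}^2$, so its derivatives reduce to derivatives of $u$, $\alpha$, and $f$. The following is a sketch for a single coordinate $x_i$ that ignores any regularization term; the notation $\mathsf{A}_{j_0,i} \in \mathbb{R}^n$ for the $i$-th column of $\mathsf{A}_{j_0}$ and $\circ$ for the entrywise product is introduced here only for this illustration.
\begin{align*}
    \frac{\mathrm{d} u(x)_{j_0}}{\mathrm{d} x_i} = & ~ u(x)_{j_0} \circ \mathsf{A}_{j_0,i}, \\
    \frac{\mathrm{d} \alpha(x)_{j_0}}{\mathrm{d} x_i} = & ~ \langle u(x)_{j_0}, \mathsf{A}_{j_0,i} \rangle, \\
    \frac{\mathrm{d} f(x)_{j_0}}{\mathrm{d} x_i} = & ~ f(x)_{j_0} \circ \mathsf{A}_{j_0,i} - f(x)_{j_0} \cdot \langle f(x)_{j_0}, \mathsf{A}_{j_0,i} \rangle, \\
    \frac{\mathrm{d} c(x,y)_{j_0,i_0}}{\mathrm{d} x_i} = & ~ \Big\langle \frac{\mathrm{d} f(x)_{j_0}}{\mathrm{d} x_i}, h(y)_{i_0} \Big\rangle.
\end{align*}
The third line follows from the quotient rule applied to $f(x)_{j_0} = \alpha(x)_{j_0}^{-1} u(x)_{j_0}$, and the last line uses that $h(y)_{i_0}$ and $b_{j_0,i_0}$ do not depend on $x$.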


\paragraph{Problem Reformulation Using SVM} 

The initial works \citep{dls23,gsy23_hyper,syz23} on attention regression problems consider the simplest $\ell_2$ norm, such as $\min_{x \in \R^d} \| \langle \exp(Ax) , {\bf 1}_n \rangle^{-1} \exp(Ax) - c \|_2^2$ ({\bf Part 1} of Definition~\ref{def:softmax}), which corresponds to a single row of the full attention matrix. Inspired by the tensor trick from \cite{dssw18,djs+19},
\begin{align*}
    \vect(A_1 X A_2^\top) = (A_1 \otimes A_2) \vect(X) \in \R^{n^2},
\end{align*}
later works \citep{gsy23_coin,gsx23_incontext} consider a slightly more complicated version of the {\bf Part 1} equation, namely the Frobenius norm of the whole matrix, such as $\min_{X \in \R^{d \times d}} \| D(X)^{-1} \exp(A_1 X A_2^\top) - C \|_F^2$ ({\bf Part 2} of Definition~\ref{def:softmax}). In particular, instead of using a single rescaling factor ({\bf Part 1}), we now have $n$ rescaling factors ({\bf Part 2}). We split $\exp((A_1 \otimes A_2) \vect(X)) \in \mathbb{R}^{n^2}$ into $n$ chunks, each of size $n$, and apply the same rescaling factor within each chunk.

\begin{remark}\label{rem:block}
    For a matrix $\mathsf{A} = A_1 \otimes A_2 \in \mathbb{R}^{n^2 \times d^2}$, we can split it into $n$ blocks, where the first block $\mathsf{A}_1 \in \mathbb{R}^{n \times d^2}$ contains the first $n$ rows of $\mathsf{A}$, the second block $\mathsf{A}_2 \in \mathbb{R}^{n \times d^2}$ contains the next $n$ rows of $\mathsf{A}$, and so on. The $j_0$-th block $\mathsf{A}_{j_0} \in \mathbb{R}^{n \times d^2}$ contains the rows from $(j_0 - 1)n + 1$ to $j_0n$ of $\mathsf{A}$, and the $n$-th block $\mathsf{A}_n \in \mathbb{R}^{n \times d^2}$ contains the rows from $(n - 1)n + 1$ to $n^2$ of $\mathsf{A}$.
\end{remark}
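
To make the tensor trick and Remark~\ref{rem:block} concrete, consider the toy case $n = 2$ and $d = 1$, under the assumption that $\vect(\cdot)$ stacks the rows of a matrix (the symbols $a_j$ and $b_k$ are introduced only for this illustration). Let $A_1 = (a_1, a_2)^\top$, $A_2 = (b_1, b_2)^\top$, and $X = x \in \mathbb{R}$. Then
\begin{align*}
    \vect(A_1 X A_2^\top) = \begin{bmatrix} a_1 b_1 \\ a_1 b_2 \\ a_2 b_1 \\ a_2 b_2 \end{bmatrix} x = (A_1 \otimes A_2) \vect(X),
\end{align*}
and the two blocks are $\mathsf{A}_1 = a_1 A_2$ and $\mathsf{A}_2 = a_2 A_2$. More generally, under the same convention, $\mathsf{A}_{j_0} = (A_1)_{j_0,*} \otimes A_2$, so that $\exp(\mathsf{A}_{j_0} x)$ is the $j_0$-th row of $\exp(A_1 X A_2^\top)$ written as a column vector.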

Note that while the tensor trick is necessary for considering matrix norm regression, it is not sufficient to account for the value matrix $A_3 Y$ in the attention optimization problem (Definition~\ref{def:attention}).
Therefore, we take a step further by incorporating both the SVM and the tensor trick to reformulate the entire equation of the attention optimization problem. The standard SVM objective function \citep{j06,cl01,gsz23,tlto23} in optimization can be viewed as the product of a summation over a batch of inner products. Inspired by this, we define $n$ functions $f(x)_{j_0 } = \langle \exp( \A_{j_0} x )  , {\bf 1}_n  \rangle^{-1} \exp( \A_{j_0} x ) \in \R^n$ (see Definition~\ref{def:f}) for each $j_0 \in [n]$ and $d$ functions $h(Y)_{i_0} = A_3 Y_{*,i_0}  \in \R^n$ (see Definition~\ref{def:h}), where $\A_{j_0} \in \R^{n \times d^2}$ is one $n \times d^2$ block from $\A$. Here, $x$ is the vectorization of $X$, and $y$ is the vectorization of $Y$. Then the objective function in Definition~\ref{def:attention}, $\| D(X)^{-1} \exp(A_1 X A_2^\top) A_3 Y - B \|_F^2$, can be turned into 
\begin{align}\label{eq:svm}
    \sum_{j_0=1}^n \sum_{i_0=1}^d ( \langle f(x)_{j_0}, h(Y)_{i_0} \rangle - b_{j_0,i_0} )^2
\end{align}
where $b_{j_0,i_0}$ is the $(j_0,i_0)$-th entry of the matrix $B \in \R^{n \times d}$. 
We call this the SVM-inspired formulation.
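
The step from the Frobenius norm to Eq.~\eqref{eq:svm} can be checked entry by entry. The sketch below assumes the usual normalization $D(X) = \mathrm{diag}( \exp(A_1 X A_2^\top) {\bf 1}_n )$ and that $x$ stacks the rows of $X$; both are assumptions made for this illustration.
\begin{align*}
    \big( D(X)^{-1} \exp(A_1 X A_2^\top) \big)_{j_0,*}^\top = & ~ \langle \exp( \mathsf{A}_{j_0} x ), {\bf 1}_n \rangle^{-1} \exp( \mathsf{A}_{j_0} x ) = f(x)_{j_0}, \\
    ( A_3 Y )_{*,i_0} = & ~ A_3 Y_{*,i_0} = h(Y)_{i_0}, \\
    \big( D(X)^{-1} \exp(A_1 X A_2^\top) A_3 Y \big)_{j_0,i_0} = & ~ \langle f(x)_{j_0}, h(Y)_{i_0} \rangle.
\end{align*}
Summing the squared differences between these entries and $b_{j_0,i_0}$ over all $j_0 \in [n]$ and $i_0 \in [d]$ recovers Eq.~\eqref{eq:svm}.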




\paragraph{Split Hessian Into Blocks ($X,Y$)}


The PSD property of the Hessian, studied in Section~\ref{sec:hessian}, is central to both the fast approximation and the convergence guarantee of the training process for the attention matrix. Unlike single or multiple softmax regression and their variants \citep{dls23,gsx23_incontext,gsy23_hyper}, here both weights $X$ and $Y$ (as defined in Definition~\ref{def:attention}) must be considered, which significantly increases the complexity of the analysis. Consequently, the Hessian matrix discussed in Section~\ref{sec:hessian} has the following block form:
\begin{align*}
H = \begin{bmatrix}
H_{x,x} & H_{x,y} \\
H_{y,x} & H_{y,y}
\end{bmatrix}.
\end{align*}
To establish the positive semi-definite property, we examine each block of the matrix above individually.

\paragraph{Positive Semi-Definite For Hessian $H_{x,x}$, $H_{y,y}$}
The positive semi-definiteness of $H_{x,x}$ and $H_{y,y}$ constitutes a crucial initial step in the proof of Lemma~\ref{lem:hessian_lower_bound}. These Hessian blocks are discussed in detail in Section~\ref{sec:psd_H_xx} and Section~\ref{sec:hessian_Y}. However, proving the PSD property for $H_{x,x}$ and $H_{y,y}$ in the context of the attention optimization problem is non-trivial. The challenges arise from the complex structure of the attention mechanism and the presence of the exponential function in the loss formulation (Definition~\ref{def:attention}).

To tackle these challenges, we dive deep into the structure of $H_{x,x}$ and $H_{y,y}$ (see Section~\ref{sec:psd_H_xx} and Section~\ref{sec:hessian_Y} for details). We express these matrices in terms of the constituent functions of the attention mechanism, such as the exponential function, the softmax function, and the key-query-value transformations. This fine-grained representation allows us to analyze the PSD property at a granular level. Another key insight in our analysis is the role of the regularization term (see details in Section~\ref{sub:preli:regularization}) in the loss function. By carefully choosing the regularization weight, we can ensure that it dominates any potentially negative contributions from the complex attention terms. This is a delicate balancing act, as the regularization weight needs to be large enough to enforce the PSD property, but not so large that it overwhelms the attention signal \citep{llr23,dls23}.

Leveraging this insight, we derive lower bounds on the regularization weight that guarantee the PSD property for $H_{x,x}$ and $H_{y,y}$ (Lemma~\ref{lem:hessian_property:x} and Lemma~\ref{lem:hessian_property:y}, respectively). These bounds are expressed in terms of the spectral norms of the attention matrices and the minimum singular values of the key-query-value transformations. By ensuring that the regularization weight exceeds these bounds, we can provably establish the PSD property: there exists a real number $l > 0$ such that
\begin{align*}
H(x) = H_{x,x} \succeq l\cdot I_{d^2} ~ ~\text{and}~~H(y) = H_{y,y} \succeq l\cdot I_{d^2}.
\end{align*}


\paragraph{Upper Bounds for the Spectral Norm of $H_{x,y}$, $H_{y,x}$}

The blocks $H_{x,y}$ and $H_{y,x}$ capture the intricate interaction between the weights $X$ and $Y$ in the attention mechanism. Bounding their influence is crucial for ensuring the overall positive semi-definite (PSD) property of the Hessian and the convergence of our optimization algorithm.
To establish the spectral upper bound of $H_{x,y}$, we decompose $H_{x,y}$ into $\{ G_i \}_{i=1}^4$ as described in Lemma~\ref{lem:summary_Gi_xy_psd}. 
Another important technique in our analysis is the use of the boundedness properties of the attention functions. We show that the exponential function and the softmax function, when applied to bounded inputs, produce outputs with controlled spectral norms. This allows us to propagate the boundedness through the complex matrix expressions in $H_{x,y}$.

Leveraging these insights, we derive a spectral upper bound for each component in Lemma~\ref{lem:summary_Gi_xy_psd}, namely
$
    \max_{i \in [4]} \| G_i \| \leq R^2,
$
where $R$ is a constant that depends on the spectral norms of the attention matrices. Using these component-wise bounds, we then derive a tight spectral upper bound for the full off-diagonal block $H_{x,y}$,
$
    \| H_{x,y} \| \leq nd \cdot 10 R^2.
$
With this upper bound in hand, the remaining step is to establish the positive semi-definite (PSD) property of the full Hessian, as described next.

\paragraph{PSD for Hessian $H$}

The challenge in establishing the PSD property for $H$ lies in the complex interplay between its constituent blocks: $H_{x,x}$, $H_{x,y}$, $H_{y,x}$, and $H_{y,y}$. Each of these blocks has its own intricate structure, involving the attention matrices, the exponential function, and the softmax normalization. Moreover, the off-diagonal blocks $H_{x,y}$ and $H_{y,x}$ introduce cross-term interactions that can potentially disrupt the PSD property.

To tackle this challenge, we employ a carefully orchestrated analysis that leverages the properties of the individual blocks and their interrelationships. Our strategy is to show that the PSD property of the diagonal blocks $H_{x,x}$ and $H_{y,y}$ is strong enough to compensate for any potentially negative contributions from the off-diagonal blocks.
 
With the PSD property of the diagonal blocks and the spectral bounds on the off-diagonal blocks in hand, we carry out the final step of proving the PSD property for the full Hessian $H$ by expanding the quadratic form
 \begin{align*}
     & ~ \begin{bmatrix}
u^\top & v^\top 
\end{bmatrix}
H
\begin{bmatrix}
u \\
v
\end{bmatrix} \\
= & ~ u^\top H_{x,x} u + v^\top H_{y,y} v + u^\top H_{x,y} v + v^\top H_{y,x} u,
 \end{align*}
for any arbitrary $u, v \in \R^{d^2}$.
 
Consequently, the lower bounds on the diagonal blocks dominate the contribution of the off-diagonal blocks, so the off-diagonal part does not affect the positivity of the entire matrix, which establishes the positive semi-definite property. With $\alpha_1$, $\alpha_2$, and $\alpha_3$ denoting the bounds for $H_{x,x}$, $H_{y,y}$, and the off-diagonal blocks, respectively, in Lemma~\ref{lem:hessian_lower_bound}, we have the following result:
\begin{align*}
    H \succeq \min \{ \alpha_1 - \alpha_3, \alpha_2 - \alpha_3 \} \cdot I_{2d^2}.
\end{align*}
Given the relationship among $\{ \alpha_i \}_{i = 1}^3$ discussed above, the positive semi-definite property of the Hessian matrix is established.
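
The final bound can be verified in a few lines. The sketch below assumes, as our reading of Lemma~\ref{lem:hessian_lower_bound}, that $H_{x,x} \succeq \alpha_1 I_{d^2}$, $H_{y,y} \succeq \alpha_2 I_{d^2}$, $\| H_{x,y} \| \leq \alpha_3$, and $H_{y,x} = H_{x,y}^\top$. For any $u, v \in \mathbb{R}^{d^2}$,
\begin{align*}
    & ~ u^\top H_{x,x} u + v^\top H_{y,y} v + u^\top H_{x,y} v + v^\top H_{y,x} u \\
    \geq & ~ \alpha_1 \| u \|_2^2 + \alpha_2 \| v \|_2^2 - 2 \alpha_3 \| u \|_2 \| v \|_2 \\
    \geq & ~ ( \alpha_1 - \alpha_3 ) \| u \|_2^2 + ( \alpha_2 - \alpha_3 ) \| v \|_2^2 \\
    \geq & ~ \min \{ \alpha_1 - \alpha_3, \alpha_2 - \alpha_3 \} \cdot ( \| u \|_2^2 + \| v \|_2^2 ),
\end{align*}
where the second step uses $2 \| u \|_2 \| v \|_2 \leq \| u \|_2^2 + \| v \|_2^2$. Under these assumptions, $H$ is positive semi-definite whenever $\alpha_3 \leq \min \{ \alpha_1, \alpha_2 \}$.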

\paragraph{Lipschitz Property for Hessian}



The Lipschitz property of the Hessian is determined by the upper bounds and Lipschitz properties of the basic functions that constitute the Hessian matrix $H$. The matrix $H$ has three distinct parts: $H_{x,x}$, $H_{x,y}$, and $H_{y,y}$. Since $H(y) = H_{y,y}$ is independent of $y$ (see Section~\ref{sec:hessian_Y}), its Lipschitz property can be easily established.
For the details of the other parts, we refer the reader to Section~\ref{sec:lips_H_xy}. 

We now briefly explain how we establish the Lipschitz continuity of $H_{x,x}$. In our proof, we first establish upper bounds for the functions $u(x)$, $c(x)$, and $f(x)$ in Lemma~\ref{lem:upper_bound}, which together form the matrix $H_{x,x}$ (as detailed in Section~\ref{sub:preli:help_def_x}). We then show that these basic functions possess the Lipschitz property in Lemma~\ref{lem:basic_lips}.
Using these foundational components, we decompose $H_{x,x}$ into $4$ distinct parts, denoted $\{ G_k \}_{k=1}^4$. 
We then combine the Lipschitz property of the basic functions with the method described below.
The following task arises repeatedly in the Lipschitz proof (for each $G_k$): we want to bound  
% \begin{align*}
$
| \prod_{i=1}^t \beta_i(x) - \prod_{i=1}^t \beta_i(\wt{x}) |,
$
% \end{align*}
which, by a telescoping sum and the triangle inequality, is upper bounded by
\begin{align*}
    \sum_{j=0}^{t-1} | \prod_{i=0}^j \beta_i( \wt{x} ) \prod_{i'=j+1}^t \beta_{i'} ( x ) - \prod_{i=1}^{j+1} \beta_i( \wt{x} ) \prod_{i'=j+2}^{t+1} \beta_{i'} ( x )  |,
\end{align*}
where we assume that $\beta_0(\cdot) = 1$ and $\beta_{t+1}(\cdot) = 1$ for convenience.  
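
For instance, when $t = 2$ the telescoping bound reads
\begin{align*}
    & ~ | \beta_1(x) \beta_2(x) - \beta_1(\wt{x}) \beta_2(\wt{x}) | \\
    \leq & ~ | \beta_1(x) - \beta_1(\wt{x}) | \cdot | \beta_2(x) | + | \beta_1(\wt{x}) | \cdot | \beta_2(x) - \beta_2(\wt{x}) |.
\end{align*}
As a hypothetical illustration (the constants $M$ and $L$ are introduced only here), if each $\beta_i$ is bounded by $M$ in absolute value and is $L$-Lipschitz, then the right-hand side is at most $2 M L \| x - \wt{x} \|_2$, and the same argument for general $t$ gives the bound $t M^{t-1} L \| x - \wt{x} \|_2$.
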
We then establish the Lipschitz continuity of $H_{x,x}$:
\begin{align*}
    & ~ \sum_{k=1}^4 \|G_{k }(x,y) - G_{k}(\wt{x}, \wt{y})\| \\
    \leq & ~ n^{1.5} \exp(20R^2) ( \|x -\wt{x} \|_2 +  \| y -\wt{y} \|_2 ).
\end{align*}


\subsection{Algorithm}
\label{sub:tech_overview:algorithm}



\begin{algorithm}[!ht]\caption{Our Algorithm }\label{alg:main_result}
\begin{algorithmic}[1]
\Procedure{TrainingAlgorithm}{$A_1, A_2, A_3$} \Comment{Theorem~\ref{thm:main_informal}}
   \State Let $x(0), y(0) \in \R^{d^2}$ denote the initialization points.
   \For{$t = 0 \to T-1$}
        \State {\color{blue} /*Forward*/ }
        \State Compute $h(y(t) ) \in \R^{n \times d}$ \Comment{$\Tmat(n,d,d)$ time}
        \State Compute $f(x(t)) \in \R^{n \times n}$ \Comment{$\Tmat(n,d,n)$ time}
        \State Compute $c(x(t),y(t)) \in \R^{n \times d}$ (based on $f(x(t))$, $h(y(t))$) \Comment{$\Tmat(n,d,d)$ time}
        \State {\color{blue} /*Gradient*/}
        \State Compute $g(x(t))$ based on Lemma~\ref{lem:compute_gradient_x} \Comment{$\Tmat(n,d,n) + \Tmat(n,d,d)$ time}
        \State Compute $g(y(t))$ based on Lemma~\ref{lem:compute_gradient_y} \Comment{$\Tmat(n,d,n) + \Tmat(n,d,d)$ time}
        \State {\color{blue} /*Hessian*/}
        \State Compute $\wt{H}$ via {\sf TensorSRHT} \Comment{$\wt{O}(nd + d^{2\omega})$}
        \State {\color{blue} /*Update*/}
        \State $\begin{bmatrix} x(t+1) \\ y(t+1) \end{bmatrix} \gets \begin{bmatrix} x(t) \\ y(t) \end{bmatrix} - \wt{H}^{-1} \begin{bmatrix} g( x(t)) \\ g(y(t)) \end{bmatrix}$ \Comment{$O(d^{2\omega})$}
   \EndFor 
   \State \Return $\begin{bmatrix} x(T) \\ y(T) \end{bmatrix}$
\EndProcedure
\end{algorithmic}
\end{algorithm}

In this subsection, we present the techniques for constructing and analyzing the properties of our algorithm (see Algorithm~\ref{alg:main_result}). First, we present our technique for simplifying the computation of the attention matrix. Then, we present the techniques for the gradient and Hessian computation. After that, we turn to the primary contribution of our work: the TensorSRHT fast approximation for the Hessian. Finally, we combine the running times of all of the previous parts (forward function, gradient, Hessian, inverse of approximate Hessian) to obtain the total running time of our algorithm (see Algorithm~\ref{alg:main_result}).

\paragraph{Forward Computation}


To simplify the computation of the attention matrix, we can decompose the computation process into three components: $f$, $c$, and $h$ as defined in Section~\ref{sub:preli:general_def}. The forward computation can then be completed in $O(\Tmat(n,d,d) + \Tmat(n,n,d))$ time, as stated in Lemma~\ref{lem:forward_computation}.


\paragraph{Gradient Computation}
As shown in Section~\ref{sec:gradient}, the gradient with respect to $x$ can be computed as follows:
\begin{align*}
\frac{\d L(x,y)}{\d x} &= \vect ( A_1^\top p(x,y) A_2 ),
\end{align*}
for some matrix $p(x,y) \in \R^{n \times n}$.
Here $A_1^\top p(x,y) A_2$ can be computed in $\Tmat(n,d,n) + \Tmat(d,n,d)$ time. 
Similarly,
\begin{align*}
\frac{\d L(x,y)}{\d y} &= \vect( A_3^\top \widetilde{q}(x,y) ),
\end{align*}
which takes $\Tmat(n,n,d) + \Tmat(n,d,d)$ time. We now establish the overall running time for gradient computation.
By the results of Lemma~\ref{sub:gradient:reform_x} and Lemma~\ref{sub:gradient:reform_y}, we can efficiently compute the gradients $g(x(t))$ and $g(y(t))$ in $\Tmat(n,d,n) + \Tmat(n,d,d)$ time.
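
The reason the gradient never requires forming $\mathsf{A} = A_1 \otimes A_2$ explicitly can be sketched as follows, assuming that $\vect(\cdot)$ stacks the rows of a matrix and that the gradient with respect to $x$ has the form $\mathsf{A}^\top \vect(p)$ for $p = p(x,y) \in \mathbb{R}^{n \times n}$ (this form is an assumption made for illustration). Applying the tensor trick to $A_1^\top$ and $A_2^\top$ gives
\begin{align*}
    \mathsf{A}^\top \vect(p) = (A_1^\top \otimes A_2^\top) \vect(p) = \vect( A_1^\top p A_2 ).
\end{align*}
Computing $A_1^\top p$ first costs $\Tmat(d,n,n)$, and multiplying the resulting $d \times n$ matrix by $A_2$ costs $\Tmat(d,n,d)$, whereas multiplying by the explicit $d^2 \times n^2$ matrix $\mathsf{A}^\top$ would take on the order of $n^2 d^2$ operations.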



\paragraph{Straightforward Hessian Computation}

Computing the Hessian in a straightforward way would take $\Tmat(d^2, n^2, d^2)$ time, because we need to explicitly write down $\A^\top \A \in \R^{d^2 \times d^2}$ where $\A \in \R^{n^2 \times d^2}$. This is too slow, so we use sketching ideas to reduce the running time. Using sketching matrices to speed up the Hessian computation has been extensively studied in convex and non-convex optimization \citep{jswz21,lsz19,sy21,gs22,gsz23,qszz23}.

\paragraph{TensorSRHT Fast Approximation for Hessian}
Now, let's delve into the key contribution of this paper. Given that $\A = A_1 \otimes A_2 \in \R^{n^2 \times d^2}$, the time complexity of regression becomes prohibitively expensive. We use a fast approximation to significantly reduce the time complexity of the Newton Method. We construct our $\mathsf{TensorSRHT}$ sketching matrix $S \in \R^{m \times n^2}$ by
\begin{align*}
    S = \frac{1}{\sqrt{m}} P \cdot (QD_1 \otimes QD_2),
\end{align*}
where each row of $P \in \{0, 1\}^{m \times n^2}$ contains exactly one $1$ at a random coordinate, $Q$ is an $n \times n$ Hadamard matrix, and $D_1$, $D_2$ are two independent $n \times n$ diagonal matrices whose diagonal entries are independent Rademacher random variables (uniform in $\{-1, 1\}$). We choose $m=O(\epsilon^{-2}d^2 \log^3(nd / (\epsilon\delta)) ) \ll n^2$, where $\epsilon > 0$ is the accuracy parameter and $\delta \in (0, 1)$ is the failure probability, so $S\A \in \R^{m \times d^2}$ is a much smaller matrix than $\A \in \R^{n^2 \times d^2}$.
Therefore, using $S\A$, we can construct a sketched Hessian. This reduces the time from $\Tmat(d^2,n^2,d^2)$ down to $\wt{O}(nd) + \Tmat(d^2,d^2,d^2)$\footnote{We consider the regime $n \gg d$ in the paper, which is the most common setting in practice because $n$ is the length of the document and $d$ is the feature dimension.}. Additionally, \cite{akk+20,swyz21} show that with $m=O(\epsilon^{-2}d^2 \log^3(nd / (\epsilon\delta)) )$, the $\mathsf{TensorSRHT}$ sketching matrix $S$ is an oblivious subspace embedding, which further implies that with probability at least $1 - \delta$, the sketched Hessian $\wt H$ approximates the true Hessian $H$ with error bounded in terms of $\epsilon$.
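
Two standard facts help explain why the sketch is both fast to apply and accurate; they are stated here as a sketch, with the vector $z$ introduced only for illustration. First, by the mixed-product property of the Kronecker product,
\begin{align*}
    S \mathsf{A} = \frac{1}{\sqrt{m}} P \cdot ( Q D_1 A_1 \otimes Q D_2 A_2 ),
\end{align*}
so each row of $S \mathsf{A}$ is, up to scaling, the Kronecker product of one row of $Q D_1 A_1 \in \mathbb{R}^{n \times d}$ with one row of $Q D_2 A_2 \in \mathbb{R}^{n \times d}$, and the $n^2 \times d^2$ matrix $\mathsf{A}$ is never formed. Second, if $S$ satisfies the subspace embedding guarantee $\| S \mathsf{A} z \|_2^2 \in (1 \pm \epsilon) \| \mathsf{A} z \|_2^2$ for all $z \in \mathbb{R}^{d^2}$, then
\begin{align*}
    (1 - \epsilon) \cdot \mathsf{A}^\top \mathsf{A} \preceq (S \mathsf{A})^\top (S \mathsf{A}) \preceq (1 + \epsilon) \cdot \mathsf{A}^\top \mathsf{A},
\end{align*}
which is the sense in which a matrix built from $S \mathsf{A}$ can stand in for one built from $\mathsf{A}$.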

\paragraph{Overall Time}
Building upon the aforementioned properties, we can apply the Newton Method in Section~\ref{sec:newton} to establish convergence for the regression problem. 
In summary, we know that 
\begin{itemize}
    \item Computing the forward function takes $\Tmat(n,n,d) + \Tmat(n,d,d)$ time (Lemma~\ref{lem:forward_computation}).
    \item Computing the gradient takes $\Tmat(n,n,d) + \Tmat(n,d,d)$ time (Lemma~\ref{lem:compute_gradient_x} and Lemma~\ref{lem:compute_gradient_y}).
    \item Computing the approximate Hessian takes $\wt{O}(nd) + \Tmat(d^2,d^2,d^2)$ time (Lemma~\ref{lem:compute_hessian_approximate}).
    \item Computing the product of the inverse of the approximate Hessian with $g$ takes $\Tmat(d^2,d^2,d^2) = d^{2\omega}$ time.
\end{itemize}


The total time can be expressed as $\widetilde{O}(\Tmat(n,d,n) + \Tmat(n,d,d) + d^{2\omega}) \log(1/\epsilon)$, where $\omega \approx 2.37$ denotes the exponent of matrix multiplication. 
