\section{Analysis Of Algorithm~\ref{alg:main_result}}\label{sec:newton}

For the optimization problem in Definition~\ref{def:attention}, we propose Algorithm~\ref{alg:main_result}; in this section, we establish its correctness and convergence. In Section~\ref{sub:newton:good}, we introduce $(l,M)$-good loss functions and the notion of a well-initialized point. In Section~\ref{sub:newton:approximation}, we present the approximate update rule and the induction hypothesis used in the convergence analysis.

\subsection{\texorpdfstring{$(l,M)$}{(l,M)}-Good Loss Function}
\label{sub:newton:good}

We now define $(l,M)$-good loss functions. First, recall the loss function from Definition~\ref{def:L}:
\begin{align*}
   L(X,Y) := 0.5 \cdot \| \underbrace{ D(X)^{-1} }_{n \times n} \underbrace{ \exp(A_1 X A_2^\top) }_{n \times n} \underbrace{ A_3 }_{n \times d} \underbrace{Y}_{d \times d} - \underbrace{ B  }_{n \times d} \|_F^2
\end{align*}
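To make the loss concrete, the following Python/NumPy sketch evaluates $L(X,Y)$ directly from its definition. It is an illustrative sketch, not the implementation behind Algorithm~\ref{alg:main_result}; the function name \texttt{attention\_loss} is our own. Subtracting the row-wise maximum before exponentiating is a standard numerical-stability device that leaves $D(X)^{-1}\exp(A_1 X A_2^\top)$ unchanged.

```python
import numpy as np

def attention_loss(X, Y, A1, A2, A3, B):
    """Evaluate L(X, Y) = 0.5 * || D(X)^{-1} exp(A1 X A2^T) A3 Y - B ||_F^2.

    Shapes: A1, A2, A3, B are (n, d); X, Y are (d, d).
    """
    S = A1 @ X @ A2.T                      # (n, n) attention logits
    # Subtracting each row's maximum rescales row i of exp(S) by a positive
    # constant, which cancels in D(X)^{-1} exp(S); it only prevents overflow.
    S = S - S.max(axis=1, keepdims=True)
    U = np.exp(S)                          # entrywise exponential
    F = U / U.sum(axis=1, keepdims=True)   # D(X)^{-1} exp(S): each row sums to 1
    R = F @ A3 @ Y - B                     # (n, d) residual
    return 0.5 * np.linalg.norm(R, "fro") ** 2
```

With $X = 0$, every row of $D(X)^{-1}\exp(A_1 X A_2^\top)$ equals $\frac{1}{n}{\bf 1}_n^\top$, which gives an easy sanity check: choosing every row of $B$ to be the column mean of $A_3 Y$ makes the loss vanish.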
We will show that $L$ satisfies the properties in the following definition.

\begin{definition}[$(l,M)$-good Loss function]\label{def:assumptions}
Let $L : \R^{d^2} \times \R^{d^2} \rightarrow \R$. We say $L$ is $(l,M)$-good if the following conditions hold:
\begin{itemize}
    \item {\bf Hessian is $M$-Lipschitz.} There exists a scalar $M>0$ such that, for all $x, \wt{x}, y, \wt{y} \in \R^{d^2}$,
    \begin{align*}
       \| \nabla^2 L(x,y) - \nabla^2 L(\wt{x},\wt{y}) \| \leq M\cdot ( \| x - \wt{x} \|_2 + \| y - \wt{y} \|_2 ).
    \end{align*}
    \item {\bf $l$-local Minimum.} For a scalar $l >0$, there exist vectors $x^* \in \R^{d^2}$ and $y^* \in \R^{d^2}$ such that
    \begin{itemize}
        \item $\nabla L(x^*, y^*) = {\bf 0}_{2d^2}$;
        \item $\nabla^2 L(x^*, y^*) \succeq l \cdot I_{2 d^2}$.
    \end{itemize}
    \item {\bf Good Initialization Point.} The initialization point $(x_0, y_0)$ satisfies
    \begin{align*}
        r_0 M \leq 0.1 l, \quad \text{where } r_0:= \| x_0 -x^*\|_2 + \| y_0 - y^*\|_2.
    \end{align*}
\end{itemize}
\end{definition}
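Given the gradient and Hessian of $L$ at a candidate minimizer, the last two conditions of Definition~\ref{def:assumptions} can be checked numerically. The sketch below is a hypothetical helper (\texttt{check\_good\_initialization} is not part of the paper). It takes $l$ to be the smallest eigenvalue of the Hessian at $(x^*,y^*)$, which is the largest $l$ permitted by the local-minimum condition, and it treats the Lipschitz constant $M$ as given rather than verifying it.

```python
import numpy as np

def check_good_initialization(grad_at_opt, hess_at_opt, M, x0, y0, x_star, y_star, tol=1e-8):
    """Numerically check the l-local-minimum and good-initialization conditions.

    Returns (l, r0, ok):
      l:  smallest eigenvalue of the Hessian at (x*, y*), i.e. the largest l
          with Hessian >= l * I;
      r0: ||x0 - x*||_2 + ||y0 - y*||_2;
      ok: True iff the gradient at (x*, y*) vanishes up to tol, l > 0,
          and r0 * M <= 0.1 * l.
    The Hessian-Lipschitz constant M is assumed known; it is NOT verified here.
    """
    H = 0.5 * (hess_at_opt + hess_at_opt.T)   # symmetrize to remove round-off asymmetry
    l = float(np.linalg.eigvalsh(H).min())    # real eigenvalues of a symmetric matrix
    r0 = float(np.linalg.norm(x0 - x_star) + np.linalg.norm(y0 - y_star))
    stationary = float(np.linalg.norm(grad_at_opt)) <= tol
    ok = stationary and l > 0 and r0 * M <= 0.1 * l
    return l, r0, ok
```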

By Lemma~\ref{lem:hessian_lower_bound} and Lemma~\ref{lem:lips_H_xy}, we can show that our loss function $L$ (see Definition~\ref{def:L}) satisfies the conditions of Definition~\ref{def:assumptions}.



\subsection{Convergence}\label{sub:newton:approximation}
Having introduced the approximation method ``Sparsifier via TensorSketch'' in Section~\ref{sec:tensorsketch}, we now describe the update rule used in Algorithm~\ref{alg:main_result}.
In this section, we define the approximate update and state the induction hypothesis used in the convergence analysis.

\begin{definition}[Approximate Update]\label{def:update_x_t}

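We consider the following update:
\begin{align*}
    \begin{bmatrix} x(t+1) \\ y(t+1) \end{bmatrix} \gets \begin{bmatrix} x(t) \\ y(t) \end{bmatrix} - \wt{H}^{-1} \begin{bmatrix} g( x(t)) \\ g(y(t)) \end{bmatrix},
\end{align*}
where $g(x(t))$ and $g(y(t))$ denote the gradients of $L$ with respect to $x$ and $y$ at $(x(t), y(t))$, and $\wt{H} \in \R^{2d^2 \times 2d^2}$ denotes the approximate Hessian at $(x(t), y(t))$ (see Lemma~\ref{lem:compute_hessian_approximate}). We write $x_t := x(t)$ and $y_t := y(t)$.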
\end{definition}
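The update in Definition~\ref{def:update_x_t} can be sketched as follows, assuming $x, y \in \R^{d^2}$ are given as flat vectors and an invertible approximate Hessian $\wt{H}$ is supplied by the caller. The helper name \texttt{approx\_newton\_step} is hypothetical, and the sketch does not implement the TensorSketch-based Hessian approximation. It solves the linear system $\wt{H} s = g$ instead of forming $\wt{H}^{-1}$ explicitly, which is the usual numerically preferable choice.

```python
import numpy as np

def approx_newton_step(x, y, grad_x, grad_y, H_approx):
    """One step of [x; y] <- [x; y] - H_approx^{-1} [grad_x; grad_y].

    x, y, grad_x, grad_y are 1-D arrays; H_approx is a square, invertible
    matrix of size len(x) + len(y) (here assumed to approximate the Hessian).
    """
    z = np.concatenate([x, y])
    g = np.concatenate([grad_x, grad_y])
    # Solve H_approx @ step = g rather than computing the inverse explicitly
    step = np.linalg.solve(H_approx, g)
    z_next = z - step
    # Split the stacked vector back into its x- and y-blocks
    return z_next[: x.size], z_next[x.size :]
```

On a quadratic $\frac12 (z - z^*)^\top H (z - z^*)$ with $\wt{H} = H$, one step lands exactly at $z^*$; with $\wt{H} = (1+\epsilon_0) H$, the error is multiplied by $\epsilon_0/(1+\epsilon_0)$.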

We now state a tool from previous work.
\begin{lemma}[Iterative shrinking, a variation of Lemma 6.9 on page 32 of \cite{lsz23}]\label{lem:one_step_shrinking}
If the following conditions hold
\begin{itemize}
    \item The loss function $L$ is $(l,M)$-good (see Definition~\ref{def:assumptions}).
    \item Let $\epsilon_0 \in (0,0.1)$ (see Lemma~\ref{lem:compute_hessian_approximate}).
    \item Let $x^*, y^*$ be as defined in Definition~\ref{def:assumptions}, and let $x_t,y_t$ be as defined in Definition~\ref{def:update_x_t}.
    \item Let $r_t:= \| x_t - x^* \|_2 + \| y_t - y^*\|_2$.
    \item Let $\ov{r}_t := M \cdot r_t$.
\end{itemize}
It follows that  
\begin{align*}
r_{t+1} \leq 2 \cdot (\epsilon_0 + \ov{r}_t/( l - \ov{r}_t ) ) \cdot r_t.
\end{align*} 
\end{lemma}
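To see what Lemma~\ref{lem:one_step_shrinking} gives quantitatively, the following sketch iterates its right-hand side as an equality. It therefore produces a sequence of upper bounds on $r_t$, not the actual errors, and the parameter values in the usage are illustrative assumptions. Once $\ov{r}_t$ becomes small, the per-step factor approaches $2\epsilon_0$, so the guaranteed rate is linear with constant roughly $2\epsilon_0$.

```python
def shrinking_bounds(r0, l, M, eps0, steps):
    """Iterate the bound r_{t+1} = 2 * (eps0 + rbar_t / (l - rbar_t)) * r_t,
    where rbar_t = M * r_t, and return [r_0, r_1, ..., r_steps].

    These are upper bounds on the distances r_t, not simulated iterates.
    """
    rs = [r0]
    for _ in range(steps):
        rbar = M * rs[-1]
        if rbar >= l:
            # The denominator l - rbar must stay positive for the bound to apply
            raise ValueError("bound is vacuous: M * r_t >= l")
        rs.append(2.0 * (eps0 + rbar / (l - rbar)) * rs[-1])
    return rs

# Illustrative values: M * r0 = 0.1 * l, so the initialization condition holds
rs = shrinking_bounds(r0=0.1, l=1.0, M=1.0, eps0=0.01, steps=10)
ratios = [b / a for a, b in zip(rs, rs[1:])]
```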

Let $T$ denote the total number of iterations of the algorithm. To apply Lemma~\ref{lem:one_step_shrinking} at every iteration $t \in [T]$, we use the following induction hypothesis from \cite{lsz23}.
\begin{lemma}[Induction hypothesis, Lemma 6.10 on page 34 of \cite{lsz23}]\label{lem:newton_induction}
If the following conditions hold
\begin{itemize}
    \item $\epsilon_0 = 0.01$ (see Lemma~\ref{lem:compute_hessian_approximate}).
    \item Let $x^*, y^*$ be as defined in Definition~\ref{def:assumptions}, and let $x_t, y_t$ be as defined in Definition~\ref{def:update_x_t}.
    \item Let $r_t:= \| x_t - x^* \|_2 + \| y_t - y^*\|_2$.
    \item $r_{i} \leq 0.4 \cdot r_{i-1}$ for all $i \in [t]$.
    \item Let $l$ and $M$ be as defined in Definition~\ref{def:assumptions}.
    \item $M \cdot r_i \leq 0.1 l$ for all $i \in [t]$.
\end{itemize}
It follows that
\begin{itemize}
    \item $r_{t+1} \leq 0.4 r_t$
    \item $M \cdot r_{t+1} \leq 0.1 l$
\end{itemize}
\end{lemma}
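As a sanity check on how the two lemmas fit together, the following calculation (our own short sketch, not the proof given in \cite{lsz23}) derives both conclusions of Lemma~\ref{lem:newton_induction} for step $t$ from Lemma~\ref{lem:one_step_shrinking}, assuming $\epsilon_0 = 0.01$ and $M \cdot r_t \leq 0.1 l$. Since $u \mapsto u/(l-u)$ is increasing on $[0,l)$ and $\ov{r}_t = M \cdot r_t \leq 0.1 l$, we have
\begin{align*}
    r_{t+1} \leq 2 \cdot \Big( \epsilon_0 + \frac{\ov{r}_t}{l - \ov{r}_t} \Big) \cdot r_t \leq 2 \cdot \Big( 0.01 + \frac{0.1 l}{0.9 l} \Big) \cdot r_t \leq 0.25 \cdot r_t \leq 0.4 \cdot r_t,
\end{align*}
and therefore $M \cdot r_{t+1} \leq 0.4 \cdot M \cdot r_t \leq 0.04 l \leq 0.1 l$. If $\epsilon_0$ were close to $0.1$, the same estimate would only give the factor $2 \cdot (0.1 + 1/9) \approx 0.42 > 0.4$, so a small $\epsilon_0$ is needed for this particular argument.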



\section{Main Theorem}

In this section, we combine the preceding analysis and present our main theorem.

\begin{theorem}[Main Theorem, Formal version of Theorem~\ref{thm:main_informal}]\label{thm:main}

    If the following conditions hold:
    \begin{itemize}
        \item Let $A_1, A_2, A_3, B \in \R^{n \times d}$.
        \item Let $X, Y \in \R^{d \times d}$.
        \item Let $D(X) \in \R^{n \times n}$ be defined as $D(X) := \diag( \exp(A_1 X A_2^\top ) {\bf 1}_n )$.
        \item Let $\epsilon \in (0, 0.1)$.
        \item Let $\omega \approx 2.37$ denote the exponent of matrix multiplication.
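        \item Let $r_0 := \| x_0 - x^*\|_2 + \| y_0 - y^*\|_2$, where $(x_0, y_0)$ is the initialization point and $(x^*, y^*)$ is as in Definition~\ref{def:assumptions}.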
    \end{itemize}
    
    Then, there exists an algorithm (see Algorithm~\ref{alg:main_result}) that runs in $O(\log(r_0/\epsilon))$ iterations and spends 
    \begin{align*}
        \wt{O}( \Tmat(n,d,n) + \Tmat(n,d,d) + d^{2\omega}) 
    \end{align*}
    per iteration
    and solves the attention optimization problem (defined in Definition~\ref{def:attention}):
    \begin{align*}
        \min_{X,Y \in \R^{d \times d}}  \| D(X)^{-1} \exp(A_1 X A_2^\top) A_3 Y - B \|_F^2,
    \end{align*}
    and finally outputs $\wt{x}$, $\wt{y}$ such that
    \begin{align*}
        (\| \wt{x} - x^*\|_2 + \|\wt{y} - y^*\|_2) \leq \epsilon
    \end{align*}
    with probability $1-1/\poly(n)$. 
\end{theorem}

\begin{proof}
    This follows from combining Lemma~\ref{lem:compute_gradient_x}, Lemma~\ref{lem:compute_gradient_y}, Lemma~\ref{lem:forward_computation}, Lemma~\ref{lem:hessian_lower_bound}, Lemma~\ref{lem:lip_main}, Lemma~\ref{lem:hessian_property:x}, Lemma~\ref{lem:hessian_property:y}, Lemma~\ref{lem:hessian_xy}, and Lemma~\ref{lem:lips_H_xy}.

    {\bf Number of iterations.}
    
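    Applying Lemma~\ref{lem:newton_induction} inductively for $t = 0, 1, \dots, T-1$, we have
    \begin{align*}
        \| x_T - x^*\|_2 + \|y_T - y^*\|_2 \leq 0.4^T \cdot (\| x_0 - x^*\|_2 + \| y_0 - y^*\|_2) = 0.4^T \cdot r_0.
    \end{align*}
    Choosing $T = \lceil \log_{2.5} (r_0 / \epsilon) \rceil = O(\log(r_0/\epsilon))$ gives $0.4^T \cdot r_0 \leq \epsilon$, so the accuracy requirement is satisfied.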
    
    {\bf Analysis of time complexity.}
   
    We divide the time complexity analysis into two parts: the forward computation and the backward computation, where the latter consists of computing the gradient, the approximate Hessian, and $\wt{H}^{-1} g$.

    


    {\bf Proof of forward computation.}

    By Lemma~\ref{lem:forward_computation}, we can compute $f$, $h$, and $c$ in 
    \begin{align*}
        O(\Tmat(n,d,d) + \Tmat(n,n,d))
    \end{align*}
    time.

    {\bf Proof of gradient computation.}

    This follows from Lemma~\ref{lem:compute_gradient_x} and Lemma~\ref{lem:compute_gradient_y}, which takes
    \begin{align*}
        O(\Tmat(n,n,d) + \Tmat(n,d,d))
    \end{align*}
    time.

    {\bf Proof of Hessian computation.}
    
    This follows from Lemma~\ref{lem:compute_hessian_approximate}, which takes
    \begin{align*}
        \wt{O}(nd) + \Tmat(d^2,d^2,d^2)
    \end{align*}
    time.

    {\bf Proof of computing $\wt{H}^{-1} g$.}

    Since $\wt{H} \in \R^{2d^2 \times 2d^2}$, computing $\wt{H}^{-1} g$ takes
    \begin{align*}
        O(\Tmat(d^2,d^2,d^2)) = O(d^{2\omega})
    \end{align*}
    time. Summing the four parts, each iteration takes
    \begin{align*}
         \wt{O}( \Tmat(n,d,n) + \Tmat(n,d,d) + d^{2\omega}) 
    \end{align*}
    time, where we use $\Tmat(n,n,d) = O(\Tmat(n,d,n))$ and $nd \leq \Tmat(n,d,d)$.
\end{proof}
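The iteration count in the proof above can also be checked numerically. The sketch below is a hypothetical helper, not part of Algorithm~\ref{alg:main_result}: it returns the smallest $T$ with $0.4^T \cdot r_0 \leq \epsilon$. For $r_0 > \epsilon$ this $T$ equals $\lceil \log_{2.5}(r_0/\epsilon) \rceil$, and hence is at most $\lceil \log_2(r_0/\epsilon) \rceil$.

```python
import math

def newton_iterations(r0, eps, rate=0.4):
    """Smallest integer T >= 0 with rate**T * r0 <= eps (requires 0 < rate < 1)."""
    if r0 <= eps:
        return 0
    # Closed form: T = ceil(log(r0 / eps) / log(1 / rate))
    T = math.ceil(math.log(r0 / eps) / math.log(1.0 / rate))
    # Correct possible off-by-one errors caused by floating-point rounding
    while rate ** T * r0 > eps:
        T += 1
    while T > 0 and rate ** (T - 1) * r0 <= eps:
        T -= 1
    return T
```

For example, with $r_0 = 1$ and $\epsilon = 10^{-6}$ the helper returns $16$, since $0.4^{15} \approx 1.07 \cdot 10^{-6}$ still exceeds $\epsilon$.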


\section{More Related Works}
 

\paragraph{Second-order Method}

Second-order methods have been used to solve many convex and non-convex optimization problems, such as linear programming \cite{cls19,b20,jswz21,sy21,gs22,hlz23}, empirical risk minimization \cite{lsz19,qszz23}, support vector machines \cite{gsz23}, cutting plane methods \cite{lsw15,jlsw20}, semi-definite programming \cite{jkl+20,hjs+22,gs22,syz23_sdp}, hyperbolic programming/polynomials \cite{dszz23,zz23}, streaming algorithms \cite{lsz+23,bs23,syz23_sdp}, and federated learning \cite{bsy23}.

\paragraph{Convergence and Deep Neural Network Optimization}

Many works focus on analyzing optimization, convergence guarantees, and training improvement. \cite{ll18} shows that stochastic gradient descent optimizes over-parameterized neural networks on structured data, while \cite{dzps18} demonstrates that gradient descent optimizes over-parameterized neural networks. In \cite{azls19a}, a convergence theory for over-parameterized deep neural networks via gradient descent is developed. \cite{azls19b} analyzes the convergence rate of training recurrent neural networks. \cite{adh+19a} provides a fine-grained analysis of optimization and generalization for over-parameterized two-layer neural networks. \cite{adh+19b} studies exact computation with an infinitely wide neural network. \cite{cgh+19} proposes a Gram-Gauss-Newton method for optimizing over-parameterized neural networks. \cite{zg19} improves the analysis of the global convergence of stochastic gradient descent when training deep neural networks, requiring a milder over-parameterization compared to prior research. Other research, such as \cite{os20,jt19,zpd+20}, focuses on optimization and generalization, while \cite{gms23,lsz23} emphasize the convergence rate and stability. Works like \cite{bpsw20,szz21,als+22,mosw22,z22,cls24_grams} concentrate on specialized optimization algorithms and techniques for training neural networks, and \cite{lss+20,hlsy21} concentrate on leveraging neural network structure.







\paragraph{Algorithmic Regularization}



A significant body of research studies the implicit bias of gradient descent on separable classification tasks. These works typically use logistic or exponentially-tailed loss functions and relate the behavior of gradient descent to margin maximization \cite{jt20, glss18, kpot21, jt21, shn+18, mwg+20, nlg+19}. These results have also been extended to non-separable data for gradient-based methods \cite{jdst20, jt19_alg, jt18}. The implicit bias in regression problems and related loss functions has been analyzed for mirror descent \cite{ykm20, aw20a, aw20b, vkr19, sata22, wgl+20, alh21, glss18} and stochastic gradient descent \cite{hwlm21, lwa21, lr20, zwb+21, dml21, lwm19, bgvv20}. These findings further extend to the implicit bias of adaptive and momentum-based optimization methods \cite{jst21, wmcl21, qq19}.



