\section{Gradient}\label{sec:gradient}

In Section~\ref{sec:gradient:x}, we show the gradient with respect to variables $x$. In Section~\ref{sec:gradient:y}, we show the gradient with respect to variables $y$. In Section~\ref{sub:gradient:compute_c_f_h}, we analyze the running time of computing $c,f,h$. In Section~\ref{sub:gradient:reform_x}, we reformulate the gradient with respect to $X$ to analyze its time complexity. In Section~\ref{sub:gradient:reform_y}, we reformulate the gradient with respect to $Y$ to analyze its time complexity.

\subsection{Gradient for \texorpdfstring{$x$}{x}}\label{sec:gradient:x}

In this section, we compute the gradient for $x$. Most of the following gradient computations can be found in \cite{gsx23_incontext,gsy23_coin}. 
\begin{lemma}[Gradient with respect to $x$]\label{lem:gradient_x}
If the following conditions hold
\begin{itemize}
    \item For each $i \in [d^2]$, let $\A_{j_0,i} \in \R^n$ denote the $i$-th column of $\A_{j_0} \in \R^{n \times d^2}$
     \item Let $u(x)_{j_0} \in \R^n$ be defined as Definition~\ref{def:u}
    \item Let $\alpha(x)_{j_0} \in \R$ be defined as Definition~\ref{def:alpha}
    \item Let $f(x)_{j_0} \in \R^n$ be defined as Definition~\ref{def:f}
    \item Let $c(x,:)_{j_0,i_0} \in \R$ be defined as Definition~\ref{def:c}
    \item Let $L(x,:)_{j_0,i_0} \in \R$ be defined as Definition~\ref{def:l}
\end{itemize}
Then, for each $i \in [d^2]$, for each $j_0 \in [n]$, we have
\begin{itemize}
    \item {\bf Part 1.}
    \begin{align*}
        \frac{ \d u(x)_{j_0} }{ \d x_i } = u(x)_{j_0} \circ \A_{j_0,i} 
    \end{align*}
    \item {\bf Part 2.} 
    \begin{align*}
        \frac{\d \alpha(x)_{j_0}}{\d x_i} = \langle u(x)_{j_0} \circ \A_{j_0,i} , {\bf 1}_n \rangle
    \end{align*}
    \item {\bf Part 3.}
    \begin{align*}
        \frac{ \d f(x)_{j_0} }{ \d x_i } = f(x)_{j_0} \circ \A_{j_0,i} - f(x)_{j_0} \cdot \langle f(x)_{j_0} , \A_{j_0,i}\rangle
    \end{align*}
    \item {\bf Part 4.} For a fixed vector $v \in \R^n$ (which doesn't depend on $x$), we have
    \begin{align*}
        \frac{ \d \langle f(x)_{j_0} , v \rangle }{ \d x_i } = \langle  f(x)_{j_0} \circ \A_{j_0,i}, v \rangle - \langle  f(x)_{j_0} , v \rangle \cdot \langle f(x)_{j_0}, \A_{j_0,i} \rangle
    \end{align*}
    \begin{figure}[!ht]
    \centering
    \includegraphics[width = \linewidth]{gradient_part_4.pdf}
    \caption{The visualization of Part 4 of Lemma~\ref{lem:gradient_x}. We are given $f(x)_{j_0} , v, \A_{j_0,i} \in \R^n$. The left-hand side of the equation is the derivative of the inner product of $f(x)_{j_0}$ and $v$ with respect to $x_i \in \R$. For the right-hand side, we have three steps. Step 1: we compute the Hadamard product of $f(x)_{j_0}$ and $\A_{j_0,i}$. Step 2: We find the inner product of this Hadamard product and $v$. Step 3: We subtract the product of two inner products, one is of $f(x)_{j_0}$ and $v$ and the other is of $f(x)_{j_0}$ and $\A_{j_0,i}$, from the result of step 2. The purple rectangles represent the vector $f(x)_{j_0}$. The red rectangles represent the vector $v$. The green rectangles represent the vector $\A_{j_0, i}$.}
    \label{fig:gradient_part_4}
\end{figure}
    \item {\bf Part 5.} For each $i_0 \in [d]$
    \begin{align*}
        \frac{ \d c(x,:)_{j_0,i_0} }{ \d x_i } = \langle  f(x)_{j_0} \circ \A_{j_0,i}, v \rangle - \langle  f(x)_{j_0} , v \rangle \cdot \langle f(x)_{j_0}, \A_{j_0,i} \rangle
    \end{align*}
    \item {\bf Part 6.}
    \begin{align*}
        \frac{\d L(x,:)_{j_0,i_0}}{\d x_i} = c(x,:)_{j_0, i_0} \cdot ( \langle  f(x)_{j_0} \circ \A_{j_0,i}, v \rangle - \langle  f(x)_{j_0} , v \rangle \cdot \langle f(x)_{j_0}, \A_{j_0,i} \rangle )
    \end{align*}
    \item {\bf Part 7.} (for hessian diagonal term)
    \begin{align*}
        \frac{ \d  \langle  f(x)_{j_0} \circ \A_{j_0,i}, v \rangle }{\d x_i} = \langle f(x)_{j_0} \circ \A_{j_0,i} \circ \A_{j_0,i}, v \rangle - \langle f(x)_{j_0} \circ \A_{j_0,i} , v \rangle \cdot \langle f(x)_{j_0} , \A_{j_0,i} \rangle
    \end{align*}
    \begin{figure}[!ht]
    \centering
    \includegraphics[width = \linewidth]{gradient_part_7.pdf}
    \caption{The visualization of Part 7 of Lemma~\ref{lem:gradient_x}. We are given $f(x)_{j_0} , v, \A_{j_0,i} \in \R^n$. First, we compute the Hadamard product between $f(x)_{j_0}$ and $\A_{j_0,i}$. The left-hand side of the equation is the derivative of the inner product of this Hadamard product and $v$ with respect to $x_i \in \R$. For the right-hand side, we have four steps. Step 1: We compute the inner product of the Hadamard product of $f(x)_{j_0}, \A_{j_0,i}, \A_{j_0,i}$ and $v$. Step 2: We compute the inner product of the Hadamard product of $f(x)_{j_0}, \A_{j_0,i}$ and $v$. Step 3: We compute the inner product between $f(x)_{j_0}$ and $\A_{j_0,i}$. Step 4: We subtract the product of steps 2 and 3 from step 1. The purple rectangles represent the vector $f(x)_{j_0}$. The red rectangles represent the vector $v$. The green rectangles represent the vector $\A_{j_0, i}$.}
    \label{fig:gradient_part_7}
\end{figure}
    \item {\bf Part 8.} (for hessian off-diagonal term)
    \begin{align*}
        \frac{ \d  \langle  f(x)_{j_0} \circ \A_{j_0,i}, v \rangle }{\d x_l} = \langle f(x)_{j_0} \circ \A_{j_0,i} \circ \A_{j_0,l}, v \rangle - \langle f(x)_{j_0} \circ \A_{j_0,i} , v \rangle \cdot \langle f(x)_{j_0} , \A_{j_0,l} \rangle
    \end{align*}
    \item {\bf Part 9} (for hessian diagonal term, this can be obtained by using Part 4 as a black-box)
    \begin{align*}
       \frac{ \d \langle f(x)_{j_0}, \A_{j_0,i} \rangle }{\d x_i} = \langle f(x)_{j_0} , \A_{j_0,i} \circ \A_{j_0,i} \rangle - \langle f(x)_{j_0}, \A_{j_0,i} \rangle \cdot \langle f(x)_{j_0}, \A_{j_0,i} \rangle
    \end{align*}
    \begin{figure}[!ht]
    \centering
    \includegraphics[width = \linewidth]{gradient_part_9.pdf}
    \caption{The visualization of Part 9 of Lemma~\ref{lem:gradient_x}. We are given $f(x)_{j_0}, \A_{j_0,i} \in \R^n$. The left-hand side of the equation is the derivative of the inner product of $f(x)_{j_0}$ and $\A_{j_0,i}$ with respect to $x_i \in \R$. For the right-hand side, we have three steps. Step 1: we compute the Hadamard product of $\A_{j_0,i}$ and $\A_{j_0,i}$. Step 2: We find the inner product of $f(x)_{j_0}$ and this Hadamard product. Step 3: We subtract the square of inner product of $f(x)_{j_0}$ and $\A_{j_0,i}$ from the result of step 2. The purple rectangles represent the vector $f(x)_{j_0}$. The green rectangles represent the vector $\A_{j_0, i}$.}
    \label{fig:gradient_part_9}
\end{figure}
    \item {\bf Part 10} (for hessian off-diagonal term, this can be obtained by using Part 4 as a black-box)
    \begin{align*}
       \frac{ \d \langle f(x)_{j_0}, \A_{j_0,i} \rangle }{\d x_l} = \langle f(x)_{j_0} , \A_{j_0,i} \circ \A_{j_0,l} \rangle - \langle f(x)_{j_0}, \A_{j_0,i} \rangle \cdot \langle f(x)_{j_0}, \A_{j_0,l} \rangle
    \end{align*}
\end{itemize}
\end{lemma}
\begin{proof}

{\bf Proof of Part 1.}
See Part 4 of Proof of Lemma 5.18 in \cite{gsx23_incontext} (Page 14).

{\bf Proof of Part 2.}
See Part 5 of Proof of Lemma 5.18 in \cite{gsx23_incontext} (Page 14).

{\bf Proof of Part 3.}
See Part 9 of Proof of Lemma 5.18 in \cite{gsx23_incontext} (page 15).

{\bf Proof of Part 4.}
See Part 14 of Proof of Lemma 5.18 in \cite{gsx23_incontext} (page 15).

{\bf Proof of Part 5.}

Note that by Definition~\ref{def:c}, we have
\begin{align}\label{eq:c_x:}
    c(x,:)_{j_0,i_0} := \langle f(x)_{j_0}, v \rangle - b_{j_0,i_0}
\end{align}

Therefore, we have
\begin{align*}
        \frac{ \d c(x,:)_{j_0,i_0} }{ \d x_i } 
        = & ~ \frac{ \d (\langle f(x)_{j_0}, v \rangle - b_{j_0,i_0}) }{ \d x_i } \\
        = & ~ \frac{ \d \langle f(x)_{j_0}, v \rangle }{ \d x_i } \\
        = & ~ \langle  f(x)_{j_0} \circ \A_{j_0,i}, v \rangle - \langle  f(x)_{j_0} , v \rangle \cdot \langle f(x)_{j_0}, \A_{j_0,i} \rangle,
    \end{align*}
    where the first step comes from Eq.~\eqref{eq:c_x:}, the second step follows from $\frac{\d b_{j_0,i_0}}{\d x_i} = 0$, and the third step is due to {\bf Part 4}.

{\bf Proof of Part 6.}
Note that by Definition~\ref{def:l}, we have
\begin{align}\label{eq:l_x:}
    L(x,:)_{j_0,i_0} = 0.5 c(x,:)_{j_0,i_0}^2 
\end{align}

Therefore, we have
\begin{align*}
    \frac{\d L(x,:)_{j_0,i_0}}{\d x_i} 
    = & ~ \frac{\d (0.5 c(x,:)_{j_0,i_0}^2)}{\d x_i} \\
    = & ~ c(x,:)_{j_0,i_0} \cdot \frac{\d c(x,:)_{j_0,i_0}}{\d x_i} \\
    = & ~ c(x,:)_{j_0,i_0} \cdot (\langle  f(x)_{j_0} \circ \A_{j_0,i}, v \rangle - \langle  f(x)_{j_0} , v \rangle \cdot \langle f(x)_{j_0}, \A_{j_0,i} \rangle),
\end{align*}
where the first step is due to Eq.~\eqref{eq:l_x:}, the second step is because of chain rule of derivative, the last step comes from {\bf Part 5}.

{\bf Proof of Part 7.}

We have
\begin{align*}
    \frac{ \d  \langle  f(x)_{j_0} \circ \A_{j_0,i}, v \rangle }{\d x_i} 
    = & ~   \langle  \frac{ \d (f(x)_{j_0} \circ \A_{j_0,i}) }{\d x_i} , v \rangle \\
    = & ~   \langle  \frac{ \d f(x)_{j_0}}{\d x_i} \circ \A_{j_0,i}  , v \rangle \\
    = & ~   \langle  (f(x)_{j_0} \circ \A_{j_0,i} - f(x)_{j_0} \cdot \langle f(x)_{j_0} , \A_{j_0,i}\rangle) \circ \A_{j_0,i}  , v \rangle \\
    = & ~   \langle  f(x)_{j_0} \circ \A_{j_0,i} \circ \A_{j_0,i} - f(x)_{j_0} \cdot \langle f(x)_{j_0} , \A_{j_0,i}\rangle \circ \A_{j_0,i}  , v \rangle \\
    = & ~   \langle  f(x)_{j_0} \circ \A_{j_0,i} \circ \A_{j_0,i} , v \rangle - \langle f(x)_{j_0} \cdot \langle f(x)_{j_0} , \A_{j_0,i}\rangle \circ \A_{j_0,i}  , v \rangle \\
    = & ~   \langle  f(x)_{j_0} \circ \A_{j_0,i} \circ \A_{j_0,i} , v \rangle - \langle f(x)_{j_0} , \A_{j_0,i}\rangle \cdot \langle f(x)_{j_0} \circ \A_{j_0,i}  , v \rangle
\end{align*}
where the first step is due to Fact~\ref{fac:exponential_der_rule}, the second step comes from Fact~\ref{fac:exponential_der_rule}, the third step is because of {\bf Part 3}, the fourth step is owing to simple algebra, the fifth step follows from Fact~\ref{fac:circ_rules}, and the last step comes from Fact~\ref{fac:circ_rules}.


{\bf Proof of Part 8.}

We have 
\begin{align*}
    \frac{ \d  \langle  f(x)_{j_0} \circ \A_{j_0,i}, v \rangle }{\d x_l} 
    = & ~ \langle  \frac{ \d (f(x)_{j_0} \circ \A_{j_0,i}) }{\d x_l} , v \rangle \\
    = & ~ \langle  \frac{ \d f(x)_{j_0}}{\d x_l} \circ \A_{j_0,i}  , v \rangle \\
    = & ~ \langle (f(x)_{j_0} \circ \A_{j_0, l} - f(x)_{j_0} \cdot \langle f(x)_{j_0}, \A_{j_0,l} \rangle) \circ \A_{j_0,i}, v \rangle \\
    = & ~ \langle  f(x)_{j_0} \circ \A_{j_0,i} \circ \A_{j_0,l} - f(x)_{j_0} \cdot \langle f(x)_{j_0} , \A_{j_0,l}\rangle \circ \A_{j_0,i}  , v \rangle \\
    = & ~ \langle  f(x)_{j_0} \circ \A_{j_0,i} \circ \A_{j_0,l} , v \rangle - \langle f(x)_{j_0} \cdot \langle f(x)_{j_0} , \A_{j_0,l}\rangle \circ \A_{j_0,i}  , v \rangle \\
    = & ~ \langle  f(x)_{j_0} \circ \A_{j_0,i} \circ \A_{j_0,l} , v \rangle - \langle f(x)_{j_0} , \A_{j_0,l}\rangle \cdot \langle f(x)_{j_0} \circ \A_{j_0,i}  , v \rangle
\end{align*}
where the first step comes from Fact~\ref{fac:exponential_der_rule}, the second step is because of Fact~\ref{fac:exponential_der_rule}, the third step follows from {\bf Part 3}, the fourth step is due to simple algebra, the fifth step is owing to Fact~\ref{fac:circ_rules}, and the last step comes from Fact~\ref{fac:circ_rules}.

{\bf Proof of Part 9.}


We have
\begin{align*}
    \frac{ \d \langle f(x)_{j_0}, \A_{j_0,i} \rangle }{\d x_i}
    = & ~  \langle \frac{ \d f(x)_{j_0} }{\d x_i}, \A_{j_0,i} \rangle \\
    = & ~  \langle f(x)_{j_0} \circ \A_{j_0,i} - f(x)_{j_0} \cdot \langle f(x)_{j_0} , \A_{j_0,i}\rangle, \A_{j_0,i} \rangle \\
    = & ~ \langle f(x)_{j_0} , \A_{j_0,i} \circ \A_{j_0,i} \rangle - \langle f(x)_{j_0}, \A_{j_0,i} \rangle \cdot \langle f(x)_{j_0}, \A_{j_0,i} \rangle
\end{align*}
where the first step is due to Fact~\ref{fac:exponential_der_rule}, the second step comes from {\bf Part 3}, and the last step is because of Fact~\ref{fac:circ_rules}.

{\bf Proof of Part 10.}
We have 
\begin{align*}
     \frac{ \d \langle f(x)_{j_0}, \A_{j_0,i} \rangle }{\d x_l}
     = & ~  \langle \frac{ \d f(x)_{j_0}}{\d x_l}, \A_{j_0,i}  \rangle \\
     = & ~ \langle f(x)_{j_0} \circ \A_{j_0,l} - f(x)_{j_0} \cdot \langle f(x)_{j_0} , \A_{j_0,l}\rangle, \A_{j_0,i} \rangle \\
     = & ~ \langle f(x)_{j_0} , \A_{j_0,i} \circ \A_{j_0,l} \rangle - \langle f(x)_{j_0}, \A_{j_0,i} \rangle \cdot \langle f(x)_{j_0}, \A_{j_0,l} \rangle
\end{align*}
where the first step comes from Fact~\ref{fac:exponential_der_rule}, the second step is owing to {\bf Part 3}, and the last step is due to Fact~\ref{fac:circ_rules}.
\end{proof}

\subsection{Gradient With Respect to \texorpdfstring{$y$}{y}}\label{sec:gradient:y}

In this section, we compute the gradient with respect to $y$.

\begin{lemma}\label{lem:gradient_y}
If the following conditions hold
\begin{itemize}
\item  Let $v \in \R^n$ be a vector which doesn't depend on $x$ and also doesn't depend on $y$.
\item Let $c(:,y)_{j_0,i_0} \in \R$ be defined as Definition~\ref{def:c}.
\item Let $L(:,y)_{j_0,i_0} \in \R$ be defined as Definition~\ref{def:l}.
\item Let 
 $h(y_{i_0}):= \underbrace{ A_3 }_{n \times d} \underbrace{ y_{i_0} }_{d \times 1}.$
 \item Let $h(y_{i_0}) = h(y)_{i_0}$ for convenience
 \item Let $A_{3,*,i_2} \in \R^n$ denote the $i_2$-th column of matrix $A_{3} \in \R^{n \times d}$ for each $i_2 \in [d]$
 \end{itemize}
 Then, we have
 \begin{itemize}
 \item {\bf Part 1.} If $i_1=i_0$
 \begin{align*}
    \frac{\d h(y_{i_0}) }{ \d y_{i_1,i_2} } = A_{3,*, i_2}
 \end{align*}
 \item {\bf Part 2.} If $i_1\neq i_0$
 \begin{align*}
    \frac{\d h(y_{i_0}) }{ \d y_{i_1,i_2} } = {\bf 0}_n
 \end{align*}
 \item {\bf Part 3.} If $i_1 = i_0$
 \begin{align*}
    \frac{\d \langle v, h(y)_{i_0} \rangle }{ \d y_{i_1,i_2}} = \langle v, A_{3,*,i_2} \rangle
 \end{align*}
 \item {\bf Part 4.} If $i_1 \neq i_0$
  \begin{align*}
    \frac{\d \langle v, h(y)_{i_0} \rangle }{ \d y_{i_1,i_2}}  = 0
 \end{align*}
\item {\bf Part 5.} If $i_1 = i_0$
 \begin{align*}
    \frac{\d c(:,y)_{j_0,i_0} }{ \d y_{i_1,i_2}} = \langle v, A_{3,*,i_2} \rangle
 \end{align*}
 \item {\bf Part 6.} If $i_1 \neq i_0$
  \begin{align*}
    \frac{\d c(:,y)_{j_0,i_0} }{ \d y_{i_1,i_2}}  = 0
 \end{align*}
 \item {\bf Part 7.} If $i_1 = i_0$
 \begin{align*}
    \frac{\d L(:,y)_{j_0,i_0} }{ \d y_{i_1,i_2}} = c(:,y)_{j_0,i_0} \langle v, A_{3,*,i_2} \rangle
 \end{align*}
 \item {\bf Part 8.} If $i_1 \neq i_0$
  \begin{align*}
    \frac{\d L(:,y)_{j_0,i_0} }{ \d y_{i_1,i_2}}  = 0
 \end{align*}
 \end{itemize}
 \end{lemma}
 \begin{proof}
{\bf Proof of Part 1.}
\begin{align*}
    \frac{\d h(y_{i_0})}{\d y_{i_1,i_2} }
    = & ~ \frac{\d A_{3} y_{i_0}}{\d y_{i_1,i_2}} \\
    = & ~ A_{3,*,i_2}
\end{align*}
where the first step is due to the definition of $h(y_{i_0})$ (see the Lemma statement), and the last step follows from $A_3 y_{i_0} = \sum_{i=1}^d A_{3,*,i} \cdot y_{i_0,i}$ together with $i_1 = i_0$, so that only the term with $i = i_2$ depends on $y_{i_1,i_2}$.

{\bf Proof of Part 2.}
\begin{align*}
    \frac{\d h(y_{i_0})}{\d y_{i_1,i_2} } = & ~ {\bf 0}_n
\end{align*}
where the first step is due to $i_1 \neq i_0$, so $h(y_{i_0}) = A_3 y_{i_0}$ does not depend on $y_{i_1,i_2}$.

{\bf Proof of Part 3.}
\begin{align*}
    \frac{\d \langle v, h(y)_{i_0} \rangle }{ \d y_{i_1,i_2}} = & ~ \langle v, \frac{\d h(y_{i_0})}{\d y_{i_1,i_2}} \rangle \\
    = & ~ \langle v, A_{3,*,i_2} \rangle
\end{align*}
where the first step comes from Fact~\ref{fac:exponential_der_rule}, the second step is due to the result of {\bf Part 1}.

{\bf Proof of Part 4.}
\begin{align*}
    \frac{\d \langle v, h(y)_{i_0} \rangle }{ \d y_{i_1,i_2}}  
    = & ~ \langle v, \frac{\d h(y_{i_0})}{\d y_{i_1,i_2}} \rangle \\
    = & ~ 0
\end{align*}
where the first step is because of Fact~\ref{fac:exponential_der_rule}, the second step comes from the result of {\bf Part 2}.

{\bf Proof of Part 5.}
\begin{align*}
    \frac{\d c(:,y)_{j_0,i_0} }{ \d y_{i_1,i_2}} 
    =  & ~ \frac{\d ( \langle v, h(y)_{i_0} \rangle - b_{j_0,i_0} )}{\d y_{i_1,i_2}} \\
    = & ~ \frac{\d \langle v, h(y)_{i_0} \rangle }{ \d y_{i_1,i_2}} \\
    = & ~ \langle v, A_{3,*,i_2} \rangle
\end{align*}
where the first step comes from the Definition~\ref{def:c}, the second step is because of $\frac{\d b_{j_0,i_0} }{ \d y_{i_1,i_2}} = 0$, and the last step is due to {\bf Part 3}.


{\bf Proof of Part 6.}
\begin{align*}
    \frac{\d c(:,y)_{j_0,i_0} }{ \d y_{i_1,i_2}} 
    =  & ~ \frac{\d ( \langle v, h(y)_{i_0} \rangle - b_{j_0,i_0} )}{\d y_{i_1,i_2}} \\
    = & ~ \frac{\d \langle v, h(y)_{i_0} \rangle }{ \d y_{i_1,i_2}} \\
    = & ~ 0
\end{align*}
where the first step is due to the Definition~\ref{def:c}, the second step comes from $\frac{\d b_{j_0,i_0} }{ \d y_{i_1,i_2}} = 0$, and the last step is owing to {\bf Part 4}.

{\bf Proof of Part 7.}
\begin{align*}
    \frac{\d L(:,y)_{j_0,i_0} }{ \d y_{i_1,i_2}} 
    = & ~ \frac{\d 0.5 c(:,y)_{j_0,i_0}^2}{\d y_{i_1,i_2}} \\
    = & ~ c(:,y)_{j_0,i_0} \cdot \frac{\d c(:,y)_{j_0,i_0}}{\d y_{i_1,i_2}} \\
    = & ~  c(:,y)_{j_0,i_0} \langle v, A_{3,*,i_2} \rangle
\end{align*}
where the first step is due to the Definition~\ref{def:l}, the second step comes from the chain rule of derivative, and the last step is owing to {\bf Part 5}.

{\bf Proof of Part 8.}
\begin{align*}
    \frac{\d L(:,y)_{j_0,i_0} }{ \d y_{i_1,i_2}} 
    = & ~ \frac{\d 0.5 c(:,y)_{j_0,i_0}^2}{\d y_{i_1,i_2}} \\
    = & ~ c(:,y)_{j_0,i_0} \cdot \frac{\d c(:,y)_{j_0,i_0}}{\d y_{i_1,i_2}} \\
    = & ~  0
\end{align*}
where the first step is because of the Definition~\ref{def:l}, the second step is due to the chain rule of derivative, and the last step comes from {\bf Part 6}.
 \end{proof}

 \subsection{Computation of \texorpdfstring{$c,f,h$}{c, f, h}}\label{sub:gradient:compute_c_f_h}

 In this section, we explain how to compute $c(x,y), f(x), h(y)$.
\begin{lemma}\label{lem:forward_computation}
If the following conditions hold
\begin{itemize}
    \item For each $j_0 \in [n]$, $i_0 \in [d]$, let $c(x,y)_{j_0,i_0} \in \R$ be defined as Definition~\ref{def:c}. (We can view $c(x,y)$ as an $n \times d$ matrix)
    \item For each $j_0 \in [n]$, let $f(x)_{j_0} \in \R^n$ be defined as Definition~\ref{def:f}. (We can view $f(x)$ as an $n \times n$ matrix)
    \item For each $i_0 \in [d]$, let $h(y)_{i_0} \in \R^n$ be defined as Definition~\ref{def:h}. (We can view $h(y)$ as $n \times d$ matrix)
    \item Let $A_3 \in \R^{n \times d}$
    \item We can view $y$ as an $d \times d$ matrix
\end{itemize}
Then, we can compute $f,h, c$ in $O(\Tmat(n,d,d) + \Tmat(n,n,d))$ time.
\end{lemma}
\begin{proof}
By Definition~\ref{def:h}, we have
\begin{align}\label{eq:h_y}
      \underbrace{h(y) }_{n \times d} =  \underbrace{ A_3 }_{n \times d} \underbrace{ y }_{d \times d}.
\end{align}
First $h(y) \in \R^{n \times d}$ can be viewed as multiplying $n \times d$ matrix ($A_3$) and $d \times d$ matrix ($y$), this can be computed in $\Tmat(n,d,d)$.

\begin{figure}[!ht]
    \centering
    \includegraphics[width = 0.6\linewidth]{h_y.pdf}
    \caption{The visualization of Eq.~\eqref{eq:h_y}. We have $A_3 \in \R^{n \times d}$. $h: \R^{d \times d} \to \R^{n \times d}$ is a function, which maps the matrix $y \in \R^{d \times d}$ to $h(y)$ by multiplying $A_3$ and $y$. The red rectangles represent matrices which are the factors, and the blue rectangle represents the matrix which is the product.}
    \label{fig:h_y}
\end{figure}

We also have
\begin{align}\label{eq:f_x}
    \underbrace{ f(x) }_{n \times n} = \underbrace{ D(X)^{-1} }_{n \times n} \exp( \underbrace{ A_1 }_{n \times d} \underbrace{ X }_{d \times d} \underbrace{ A_2^\top }_{d \times n} ), \mathrm{~~~and~~~} D(X) = \diag( \exp(A_1XA_2^\top) {\bf 1}_n )
\end{align}
Then the computation of $f(x) \in \R^{n \times n}$ can be done in $\Tmat(n,n,d) + \Tmat(n,d,d)$.

\begin{figure}[!ht]
    \centering
    \includegraphics[width = \linewidth]{f_x.pdf}
    \caption{The visualization of Eq.~\eqref{eq:f_x}. We have $A_1, A_2 \in \R^{n \times d}$, $X \in \R^{d \times d}$, and $D(X) \in \R^{n \times n}$ (see Definition~\ref{def:attention} and Figure~\ref{fig:attention_optimization}). First, we find the inverse of the matrix $D(X)$ and compute $\exp(A_1 X A_2^\top) \in \R^{n \times n}$, as shown in Figure~\ref{fig:attention_optimization}. Then, we multiply $D(X)^{-1}$ and $\exp(A_1 X A_2^\top)$ to get $f(x) \in \R^{n \times n}$. The green squares represent the square matrices in $\R^{n \times n}$. The blue rectangles represent the matrices in $\R^{n \times d}$ (the dark blue denotes the transpose of the matrix in $\R^{n \times d}$). The red square represents the square matrices in $\R^{d \times d}$.}
    \label{fig:f_x}
\end{figure}

Given that 
\begin{align}\label{eq:c_xy}
    \underbrace{ c (x,y) }_{n \times d} = \underbrace{ f(x) }_{n \times n} \underbrace{ h(y) }_{n \times d} - \underbrace{ B }_{n \times d}
\end{align}

Then $c(x,y)$ can be computed in $\Tmat(n,n,d)$ time.

\begin{figure}[!ht]
    \centering
    \includegraphics[width = 0.7\linewidth]{c_xy.pdf}
    \caption{The visualization of Eq.~\eqref{eq:c_xy}. Let $f(x) \in \R^{n \times n}$ (see Figure~\ref{fig:f_x}) and $h(y) \in \R^{n \times d}$ (see Figure~\ref{fig:h_y}). We have $B \in \R^{n \times d}$. We multiply $f(x)$ with $h(y)$ and subtract $B$ from their product to get $c(x, y) \in \R^{n \times d}$. The green square represents the square matrices in $\R^{n \times n}$. The blue rectangles represent the matrix in $\R^{n \times d}$.}
    \label{fig:c_xy}
\end{figure}
\end{proof}

 \subsection{Reformulating Gradient (\texorpdfstring{$x$}{x}) in Matrix View}\label{sub:gradient:reform_x}

 In this section, we reformulate the gradient with respect to $x$ in matrix view.

 \begin{lemma}\label{lem:compute_gradient_x}
If the following conditions hold
    \begin{itemize}
        \item  $\frac{\d L(x,y)_{j_0,i_0}}{\d x_i} 
    =  c(x,y)_{j_0,i_0} \cdot (\langle  f(x)_{j_0} \circ \A_{j_0,i}, h(y)_{i_0} \rangle - \langle  f(x)_{j_0} , h(y)_{i_0} \rangle \cdot \langle f(x)_{j_0}, \A_{j_0,i} \rangle)$
    \item Let $c(x,y) \in \R^{n \times d}$
    \item Let $f(x)_{j_0} \in \R^n$
    \item Let $v = h(y)_{i_0} \in \R^n$
    \item Let $\frac{\d L(x,y)}{\d x} = \sum_{j_0=1}^n \sum_{i_0=1}^d \frac{\d L(x,y)_{j_0,i_0} }{ \d x }$
    \item Let \begin{align*}
        q(x,y)_{j_0} = \sum_{i_0=1}^d c(x,y)_{j_0,i_0} h(y)_{i_0}
\end{align*}
    \end{itemize}
    then, we have
    \begin{itemize}
        \item {\bf Part 1.}
        \begin{align*}
            \frac{\d L(x,y)_{j_0,i_0} } { \d x } = \underbrace{ c(x,y)_{j_0,i_0} }_{ \mathrm{scalar} } \cdot \underbrace{ \A_{j_0}^\top }_{d^2 \times n} \underbrace{ ( \diag(f(x)_{j_0}) - f(x)_{j_0} f(x)_{j_0}^\top ) }_{n \times n} \underbrace{ h(y)_{i_0} }_{n \times 1}  
        \end{align*}
        \item {\bf Part 2.} Suppose $c(x,y), \A, f(x), h(y)$ are given, then $\frac{\d L(x,y)_{j_0,i_0} } { \d x } $ can be computed in $O(n d^2 )$ time.
        \item {\bf Part 3.} 
        \begin{align*}
            \frac{\d L(x,y)}{\d x} = \sum_{j_0=1}^n \underbrace{ \A_{j_0}^\top }_{d^2 \times n} \underbrace{ ( \diag(f(x)_{j_0}) - f(x)_{j_0} f(x)_{j_0}^\top ) }_{n \times n} \underbrace{q(x,y)_{j_0}}_{n \times 1}
        \end{align*}
        \item {\bf Part 4.}
        Suppose $c(x,y), \A, f(x), h(y)$ are given, then 
        $
            \frac{\d L(x,y)}{\d x} \in \R^{d^2} 
       $
        can be computed in $\Tmat(n,d,n) + \Tmat(n,d,d)$ time
    \end{itemize}
 \end{lemma}
 \begin{proof}

{\bf Proof of Part 1.}

From the Lemma statement, we have
\begin{align}\label{eq:lxy_j0_i0}
    \frac{\d L(x,y)_{j_0,i_0}}{\d x_i} 
    =  c(x,y)_{j_0,i_0} \cdot (\langle  f(x)_{j_0} \circ \A_{j_0,i}, h(y)_{i_0} \rangle - \langle  f(x)_{j_0} , h(y)_{i_0} \rangle \cdot \langle f(x)_{j_0}, \A_{j_0,i} \rangle)
\end{align}

Note that by Fact~\ref{fac:circ_rules}, we have
\begin{align*}
    \langle  f(x)_{j_0} \circ \A_{j_0,i}, h(y)_{i_0} \rangle = \A_{j_0,i}^\top \diag(f(x)_{j_0}) h(y)_{i_0}
\end{align*}
and 
\begin{align*}
    \langle  f(x)_{j_0} , h(y)_{i_0} \rangle \cdot \langle f(x)_{j_0}, \A_{j_0,i} \rangle
    = \A_{j_0,i}^\top f(x)_{j_0} f(x)_{j_0}^\top h(y)_{i_0}
\end{align*}


Therefore, Eq.~\eqref{eq:lxy_j0_i0} becomes
\begin{align*}
    \frac{\d L(x,y)_{j_0,i_0}}{\d x_i} 
    = & ~ c(x,y)_{j_0,i_0} \cdot (\A_{j_0,i}^\top \diag(f(x)_{j_0}) h(y)_{i_0} - \A_{j_0,i}^\top f(x)_{j_0} f(x)_{j_0}^\top h(y)_{i_0})\\
    = & ~ c(x,y)_{j_0,i_0} \cdot \A_{j_0,i}^\top ( \diag(f(x)_{j_0}) - f(x)_{j_0} f(x)_{j_0}^\top)h(y)_{i_0},
\end{align*}
where the second step follows from simple algebra.




Thus, we complete the proof.

{\bf Proof of Part 2.}

We first compute 
$( \diag( f(x)_{j_0} ) - f(x)_{j_0} f(x)_{j_0}^\top ) h(y)_{i_0}$, this can be done in $O(n)$ time.

Then we can compute the rest, it takes $O(nd^2)$ time.

{\bf Proof of Part 3 and Part 4.}

Firstly, we can compute $q(x,y)_{j_0}  \in \R^n$. 

Recall from the Lemma statement, we have 
\begin{align}\label{eq:q_xy_j0}
q(x,y)_{j_0} = \sum_{i_0=1}^d c(x,y)_{j_0,i_0} h(y)_{i_0}.
\end{align}

Let $q(x,y)_{j_0} \in \R^n$ denote the $j_0$-th column of $q(x,y)$.

Then we have
\begin{align*}
    q(x,y) = \underbrace{ h(y) }_{n \times d} \underbrace{ c(x,y)^\top }_{d \times n}
\end{align*}

This takes $\Tmat(n,d,n)$ time.

Then, we compute 
\begin{align}\label{eq:p_xy_j0}
p(x,y)_{j_0} = ( \diag( f(x)_{j_0} ) - f(x)_{j_0} f(x)_{j_0}^\top ) q(x,y)_{j_0}.
\end{align}

This takes $O(n^2)$ time in total.

We can show that
\begin{align*}
    & ~ \frac{\d L(x,y)}{\d x} \\
    = & ~ \sum_{j_0=1}^n \sum_{i_0=1}^d \frac{\d L(x,y)_{j_0,i_0} }{ \d x } \\
    = & ~ \sum_{j_0=1}^n \sum_{i_0=1}^d \underbrace{ c(x,y)_{j_0,i_0} }_{ \mathrm{scalar} } \cdot \underbrace{ \A_{j_0}^\top }_{d^2 \times n} \underbrace{ ( \diag( f(x)_{j_0} ) - f(x)_{j_0} f(x)_{j_0}^\top ) }_{n \times n} \underbrace{ h(y)_{i_0} }_{n \times 1}  \\
    = & ~ \sum_{j_0=1}^n \A_{j_0}^\top ( \diag( f(x)_{j_0} ) - f(x)_{j_0} f(x)_{j_0}^\top ) q(x,y)_{j_0} \\
    = & ~ \sum_{j_0=1}^n \A_{j_0}^\top p(x,y)_{j_0} \\
    = & ~ \vect( A_1^\top p(x,y) A_2 )
\end{align*}
where the first step is based on Definition~\ref{def:L}, the second step is because of {\bf Part 1}, the third step is due to Eq.~\eqref{eq:q_xy_j0}, the fourth step follows from Eq.~\eqref{eq:p_xy_j0}, and the last step due to tensor-trick.

Note that $A_1^\top p(x,y) A_2$ can be computed in $\Tmat(n,d,n) + \Tmat(d,n,d)$ time.
 \end{proof}

\subsection{Reformulating Gradient (\texorpdfstring{$y$}{y}) in Matrix View}\label{sub:gradient:reform_y}


In this section, we reformulate the gradient with respect to $y$ in matrix view.




\begin{lemma}\label{lem:compute_gradient_y}
If the following conditions hold
\begin{itemize}
    \item if $i_1=i_0$, $\frac{\d L(x,y)_{j_0,i_0} }{ \d y_{i_1,i_2}} = c(x,y)_{j_0,i_0} \langle f(x)_{j_0}, A_{3,*,i_2} \rangle$
    \item if $i_1\neq i_0$, $\frac{\d L(x,y)_{j_0,i_0} }{ \d y_{i_1,i_2}} =0$
    \item Let $\frac{\d L(x,y) }{\d y_{i_0,i_2}} = \sum_{j_0=1}^n  c(x,y)_{j_0,i_0} \langle f(x)_{j_0}, A_{3,*,i_2} \rangle $
    \item Let $\wt{q}(x,y)_{i_0} = \sum_{j_0=1}^n f(x)_{j_0} c(x,y)_{j_0,i_0}$
\end{itemize}
Then we have
\begin{itemize}
    \item {\bf Part 1.} 
    \begin{align*}
        \frac{\d L(x,y)_{j_0,i_0} }{ \d y_{i_0,i_2}} = \underbrace{A_{3,*,i_2}^\top}_{1 \times n} \underbrace{f(x)_{j_0}}_{n \times 1} \underbrace{c(x,y)_{j_0,i_0}}_{\mathrm{scalar}}
    \end{align*}    
    \item {\bf Part 2.}
    \begin{align*}
        \frac{\d L(x,y) }{\d y_{i_0,i_2}} = \underbrace{A_{3,*,i_2}^\top}_{1 \times n} \underbrace{\wt{q}(x,y)_{i_0}}_{n \times 1}
    \end{align*}
    \item {\bf Part 3.}
    \begin{align*}
        \frac{\d L(x,y) }{\d y} = \vect ( \underbrace{ A_3^\top}_{d \times n} \underbrace{ \wt{q}(x,y) }_{n \times d} )
    \end{align*}
    \item {\bf Part 4}. Computing $ \frac{\d L(x,y) }{\d y} $ takes $\Tmat(n,n,d) + \Tmat(n,d,d)$ time.
\end{itemize}
\end{lemma}
\begin{proof}
{\bf Proof of Part 1.}
\begin{align*}
    \frac{\d L(x,y)_{j_0,i_0} }{ \d y_{i_0,i_2}}
    = & ~ c(x,y)_{j_0,i_0} \langle f(x)_{j_0}, A_{3,*,i_2} \rangle \\
    = & ~ A_{3,*,i_2}^{\top} f(x)_{j_0} c(x,y)_{j_0,i_0} 
\end{align*}
where the first step comes from the assumption from the Lemma statement and the second step is based on Fact~\ref{fac:circ_rules}.


{\bf Proof of Part 2.}
\begin{align*}
    \frac{\d L(x,y) }{\d y_{i_0,i_2}} = & ~ \sum_{j_0=1}^n  c(x,y)_{j_0,i_0} \langle f(x)_{j_0}, A_{3,*,i_2} \rangle \\
    = & ~ \sum_{j_0=1}^n A_{3,*,i_2}^{\top} f(x)_{j_0} c(x,y)_{j_0,i_0} \\
    = & ~ A_{3,*,i_2}^\top \wt{q}(x,y)_{i_0}
\end{align*}
where the first step is due to the assumption from the Lemma statement,
the second step is because of Fact~\ref{fac:circ_rules}, and the last step comes from the definition of $\wt{q}(x,y)_{i_0}$ (see from the Lemma statement).

{\bf Proof of Part 3.}
\begin{align*}
    \frac{\d L(x,y) }{\d y} = \vect (  A_3^\top \wt{q}(x,y)  )
\end{align*}
where the first step comes from tensor trick based on {\bf Part 2}.

{\bf Proof of Part 4.}
Computing $\wt{q}(x,y) \in \R^{n \times d}$ takes $\Tmat(n,n,d)$ time.

Computing $A_3^\top \wt{q}(x,y)$ takes $\Tmat(n,d,d)$ time. 

\end{proof}