%\documentclass{styles/uai2023} % for initial submission
\documentclass[accepted]{uai2023} % after acceptance, for a revised
                                    % version; also before submission to
                                    % see how the non-anonymous paper
                                    % would look like
%% There is a class option to choose the math font
% \documentclass[mathfont=ptmx]{uai2023} % ptmx math instead of Computer
                                         % Modern (has noticable issues)
% \documentclass[mathfont=newtx]{uai2023} % newtx fonts (improves upon
                                          % ptmx; less tested, no support)
% NOTE: Only keep *one* line above as appropriate, as it will be replaced
%       automatically for papers to be published. Do not make any other
%       change above this note for an accepted version.

%% Choose your variant of English; be consistent
\usepackage[american]{babel}
% \usepackage[british]{babel}

%% Some suggested packages, as needed:
%\usepackage{biblatex}
\usepackage{natbib} % has a nice set of citation styles and commands
    \bibliographystyle{plainnat}
    %\bibliographystyle{abbrv}
    \renewcommand{\bibsection}{\subsubsection*{References}}
\usepackage{mathtools} % amsmath with fixes and additions
% \usepackage{siunitx} % for proper typesetting of numbers and units
\usepackage{booktabs} % commands to create good-looking tables
%\usepackage{tikz} % nice language for creating drawings and diagrams
\usepackage[normalem]{ulem}
%\usepackage[symbol]{footmisc}
%% Provided macros
% \smaller: Because the class footnote size is essentially LaTeX's \small,
%           redefining \footnotesize, we provide the original \footnotesize
%           using this macro.
%           (Use only sparingly, e.g., in drawings, as it is quite small.)

%% Self-defined macros
\newcommand{\swap}[3][-]{#3#1#2} % just an example


\usepackage[capitalize,noabbrev]{cleveref}
\usepackage{amsmath}
\usepackage{amssymb}
\usepackage{amsthm}

\theoremstyle{plain}
\newtheorem{theorem}{Theorem}[section]
\newtheorem{proposition}[theorem]{Proposition}
\newtheorem{lemma}[theorem]{Lemma}
\newtheorem{corollary}[theorem]{Corollary}
\theoremstyle{definition}
\newtheorem{definition}[theorem]{Definition}
\newtheorem{assumption}[theorem]{Assumption}
\theoremstyle{remark}
\newtheorem{remark}[theorem]{Remark}

\usepackage{amsfonts}
\usepackage{thmtools,thm-restate}
\usepackage{algorithm, algorithmic}
\usepackage{jmei}

\usepackage{titletoc}

\title{KrADagrad: Kronecker Approximation-Domination Gradient\\Preconditioned Stochastic Optimization}

% The standard author block has changed for UAI 2023 to provide
% more space for long author lists and allow for complex affiliations
%
% All author information is automatically removed by the class for the
% anonymous submission version of your paper, so you can already add your
% information below.
%
% Add authors
\author[1$\dagger$]{\href{mailto:j@mei.to}{Jonathan~Mei}}
\author[1$\dagger$]{Alexander~Moreno}
\author[1$\dagger$]{Luke~Walters}
% Add affiliations after the authors
\affil[1]{Independent Researcher}
\begin{document}
\maketitle
\footnotetext[2]{Funded by and conducted at Luminous Computing.}
\begin{abstract}
    Second-order stochastic optimizers allow the step size and direction of parameter updates to adapt to loss curvature, but have traditionally required too much memory and compute for deep learning. Recently, Shampoo \citep{gupta18shampoo} introduced a Kronecker-factored preconditioner to reduce these requirements; it has been used for large deep models \citep{anil20scalable} and in production \citep{anil22factory}.
    However, it requires inverse matrix roots of ill-conditioned matrices, which in turn require 64-bit precision and impose strong hardware constraints. In this paper, we propose a novel factorization, Kronecker Approximation-Domination (KrAD). Using KrAD, we update a matrix that directly approximates the inverse empirical Fisher matrix (as in full-matrix AdaGrad), avoiding inversion and hence the need for 64-bit precision. We then propose %two algorithms: 1) KrADagrad, using alternating updates of the preconditioner factors. This has significantly reduced computations compared to Shampoo but only theoretically achieves within $\varepsilon$-tolerance of optimal regret; and 2)
KrADagrad$^\star$, with computational costs similar to Shampoo's and the same regret bound. In 32-bit precision, synthetic ill-conditioned experiments show improved performance over Shampoo, while on several real datasets we achieve comparable or better generalization.
\end{abstract}


\section{Introduction}
    \label{sec:intro}
    \iffalse
    To do
    \begin{itemize}
    	\item Add experimental design writeup
    	\item Add a table of compute/memory costs
    	\item Finish Tensor proofs
    	\item Add 1-2 sentences in some sections for key takeaways.
    \end{itemize}
    \textcolor{red}{We need some empirical wins.}
    Selling points
    \begin{itemize}
    	\item A different approach to dealing with the ill-conditioned:
    	\item Because of this approximation, may not be getting around this problem, you may not be able to achieve a sufficiently accurate pre-conditioner. Potential path.
    	\item There's not a substantial difference between SGD/ADAM/Shampoo. We tried a bunch of experiments on different tasks and after hyperparameter optimization, it was actually quite difficult to demonstrate a substantial advantage of any methods.
    	\item batch norm needs to be removed
    	\item Architectural things you can do to make the performance much more similar.
    	\item We have a win of krad32 over shampoo32, but only in synthetic experiments.
    \end{itemize}
    \fi
    
    Second-order stochastic optimization methods adapt to loss curvature: they take smaller parameter update steps in regions where the gradient changes quickly, avoiding oscillation, and larger steps in flat regions. Traditionally, they required storing and inverting the Hessian to update parameters, which costs quadratic memory and cubic computation in the number of parameters. Thus, methods using only a diagonal Hessian/Fisher approximation \citep{duchi11adaptive,kingma14adam} have dominated the field. However, diagonal preconditioners only scale gradients in the canonical basis, while full preconditioners can scale them in a rotated basis that aligns more closely with the loss curvature.
    
    Recently, Shampoo \citep{gupta18shampoo,anil20scalable} proposed approximating the (empirical) Fisher matrix using Kronecker-factored matrices. The matrix version of Shampoo factorizes the full preconditioner into left and right Kronecker factors, so that only the smaller factors need to be stored and inverted. For parameters $\W\in \mathbb{R}^{m\times n}$, this reduces computation costs from $O(m^3n^3)$ to $O(m^3+n^3)$ and storage costs from $O(m^2n^2)$ to $O(m^2+n^2)$. AdaGrad \citep{duchi11adaptive} relies on regret bound techniques based on Online Mirror Descent (OMD) \citep{srebro11universality}, which are designed for vector updates. To use these techniques for matrix/tensor updates, \citet{gupta18shampoo} exploit domination results that relate vector update preconditioners to their matrix/tensor counterparts. However, Shampoo still requires inverse matrix roots, which are numerically unstable or inaccurate for ill-conditioned matrices in 32-bit precision. For preconditioned gradient descent, Shampoo must resolve the smallest eigenvalues of its statistics accurately, since they become the largest after inversion.
    It thus needs 64-bit precision, which requires some combination of slow TPU-CPU data transfers, stale preconditioner matrices, or even new machine learning accelerator hardware \citep{anil20scalable} supporting fast 64-bit matrix multiplication or fast, accurate eigendecomposition.
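    To illustrate the precision issue, the following minimal NumPy sketch computes the inverse fourth root of a synthetic statistic with condition number $10^{10}$ by eigendecomposition, once in 64-bit and once in 32-bit arithmetic. The test matrix, the eigendecomposition route (practical Shampoo implementations may use Newton iterations instead), and the helper names \texttt{inv\_root\_eigh} and \texttt{rel\_err} are our own illustrative assumptions, not part of Shampoo.
\begin{verbatim}
```python
import numpy as np


def inv_root_eigh(mat, p, dtype):
    """Inverse p-th root of a symmetric PD matrix, computed in `dtype`."""
    w, v = np.linalg.eigh(mat.astype(dtype))
    # Rounding can push tiny eigenvalues to zero or below; clip them to the
    # smallest positive normal number so the negative power stays finite.
    w = np.maximum(w, np.finfo(dtype).tiny)
    return (v * w ** (-1.0 / p)) @ v.T


rng = np.random.default_rng(0)
n = 50
q, _ = np.linalg.qr(rng.standard_normal((n, n)))  # random orthogonal basis
lam = np.geomspace(1.0, 1e-10, n)                 # condition number 1e10
B = (q * lam) @ q.T                               # ill-conditioned SPD statistic
exact = (q * lam ** -0.25) @ q.T                  # reference B^{-1/4}


def rel_err(approx):
    """Relative Frobenius error with respect to the exact inverse root."""
    return np.linalg.norm(approx - exact) / np.linalg.norm(exact)


err64 = rel_err(inv_root_eigh(B, 4, np.float64))
err32 = rel_err(inv_root_eigh(B, 4, np.float32).astype(np.float64))
print(f"float64: {err64:.1e}  float32: {err32:.1e}")
```
\end{verbatim}
    In 64-bit arithmetic the relative error stays small. In 32-bit arithmetic the smallest eigenvalues fall below the rounding level of the stored matrix, and the error becomes of order one or larger.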
    
    The primary motivation for any optimizer is to reach a solution of the same quality in less time, or to reach a better solution that other optimizers fail to find. Second-order optimizers are currently less popular than first-order methods due to: a) inertia against adoption, with a lack of highly optimized implementations in all major ML frameworks; b) the added compute and memory requirements; c) numerical instability and, consequently, the additional considerations required to make them work or to debug them (e.g., numerical linear algebra, computer number formats); and d) although they sometimes reach a solution unreachable by first-order methods (or the same solution in lower wall clock time (WCT) if properly optimized), they do not do so consistently across tasks and architectures.
    
    The key tradeoff is between (b, c) and (d). \citet{shukla22understanding} of Weights and Biases noted that some of their customers using Shampoo find solutions that generalize better than those found with Adam on their real-world tasks (d), but that second-order optimizers are still more expensive (b) and require additional considerations (c).
    
    In this paper, we address limitation (c) by introducing a novel factorization, Kronecker Approximation-Domination (KrAD), which has a simple form that updates the preconditioning matrix without explicitly inverting it. Shampoo constructs Kronecker factors of intermediate statistics such that their Kronecker product dominates the gradient outer product matrix. Our key idea is to construct factors for those statistics such that the \textit{inverse} of their Kronecker product dominates the gradient outer product matrix. This leads to preconditioners that require fractional powers, rather than inverse fractional powers, of the factors. While this does not reduce the computational complexity relative to Shampoo, it avoids the need for 64-bit precision.
    
    This paper has three primary contributions: 1) we introduce two new algorithms, both with $O(m^3+n^3)$ computational and $O(m^2+n^2)$ memory complexity, that require only positive matrix roots, in contrast to previous work requiring inverse matrix roots; 2) we show domination properties and use them to prove that 
    KrADagrad$^\star$, which has computational cost similar to Shampoo's, 
    %one algorithm, requiring far fewer FLOPs than Shampoo, has within $\varepsilon$ tolerance of optimal stochastic regret, while the second algorithm, with similar FLOP requirements, 
    achieves optimal regret; 3) we show empirically that, in 32-bit precision, we outperform Shampoo in synthetic experiments and perform comparably on several real datasets.%that we achieve similar performance to Shampoo and only require 32-bit precision. %{\color{teal} Something else worth exploring is the fact that current Hessian-free optimizers involve only FP32 vector ops, but o/w matmuls through the model can be performed with bfloat16 inputs (but FP32 results) when using mixed precision. (i.e. matmuls normally don't need to be performed in full FP32). If there is any hope that Kradagrad would still work if all of the matmuls involved could take bfloat16 inputs (by casting the inputs to bfloat16), this would be another big win}.
    We first describe some mathematical tools, set up the problem, and review second-order optimization and established results related to our method in Section~\ref{sec:background}. We then present our method and its theoretical properties in Section~\ref{sec:kradagrad}, consider its practical implementation in Section~\ref{sec:implementation}, and show empirical results in Section~\ref{sec:experiments}. Finally, we discuss implications in Section~\ref{sec:conclusion}.
    
    
    \section{Background and related work}
    \label{sec:background}
%\textcolor{red}{This is a non-standard related work section, as it's presented more as mathematical background/tutorial than a high-level summary of recent ideas. It explains the math very well, but it also makes it hard to immediately contextualize the high-level novelty. If you look at my 2020 paper shorturl.at/yBVY2 that is a standard way to do related work. Yours is a bit like the following paper from my former labmate shorturl.at/sTVXZ, where she reviews CTMC mathematically as related work (instead of just diving into CT-HMM). Some reviewers don't like the latter style, and one of the reviewers of that paper said `I feel that too much space is spent explaining prior art (until line 235, 4.5 pages into the paper).' That said, in her case, as in this, it does help to introduce prior work mathematically so the reader isn't shocked by what's going on with a `where did this come from???,' so it's potentially fine to do it this way, but we need to be very careful that we deliver accents in the right place to emphasize novelty. Sections 2.1-2.4 read more like a tutorial on previous methods. Some have that in the main paper successfully, but it's a little harder to pull off in conference papers due to page limits. We should discuss whether we want to do it: you can decide in the end but it's worth having the discussion. An alternative would be to have a more standard related work section, have a very brief background section, and the rest of the tutorial in the appendix.}

%\textcolor{red}{We need to mention the other recent methods for 2nd order optimization and relate what we're doing to them. Tensor normal training, kfac, etc. Not mentioning them can get the paper majorly dinged. Ideally we compare to some of them empirically, but only if we have time.}

Here, we set up the notation and the problem and then describe related work. We first describe optimization for vector-valued parameters and then extend the discussion to matrix-valued parameters. The discussion generalizes to tensors, but for clarity we continue with matrices in the main text and leave the tensor formulation to Appendix~C.5.%\ref{app:implementation_tensors}.

\subsection{Notation and preliminaries}
\label{subsec:notation}
We use bold lower-case letters to denote column vectors (e.g., $\g\in\Rbb^{n}$), bold upper-case letters to denote matrices (e.g., $\G\in\Rbb^{m\times n}$), and calligraphic letters to denote matrices formed by stacking vectors of interest (e.g., $\Gcal_k=\begin{pmatrix}\g_k & \g_{k-1} &\ldots & \g_1 \end{pmatrix}\in \Rbb^{n\times k}$). For a square matrix $\A\in\Rbb^{n\times n}$, let the trace be $\tr(\A)=\sum_{i=1}^{n}a_{i,i}$, where $a_{i,j}$ denotes the element in row $i$ and column $j$ of $\A$. For matrices $\A,\B\in\Rbb^{m\times n}$, let $\A\cdot\B=\tr(\A^\top\B)$ be the matrix (Frobenius) inner product and $\|\A\|_F = (\A\cdot\A)^{1/2}$ the induced Frobenius norm. We use $\|\A\|_2$ to denote the spectral norm, the largest singular value of $\A$. We write $\A\succeq 0$ to mean $\A$ is symmetric positive semi-definite (PSD), while $\A\succ 0$ means $\A$ is symmetric positive definite (PD). For symmetric matrices $\A$ and $\B$, $\B\succeq \A$ means that $\B-\A \succeq 0$ (similarly for $\succ$). For a PSD matrix $\A$, let $\A=\V\boldsymbol{\Lambda}\V^\top$ be its eigenvalue decomposition, with orthogonal $\V$ (i.e., $\V^{-1}=\V^\top$). Define $f(\boldsymbol{\Lambda})$ as the diagonal matrix with elements $(f(\boldsymbol{\Lambda}))_{i,i}=f(\lambda_{i})$, the principal values of the function applied to the scalar eigenvalues. Then we take $f(\A)=\V f(\boldsymbol{\Lambda})\V^\top$. In this way, we define a unique value for functions applied to PSD matrices with eigenvalues within the domain of the function. In particular, this defines real powers of PD matrices.
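As a quick sanity check of this definition, the following minimal NumPy sketch computes $f(\A)=\V f(\boldsymbol{\Lambda})\V^\top$ with \texttt{numpy.linalg.eigh} and verifies that $\A^{1/2}\A^{1/2}=\A$ and $\A^{1/4}\A^{-1/4}=\I$ for a PD matrix. The helper name \texttt{mat\_fn} and the random test matrix are hypothetical.
\begin{verbatim}
```python
import numpy as np


def mat_fn(A, f):
    """f(A) = V f(Lambda) V^T for symmetric PSD A, f applied to eigenvalues."""
    lam, V = np.linalg.eigh(A)      # A = V diag(lam) V^T, V orthogonal
    return (V * f(lam)) @ V.T       # scales column i of V by f(lam_i)


rng = np.random.default_rng(0)
X = rng.standard_normal((5, 8))
A = X @ X.T                          # symmetric PD (full rank almost surely)
A_half = mat_fn(A, np.sqrt)
A_quarter = mat_fn(A, lambda x: x ** 0.25)
A_inv_quarter = mat_fn(A, lambda x: x ** -0.25)
```
\end{verbatim}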

%The Kronecker product $\otimes:\Rbb^{m\times n}\times \Rbb^{q\times r}\rightarrow \Rbb^{mq\times nr}$ is\par
Let $\otimes$ denote the Kronecker product, for matrices $\A\in\Rbb^{m\times n}$ and $\B\in\Rbb^{q\times r}$ defined as\par
{\centering
	$\B\otimes\A=
	\begin{pmatrix}
		b_{1,1}\A & \ldots & b_{1,r}\A \\
		\vdots & \ddots & \vdots \\
		b_{q,1}\A & \ldots & b_{q,r}\A
	\end{pmatrix} \in \Rbb^{mq\times nr}$.
	\par}
Let the vectorization operation for a matrix $\A\in\Rbb^{m\times n}$ be
\par
{\centering
	$\vec(\A) = \begin{pmatrix} \a_{1}^\top & \ldots & \a_{n}^\top \end{pmatrix}^\top \in\Rbb^{mn}$:
	\par}
$\a_i$ is the $i$-th column of $\A$, and the corresponding inverse vectorization for $\a\in\Rbb^{mn}$ given a target matrix in $\Rbb^{m\times n}$ is\par
{\centering $\vec^{-1}_{m,n}(\a) = \begin{pmatrix} \a_{1:m} & \ldots & \a_{m(n-1)+1:mn} \end{pmatrix} \in\Rbb^{m\times n}$,
	\par}
where $\a_{i:j}=\begin{pmatrix}a_i & \ldots & a_j\end{pmatrix}^\top$.%size of $(m, n)$
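These conventions match NumPy's \texttt{numpy.kron}, which places $b_{i,j}\A$ in block $(i,j)$ of $\B\otimes\A$, and column-major (Fortran-order) reshaping. The following minimal sketch checks both; the helper names \texttt{vec} and \texttt{vec\_inv} are hypothetical.
\begin{verbatim}
```python
import numpy as np


def vec(A):
    """Stack the columns of A into one vector (column-major order)."""
    return A.reshape(-1, order="F")


def vec_inv(a, m, n):
    """Undo vec for a target shape (m, n)."""
    return a.reshape(m, n, order="F")


rng = np.random.default_rng(0)
m, n, q, r = 2, 3, 4, 5
A = rng.standard_normal((m, n))
B = rng.standard_normal((q, r))
K = np.kron(B, A)   # block (i, j) is b_{i,j} * A; shape (m*q, n*r)
```
\end{verbatim}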

Several properties \citep{bellman80some,vanloan00ubiquitous,boyd04convex,baumgartner11inequality} of trace and Kronecker products are important for our results. We list them in Appendix~D %\ref{app:derivations}
due to space constraints.% algorithm and theoretical

\subsection{Optimization in machine learning}
\label{subsec:opt_in_ml}
We are interested in iterative empirical risk minimization under a loss $f$, with data $\x\sim p_n(\x)$ drawn from an empirical distribution and parameters $\w\in\Rbb^{N}$,
% a function
% of a vector $\w\in\Rbb^{N}$ (e.g. parameters of a model) that also takes additional inputs $\x\in\Rbb^{M}$ (e.g. data) that arise from distribution $\x\sim p(\x)$,
\par    
{\centering
	$\w^* = \underset{\w}{\argmin}\, \mathbb{E}_{p_n}[f(\w, \x)]$.
	\par}
% We are interested in iteratively optimizing a function
% of a vector $\w\in\Rbb^{N}$ (e.g. parameters of a model) that also takes additional inputs $\x\in\Rbb^{M}$ (e.g. data) that arise from distribution $\x\sim p(\x)$,\par    
% {\centering
	%     $\w^* = \underset{\w}{\argmin}\, \mathbb{E}_{p}[f(\w, \x)]$.
	% \par}
%make the standard assumption that we have
We assume access to gradients $\nabla_\w f$. Let $\g_k = \sum_{i\in B_k}\nabla_\w f(\w_k,\x_i)\in \Rbb^{N}$ be the estimated gradient of $f$ w.r.t. $\w$, evaluated at $\w_k$ with data from batch $B_k$ at iteration $k$. Hereafter, we omit the qualifier ``stochastic'' for gradients, but it is always implied.% w.r.t $\w$

For step size $\eta_k\in \Rbb^+$, gradient-based methods update
\begingroup\abovedisplayskip=3pt plus 1pt minus 3pt\belowdisplayskip=3pt plus 1pt minus 3pt
\begin{align}
	\label{eq:preconditioned_update}
	\w_{k+1} = \w_k - \eta_k \P_k \g_k
\end{align}
\endgroup
where $\P_k\in \Rbb^{N\times N}$ is a \textit{preconditioner} matrix (some sources instead refer to $\P_k^{-1}$ as the preconditioner). In some algorithms, additional intermediate \textit{statistics} are stored and updated to aid preconditioner computation. Gradient descent uses $\P_k = \I_N$ (no preconditioning); Newton's method takes $\P_k=\mathbf{H}_k^{-1}$ to be the (pseudo)inverse of the Hessian.


While vanilla gradient descent updates are cheap to compute, convergence can require many iterations. Newton updates are more expensive, but may require far fewer iterations. In practice, the choice of preconditioner trades off computational tractability against improved convergence.
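To make this trade-off concrete, the following toy sketch applies Equation~\eqref{eq:preconditioned_update} with $\P_k=\I_N$ and with $\P_k=\mathbf{H}^{-1}$. It assumes a deterministic quadratic loss with an ill-conditioned Hessian rather than a stochastic one, and the helper names \texttt{grad} and \texttt{run} are hypothetical. The gradient descent step size is limited by the largest curvature, so the low-curvature coordinate is still about $0.37$ away from the minimizer after 100 steps, whereas a single Newton step reaches it.
\begin{verbatim}
```python
import numpy as np

# Quadratic loss f(w) = 0.5 * w^T H w with an ill-conditioned Hessian H.
H = np.diag([100.0, 1.0])


def grad(w):
    return H @ w


def run(P, eta, steps):
    """Iterate w <- w - eta * P @ grad(w), starting from w = (1, 1)."""
    w = np.ones(2)
    for _ in range(steps):
        w = w - eta * P @ grad(w)
    return w


w_gd = run(np.eye(2), eta=1.0 / 100.0, steps=100)   # P = I
w_newton = run(np.linalg.inv(H), eta=1.0, steps=1)  # P = H^{-1}
print(w_gd, w_newton)
```
\end{verbatim}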
%be slow in terms of the number of iterations%\textcolor{red}{ here we mention gradient descent and not SGD., while in the previous paragraph it read as if we were talking about SGD since it mentioned a mini-batch. In the full data setting, we can actually make a stronger statement for strongly convex and Lipschitz gradient loss: gradient descent has linear convergence while Newton's method has quadratic convergence. For stochastic optimization though we can't say that and they have the same regret (I think). If we're talking about stochastic optimization then this paragraph should read like it's stochastic optimization. If not we can mention the stronger result.} \textcolor{magenta}{We should prioritize the claims about what has been observed in practice in the stochastic setting over the theory we can achieve in full data setting. There's a note in the previous paragraph about omitting the designation ``stochastic'' for brevity. Do you think that's confusing or inadequate for the discussion?}\textcolor{red}{It's fine: I missed it. I think it was black and surrounded by other colors so it got drowned out.}


\subsection{Adaptive gradient preconditioners}
\label{subsec:ada_grad_preconditioners}
We can collect gradients through iteration $k$,\par
{\centering    $\Gcal'_k = \begin{pmatrix}\g_k & \g_{k-1} &\ldots & \g_i &\ldots & \g_1 \end{pmatrix}\in \Rbb^{N\times k}$ \par}
and augment this with a scaled identity,\par
{\centering
	$\Gcal_k = \begin{pmatrix}
		\Gcal'_k & \epsilon \I_N
	\end{pmatrix}\in \Rbb^{N\times (k+N)}$.
	\par}
\iffalse
so that we have the recursion,
\begin{align}
	\label{eq:gradient_recursion}
	{\small \Gcal_0 \Gcal_0^\top &= \epsilon^2 \I_N \nonumber \\
		\Gcal_{k+1} \Gcal_{k+1}^\top &= \g_{k+1} \g_{k+1}^{\top} + \Gcal_{k} \Gcal_{k}^\top}.
\end{align}
\fi
One form of adaptive gradient update is\par%\textcolor{red}{so if I'm understanding this, we have, excluding the scaled identity term, a geometric sum of rank one approximations to the Hessian, right?} \textcolor{magenta}{sorta. the whole sum including the scaled identity is the approximation of Hessian, since each rank-one term is itself a pretty poor approximation (only captures the curvature in one dimension)}
{\centering
	${\small \w_{k+1} = \w_k - \eta_k \left(\Gcal_{k} \Gcal_{k}^\top\right)^{-1/2} \g_k}$.
	\par}
Expressing this in terms of the non-augmented $\Gcal'_k$ and taking $\delta=\epsilon^2$ gives the full-matrix version of AdaGrad~\citep{duchi11adaptive}: $\Gcal_{k} \Gcal_{k}^\top$ can be seen as the statistic that is stored and updated in each iteration, and $\big(\delta\I_N +\Gcal'_{k} \Gcal_{k}^{'\top}\big)^{-1/2}$ is the preconditioner computed from the statistic. Unfortunately, storing the full matrix $\Gcal_{k} \Gcal_{k}^\top$ costs $O(N^2)$ memory, and computing its inverse square root via an SVD costs $O(N^3)$ computation.
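The following minimal NumPy sketch of full-matrix AdaGrad stores the statistic $\delta\I_N+\Gcal'_k\Gcal_k^{'\top}$ explicitly and checks that it equals $\Gcal_k\Gcal_k^\top$ for the augmented matrix. It is only an illustration: the eigendecomposition route, the random gradients, and the helper name \texttt{full\_adagrad\_step} are our assumptions.
\begin{verbatim}
```python
import numpy as np


def full_adagrad_step(w, g, S, eta):
    """One full-matrix AdaGrad step.

    S is the statistic delta * I + sum_i g_i g_i^T (N x N, O(N^2) memory);
    the preconditioner S^{-1/2} needs an O(N^3) eigendecomposition.
    """
    S = S + np.outer(g, g)
    lam, V = np.linalg.eigh(S)
    P = (V * lam ** -0.5) @ V.T
    return w - eta * P @ g, S


rng = np.random.default_rng(0)
N, k, eps = 4, 6, 1e-2
grads = rng.standard_normal((k, N))   # rows are g_1, ..., g_k
S = eps ** 2 * np.eye(N)              # delta = eps^2
w = np.zeros(N)
for g in grads:
    w, S = full_adagrad_step(w, g, S, eta=0.1)

# Augmented matrix (G'_k, eps * I_N) with columns g_k, ..., g_1.
G_aug = np.hstack([grads[::-1].T, eps * np.eye(N)])
```
\end{verbatim}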
%.The intermediate quantity

Diagonal AdaGrad reduces computational complexity by keeping only the diagonal of the statistic,\par
% The commonly implemented modification to reduce computational complexity leads to the diagonal version that is perhaps the most well-known,\par
{\centering
	$\w_{k+1} = \w_k - \eta_k \left(\delta\I_N + \textrm{diag}\left(\Gcal'_{k} \Gcal_{k}^{'\top}\right)\right)^{-1/2} \g_k$.
	\par}
This is $O(N)$ complexity in both memory and computation.
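In code, the diagonal statistic is simply a running sum of squared gradient entries, so the update becomes elementwise. The following minimal sketch uses random stand-in gradients, and the helper name \texttt{diag\_adagrad\_step} is hypothetical.
\begin{verbatim}
```python
import numpy as np


def diag_adagrad_step(w, g, s, eta, delta):
    """One diagonal AdaGrad step; s stores diag(G'_k G'_k^T) = sum_i g_i**2."""
    s = s + g * g                                 # O(N) memory and compute
    return w - eta * g / np.sqrt(delta + s), s


rng = np.random.default_rng(0)
N, k, delta = 4, 6, 1e-4
grads = rng.standard_normal((k, N))               # rows are g_1, ..., g_k
w, s = np.zeros(N), np.zeros(N)
for g in grads:
    w, s = diag_adagrad_step(w, g, s, eta=0.1, delta=delta)
```
\end{verbatim}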

\iffalse
Also of note is Online Newton Step (ONS)~\cite{hazan06logarithmic}, in which the parameter update is
\begin{align}
	\label{eq:ONS_update}
	\w_{k+1} = \w_k - \eta_k \left(\delta\I_N +\Gcal'_{k} \Gcal_{k}^{'\top}\right)^{-1} \g_k.
\end{align}
This still requires an ill-conditioned matrix inverse, but avoids the additional potential difficulty of computing a matrix square root. This method is not as widely used in machine learning, but~\cite{anil20scalable} note that one setting of an exponent hyperparameter in the actual implementations of their method corresponds to a factored approximation to the ONS, while another corresponds to that of AdaGrad.
\fi

\iffalse
\subsection{Approximating inverse square roots}
\label{subsec:approx_inv_sqrts}
While for the ONS update, we do not need to compute square roots, we discuss methods for doing so in service of the other updates.

Recently developed methods for low-rank updates to matrix square roots~\cite{shmueli22low} may enable us to trade off the computational complexity of performing a full matrix inverse square root at each time step in exchange for additional memory. Here we note that while the additional memory required may seem prohibitive at $O(N^2)$ given we just declared $O(Nh)$ too expensive, we address this concern in following sections.

Suppose at iteration $k$, we have an existing cached approximation for the square root and its inverse for our adaptive gradient estimator from Equation~\eqref{eq:naive_adaptive_gradient},
\begin{align}
	\Q_k &= \left(\Gcal_{k} \Gcal_{k}^\top\right)^{1/2}\\
	\P_k &= \Q_k^{-1}.
\end{align}
By our recursion from Equation~\eqref{eq:gradient_recursion}, the updates for the next iteration are,
\begin{align}
	\label{eq:update_Q_full}
	\Q_{k+1} &= \left(\g_{k+1}\g_{k+1}^\top + \Q_{k}^{2} \right)^{1/2} \nonumber\\
	&\approx \Vcal_{k+1}\Vcal_{k+1}^\top + \Q_{k}
\end{align}
and
\begin{align}
	\label{eq:update_P_full}
	\P_{k+1} &= \Q_{k+1}^{-1} \nonumber\\
	&\approx \left(\Vcal_{k+1}\Vcal_{k+1}^\top + \Q_{k} \right)^{-1}\nonumber\\
	&= -\Ucal_{k+1}\Ucal_{k+1}^\top + \P_{k},
\end{align}
where $\Vcal_{k+1}\in\Rbb^{N\times r}$ in Equation~\eqref{eq:update_Q_full} for some $1 \le r < \sqrt{N}$ can be obtained in $O(N^2 r^2)$ via a low-rank Algebraic Riccati Equation due to~\cite{shmueli22low}. Note that this cost is due to the cost of matrix-vector multiplication of $\Q_k$ and $\P_k$. Then, $\Ucal_{k+1}\in\Rbb^{N\times r}$ in Equation~\eqref{eq:update_P_full} can be determined in $O(N^2 r)$ via an application of the Woodbury~\yrcite{woodbury50inverting} update
\begin{align}
	\label{eq:update_U_full}
	\T_{k+1} &= \I_r +\Vcal_{k+1}^\top \P_k\Vcal_{k+1} \in \Rbb^{r\times r} \nonumber \\
	\Ucal_{k+1} &= \P_k \Vcal_{k+1}\T_{k+1}^{-1/2}. 
\end{align}
Note that this cost is due to the cost of multiplying $\P_k$ with $\Vcal_{k+1}$.
\fi

\subsection{Matrix variables and Shampoo}
\label{subsec:mat_vars}
Now, we consider $N=mn$ and optimize w.r.t. a matrix $\W\in\Rbb^{m\times n}$. We consider a single matrix for clarity, but the derivations and analyses extend to tensors and apply individually to each tensor-valued parameter in a model (e.g., total costs are obtained by summing over parameters). We can then use the same optimization framework, now taking $\w=\vec(\W)\in\Rbb^{mn}$. However, in forming a preconditioner, we exploit the additional structure of the parameter being a matrix. Consider\par
{\centering $\G_k = \vec^{-1}_{m,n}(\g_k)$.
	\par}
%where $\vec^{-1}_{m,n}$ undoes the effect of $\vec$ on a $m\times n$ matrix.  %% already defined in notation

One convenient factorized form of a preconditioner is\par
{\centering 
	$\P_k = \R_k \otimes \L_k$,
	\par}
where $\L_k\in\Rbb^{m\times m}$ and $\R_k\in\Rbb^{n\times n}$ are symmetric. This factorization reduces storage and computation without restricting $\P_k$ to be low rank. To see this, we return to Equation~\eqref{eq:preconditioned_update} and simplify
\begin{align*}
	\w_{k+1} &= \w_k - \eta_k \P_k \g_k \nonumber\\
	%&=\w_k - \eta_k (\R_k\otimes\L_k) \vec(\G_k) \nonumber\\
	&=\w_k - \eta_k \vec(\L_k\G_k\R_k) \qquad \because (\textrm{P}18) %\pref{property:kron_vec}
	\nonumber\\
	\Rightarrow \W_{k+1} & = \W_k - \eta_k \L_k\G_k\R_k. \qquad \because (\vec^{-1}_{m,n})
\end{align*}

This requires $O(m^2 + n^2)=O\left(N\left(\frac{m}{n}+\frac{n}{m}\right)\right)$ storage and $O(N(m+n))$ compute.
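The identity used above, $(\R\otimes\L)\vec(\G)=\vec(\L\G\R)$ for symmetric $\R$, can be checked numerically. The following minimal NumPy sketch uses random symmetric factors as stand-ins for $\L_k$ and $\R_k$; the helper name \texttt{vec} is hypothetical.
\begin{verbatim}
```python
import numpy as np


def vec(A):
    """Column-major vectorization."""
    return A.reshape(-1, order="F")


rng = np.random.default_rng(0)
m, n = 3, 5
L = rng.standard_normal((m, m))
L = L + L.T                      # symmetric left factor
R = rng.standard_normal((n, n))
R = R + R.T                      # symmetric right factor
G = rng.standard_normal((m, n))

dense = np.kron(R, L) @ vec(G)   # materializes an (mn x mn) matrix
factored = vec(L @ G @ R)        # O(mn(m + n)) work, O(m^2 + n^2) storage
```
\end{verbatim}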
%Define a matrix shape condition number that describes how far from square it is
%\begin{align}
%    \kappa = \frac{m}{n}+\frac{n}{m}.
%\end{align}
%Note that $2\le \kappa\le N + \frac{1}{N}$.
Unless otherwise specified, we assume w.l.o.g. that $m\le n$. Shampoo~\citep{gupta18shampoo} tracks statistics
\begingroup\abovedisplayskip=3pt plus 1pt minus 3pt\belowdisplayskip=3pt plus 1pt minus 3pt
\begin{align}
	{\small \B_{k} =\epsilon\I_m + \textrm{$\sum_{i=1}^{k}$} \G_i \G_i^\top}, \label{eq:shampoo_stat1} \quad
	{\small \C_{k} =\epsilon\I_n + \textrm{$\sum_{i=1}^{k}$} \G_i^\top \G_i} ,
\end{align}
\endgroup
and forms the preconditioner from the Kronecker factors
\begingroup\abovedisplayskip=3pt plus 1pt minus 3pt\belowdisplayskip=3pt plus 1pt minus 3pt
\begin{align}
	\label{eq:shampoo_precon1}
	(\L_{k}, \R_{k}) &= (\B_k^{-1/4}, \C_k^{-1/4}).
\end{align}
\endgroup
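Putting the statistics and factors together, the following minimal float64 sketch performs matrix Shampoo steps. It computes the roots by eigendecomposition, uses random stand-in gradients, and omits the engineering found in practical implementations. The helper names \texttt{inv\_root} and \texttt{shampoo\_step} are hypothetical.
\begin{verbatim}
```python
import numpy as np


def inv_root(A, p):
    """A^{-1/p} for symmetric PD A via eigendecomposition (float64)."""
    lam, V = np.linalg.eigh(A)
    return (V * lam ** (-1.0 / p)) @ V.T


def shampoo_step(W, G, B, C, eta):
    """One matrix Shampoo step; B (m x m) and C (n x n) are the statistics."""
    B = B + G @ G.T                          # B_k = B_{k-1} + G_k G_k^T
    C = C + G.T @ G                          # C_k = C_{k-1} + G_k^T G_k
    L, R = inv_root(B, 4), inv_root(C, 4)    # (L_k, R_k) = (B_k^{-1/4}, C_k^{-1/4})
    return W - eta * L @ G @ R, B, C         # W_{k+1} = W_k - eta L_k G_k R_k


rng = np.random.default_rng(0)
m, n, eps = 3, 5, 1e-3
W = np.zeros((m, n))
B, C = eps * np.eye(m), eps * np.eye(n)      # B_0 = eps I_m, C_0 = eps I_n
for _ in range(10):
    G = rng.standard_normal((m, n))          # stand-in for a stochastic gradient
    W, B, C = shampoo_step(W, G, B, C, eta=0.1)
```
\end{verbatim}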

\citet{gupta18shampoo} rely on three key conditions to prove that Shampoo achieves optimal regret. First, with $r\le m$ an upper bound on the rank of every $\G_i$,
\begingroup\abovedisplayskip=3pt plus 1pt minus 3pt\belowdisplayskip=3pt plus 1pt minus 3pt
\begin{align}
	\label{eq:shampoo_dominate_grad}
	\R_k^{-2} \otimes \L_k^{-2} &\succeq \frac{1}{r}\Big(\epsilon\I_{mn} + \sum_{i=1}^{k} \g_i \g_i^\top\Big).
\end{align}
\endgroup
Second, it requires that
\begingroup\abovedisplayskip=3pt plus 1pt minus 3pt\belowdisplayskip=3pt plus 1pt minus 3pt
\begin{align}
	\label{eq:shampoo_dominate_prev}
	\R_k^{-2} \otimes \L_k^{-2} &\succeq \R_{k-1}^{-2} \otimes \L_{k-1}^{-2}.
\end{align}
\endgroup
%
%
\iffalse
Also, they show that individually,
\begin{align}
	\label{eq:shampoo_dominate_both}
	\begin{aligned}
		\I_n \otimes \L_k^{-4} &\succeq \epsilon\I_{mn} + \frac{1}{r}\sum_{i=1}^{k} \g_i \g_i^\top\\
		\R_k^{-4} \otimes \I_m  &\succeq \epsilon\I_{mn} + \frac{1}{r}\sum_{i=1}^{k} \g_i \g_i^\top.
	\end{aligned}
\end{align}
\fi
%
%
Finally, it requires that, under mild conditions,
\begin{align}
	\label{eq:shampoo_trace_rate}
	{\small \tr(\L_k^{-1}), \tr(\R_k^{-1}) = O(k^{1/4})}.
\end{align}
Computing the inverse fractional powers adds an $O(m^3+n^3)$ cost, via either an SVD or a high-precision (64-bit) Newton iteration involving repeated matrix multiplications (see Equation (25) %\eqref{eq:coupled_newton_iteration}
in Appendix~C.3 %\ref{app:implementation_matrix_roots}
for further details); this dominates the $O(N(m+n))$ cost of applying the preconditioner.

The motivation for these properties stems from OMD analysis. For vector parameter updates with general preconditioners $\P_k \succ 0$, the regret is initially bounded by sums of quadratic forms in the parameter errors (weighted by $\P_k^{-1}$) and in the gradients (weighted by $\P_k$). If the domination and monotonicity properties hold, these sums can be bounded in terms of $\tr(\P_k^{-1})$. Using Property~(P13), %\pref{property:trace_kron},
this can be further expressed in terms of $\tr(\L_k^{-1})$ and $\tr(\R_k^{-1})$. The trace growth rates then give the final bound.
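For reference, the following minimal float64 sketch implements a standard coupled Newton iteration for inverse $p$-th roots, of the kind used for Shampoo's $O(m^3+n^3)$ root computation. We do not reproduce Equation (25) of Appendix~C.3; the initial scaling, stopping rule, and helper name \texttt{inv\_root\_newton} are our own assumptions and may differ from that variant. The iteration maintains $\mathbf{M}=\mathbf{X}^p\A$, so $\mathbf{X}\to\A^{-1/p}$ as $\mathbf{M}\to\I$.
\begin{verbatim}
```python
import numpy as np


def inv_root_newton(A, p, iters=100, tol=1e-12):
    """Coupled Newton iteration for A^{-1/p}, A symmetric PD."""
    I = np.eye(A.shape[0])
    z = (1 + p) / (2 * np.linalg.norm(A, 2))  # eigenvalues of M_0 in (0, (p+1)/2]
    X = z ** (1.0 / p) * I                    # X_0 = z^{1/p} I
    M = z * A                                 # M_0 = X_0^p A
    for _ in range(iters):
        T = ((p + 1) * I - M) / p
        X = X @ T                             # X_{j+1} = X_j T_j
        M = np.linalg.matrix_power(T, p) @ M  # keeps M = X^p A
        if np.max(np.abs(M - I)) < tol:
            break
    return X


rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.standard_normal((6, 6)))
lam = np.geomspace(1.0, 10.0, 6)
A = (Q * lam) @ Q.T
X = inv_root_newton(A, 4)
exact = (Q * lam ** -0.25) @ Q.T
```
\end{verbatim}
Each iteration costs a few matrix multiplications, which is why such iterations map well to accelerators; in low precision, however, they inherit the conditioning issues discussed in Section~\ref{sec:intro}.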

\subsection{Related Work}
\label{subsec:related_work}

Recently, there has been a surge in interest in tractable preconditioned gradient methods. We briefly contrast some of the most similar or otherwise notable methods to ours.

\textbf{Kronecker-factored}: KFAC \citep{martens15optimizing} and TNT \citep{ren21tensor} use Kronecker factors to approximate Fisher matrices while reducing storage and computation costs. KFAC requires knowledge of the network architecture, and thus modifications or even re-implementations for each parametric layer type within the network. KFAC and TNT require an additional backward pass and a matrix inversion. KBFGS \citep{goldfarb20practical} does not require matrix inversion but requires an additional forward and backward pass. We note that the empirical performance achieved by KBFGS is partially due to initialization using curvature estimated from the entire training set \citep{goldfarb20practical}, which is not available in a truly online setting. Shampoo is the most closely related work: it relies on the empirical Fisher matrix rather than an estimate of the Fisher matrix, and thus requires neither additional sampling nor additional forward/backward passes. In addition, it requires knowledge only of tensor shapes. While estimating the Fisher matrix may have intuitively desirable properties over using the empirical Fisher, the empirical Fisher is more practical to compute \citep{martens20new}, since in distributed data- or model-parallel settings, additional forward/backward passes become prohibitive (in terms of both computation and engineering cost). To our knowledge, Shampoo is the only second-order optimizer that has been successfully deployed in a large-scale production deep learning setting \citep{anil22factory}, which makes it of primary interest.

\textbf{Limited Memory}: GGT~\citep{agarwal19efficient} uses a limited history of $h$ past gradients to form a low-rank approximation to the full AdaGrad matrix, reducing storage costs to $O(Nh)$ and compute costs to $O(Nh^2)$; however, this requires many copies ($200$, for the problems they consider) of the full gradient to be stored as statistics ($h$ still scales as a function of $N$), which can become prohibitive without modifications to reduce $N$.

\textbf{Sketching}: AdaHessian \citep{yao21adahessian}, SENG \citep{yang22sketch}, and SketchySGD \citep{frangella22sketchysgd} estimate the Hessian (either its diagonal or a low-rank approximation) via automatic differentiation to compute Hessian-vector products (HVPs), which require additional backpropagation steps and two batches of data per step, one for gradient computation and one for Hessian sketching. This is less expensive than a full forward/backward pass, but can still be expensive in distributed settings. In addition, the low-rank factorization still requires many times the storage of the full model gradient ($100$--$200\times$ in the problems they consider). While KrADagrad does not use sketching or HVPs, combining these methods with the KrAD factorization is a potential direction for future work.
    
    \section{KrADagrad}
    \label{sec:kradagrad}

Here, we derive a pair of new optimization algorithms, KrADagrad and KrADagrad$^\star$. Along the way, we present intermediate results that attain domination properties analogous to those in Equations~\eqref{eq:shampoo_dominate_grad}-\eqref{eq:shampoo_dominate_prev} and the trace growth rates in Equation~\eqref{eq:shampoo_trace_rate} required for good regret, while retaining the low computational complexity of Kronecker-factorized methods. To derive the algorithms, we present 
\begin{enumerate}
	\item KrAD, a method for producing a Kronecker factorization approximating a matrix that yields the property in Equation \eqref{eq:shampoo_dominate_grad}
	\item Derivation of the basic form of KrADagrad, applying KrAD and the Woodbury matrix identity \cite{woodbury50inverting} to AdaGrad-style updates
	\item Statements confirming the property in Equation \eqref{eq:shampoo_dominate_prev}
	\item Extension to KrADagrad$^\star$.
\end{enumerate}

KrADagrad alternates updates of the Kronecker factors of the statistics and achieves regret within an $\varepsilon$ tolerance of optimal. Using insights from the KrADagrad regret analysis, we formulate KrADagrad$^\star$, which can be seen as an ``average'' of two KrADagrad estimators. KrADagrad$^\star$ updates both Kronecker factors of the statistics simultaneously and achieves optimal regret. 
We present theoretical results along the way in this section as they are needed but defer the proofs to Appendix~B. %\ref{app:proofs}.
We derive the algorithm for matrix-valued parameters for clarity and again leave the extension to tensor-valued parameters for Appendix~C.5.%\ref{app:implementation_tensors}.

\subsection{Kronecker Approximation-Domination}
\label{subsec:kron_approx_dom}
First, we state the lemma that KrAD is designed to satisfy; it is needed to achieve the condition in Equation \eqref{eq:shampoo_dominate_grad}, which allows us to prove optimal regret.
\begin{restatable}{lemma}{krad}
	\label{lem:krad}
	Let $\C\in\Rbb^{n\times n}$ be PD, let $\Ucal\in\Rbb^{mn\times r}$, and let $\u_i$ be the $i$-th column of $\Ucal$ with $\U_i=\vec^{-1}_{m,n}(\u_i)$. Then
	\begingroup\abovedisplayskip=2pt plus 1pt minus 2pt\belowdisplayskip=2pt plus 1pt minus 2pt
	\begin{align*}
		{\smaller\B} &{\smaller = \sum\limits_{i=1}^{r} \U_i\C^{-1}\U_i^\top \succeq 0} \\
		{\smaller\Rightarrow \C\otimes\B} &{\smaller \succeq \frac{1}{n}\Ucal\Ucal^\top.}
	\end{align*}
	\endgroup
\end{restatable}
The matrices $\U_i$ and $\Ucal$ are fairly general; in our setting, we apply KrAD to gradient matrices. In this context, the result states that given a PD right matrix, we can express a left matrix as a quadratic form such that the Kronecker product of the right and left matrices dominates the scaled gradient outer product matrix. We present the proof in Appendix~B.1.%\ref{app:proof_krad}.
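As a quick numerical sanity check of Lemma~\ref{lem:krad}, the NumPy sketch below draws a random $\Ucal$ and a random PD $\C$, forms $\B$ as in the lemma, and computes the smallest eigenvalue of $\C\otimes\B-\frac{1}{n}\Ucal\Ucal^\top$, which should be nonnegative up to round-off. It assumes a column-stacking (Fortran-order) $\vec$, so that $(\C\otimes\B)\vec(\X)=\vec(\B\X\C)$ for symmetric $\C$; the helper \texttt{krad\_left\_factor} and the problem sizes are hypothetical choices for illustration, not part of the released implementation.

```python
import numpy as np


def krad_left_factor(U_cal, C, m, n):
    """Hypothetical helper: B = sum_i U_i C^{-1} U_i^T with U_i = vec^{-1}_{m,n}(u_i).

    Assumes column-stacking vec, so each column u_i of U_cal is reshaped
    to an m x n matrix in Fortran (column-major) order.
    """
    C_inv = np.linalg.inv(C)
    B = np.zeros((m, m))
    for u in U_cal.T:                          # iterate over the r columns u_i
        U_i = u.reshape((m, n), order="F")     # U_i = vec^{-1}_{m,n}(u_i)
        B += U_i @ C_inv @ U_i.T
    return B


rng = np.random.default_rng(0)
m, n, r = 5, 3, 4
U_cal = rng.standard_normal((m * n, r))
M = rng.standard_normal((n, n))
C = M @ M.T + 0.1 * np.eye(n)                  # random PD right matrix

B = krad_left_factor(U_cal, C, m, n)
# With column-stacking vec, (C kron B) vec(X) = vec(B X C) for symmetric C.
gap = np.kron(C, B) - U_cal @ U_cal.T / n      # PSD according to the lemma
gap = 0.5 * (gap + gap.T)                      # remove round-off asymmetry
min_eig = np.linalg.eigvalsh(gap).min()
```

Repeating the check with other seeds, sizes, or ranks $r$ is a cheap way to probe the lemma before relying on it inside the optimizer.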

\subsection{KrADagrad updates: derivation}
\label{subsec:KrADagrad_updates}
We start by deriving the general statistic update. We outline it here and fill in the details in Appendix~D. %\ref{app:derivations}.
Suppose we have the previous Kronecker factorization of a statistic $\Q_{k-1}$, which dominates the gradient outer product matrix scaled by $\frac{1}{nt}$ (we specify $t$ later), together with its inverse $\P_{k-1}$:
\begingroup\abovedisplayskip=3pt plus 1pt minus 3pt\belowdisplayskip=3pt plus 1pt minus 3pt
\begin{align}
	\Q_{k-1} &= \C_{k-1}\otimes\B_{k-1} \succeq \frac{1}{nt}\Gcal_{k-1} \Gcal_{k-1}^\top \\
	\P_{k-1} &= \R_{k-1}\otimes\L_{k-1} = \C_{k-1}^{-1}\otimes\B_{k-1}^{-1} = \Q_{k-1}^{-1},
\end{align}
\endgroup
where $\B_{k-1}$, $\C_{k-1}$, $\L_{k-1}$, and $\R_{k-1}$ are PD. 
The intermediate update to which we apply our KrAD factorization is\par 
{\centering $\widetilde\Q_{k} = \frac{1}{nt_k}\g_{k}\g_{k}^\top + \Q_{k-1},$ \par}
for some $t_{k}\le t$. 
To update $\widetilde\L_{k}$, we first compute an intermediate update of its inverse, the left statistic $\widetilde\B_{k}=\widetilde\L_{k}^{-1}$. Letting $\widetilde\B_k=\B_{k-1} + \Delta\B_k$, we can apply KrAD to $\widetilde\Q_{k}$ with a fixed $\C_k=\C_{k-1}$ (i.e., we use the old right statistic to update the current left one),
\begingroup\abovedisplayskip=3pt plus 1pt minus 3pt\belowdisplayskip=3pt plus 1pt minus 3pt
\begin{align}
	&\Delta\B_{k} = \frac{1}{t_k}\G_k\C_{k}^{-1}\G_k^{\top}\\
	\Rightarrow&\C_k \otimes \Delta\B_{k} \succeq \frac{1}{nt_k}\g_k\g_k^{\top} \qquad \because \textrm{Lemma \ref{lem:krad}}\\
	\Rightarrow& \C_k \otimes (\B_{k-1} + \Delta\B_{k}) \succeq \frac{1}{nt}\Gcal_k\Gcal_k^\top \quad \because (\textrm{P15})%\pref{property:kron_prod_distributive}
	\label{eqn:b-tilde-dominating-scaled-gradients}.
\end{align}
\endgroup
Then, letting $\widehat{\C}_k = \C_{k}\!+\!\frac{1}{t_k}\G_k^{\top}\L_{k\!-\!1}\G_k$ and applying the Woodbury matrix identity,
\begingroup\abovedisplayskip=3pt plus 1pt minus 3pt\belowdisplayskip=3pt plus 1pt minus 3pt
\begin{align}
	\widetilde\L_{k}&=\widetilde\B_{k}^{-1}=(\B_{k-1}+\Delta\B_{k})^{-1} \\
	%&=\!\L_{k\!-\!1} \!-\! \frac{1}{t_k}\L_{k\!-\!1}\G_k(\widehat{\C}_k)^{-1}\G_k^{\top}\L_{k\!-\!1}\\
	%&\succeq \L_{k\!-\!1} \!-\! \frac{1}{t_k}\L_{k\!-\!1}\G_k(\C_{k})^{-1}\G_k^{\top}\L_{k\!-\!1} \\
	&\succeq \underbrace{\L_{k\!-\!1} \!-\! \frac{1}{t_k}\L_{k\!-\!1}\G_k\R_{k}\G_k^{\top}\L_{k\!-\!1}}_{\L_k} \label{eq:krad_update_dominate_inverse}
\end{align}
\endgroup
(we provide more detail in Appendix~D). %\ref{app:derivations}).
Note that our update for $\L_k$ depends on neither $\B$ nor $\C$, and it requires no expensive matrix inverses to compute. This suggests that we need not store $\B$ or $\C$ at all to obtain a computationally tractable implementation.
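The following NumPy sketch illustrates Equation~\eqref{eq:krad_update_dominate_inverse} on random data. It builds the reference $\widetilde\L_k=(\B_{k-1}+\Delta\B_k)^{-1}$ with an explicit inversion, forms the inverse-free $\L_k$ from $\L_{k-1}$, $\R_k$, and $\G_k$ alone, and checks that $\widetilde\L_k-\L_k$ is PSD up to round-off. The random PD construction and the sizes are assumptions made for illustration; the explicit inverses appear only to build the check, not in the update itself.

```python
import numpy as np


def random_pd(d, rng):
    # Random symmetric PD matrix with all eigenvalues at least 1
    M = rng.standard_normal((d, d))
    return M @ M.T + np.eye(d)


rng = np.random.default_rng(1)
m, n = 6, 4
B_prev, C = random_pd(m, rng), random_pd(n, rng)      # B_{k-1} and C_k = C_{k-1}
L_prev, R = np.linalg.inv(B_prev), np.linalg.inv(C)   # L_{k-1} and R_k
G = rng.standard_normal((m, n))
t = 1.0 + np.linalg.norm(L_prev @ G @ R @ G.T, "fro")  # one valid choice of t_k

# Reference only: exact inverse of B_{k-1} + Delta B_k (requires an inversion).
L_tilde = np.linalg.inv(B_prev + G @ R @ G.T / t)

# Inverse-free update: uses only L_{k-1}, R_k, and G_k.
L_new = L_prev - L_prev @ G @ R @ G.T @ L_prev / t

gap = L_tilde - L_new
gap = 0.5 * (gap + gap.T)        # remove round-off asymmetry before eigvalsh
min_eig = np.linalg.eigvalsh(gap).min()
```

Since $\G_k$ here has more rows than columns, the gap has rank at most $n$, so its smallest eigenvalues are zero up to round-off rather than strictly positive.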

Next, we state intermediate results that support our proposed updates (proof in Appendix~B.2). %\ref{app:proof_middle_term}).
\begin{restatable}{proposition}{propmiddleterm}
	\label{prop:middle_term}
	Taking $t_k=1+\|\L_{k-1}\G_k\R_{k-1}\G_k^\top\|_2$ (or the looser but more computationally friendly $t_k=1+\|\L_{k-1}\G_k\R_{k-1}\G_k^\top\|_F$), the PSD matrix\par
	{\centering $\M_k \overset{\Delta}{=} \frac{1}{t_k}\L_{k-1}^{1/2}\G_{k}\R_{k-1}\G_{k}^\top \L_{k-1}^{1/2} \prec \I.$ \par}
\end{restatable}
\begin{restatable}{corollary}{corKrADupdatePD}
	\label{cor:KrAD_update_PD}
	If $\L_{k-1}\succ 0$, the updated $0 \prec \L_k \preceq \L_{k-1}$.
\end{restatable}
In Corollary~\ref{cor:KrAD_update_PD}, the second inequality $\L_k \preceq \L_{k-1}$ ensures that the step size does not increase, while the first inequality $\L_k \succ 0$ guarantees that we do not reverse the direction of the gradient. Both are desirable theoretical properties of preconditioners, as they help avoid divergent behavior. In practice, they may be mildly violated with good results; for example, Adam can technically increase its step size, as shown by~\citet{reddi18convergence}. Corollary~\ref{cor:KrAD_update_PD} also allows us to leverage existing techniques to bound the regret of our algorithm.
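Proposition~\ref{prop:middle_term} and Corollary~\ref{cor:KrAD_update_PD} can be checked numerically in the same way. In the sketch below (an illustration with arbitrarily chosen sizes and a deliberately large gradient, not the released implementation), $t_k$ uses the spectral norm; we expect the eigenvalues of $\M_k$ to stay below one, $\L_k$ to remain PD, and $\L_{k-1}-\L_k$ to be PSD.

```python
import numpy as np


def sym_sqrt(S):
    # Square root of a symmetric PSD matrix via its eigendecomposition
    w, V = np.linalg.eigh(S)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T


rng = np.random.default_rng(2)
m, n = 6, 4
A, D = rng.standard_normal((m, m)), rng.standard_normal((n, n))
L_prev = A @ A.T + np.eye(m)            # PD left factor L_{k-1}
R_prev = D @ D.T + np.eye(n)            # PD right factor R_{k-1}
G = 3.0 * rng.standard_normal((m, n))   # deliberately large gradient

core = L_prev @ G @ R_prev @ G.T
t = 1.0 + np.linalg.norm(core, 2)       # spectral-norm choice of t_k

S = sym_sqrt(L_prev)
M = S @ G @ R_prev @ G.T @ S / t        # M_k from the proposition
L_new = L_prev - L_prev @ G @ R_prev @ G.T @ L_prev / t

eig_M = np.linalg.eigvalsh(0.5 * (M + M.T))
eig_L = np.linalg.eigvalsh(0.5 * (L_new + L_new.T))
diff = L_prev - L_new
eig_gap = np.linalg.eigvalsh(0.5 * (diff + diff.T))
```

Replacing the spectral norm by the Frobenius norm only enlarges $t_k$, so the same checks pass with the cheaper choice.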

\iffalse
One additional practical consideration is the shrinking norm of the estimator. This can lead to round off issues in matrix-vector products, so in practice we can rescale the estimate by factoring out the norm every $\nu$ iterations.
\fi

\subsubsection{Update schemes}
Thus far, we have glossed over the fact that we have only updated $\L_k$. We may update $\R_k$ in the same way; since updating $\L_k$ alone already achieves domination, we now consider how to update $(\L_k, \R_k)$ jointly. Further, we have only described how to update the statistics; we must still compute the preconditioner.

In the next few subsections, we discuss these issues more concretely, proposing KrADagrad$^\star$, an algorithm that combines two sets of KrAD preconditioners to obtain optimal regret, at the cost of updating two matrix statistics and computing higher-order matrix roots. This algorithm is reminiscent of Shampoo, but avoids the numerical difficulty of inverting ill-conditioned matrices. We also propose KrADagrad, a separate scheme that has suboptimal regret but is more intuitive and illustrates how we arrive at this form of update. Due to space constraints, we defer KrADagrad to Appendix~C.1. %\ref{app:implementation_kradagrad}.

\iffalse
\subsection{ONS-style}
\label{subsec:KrADagrad_ONS}

Once we have $\L_k$ and $\R_k$, we can apply it to Equation~\eqref{eq:ONS_update},
\begin{align}
	\label{eq:KrADagrad_ONS_update}
	\w_{k+1} &= \w_k - \eta_k (\R_k\otimes \L_k) \g_k\\
	\Rightarrow \W_{k+1} &= \W_k - \eta_k \L_k\G_k \R_k.
\end{align}
Since we have held one factor constant, we can store its value from the previous iteration and only need to compute the update for the other factor. 
\fi



\subsection{KrADagrad$^\star$: combining preconditioners}
Suppose we have two distinct sets of KrADagrad preconditioners. Here we overload notation slightly: holding the iteration $k$ fixed, we drop it from the subscript and instead use the subscript to index the set of preconditioners to which each matrix belongs,\par
{\centering        $\R_1^{-1} \otimes \L_1^{-1} \succeq \Gcal\Gcal^\top, \qquad        \R_2^{-1} \otimes \L_2^{-1} \succeq \Gcal\Gcal^\top$,
	\par}
recalling that $\Gcal=\begin{pmatrix}
	\Gcal^{'} & \epsilon \I_{N}
\end{pmatrix}$, the augmented matrix collecting the history of observed gradients.
%By (P9) and (P10), %\pref{property:domination_operator_monotone} and \pref{property:domination_quadratic}, 
%we can take certain positive powers and products of both sets of statistics, and the results will also dominate $\Gcal\Gcal^\top$. For instance, u
Using the matrix geometric mean of two PD matrices \citep{ando04geometric},\par
{\centering
	$\begin{aligned}M_g(\A, \B) &\overset{\Delta}{=} \A^{1/2}(\A^{-1/2}\B\A^{-1/2})^{1/2}\A^{1/2}\\
		& = \A (\A^{-1}\B)^{1/2}\end{aligned}$
	\par}
and the geometry of the manifold of PD matrices \citep{bhatia09positive} (in particular, $M_g$ is monotone in each argument, commutes with inversion, and factorizes over Kronecker products), we have\par
{\centering
	$\R_c^{-1} \otimes \L_c^{-1} \succeq \Gcal\Gcal^\top$,
	\par}
where $\L_c=M_g(\L_1, \L_2)$ and $\R_c=M_g(\R_1, \R_2)$ form another pair of KrAD estimates\footnote{The subscript $c$ stands for ``combined.''}.
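The properties of $M_g$ relied upon above can be checked numerically. The NumPy sketch below (the helper names \texttt{sym\_power} and \texttt{geometric\_mean} are hypothetical) implements the first expression for $M_g$ for PD inputs and checks three identities: $M_g(\A,\I)=\A^{1/2}$, $M_g(\A^{-1},\B^{-1})=M_g(\A,\B)^{-1}$, and $M_g(\A_1\otimes\B_1,\A_2\otimes\B_2)=M_g(\A_1,\A_2)\otimes M_g(\B_1,\B_2)$. Combined with the monotonicity of $M_g$ and $M_g(\X,\X)=\X$, the last two identities are what one would use to show that the combined pair $(\L_c,\R_c)$ inherits domination from the two individual estimates.

```python
import numpy as np


def sym_power(S, p):
    # S^p for a symmetric PD matrix S, via its eigendecomposition
    w, V = np.linalg.eigh(S)
    return (V * w**p) @ V.T


def geometric_mean(A, B):
    # M_g(A, B) = A^{1/2} (A^{-1/2} B A^{-1/2})^{1/2} A^{1/2}
    A_half, A_neg_half = sym_power(A, 0.5), sym_power(A, -0.5)
    inner = A_neg_half @ B @ A_neg_half
    inner = 0.5 * (inner + inner.T)       # keep it exactly symmetric
    return A_half @ sym_power(inner, 0.5) @ A_half


def random_pd(d, rng):
    M = rng.standard_normal((d, d))
    return M @ M.T + np.eye(d)


rng = np.random.default_rng(3)
A1, A2 = random_pd(3, rng), random_pd(3, rng)
B1, B2 = random_pd(4, rng), random_pd(4, rng)
inv = np.linalg.inv

checks = {
    "M_g(A, I) = A^(1/2)": np.allclose(
        geometric_mean(A1, np.eye(3)), sym_power(A1, 0.5)),
    "M_g(A^-1, B^-1) = M_g(A, B)^-1": np.allclose(
        geometric_mean(inv(A1), inv(A2)), inv(geometric_mean(A1, A2))),
    "Kronecker factorization": np.allclose(
        geometric_mean(np.kron(A1, B1), np.kron(A2, B2)),
        np.kron(geometric_mean(A1, A2), geometric_mean(B1, B2))),
}
```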

In this case, we may need to take additional square roots of the quantities $(\L_1^{-1}\L_2, \R_1^{-1}\R_2)$ and of $(\L_c,\R_c)$ themselves. We first show how these combined estimators relate to Shampoo, and then use this insight to arrive at a second version of preconditioning, which we call KrADagrad$^\star$.% (read as ``kradagrad-star'').

\subsubsection{Shampoo combines inverse KrAD estimates}
Suppose we keep a pair of KrAD estimates for $\Q$ (an integral power of the preconditioner inverse) instead of directly for $\P$: in the first, we update only $\B_1$ and keep $\C_1=\I$ fixed, while in the second, we update only $\C_2$ and keep $\B_2=\I$ fixed. We then end up with exactly the Shampoo statistics updates\par
{\centering
	$\begin{aligned}[b]\Delta\B_1 &= \G\C_1^{-1}\G^{\top}=\G\G^{\top} \\
		\Delta\C_2 &= \G^{\top}\B_2^{-1}\G=\G^{\top}\G\end{aligned}$.
	\par}
Then, the full Shampoo statistics matrices are related to a combination of these two statistics,\par
{\centering
	$\begin{aligned}[b]\B_c &= \! M_g(\B_1,\I) \!=\! \B_1^{1/2}\\
		\C_c &= \! M_g(\I,\C_2) \!=\! \C_2^{1/2}\end{aligned}$.
	\par}
These must still be raised to the power $-1/2$ to be applied as the preconditioner, since here we have constructed\par
{\centering
	$\C_c\otimes\B_c \succeq \Gcal\Gcal^{\top}$.
	\par}
% \textcolor{red}{we should write out preconditioner explicitly for clarity}
Combining this inverse square root with the square roots in the expressions above yields the inverse $1/4$-th power in the preconditioner,\par
{\centering
	$\begin{aligned}\W_k &= \W_{k-1} - \eta\B_c^{-1/2}\G_k\C_c^{-1/2}\\
		&= \W_{k-1} - \eta\B_1^{-1/4}\G_k\C_2^{-1/4}\end{aligned}$
	\par}
(compare with Equations~\eqref{eq:shampoo_stat1}-\eqref{eq:shampoo_precon1}). Each individual estimator keeps one Kronecker factor fixed at the identity.
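To make the connection between the matrix and Kronecker forms concrete, the sketch below accumulates Shampoo-style statistics $\B_1=\epsilon\I+\sum\G\G^\top$ and $\C_2=\epsilon\I+\sum\G^\top\G$ from random gradients (the value of $\epsilon$ and the sizes are illustrative assumptions) and computes the step direction $\B_1^{-1/4}\G\C_2^{-1/4}$ both directly and as $(\C_2^{-1/4}\otimes\B_1^{-1/4})\vec(\G)$ with a column-stacking $\vec$. The two agree, but only the matrix form avoids materializing an $mn\times mn$ matrix, which is why Kronecker-factored methods store just the $m\times m$ and $n\times n$ factors.

```python
import numpy as np


def sym_power(S, p):
    # S^p for a symmetric PD matrix S, via its eigendecomposition
    w, V = np.linalg.eigh(S)
    return (V * w**p) @ V.T


rng = np.random.default_rng(4)
m, n, eps = 5, 3, 1e-3
grads = [rng.standard_normal((m, n)) for _ in range(10)]

# Shampoo-style statistics: B_1 accumulates G G^T, C_2 accumulates G^T G
B1 = eps * np.eye(m) + sum(g @ g.T for g in grads)
C2 = eps * np.eye(n) + sum(g.T @ g for g in grads)
G = grads[-1]

# Matrix form used in practice: B_1^{-1/4} G C_2^{-1/4}
step_matrix = sym_power(B1, -0.25) @ G @ sym_power(C2, -0.25)

# Equivalent Kronecker form acting on vec(G) (column-stacking vec)
P = np.kron(sym_power(C2, -0.25), sym_power(B1, -0.25))
step_kron = (P @ G.flatten(order="F")).reshape((m, n), order="F")
```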

\subsubsection{KrADagrad$^\star$}
Inspired by Shampoo's optimality, we can maintain a pair of KrADagrad preconditioners $(\L_{k,1},\R_{k,1})=(\L_{k},\I_n)$ and $(\L_{k,2},\R_{k,2})=(\I_m,\R_{k})$, where the second index in the subscripts on the left-hand side denotes the estimator. Since we hold the identity matrices constant, we need not store or multiply by them explicitly. This also means we can unambiguously drop the second index in the subscripts. On iteration $k$, we update both $\L_{k}$ and $\R_{k}$, just as in Shampoo. With\par
{\centering
	$\begin{aligned}[b]t_{k,L} & \leftarrow 1+\|\G_k\G_k^\top \L_{k-1}\|_F \\
		t_{k,R} & \leftarrow 1+\|\G_k^\top \G_k\R_{k-1}\|_F\end{aligned}$,
	\par}
we summarize the method in Algorithm \ref{alg:KrADagrad_star}.
\begin{algorithm}[tb]
	\caption{KrADagrad$^\star$: Precondition \textit{without} inversion}
	\label{alg:KrADagrad_star}
	\begin{algorithmic}
		\STATE {\bfseries Input:} Parameters $\W_0\in\Rbb^{m\times n}$, iterations $K$, step size $\eta$, exponent $\alpha = 1/2$.
		\STATE {\bfseries Initialize:} $(\L_0, \R_0)=(\I_m, \I_n)$
		\FOR{$k=1, \ldots, K$}
		\STATE Obtain gradient $\G_k$
		\STATE Compute $\Delta\L_{k}=\frac{1}{t_{k,L}}\L_{k-1}\G_{k}\G_{k}^\top \L_{k-1}$
		\STATE Compute $\Delta\R_{k}=\frac{1}{t_{k,R}}\R_{k-1}\G^\top_{k}\G_{k}\R_{k-1}$
		\STATE Update $(\L_k,\R_k) \leftarrow (\L_{k-1}-\Delta\L_{k}, \R_{k-1}-\Delta\R_{k})$
		\STATE Compute $(\L^{\alpha/2}_k, \R^{\alpha/2}_k)$ from $(\L_k, \L^{\alpha/2}_{k\!-\!1}, \R_k, \R^{\alpha/2}_{k\!-\!1})$
		\STATE Apply preconditioned gradient step\par
		{\centering
			$\W_k = \W_{k-1} - \eta\L^{\alpha/2}_k\G_k\R^{\alpha/2}_k$
			\par}
		\ENDFOR
	\end{algorithmic}
\end{algorithm}
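A minimal NumPy sketch of one iteration of Algorithm~\ref{alg:KrADagrad_star} follows, with $\alpha=1/2$. It is an illustration rather than the released implementation: it recomputes $\L_k^{\alpha/2}$ and $\R_k^{\alpha/2}$ from scratch by eigendecomposition at every step instead of updating them from the previous roots as in the algorithm (see Appendix~C.3), it refreshes the preconditioner every step rather than every 20 steps as in our experiments, and it omits damping (Appendix~C.4). The function names and the toy quadratic loss in the usage example are hypothetical.

```python
import numpy as np


def psd_power(S, p):
    # S^p for a symmetric PSD matrix; tiny negative round-off eigenvalues are clipped
    w, V = np.linalg.eigh(0.5 * (S + S.T))
    return (V * np.clip(w, 0.0, None) ** p) @ V.T


def kradagrad_star_step(W, G, L, R, eta, alpha=0.5):
    """One KrADagrad* iteration (illustrative sketch; L is m x m, R is n x n)."""
    t_L = 1.0 + np.linalg.norm(G @ G.T @ L, "fro")
    t_R = 1.0 + np.linalg.norm(G.T @ G @ R, "fro")
    L_new = L - L @ G @ G.T @ L / t_L      # L_k = L_{k-1} - Delta L_k
    R_new = R - R @ G.T @ G @ R / t_R      # R_k = R_{k-1} - Delta R_k
    L_new = 0.5 * (L_new + L_new.T)        # keep the factors exactly symmetric
    R_new = 0.5 * (R_new + R_new.T)
    # Only positive powers of L_k and R_k are needed: no matrix inversion.
    step = psd_power(L_new, alpha / 2) @ G @ psd_power(R_new, alpha / 2)
    return W - eta * step, L_new, R_new


# Usage on a toy loss 0.5 * ||W - W_star||_F^2 (hypothetical target W_star)
rng = np.random.default_rng(5)
m, n = 6, 4
W_star = rng.standard_normal((m, n))
W, L, R = np.zeros((m, n)), np.eye(m), np.eye(n)   # (L_0, R_0) = (I_m, I_n)
initial_loss = 0.5 * np.linalg.norm(W - W_star) ** 2
for _ in range(50):
    G = W - W_star                                 # gradient of the toy loss
    W, L, R = kradagrad_star_step(W, G, L, R, eta=0.5)
final_loss = 0.5 * np.linalg.norm(W - W_star) ** 2
```

Apart from the eigendecompositions used for the positive roots, each iteration consists only of matrix products, in line with the goal of avoiding inversion.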
In order to bound the regret, we need a few intermediate results.
\begin{restatable}{lemma}{lemtracebound}
	\label{lem:trace_bound}
	Assume a 1-Lipschitz loss, implying $\|\G_k\|_2 \le 1$ for all $k$.
	%. This implies\par
	%{\centering $\|\G_k\|_2 \le 1 \quad \forall k.$ \par}
	Suppose that, of the first $k$ iterations, $\L$ is updated in $k'$ and $\R$ in the remaining $k-k'$, for some $0\le k' \le k$. Letting $\B_k=\L_k^{-1}$, $\C_k=\R_k^{-1}$, and $s=1 \!+\! \|\L_0\|_2\|\R_0\|_2$,
	\begingroup\abovedisplayskip=3pt plus 1pt minus 3pt\belowdisplayskip=3pt plus 1pt minus 3pt
	\begin{align}
		\tr(\B_k) & \le\! \tr(\B_0) \!+\! k' ms\|\R_0\|_2 \label{eq:trace1}\\
		\tr(\C_k) & \le\! \tr(\C_0) \!+\! (k \!-\! k') ns\|\L_0\|_2. \label{eq:trace2}
	\end{align}
	\endgroup
\end{restatable}
Next, we restate Theorem 7 from~\cite{gupta18shampoo} in our notation. We also make one minor substitution, replacing the rank with the smaller matrix dimension, which upper bounds the rank.
\begin{lemma}[From~\cite{gupta18shampoo}]
	\label{lem:gupta}
	Let $\w_{*}\in\Rbb^{mn}$ with $m\ge n$, $t>0$, and $D\overset{\Delta}{=}\max_{1\le k \le K}\|\w_k-\w_{*}\|_2$. The regret from using a Kronecker factorized preconditioner\par
	{\centering {\small $\P_k=\C_{k}^{-1/2}\otimes\B_{k}^{-1/2}$} \par}
	that dominates the empirical Fisher matrix\par
	{\centering {\small $\C_{k} \otimes \B_{k} \succeq \frac{1}{nt} \Gcal_k \Gcal_k^\top$} \par}
	is bounded\par
	{\centering {\small $\sum_{k=1}^{K}(f_k(\w_k)-f_k(\w_{*})) \le D\sqrt{2tn}\tr(\B_K^{1/2})\tr(\C_K^{1/2})$ } \par}
\end{lemma}

We can now state our regret result; a proof is in Appendix~B.4. %\ref{app:proof_krad_regret}.
\begin{theorem}
	\label{thm:krad_regret}
	Assuming a 1-Lipschitz loss, the regret from using KrADagrad$^\star$ scales as $O(\sqrt{K})$.
\end{theorem}
%We provide the proof in Appendix B.



    
    \section{Implementation}
    \label{sec:implementation}
%We analyze the algorithms and compare it with others. Due to symmetry of the Kronecker product, we choose to present results in terms of only $\L_k$ for brevity, noting that the analogous claims hold for $\R_k$.
Now we discuss the algorithmic considerations for implementing KrADagrad$^\star$\footnote{\href{http://github.com/jonathanmei/kradagrad}{http://github.com/jonathanmei/kradagrad}}. The main difficulty lies in efficiently computing the matrix roots. Differentiable matrix square roots for machine learning have been the subject of substantial research~\citep{song21approximate}. For preconditioning, we do not require our roots to be differentiable, so they can be computed using numerical linear algebra methods (e.g., SVD) without concern for the backward pass. While such algorithms are gaining hardware support on current general matrix multiplication (GeMM) accelerators and in software frameworks, this support is not yet universal (PyTorch supports SVD on CUDA in double precision via cuSOLVER and MAGMA \citep{pytorch21blog,cusolver23api,dongarra14accelerating}, but, to our knowledge, TPUs do not \citep{jouppi21ten}).

However, as \citet{anil20scalable} found, accurate inverse powers require high precision to retain the important contributions of the eigenvectors corresponding to the smallest eigenvalues, and current GeMM accelerators do not prioritize 64-bit computation (i.e., GPUs see significant speed reductions, while TPUs do not support it at all). Moreover, they assume that the preconditioners form a slowly varying sequence of matrices. Thus, they propose computing matrix roots iteratively on CPU due to these algorithmic and hardware architectural constraints.

% However, as~\citep{anil20scalable} have discovered, accurate inverse powers require high precision to retain the important contributions of the eigenvectors corresponding to the smallest eigenvalues, and current general matrix multiplication (GeMM) accelerators do not prioritize 64-bit computation performance (i.e. GPU's see significant speed reductions, while TPU's do not support it at all). Plus, they assume the preconditioners are a slowly-varying sequence of matrices. Thus, they propose computing matrix roots iteratively on CPU due to this combination of algorithmic and hardware architectural constraints.

In contrast, KrADagrad$^\star$ requires only positive roots, for which both these iterative matrix methods and standard numerical linear algebra techniques are amenable to computation on GeMM accelerators. Ultimately, as a straightforward solution for the roots that appear in KrADagrad$^\star$, the SVD on GPU suffices. We note that avoiding inversion is thus not about reducing computational complexity; rather, we aim to avoid needing 64-bit computation. Further details of matrix roots are discussed in Appendix~C.3. %\ref{app:implementation_matrix_roots}. 

Additionally, diagonal damping is a common feature of preconditioned methods due to its contribution to numerical stability. While we do not need damping for numerical stability, we find empirically that it still helps reduce the effect of gradient noise and smooths the loss curve. We describe how damping can be applied to KrADagrad$^\star$ in further detail in Appendix~C.4. %\ref{app:implementation_damping}.

So far, our proposed algorithms and analyses apply to matrix parameters. As noted earlier, the extension to tensor parameters is fairly straightforward and is mainly a matter of additional notation and bookkeeping; hence, we relegate it to Appendix~C.5. %\ref{app:implementation_tensors}.

The baseline PyTorch Shampoo implementation we use is not optimized, so its timings may not reflect the true capability of Shampoo. Similarly, our KrADagrad variants, being derived from this implementation of Shampoo, are not optimized either.
Nonetheless, we have conducted some preliminary wall clock time comparisons between our implementations of the KrADagrad variants and Shampoo, which we summarize in Table~\ref{tab:wct}. For each data set, we report the time in seconds to run a single epoch for (KrADagrad, KrADagrad$^\star$, Shampoo), along with a 2 standard deviation interval, computed from epochs 5 through 10 on an NVIDIA A40 GPU.

\setlength{\belowcaptionskip}{-3pt}
\begin{table}[ht]
	\caption{Wall clock time per epoch (in seconds) for our implementations of the KrADagrad variants and Shampoo on the CIFAR-10/100 data sets on an NVIDIA A40 GPU, with a 2 standard deviation interval computed from 5 epochs.}
	\label{tab:wct}
	\vspace{-10pt}
	\begin{center}
		\begin{smaller}
			\begin{sc}
				\begin{tabular}{lccc}
					\toprule
					Data set & KrADagrad & KrADagrad$^\star$ & Shampoo \\
					\midrule
					CIFAR-10  & 29.50 $\pm$ 0.88 & 45.76 $\pm$ 0.20 & 35.62 $\pm$ 0.17 \\
					CIFAR-100 & 54.71 $\pm$ 0.36 & 76.12 $\pm$ 1.72 & 66.45 $\pm$ 0.33 \\
					\bottomrule
				\end{tabular}
			\end{sc}
		\end{smaller}
	\end{center}
	\vspace{-16pt}
\end{table}


\subsection{Compute and memory costs}
For KrADagrad, computing $\Delta\L_k$ has a computational cost of $2N(2m+n)+2m^2+2m^3$.
While the KrADagrad update still requires matrix square roots costing $O(m^3+n^3)$, these are less difficult numerically, and thus computationally, than the inverse 4th root required by Shampoo.

KrADagrad$^\star$ requires 4th root computation, which is comparable in cost to the inverse 4th root, but without the need for high precision.

Storage involves tracking the two matrix factors, and so is $O(m^2+n^2)$ for all methods.

We summarize these costs in Table~\ref{tab:complexity}.
\begin{table}[ht]
	\caption{Comparison of compute and memory complexity between our implementations of the KrADagrad variants and Shampoo, for an $m\times n$ parameter matrix.}
	\label{tab:complexity}
	\vspace{-10pt}
	\begin{center}
		\begin{smaller}
			\begin{sc}
				\begin{tabular}{lccc}
					\toprule
					&KrADagrad & KrADagrad$^\star$ & Shampoo \\
					\midrule
					Compute    & $O(m^3 + n^3)$ & $O(m^3 + n^3)$ & $O(m^3 + n^3)$ \\
					Memory &  $O(m^2 + n^2)$ & $O(m^2 + n^2)$ & $O(m^2 + n^2)$ \\
					\bottomrule
				\end{tabular}
			\end{sc}
		\end{smaller}
	\end{center}
	\vspace{-16pt}
\end{table}

    
    \section{Experiments}
    \label{sec:experiments}
We address the following questions: 1) How sensitive is our method to matrix conditioning compared to Shampoo? 2) How does convergence speed, as measured by task-specific validation metrics, compare in the number of training steps? 3) How do our methods compare in the model quality achieved at or near convergence? In each case, the goal is to compare optimizers in various challenging loss landscapes, rather than to achieve state-of-the-art performance.% 4) How do our wall clock times compare to alternatives? 

To answer 1, in a synthetic experiment we minimize a multidimensional quadratic function with a non-diagonal, poorly-conditioned Hessian that neatly factorizes into a Kronecker product of two individually poorly-conditioned PD matrices.

To answer 2 and 3, we compare KrADagrad and KrADagrad$^\star$ to alternatives across a variety of tasks: image classification (IC), autoencoder problems (AE), recommendation (RecSys), and continual learning (CL). For IC experiments, we train ResNet-32/56~\citep{he16deep} without BatchNorm (BN) on CIFAR-10/100~\citep{krizhevsky09learning}. For AE, we train simple autoencoders following~\citet{goldfarb20practical} on MNIST~\citep{mnist}, as well as CURVES and FACES~\citep{curves_faces}. For recommendation, we train H+Vamp Gated~\citep{Kim_2019} on MovieLens20M~\citep{movielens20m}. In the continual learning setting, we train on two benchmarks from \citet{lomonaco2021avalanche}: GEM~\citep{GEM} on Permuted MNIST~\citep{permuted_mnist} and LaMAML~\citep{lamaml} on Split CIFAR100~\citep{split_cifar100}.% For LM, we train GPT-2~\citep{radford2019better} on \citep{}.

We choose Shampoo as the baseline second-order optimizer because a) the KrADagrad variants are most similar to Shampoo, and b) we expect the KrADagrad variants to perform at best similarly to Shampoo, and worse if KrADagrad's approximations and lower precision prove detrimental in practice. To ground all comparisons, we always include SGD and Adam, plus any optimizer specific to an existing benchmark.

In all our experiments, we initialized from common seeds across optimizers. However, the first point on each training curve is plotted after the first evaluation interval, so the curves do not visually appear to start from the same point; plotting a common starting point would have required modifying the codebases separately for each experiment. Ideally, we would also have run multiple shared seeds for the selected HPs of every task we explored. For smaller tasks where convergence is achieved quickly, such as the autoencoder experiments, this was feasible, and we report the results in Figures \ref{fig:ae} and 8. %\ref{fig:ae_seeds}. 
Given our compute and time budget, we prioritized fair evaluations of the optimizers (i.e., ensuring shared seeds and optimized HPs per optimizer per task) across a diverse set of tasks, with repeatability statistics from different initializations wherever possible.


\setlength{\parskip}{.25\baselineskip}
\subsection{Conditioning Experiment}
To see the effects of Hessian conditioning on optimizers of interest, we create a synthetic deterministic convex problem. For $\X\in\Rbb^{128\times 128}$, we minimize a quadratic loss function\par
{\centering
	$\min_{\X}\, \tr(\X^\top\A\X\B)$
	\par}
where $\A,\B \succ 0$ are non-diagonal and have condition numbers $\kappa(\A)=\kappa(\B)=10^{10}$. We initialize each optimizer from the same starting point, sweep learning rates, and, for each optimizer, pick the learning rate yielding the lowest loss. We provide further details in Appendix~G. %\ref{app:addl_exp}.
While this stationary problem without a validation or test dataset is different from the online setting assumed in Theorem \ref{thm:krad_regret}, and even from a typical offline machine learning problem, it helps isolate differences in behavior between optimizers in a badly conditioned convex loss landscape.%across a grid 
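For concreteness, the NumPy sketch below builds a smaller problem of the same form. The construction (random orthogonal eigenvectors from a QR factorization with log-spaced eigenvalues), the dimension, and the condition numbers are illustrative assumptions, not necessarily the setup described in Appendix~G. For symmetric $\A,\B$ the gradient is $2\A\X\B$, and with a column-stacking $\vec$ the Hessian is $2\,\B\otimes\A$, whose condition number is $\kappa(\A)\kappa(\B)$, i.e., $10^{20}$ in the setting above; the sketch checks both facts at a small scale.

```python
import numpy as np


def random_spd_with_cond(d, cond, rng):
    # Non-diagonal SPD matrix with eigenvalues log-spaced in [1/cond, 1]
    Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
    eigs = np.logspace(0.0, -np.log10(cond), d)
    return (Q * eigs) @ Q.T


def loss(X, A, B):
    return np.trace(X.T @ A @ X @ B)


def grad(X, A, B):
    return 2.0 * A @ X @ B   # valid because A and B are symmetric


rng = np.random.default_rng(6)
d, cond = 8, 1e2              # small stand-ins for 128 and 1e10
A = random_spd_with_cond(d, cond, rng)
B = random_spd_with_cond(d, cond, rng)

# Hessian of the vectorized loss (column-stacking vec) and its conditioning
H = 2.0 * np.kron(B, A)
hessian_cond = np.linalg.cond(H)

# Finite-difference check of the gradient along a random direction
X, D, h = rng.standard_normal((d, d)), rng.standard_normal((d, d)), 1e-5
fd = (loss(X + h * D, A, B) - loss(X - h * D, A, B)) / (2 * h)
analytic = np.sum(grad(X, A, B) * D)
```

Because the Hessian's condition number is the product of the factors' condition numbers, even moderately conditioned factors yield a very poorly conditioned problem, which is what the experiment is designed to stress.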

\setlength{\belowcaptionskip}{-8pt}
\begin{figure}[t]
	\includegraphics[width=0.5\textwidth]{images/quad_rel.png}
	\caption{Loss on the synthetic quadratic (log scale), relative to SGD: each curve is \textit{divided} by that of SGD.}
	\label{fig:quad}
\end{figure}

In Figure \ref{fig:quad}, Shampoo in double precision achieves the lowest final loss, but single precision KrADagrad$^\star$ outperforms single precision Shampoo. Adam does not converge as quickly or to as good a loss, as the adaptivity it provides is aligned with the canonical axes, whereas the non-diagonal $\A, \B$ act in a rotated basis.

\subsection{Real Datasets}

For each task and optimizer, we sweep hyperparameters (HPs), select those yielding the best validation metric at some epoch or iteration, and display the corresponding learning curves in Figures \ref{fig:resnet} through \ref{fig:hvamp}. We update the preconditioners of Shampoo and KrADagrad$^\star$ every 20 training steps. For all experiments, unless stated otherwise, Shampoo uses double precision for negative matrix roots, and SGD includes momentum. We summarize our observations here and provide additional detail in Appendix~G. %\ref{app:experiments}.%, and summarize our observations here.%\ref{app:experiments}. experiment descriptions and details for each task 

\begin{figure}[t]
	\includegraphics[width=0.5\textwidth]{images/resnet32_rel.png}
	\includegraphics[width=0.5\textwidth]{images/resnet56_rel.png}
	\caption{Vision experiments. Each curve is initialized from the same single seed and \textit{subtracted} from that of Adam. Top: Top 1 Accuracy (in \%) of ResNet-32 (without BN) on CIFAR-10, relative to Adam; Bottom: Top 1 Accuracy (in \%) of ResNet-56 (without BN) on CIFAR-100, relative to Adam.}
	\label{fig:resnet}
\end{figure}
\setlength{\belowcaptionskip}{-6pt}

For the three autoencoder experiments, after HP tuning we retrain with the best HPs using 5 seeds shared across optimizers, with results in Figures \ref{fig:ae} and 8. %\ref{fig:ae_seeds}. 
We also find the autoencoder tasks to be relevant benchmarks for evaluating Shampoo alternatives because Shampoo consistently outperforms SGD and Adam on them. While KrADagrad$^\star$ also outperforms SGD and Adam, Shampoo reaches better solutions, or similar solutions in fewer steps. Following the synthetic experiment, we test the hypothesis that KrADagrad$^\star$ might outperform 32-bit Shampoo by rerunning our best HPs for Shampoo in single precision. The 32- and 64-bit versions perform similarly on these tasks, so the hypothesis does not hold in general. KrADagrad, which needs more steps and generally performs worse, eventually reaches performance similar to KrADagrad$^\star$ on CURVES, with some runs matching 32-bit Shampoo (see Figure 8b). %\ref{fig:ae_seeds}b). 
One hypothesis is that, on these particular tasks, effectively traversing the loss landscape requires eigenvectors that the KrADagrad approximation captures less readily than Shampoo does.

\begin{figure*}[!h]
	\centering
	\includegraphics[width=0.33\textwidth]{images/mnist_seeds_CI95.png}
	\includegraphics[width=0.33\textwidth]{images/curves_seeds_CI95.png}
	\includegraphics[width=0.33\textwidth]{images/faces_seeds_CI95.png}
	\caption{Reconstruction validation error for a fully connected autoencoder: a) mean cross entropy on MNIST; b) mean cross entropy on CURVES; c) mean squared error on FACES. The learning curves shown are averaged across 5 unique seeds, with CI95 error bars. We do this analysis for the autoencoder experiments since they are relatively inexpensive to run.}
	\label{fig:ae}
\end{figure*}




\begin{figure}[t]
	\vspace{-6pt}
	\centering
	\includegraphics[width=0.45\textwidth]{images/gem_bottom.png}
	\includegraphics[width=0.45\textwidth]{images/lamaml_bottom.png}
	\caption{a) Top-1 accuracy of GEM on Permuted MNIST, relative to SGD. Each curve is averaged over multiple seeds, and the corresponding SGD curve is subtracted from it to reveal differences that would otherwise be hard to see. See Appendix~G %\ref{app:addl_exp} 
		for absolute curves. b) Top-1 accuracy of LaMAML on Split CIFAR-100, relative to Adam.}\label{fig:pmnist-lamaml}
	\vspace{-6pt}
\end{figure}

While \citet{gupta18shampoo} perform experiments on ``standard'' machine learning tasks, their regret results pertain to an online setting. We include continual learning problems to test the limits of this theoretical setting. Figure~\ref{fig:pmnist-lamaml}a shows learning curves for GEM on Permuted MNIST relative to SGD, which makes small differences visible (we include a plot of the absolute accuracies in Appendix~G). % \ref{app:addl_exp}).
On average, KrADagrad stays slightly ahead of the other optimizers throughout training, whereas KrADagrad$^\star$ does not.
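The relative curves can be written explicitly. Let $a^{(o)}_{t,s}$ denote the top-1 accuracy of optimizer $o$ at step $t$ under seed $s$, and let $b$ denote the baseline ($b = \text{SGD}$ for GEM and $b = \text{Adam}$ for LaMAML). Assuming each optimizer is averaged over the same $S$ seeds, the plotted quantity is
\begin{equation*}
\Delta^{(o)}_t = \frac{1}{S}\sum_{s=1}^{S} a^{(o)}_{t,s} - \frac{1}{S}\sum_{s=1}^{S} a^{(b)}_{t,s} = \frac{1}{S}\sum_{s=1}^{S}\left(a^{(o)}_{t,s} - a^{(b)}_{t,s}\right).
\end{equation*}
Under this convention, $\Delta^{(o)}_t > 0$ means that optimizer $o$ is ahead of the baseline at step $t$, and the baseline itself plots as the zero line. Subtracting the baseline removes the rise in accuracy that all optimizers share over training, which otherwise dominates the vertical scale and hides small differences.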



Figure~\ref{fig:pmnist-lamaml}b shows analogous accuracy curves for LaMAML on Split CIFAR-100 relative to Adam (again, the absolute accuracies are in Appendix~G). %\ref{app:addl_exp}). 
KrADagrad$^\star$ closely tracks Shampoo and the adaptive learning-rate optimizer of \citet{lamaml}, whereas KrADagrad outperforms only Adam.


We also evaluate Shampoo and the KrADagrad variants against the baseline for H+Vamp Gated\footnote{https://github.com/psywaves/EVCF}. Shampoo reaches a lower training loss than KrADagrad$^\star$, while both Adam variants converge faster and to lower loss values (Figure~\ref{fig:hvamp}).

\begin{figure}[t]
	\includegraphics[width=0.45\textwidth]{images/hvamp_train.png}
	\caption{Training loss of H+Vamp Gated on ML-20M. We compare training loss for this experiment because the objective is a weighted combination of components, $\mathrm{RE} + \beta \cdot \mathrm{KL}$, where $\beta$ follows a predefined schedule (causing the loss to grow during the first 100 epochs); this weighting is not reflected consistently in the validation loss.}
	\label{fig:hvamp}
\end{figure}
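The objective described in the caption of Figure~\ref{fig:hvamp} can be written as
\begin{equation*}
\mathcal{L}_t(\theta) = \mathrm{RE}(\theta) + \beta_t\,\mathrm{KL}(\theta),
\end{equation*}
where RE is the reconstruction error and KL is the Kullback--Leibler term. As an illustrative assumption only (the actual schedule is set by the H+Vamp Gated baseline and is not restated here), consider a linear warm-up $\beta_t = \beta_{\max}\min(1, t/T)$ with a hypothetical final weight $\beta_{\max}$ and $T = 100$ epochs. For fixed parameters $\theta$, the objective then changes between consecutive epochs by $(\beta_{t+1} - \beta_t)\,\mathrm{KL}(\theta) \geq 0$, because $\mathrm{KL}(\theta) \geq 0$ and $\beta_t$ is nondecreasing. The reported training loss can therefore grow during warm-up even while the optimizer makes progress on the current objective.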

% \begin{figure}
	%     \centering
	%     \includegraphics[width=0.45\textwidth]{images/causal-lm.png}
	%     \caption{Causal language model on codeparrot dataset. We train for a single epoch as the dataset is fairly large and it takes several hours to do a sweep of a single optimizer on four T4s. The flat lines are due to the improvements being difficult to discern at this scale. We see that Kradagrad$^*$ outperforms Shampoo substantially.}
	%     \label{fig:causal-lm}
	% \end{figure}




    
    \section{Conclusion}
    \label{sec:conclusion}
\setlength{\parskip}{.5\baselineskip}
We introduced KrADagrad$^\star$, a preconditioned gradient optimizer that avoids matrix inversion, and proved that it has optimal regret properties. We showed that it outperforms single-precision Shampoo on an ill-conditioned synthetic convex optimization problem. Our experiments on real datasets show that KrADagrad$^\star$ often performs similarly to Shampoo.

Future work includes tractable methods for computing low-rank updates to matrix roots, for instance with rational Krylov subspace methods, to improve the performance of KrADagrad$^\star$ and other Kronecker-factored preconditioners. It is also worth applying the optimizer in other areas, such as decision making and control, or scientific discovery.
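A contrast with the matrix inverse helps explain why low-rank updates to matrix roots remain an open problem. For a rank-$r$ update $L_t = L_{t-1} + U U^\top$ with $U \in \mathbb{R}^{m \times r}$ (for example, $U = G_t$ in the Shampoo recursion, with $r$ equal to its number of columns), the inverse admits the exact Sherman--Morrison--Woodbury update
\begin{equation*}
L_t^{-1} = L_{t-1}^{-1} - L_{t-1}^{-1} U \left(I_r + U^\top L_{t-1}^{-1} U\right)^{-1} U^\top L_{t-1}^{-1},
\end{equation*}
which costs $O(m^2 r + r^3)$ once $L_{t-1}^{-1}$ is known. No comparably simple exact formula is available in general for fractional powers such as $L_t^{-1/4}$, which is what makes approximate approaches such as rational Krylov methods attractive.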

%{\color{teal} My two cents, but I would delete this entire paragraph based on irrelevance.}The potential societal impacts of newly developed optimizers are wide ranging, as they can be used to train many types of models for all types of applications However, the nature of the impact for each case depends on the application, model design and deployment, and datasets involved, among other factors; thus the direct impact of a new optimizer on social advancement and equity are out of the scope of this discussion.



    \begin{acknowledgements}
        We thank Luminous Computing for funding this work. In particular, we thank David Baker for encouraging this line of investigation and David Scott for fruitful discussions about numerical linear algebra.
    \end{acknowledgements}
\clearpage
% References
\bibliography{mei_555.bib}


\end{document}
