Tensor Normal Training for Deep Learning Models

Published: 09 Nov 2021, Last Modified: 22 Oct 2023
NeurIPS 2021 Spotlight
Readers: Everyone
Keywords: Optimization for deep networks
Abstract: Despite the predominant use of first-order methods for training deep learning models, second-order methods, and in particular natural gradient methods, remain of interest because of their potential for accelerating training through the use of curvature information. Several methods with non-diagonal preconditioning matrices, including KFAC, Shampoo, and K-BFGS, have been proposed and shown to be effective. Based on the so-called tensor normal (TN) distribution, we propose and analyze a new approximate natural gradient method, Tensor Normal Training (TNT), which, like Shampoo, requires knowledge only of the shape of the training parameters. By approximating the probabilistically based Fisher matrix, as opposed to the empirical Fisher matrix, our method uses the block-wise covariance of the sampling-based gradient as the preconditioning matrix. Moreover, the assumption that the sampling-based (tensor) gradient follows a TN distribution ensures that its covariance has a Kronecker-separable structure, which leads to a tractable approximation to the Fisher matrix. Consequently, TNT's memory requirements and per-iteration computational costs are only slightly higher than those of first-order methods. In our experiments, TNT exhibited optimization performance superior to that of state-of-the-art first-order methods and comparable to that of the state-of-the-art second-order methods KFAC and Shampoo. Moreover, TNT generalized as well as first-order methods while using fewer epochs.
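To make the Kronecker-separable idea in the abstract concrete, the following is a minimal NumPy sketch for a single matrix-shaped parameter block. It is not the authors' implementation: the function names `update_factors` and `kron_precondition`, the moving-average decay `beta`, the `damping` term, and the `1/n` and `1/m` normalizations are all assumptions made for illustration. It only shows why a covariance of the form U1 ⊗ U2 lets the preconditioner be applied with two small solves instead of one (mn × mn) solve.

```python
import numpy as np


def update_factors(U1, U2, sampled_grad, beta=0.9):
    """Update running estimates of the two covariance factors.

    sampled_grad is an (m, n) sampling-based gradient for one parameter block.
    The normalizations by n and m are an illustrative assumption.
    """
    m, n = sampled_grad.shape
    U1 = beta * U1 + (1.0 - beta) * (sampled_grad @ sampled_grad.T) / n
    U2 = beta * U2 + (1.0 - beta) * (sampled_grad.T @ sampled_grad) / m
    return U1, U2


def kron_precondition(grad, U1, U2, damping=1e-3):
    """Return (U1 + damping*I)^{-1} @ grad @ (U2 + damping*I)^{-1}.

    With row-major flattening this equals solving against the Kronecker
    product of the damped factors, without ever forming that large matrix.
    """
    m, n = grad.shape
    A1 = U1 + damping * np.eye(m)
    A2 = U2 + damping * np.eye(n)
    left = np.linalg.solve(A1, grad)        # A1^{-1} @ grad
    return np.linalg.solve(A2, left.T).T    # left @ A2^{-1} (A2 is symmetric)


# Hypothetical usage on random data for a 3 x 4 parameter block.
rng = np.random.default_rng(0)
m, n = 3, 4
U1, U2 = np.eye(m), np.eye(n)
for _ in range(5):
    U1, U2 = update_factors(U1, U2, rng.standard_normal((m, n)))
grad = rng.standard_normal((m, n))
direction = kron_precondition(grad, U1, U2)
```

In this sketch the stored state is one m × m and one n × n matrix per block, which is the source of the modest memory overhead relative to first-order methods that the abstract describes.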
Code Of Conduct: I certify that all co-authors of this work have read and commit to adhering to the NeurIPS Statement on Ethics, Fairness, Inclusivity, and Code of Conduct.
TL;DR: We propose a new second-order method (an approximate natural gradient method) based on the tensor normal distribution for training deep learning models.
Supplementary Material: pdf
Code: https://github.com/renyiryry/tnt_neurips_2021
Community Implementations: [2 code implementations](https://www.catalyzex.com/paper/arxiv:2106.02925/code)