Distributed Second-Order Optimization using Kronecker-Factored Approximations

Published: 21 Jul 2022, Last Modified: 05 May 2023
ICLR 2017 Poster
Readers: Everyone
Abstract: As more computational resources become available, machine learning researchers train ever larger neural networks on millions of data points using stochastic gradient descent (SGD). Although SGD scales well in terms of both the size of the dataset and the number of model parameters, it has rapidly diminishing returns as parallel computing resources increase. Second-order optimization methods have an affinity for well-estimated gradients and large mini-batches, and can therefore, in principle, benefit much more from parallel computation. Unfortunately, they often employ severe approximations to the curvature matrix in order to scale to large models with millions of parameters, limiting their effectiveness in practice versus well-tuned SGD with momentum. The recently proposed K-FAC method (Martens and Grosse, 2015) uses a stronger and more sophisticated curvature approximation, and has been shown to make much more per-iteration progress than SGD while introducing only a modest overhead. In this paper, we develop a version of K-FAC that distributes the computation of gradients and the additional quantities required by K-FAC across multiple machines, thereby taking advantage of the method’s superior scaling to large mini-batches and mitigating its additional overheads. We provide a TensorFlow implementation of our approach which is easy to use and can be applied to many existing codebases without modification. Additionally, we develop several algorithmic enhancements to K-FAC which can improve its computational performance for very large models. Finally, we show that our distributed K-FAC method speeds up training of various state-of-the-art ImageNet classification models by a factor of two compared to Batch Normalization (Ioffe and Szegedy, 2015).
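To make the "Kronecker-factored" curvature approximation mentioned in the abstract concrete, the following is a minimal NumPy sketch for a single fully connected layer. It is not the authors' TensorFlow implementation. The function names `kronecker_factors` and `precondition`, the array shapes, and the simple choice of adding the same damping term to each factor are all assumptions made for illustration. The sketch relies only on the standard identity that, for symmetric factors A and G, solving with the Kronecker product A ⊗ G reduces to two small solves, one per factor.

```python
import numpy as np


def kronecker_factors(acts, grads):
    """Estimate the two Kronecker factors for one layer from a mini-batch.

    acts:  (batch, n_in)  inputs to the layer (hypothetical layout)
    grads: (batch, n_out) per-example gradients w.r.t. the layer's outputs
    """
    n = acts.shape[0]
    A = acts.T @ acts / n    # (n_in, n_in) second moment of the inputs
    G = grads.T @ grads / n  # (n_out, n_out) second moment of the output gradients
    return A, G


def precondition(grad_W, A, G, damping=1e-2):
    """Return G_d^{-1} @ grad_W @ A_d^{-1} without forming A_d kron G_d.

    grad_W has shape (n_out, n_in). Adding `damping` to each factor
    separately is a simplifying assumption of this sketch.
    """
    A_d = A + damping * np.eye(A.shape[0])
    G_d = G + damping * np.eye(G.shape[0])
    X = np.linalg.solve(G_d, grad_W)    # G_d^{-1} @ grad_W
    return np.linalg.solve(A_d, X.T).T  # X @ A_d^{-1}, using symmetry of A_d
```

The point of the factorization is cost: the two solves involve matrices of size n_in × n_in and n_out × n_out, whereas the explicit Kronecker product has size (n_in · n_out) × (n_in · n_out). Estimating the factors from mini-batch statistics is also the kind of additional work that can be spread across machines alongside the gradient computation.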
TL;DR: Fixed typos pointed out by AnonReviewer1 and AnonReviewer4, and added the experiments in Fig. 6 showing the poor scaling of batch-normalized SGD with a batch size of 2048 on GoogLeNet.
Conflicts: cs.toronto.edu, google.com
Keywords: Deep learning, Optimization