Deep Linear Network Training Dynamics from Random Initialization: Data, Width, Depth, and Hyperparameter Transfer
TL;DR: A theory of train and test loss dynamics for randomly initialized deep linear networks with applications to hyperparameter transfer.
Abstract: We theoretically characterize gradient descent dynamics in deep linear networks trained at large width from random initialization and on large quantities of random data. Our theory captures the ``wider is better'' effect of mean-field/maximum-update parameterized networks as well as hyperparameter transfer effects, in contrast with the neural-tangent parameterization, where optimal learning rates shift with model width. We provide asymptotic descriptions of both non-residual and residual neural networks, the latter of which admits an infinite-depth limit when residual branches are scaled as $1/\sqrt{\text{depth}}$. We also compare training with one-pass stochastic gradient descent to the dynamics obtained when the same training data are repeated at each iteration. Lastly, we show that this model recovers the accelerated power-law training dynamics observed in recent works for power-law structured data in the rich regime.
Lay Summary: Methods that scale up neural networks in a way that preserves certain properties of training can reduce the need to retune hyperparameters at each model size. One popular approach is to increase the width and depth of a neural network in maximum-update ($\mu$P) scaling. Prior works have empirically shown that (1) wider networks perform better in $\mu$P and (2) optimal learning rates are approximately constant across widths and depths. However, to date no theory has been proposed that captures these effects. In this work, we develop a minimal theory of the learning-rate transfer effect in randomly initialized deep linear networks. Our theory captures both (1) arbitrarily large deviations from lazy learning and (2) the harmful effects of finite width. It accurately captures the failure of hyperparameter transfer in NTK scaling and its success across widths and depths in $\mu$P with $1/\sqrt{\text{depth}}$ residual branch scaling. Our results are based on a dynamical mean field theory approach, in which finite-width and SGD-noise effects gradually build up over training and corrupt the dynamics of finite models relative to their infinite-width counterparts.
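To make the setup concrete, the following is a minimal sketch (not the authors' code) of the kind of model described above: a deep linear residual network in a mean-field/$\mu$P-style parameterization with $1/\sqrt{\text{depth}}$ branch scaling, trained by one-pass SGD on random data from a random linear teacher. The sizes, the teacher `w_star`, the learning rate, and the particular layerwise scaling convention are illustrative assumptions and may differ from the paper's exact parameterization.

```python
# Minimal sketch (illustrative, not the authors' code) of a deep linear residual
# network in a mean-field / muP-style parameterization with 1/sqrt(depth) branch
# scaling, trained by one-pass SGD on random data from a random linear teacher.
# Sizes, the teacher `w_star`, and the exact layerwise scalings are assumptions.
import jax
import jax.numpy as jnp

N, L, D = 512, 8, 32    # width, depth, input dimension (illustrative sizes)
eta0 = 0.1              # bare learning rate; in muP this is expected to transfer across N and L
key = jax.random.PRNGKey(0)
k_in, k_hid, k_out, k_teacher, k_data = jax.random.split(key, 5)

# All trainable matrices have O(1) entries at initialization; the width and depth
# factors appear explicitly in the forward pass (one common mean-field convention).
params = {
    "W_in": jax.random.normal(k_in, (N, D)),
    "W_hid": jax.random.normal(k_hid, (L, N, N)),
    "w_out": jax.random.normal(k_out, (N,)),
}

def forward(params, x):
    h = params["W_in"] @ x / jnp.sqrt(D)       # 1/sqrt(fan_in) input scaling
    for W in params["W_hid"]:
        h = h + W @ h / jnp.sqrt(N * L)        # 1/sqrt(width) and 1/sqrt(depth) branch scaling
    return params["w_out"] @ h / N             # mean-field/muP readout: divide by width

def loss(params, x, y):
    return 0.5 * (forward(params, x) - y) ** 2

grad_fn = jax.jit(jax.grad(loss))
w_star = jax.random.normal(k_teacher, (D,)) / jnp.sqrt(D)   # random linear teacher

# One-pass SGD: each step draws a fresh random sample, never reusing data.
for t in range(1000):
    k_data, k_x = jax.random.split(k_data)
    x = jax.random.normal(k_x, (D,))
    y = w_star @ x
    g = grad_fn(params, x, y)
    # With a 1/N readout, gradients on the weights are O(1/N); scaling the learning
    # rate by N keeps feature learning O(1) as N grows (mean-field scaling).
    params = jax.tree_util.tree_map(lambda p, gp: p - eta0 * N * gp, params, g)
```

Under this convention the bare learning rate $\eta_0$ plays the role of the transferable hyperparameter; other, equivalent parameterizations can instead absorb the width and depth factors into the initialization scales and per-layer learning rates.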
Primary Area: Deep Learning->Theory
Keywords: deep learning, mean field, learning dynamics, $\mu$P, residual networks, dynamical mean field theory
Submission Number: 5169