Three factors influencing minima in SGD

Stanisław Jastrzębski, Zac Kenton, Devansh Arpit, Nicolas Ballas, Asja Fischer, Amos Storkey, Yoshua Bengio

Feb 15, 2018 (modified: Feb 15, 2018) ICLR 2018 Conference Blind Submission
  • Abstract: We study the statistical properties of the endpoint of stochastic gradient descent (SGD). We approximate SGD as a stochastic differential equation (SDE) and consider its Boltzmann-Gibbs equilibrium distribution under the assumption of isotropic variance in loss gradients. Through this analysis, we find that three factors (learning rate, batch size and the variance of the loss gradients) control the trade-off between the depth and width of the minima found by SGD, with wider minima favoured by a higher ratio of learning rate to batch size. The equilibrium distribution depends on learning rate and batch size only through their ratio, so it is invariant when both are rescaled by the same factor. We experimentally show how learning rate and batch size affect SGD from two perspectives: the endpoint of SGD and the dynamics that lead up to it. For the endpoint, the experiments suggest that the endpoint of SGD is similar under simultaneous rescaling of batch size and learning rate, and that a higher ratio leads to flatter minima; both findings are consistent with our theoretical analysis. We note experimentally that the dynamics also seem to be similar under the same rescaling of learning rate and batch size, which we explore by showing that one can exchange batch size and learning rate in a cyclical learning rate schedule. Next, we illustrate how noise affects memorization, showing that high noise levels lead to better generalization. Finally, we find experimentally that the similarity under simultaneous rescaling of learning rate and batch size breaks down if the learning rate gets too large or the batch size gets too small.
  • TL;DR: Three factors (batch size, learning rate, gradient noise) change the properties (e.g. sharpness) of the minima found by SGD in a predictable way.
  • Keywords: SGD, Deep Learning, Generalization
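A minimal numerical sketch of the ratio invariance described in the abstract, under assumptions that are ours rather than the paper's: a one-dimensional quadratic loss L(θ) = hθ²/2 with curvature h, and per-example gradient noise that is isotropic Gaussian with standard deviation σ, so a minibatch of size S carries gradient noise of variance σ²/S. We also assume the equilibrium density takes the Boltzmann-Gibbs form P(θ) ∝ exp(-2L(θ)/(nσ²)) with n = η/S (η the learning rate); for this loss its variance is nσ²/(2h). The sketch compares that prediction with the exact stationary variance of the discrete SGD recursion, ησ²/(S h (2 - ηh)), and with a simulation over many independent runs. All function names are hypothetical.

```python
import numpy as np


def sde_stationary_variance(lr, batch_size, curvature, grad_std):
    """Variance of the assumed Gibbs density exp(-2 L / (n * sigma^2)), n = lr / batch_size,
    for the toy loss L(theta) = curvature * theta**2 / 2. Depends on lr and batch_size only via n."""
    n = lr / batch_size
    return n * grad_std**2 / (2.0 * curvature)


def exact_sgd_stationary_variance(lr, batch_size, curvature, grad_std):
    """Stationary variance of the discrete recursion
    theta <- theta - lr * (curvature * theta + noise), Var(noise) = grad_std**2 / batch_size.
    Valid only for 0 < lr * curvature < 2; note the extra dependence on lr alone."""
    return lr * grad_std**2 / (batch_size * curvature * (2.0 - lr * curvature))


def simulate_sgd_variance(lr, batch_size, curvature=1.0, grad_std=1.0,
                          n_chains=4000, n_steps=600, seed=0):
    """Run many independent SGD chains from theta = 0 and return the variance of their endpoints.
    n_steps should be well above 1 / (lr * curvature) so the chains reach stationarity."""
    rng = np.random.default_rng(seed)
    theta = np.zeros(n_chains)
    for _ in range(n_steps):
        # Hypothetical per-example gradients: true gradient plus isotropic Gaussian noise.
        per_example_noise = rng.normal(0.0, grad_std, size=(n_chains, batch_size))
        minibatch_grad = curvature * theta + per_example_noise.mean(axis=1)
        theta = theta - lr * minibatch_grad
    return float(theta.var())


if __name__ == "__main__":
    # Two configurations with the same ratio n = lr / batch_size = 0.0025.
    for lr, bs in [(0.02, 8), (0.04, 16)]:
        print(f"lr={lr}, S={bs}: SDE {sde_stationary_variance(lr, bs, 1.0, 1.0):.6f}, "
              f"exact {exact_sgd_stationary_variance(lr, bs, 1.0, 1.0):.6f}, "
              f"simulated {simulate_sgd_variance(lr, bs):.6f}")
    # Same ratio, but lr * curvature close to 2: the discrete recursion departs from the SDE.
    print(f"lr=1.9, S=760: SDE {sde_stationary_variance(1.9, 760, 1.0, 1.0):.6f}, "
          f"exact {exact_sgd_stationary_variance(1.9, 760, 1.0, 1.0):.6f}")
```

In this toy setting, the configurations (η, S) = (0.02, 8) and (0.04, 16) share n = 0.0025 and should yield nearly the same variance. The exact discrete formula carries an extra factor 2/(2 - ηh), which is close to 1 for small ηh but reaches 20 at η = 1.9, S = 760 with the same ratio; this is one simple way the rescaling invariance can fail when the learning rate is too large. The toy model does not capture the small-batch breakdown mentioned in the abstract.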