Keywords: sparse training, training dynamics, normalization layers, batch norm, dynamic sparse training
TL;DR: We study the role of normalization layers in sparse training and show that BatchNorm causes instability by inflating the gradients. To address this, we propose a preconditioned gradient descent method that improves both convergence rates and generalization.
Abstract: Normalization layers have become an essential component of deep neural networks, improving training dynamics and convergence rates. While the effects of layers such as BatchNorm and LayerNorm are well studied for dense networks, their impact on the training dynamics of Sparse Neural Networks (SNNs) is not well understood. In this work, we analyze the role of Batch Normalization (BN) in the training dynamics of SNNs. We show, theoretically and empirically, that BatchNorm induces training instability in SNNs, leading to lower convergence rates and worse generalization than in dense models. Specifically, we show that adding BatchNorm layers to sparse neural networks can significantly increase the gradient norm, causing training instability. We further validate this instability by analyzing the operator norm of the Hessian, finding it substantially larger in sparse training than in dense training, which indicates that sparse training operates beyond the "edge of stability" bound of 2/η. To mitigate this instability, we propose a novel preconditioned gradient descent method for sparse networks with BatchNorm. Our method takes the sparse topology of the network into account and rescales the gradients to prevent blow-up. We empirically demonstrate that the proposed preconditioned gradient descent improves both the convergence rate and generalization in Dynamic Sparse Training.
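To make the idea concrete, here is a minimal PyTorch sketch of a gradient-rescaling step for a masked (sparse) network with BatchNorm. The specific rescaling rule (capping each layer's gradient norm relative to its parameter norm via `max_ratio`), the `MaskedLinear` layer, and all hyperparameters are illustrative assumptions for exposition, not the paper's actual preconditioner.

```python
# Hypothetical sketch: a preconditioned SGD step for a sparse (masked) network
# with BatchNorm that rescales a layer's gradient whenever its norm grows too
# large relative to the parameter norm. The rule and `max_ratio` are assumptions.
import torch
import torch.nn as nn

class MaskedLinear(nn.Module):
    """Linear layer whose weights are multiplied by a fixed binary sparsity mask."""
    def __init__(self, in_f, out_f, sparsity=0.9):
        super().__init__()
        self.linear = nn.Linear(in_f, out_f)
        # Keep roughly (1 - sparsity) of the weights active.
        self.register_buffer("mask", (torch.rand(out_f, in_f) > sparsity).float())

    def forward(self, x):
        return nn.functional.linear(x, self.linear.weight * self.mask, self.linear.bias)

model = nn.Sequential(MaskedLinear(32, 64), nn.BatchNorm1d(64), nn.ReLU(), MaskedLinear(64, 10))
opt = torch.optim.SGD(model.parameters(), lr=0.1)

def preconditioned_step(model, opt, max_ratio=1.0):
    """Rescale each parameter's gradient if ||g|| exceeds max_ratio * ||w||, then step."""
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is None:
                continue
            g_norm = p.grad.norm()
            w_norm = p.norm().clamp_min(1e-12)
            if g_norm > max_ratio * w_norm:
                p.grad.mul_(max_ratio * w_norm / g_norm)
    opt.step()

# One illustrative training step on random data.
x, y = torch.randn(16, 32), torch.randint(0, 10, (16,))
loss = nn.functional.cross_entropy(model(x), y)
opt.zero_grad()
loss.backward()
preconditioned_step(model, opt)
```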
Student Paper: Yes
Submission Number: 63