Neural Collapse Under MSE Loss: Proximity to and Dynamics on the Central Path

29 Sept 2021, 00:35 (edited 10 May 2022) · ICLR 2022 Oral
  • Keywords: neural collapse, deep learning theory, deep learning, inductive bias, equiangular tight frame, ETF, nearest class center, mean squared error loss, MSE loss, invariance, renormalization, gradient flow, dynamics, adversarial robustness
  • Abstract: The recently discovered Neural Collapse (NC) phenomenon occurs pervasively in today's deep net training paradigm of driving cross-entropy (CE) loss towards zero. During NC, last-layer features collapse to their class means, both classifiers and class means collapse to the same Simplex Equiangular Tight Frame, and classifier behavior collapses to the nearest-class-mean decision rule. Recent works demonstrated that deep nets trained with mean squared error (MSE) loss perform comparably to those trained with CE. As a preliminary, we empirically establish, through experiments on three canonical networks and five benchmark datasets, that NC emerges in such MSE-trained deep nets as well. We provide PyTorch code for reproducing MSE-NC and CE-NC in a Google Colab notebook. The analytically tractable MSE loss offers more mathematical opportunities than the hard-to-analyze CE loss, inspiring us to leverage MSE loss for the theoretical investigation of NC. We develop three main contributions: (I) We show a new decomposition of the MSE loss into (A) terms that are directly interpretable through the lens of NC and that assume the last-layer classifier is exactly the least-squares classifier; and (B) a term capturing the deviation from this least-squares classifier. (II) We exhibit experiments on canonical datasets and networks demonstrating that term (B) is negligible during training. This motivates us to introduce a new theoretical construct: the central path, where the linear classifier stays MSE-optimal for the feature activations throughout the dynamics. (III) By studying renormalized gradient flow along the central path, we derive exact dynamics that predict NC.
  • One-sentence Summary: Neural Collapse occurs empirically on deep nets trained with MSE loss and studying this setting leads to insightful closed-form dynamics.
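  • Illustrative sketch (not from the paper): The split in contribution (I) rests on a standard least-squares fact. For fixed features, the MSE loss of any linear classifier equals the loss of the least-squares classifier plus a nonnegative deviation term, because the least-squares residual is orthogonal to the span of the (bias-augmented) features. The NumPy sketch below checks this numerically on random data. It also builds a Simplex ETF to confirm its defining geometry: unit-norm columns with pairwise cosine -1/(C-1). It relies on our own assumptions: features H of shape (p, N), one-hot targets Y of shape (C, N), and the loss normalized as (1/2N)||WH + b1^T - Y||_F^2. All function names are hypothetical. This is not the authors' Colab code, and it does not reproduce the paper's further split of the least-squares loss into the NC-interpretable terms (A).

```python
import numpy as np

# Hypothetical helper names; a sketch of the least-squares/deviation split,
# not the paper's released code.


def simplex_etf(C):
    """Columns form a C-class Simplex ETF: unit norm, pairwise cosine -1/(C-1)."""
    return np.sqrt(C / (C - 1)) * (np.eye(C) - np.ones((C, C)) / C)


def mse_loss(W, b, H, Y):
    """(1/2N) * ||W H + b 1^T - Y||_F^2 for features H (p, N), one-hot Y (C, N)."""
    N = H.shape[1]
    R = W @ H + b[:, None] - Y
    return 0.5 * np.sum(R ** 2) / N


def least_squares_classifier(H, Y):
    """MSE-optimal (W, b) for fixed features, via an ordinary least-squares solve."""
    p, N = H.shape
    H_aug = np.vstack([H, np.ones((1, N))])            # (p+1, N): append a bias row
    X, *_ = np.linalg.lstsq(H_aug.T, Y.T, rcond=None)  # solves H_aug^T X ~= Y^T
    W_aug = X.T                                        # (C, p+1)
    return W_aug[:, :p], W_aug[:, p]


def decompose(W, b, H, Y):
    """Return (least-squares loss, deviation term); they sum to mse_loss(W, b, H, Y)."""
    W_ls, b_ls = least_squares_classifier(H, Y)
    loss_ls = mse_loss(W_ls, b_ls, H, Y)
    N = H.shape[1]
    D = (W - W_ls) @ H + (b - b_ls)[:, None]           # deviation of the classifier outputs
    deviation = 0.5 * np.sum(D ** 2) / N
    return loss_ls, deviation


# Usage on random (non-collapsed) features and an arbitrary classifier.
rng = np.random.default_rng(0)
C, p, n = 3, 5, 20
labels = np.repeat(np.arange(C), n)
Y = np.eye(C)[:, labels]                               # (C, N) one-hot targets
H = rng.normal(size=(p, C * n))
W, b = rng.normal(size=(C, p)), rng.normal(size=C)

total = mse_loss(W, b, H, Y)
loss_ls, deviation = decompose(W, b, H, Y)
print(np.isclose(total, loss_ls + deviation))          # True

# Gram matrix of a 4-class Simplex ETF: ones on the diagonal, -1/3 elsewhere.
G = simplex_etf(4).T @ simplex_etf(4)
print(np.allclose(np.diag(G), 1.0), np.allclose(G[~np.eye(4, dtype=bool)], -1 / 3))
```

On the central path described in (II), the classifier is held at the least-squares solution, which in this sketch corresponds to the deviation term being exactly zero.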