Gradient-Based Feature Learning under Structured Data

Published: 21 Sept 2023 · Last Modified: 02 Nov 2023 · NeurIPS 2023 poster
Keywords: feature learning, neural networks, single-index model, gradient descent
TL;DR: We show that when learning single-index models with neural networks using gradient-based training, additional structure in the covariance matrix can improve the sample complexity.
Abstract: Recent works have demonstrated that the sample complexity of gradient-based learning of single-index models, i.e., functions that depend on a one-dimensional projection of the input data, is governed by their information exponent. However, these results only concern isotropic data, while in practice the input often contains additional structure that can implicitly guide the algorithm. In this work, we investigate the effect of a spiked covariance structure and reveal several interesting phenomena. First, we show that in the anisotropic setting, the commonly used spherical gradient dynamics may fail to recover the true direction, even when the spike is perfectly aligned with the target direction. Next, we show that appropriate weight normalization that is reminiscent of batch normalization can alleviate this issue. Further, by exploiting the alignment between the (spiked) input covariance and the target, we obtain improved sample complexity compared to the isotropic case. In particular, under the spiked model with a suitably large spike, the sample complexity of gradient-based training can be made independent of the information exponent while also outperforming lower bounds for rotationally invariant kernel methods.
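To make the setting concrete, below is a minimal sketch (not the authors' code) of the kind of experiment the abstract describes: online SGD on a single neuron learning a single-index model with information exponent 2 (second Hermite polynomial as the link), where the inputs follow a spiked covariance Sigma = I + theta * w* w*^T aligned with the target direction w*, and the preactivation is rescaled by its empirical batch standard deviation, a batch-norm-style normalization. All dimensions, step sizes, and the stop-gradient treatment of the batch statistic are illustrative assumptions, not details from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
d, theta = 200, 10.0          # assumed dimension and spike strength
lr, steps, batch = 5e-3, 50_000, 64

w_star = np.zeros(d)
w_star[0] = 1.0               # target direction, aligned with the spike

def sample(n):
    """Draw x ~ N(0, I + theta * w* w*^T) via a rank-one update of white noise."""
    z = rng.standard_normal((n, d))
    coef = np.sqrt(1.0 + theta) - 1.0
    return z + coef * np.outer(z @ w_star, w_star)

def target(x):
    """Single-index label y = He_2(<w*, x>), which has information exponent 2."""
    t = x @ w_star
    return t**2 - 1.0

# Online SGD with a batch-norm-style normalized preactivation and a
# spherical projection step keeping ||w|| = 1.
w = rng.standard_normal(d)
w /= np.linalg.norm(w)
for _ in range(steps // batch):
    x = sample(batch)
    y = target(x)
    pre = x @ w
    scale = pre.std() + 1e-8          # empirical preactivation std (BN-style)
    h = pre / scale
    pred = h**2 - 1.0
    # Gradient of 0.5 * (pred - y)^2 w.r.t. w, treating the batch statistic
    # `scale` as a constant (stop-gradient) -- a common simplification.
    g = ((pred - y) * 2.0 * h / scale) @ x / batch
    w -= lr * g
    w /= np.linalg.norm(w)            # project back to the unit sphere

# Overlap with the target direction; values near 1 indicate recovery.
print("overlap |<w, w*>| =", abs(w @ w_star))
```

In this sketch the normalization by the batch statistic plays the role of the Sigma-dependent rescaling that the abstract credits with fixing the failure of plain spherical dynamics; with a large theta, the spike inflates the signal along w* and alignment is reached with far fewer samples than in the isotropic case.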
Submission Number: 9211