Learning Gaussian Multi-Index Models with Gradient Flow: Time Complexity and Directional Convergence

Published: 11 Feb 2025, Last Modified: 06 Mar 2025, Venue: CPAL 2025 (Recent Spotlight Track), License: CC BY 4.0
Keywords: time complexity, gradient flow dynamics, hardness
TL;DR: We study the time complexity of learning features in arbitrary directions, and the emergence of a local minimum when the features get too close to each other, i.e., when their pairwise correlation exceeds an explicit threshold.
Abstract: This work studies the gradient flow dynamics of a neural network trained with a correlation loss to approximate a multi-index function of high-dimensional standard Gaussian data. Specifically, the multi-index function we consider is a sum of neurons $f^*(x) = \sum_{j=1}^k \sigma^*(v_j^T x)$, where $v_1, \dots, v_k$ are unit vectors and the Hermite expansion of $\sigma^*$ has zero coefficients on the first- and second-degree Hermite polynomials. It is known that, in the single-index case ($k = 1$), escaping the search phase requires a time that is polynomial in the input dimension. We first generalize this result to multi-index functions whose index vectors point in arbitrary directions. After the search phase, it is not clear whether the network neurons converge to the index vectors or get stuck at a sub-optimal solution. When the index vectors are orthogonal, we give a complete characterization of the fixed points and prove that each neuron converges to its nearest index vector. Consequently, with $n \asymp k \log k$ neurons, gradient flow recovers the full set of index vectors with high probability over the random initialization. When $v_i^T v_j = \beta \geq 0$ for all $i \neq j$, we prove the existence of a sharp threshold $\beta_c = c/(c+k)$ at which the fixed point that computes the average of the index vectors transitions from a saddle point to a minimum. Numerical simulations show that a correlation loss with mild overparameterization suffices to learn all of the index vectors when they are nearly orthogonal; however, the correlation loss fails when the dot product between the index vectors exceeds a certain threshold.
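The following minimal NumPy sketch makes the orthogonal setting concrete. It runs spherical gradient flow on the population correlation loss. The simplifying assumptions are not taken from the paper. The student uses the same activation as the target, $\sigma = \sigma^* = He_4/\sqrt{24}$ (the normalized fourth probabilists' Hermite polynomial, which has no degree-1 or degree-2 component), and every neuron is constrained to the unit sphere. For unit vectors $w, v$ and $x \sim N(0, I_d)$, the identity $E[He_m(w^T x)\, He_m(v^T x)] = m!\,(w^T v)^m$ gives the per-neuron correlation $g(w) = \sum_j (w^T v_j)^4$. The correlation loss is linear in the network output, so each neuron evolves independently. The flow is discretized with explicit Euler steps followed by renormalization. The function names, step size, and horizon are illustrative choices.

```python
import numpy as np

# Toy assumptions (not from the paper): sigma = sigma* = He_4 / sqrt(24), so for
# unit vectors E[sigma(w.x) sigma(v.x)] = (w.v)^4 when x ~ N(0, I_d). The population
# correlation of neuron w is g(w) = sum_j (w . v_j)^4, and neurons decouple.

def correlation_gradient_step(W, V, lr):
    """One Euler step of spherical gradient ascent on g for every neuron (rows of W)."""
    M = W @ V.T                                    # overlaps m_ij = w_i . v_j, shape (n, k)
    grad = 4.0 * (M ** 3) @ V                      # Euclidean gradient of g, shape (n, d)
    radial = np.sum(grad * W, axis=1, keepdims=True)
    grad_tan = grad - radial * W                   # project onto the sphere's tangent space
    W = W + lr * grad_tan
    return W / np.linalg.norm(W, axis=1, keepdims=True)  # retract to the unit sphere

def run_gradient_flow(V, n, steps=20000, lr=0.05, seed=0):
    """Discretized spherical gradient flow from an isotropic random initialization."""
    rng = np.random.default_rng(seed)
    W = rng.standard_normal((n, V.shape[1]))
    W /= np.linalg.norm(W, axis=1, keepdims=True)
    for _ in range(steps):
        W = correlation_gradient_step(W, V, lr)
    return W

if __name__ == "__main__":
    d, k, n = 20, 4, 12
    V = np.eye(k, d)                               # orthonormal index vectors v_j = e_j
    W = run_gradient_flow(V, n)
    print(np.round(np.abs(W @ V.T), 3))            # rows should be close to one-hot
```

With orthonormal index vectors, each update multiplies the overlap $m_j$ by $1 + 4\,\mathrm{lr}\,(m_j^2 - \sum_l m_l^4)$ before the common renormalization. The ordering of the $|m_j|$ is therefore preserved, and each neuron aligns (up to sign, since $He_4$ is even) with the index vector it had the largest absolute overlap with at initialization. In this toy model, that is one concrete reading of "nearest index vector".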
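The $n \asymp k \log k$ neuron count can be read as a coupon-collector bound. The sketch below makes one simplifying assumption: with orthogonal index vectors and an isotropic random initialization, each neuron independently converges to a uniformly random index vector, since by symmetry the nearest one is uniform. Under that assumption, it computes three quantities. The first is the exact probability that $n$ neurons cover all $k$ index vectors, obtained by inclusion-exclusion. The second is the smallest $n$ whose failure probability is at most $\delta$. The third is the union-bound value $\lceil k \ln(k/\delta) \rceil$. The helper names are hypothetical.

```python
from math import ceil, comb, log

def prob_all_recovered(n, k):
    """P(all k index vectors are hit by n neurons), assuming each neuron independently
    lands on a uniformly random index vector (coupon-collector model).
    Exact inclusion-exclusion over the set of missed indices."""
    return sum((-1) ** i * comb(k, i) * (1 - i / k) ** n for i in range(k + 1))

def neurons_needed(k, delta):
    """Smallest n whose failure probability 1 - prob_all_recovered(n, k) is <= delta."""
    n = k
    while 1 - prob_all_recovered(n, k) > delta:
        n += 1
    return n

def union_bound_neurons(k, delta):
    """Sufficient n from the union bound k (1 - 1/k)^n <= k exp(-n/k) <= delta."""
    return ceil(k * log(k / delta))

if __name__ == "__main__":
    for k in (5, 10, 20, 40):
        print(k, neurons_needed(k, 0.01), union_bound_neurons(k, 0.01))
```

Both the exact requirement and the union-bound value grow like $k \log k$. The union bound also shows why the logarithmic factor cannot be avoided: with $n = k$ neurons, missing some index vector is likely.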
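The saddle-to-minimum transition can be checked numerically in a toy version of the equal-overlap setting. The sketch makes three illustrative assumptions, not taken from the paper's general setting:

- A single neuron $w$ lives on the unit sphere.
- The activation is a pure Hermite polynomial, $\sigma = \sigma^* = He_p/\sqrt{p!}$ with $p \geq 3$, so the per-neuron correlation is $g(w) = \sum_j (w^T v_j)^p$.
- The index vectors are built as $v_j = \sqrt{1-\beta}\, e_j + \sqrt{\beta}\, e_{k+1}$, which gives $v_i^T v_j = \beta$ for $i \neq j$.

At the normalized average $\bar w = \sum_j v_j / \|\sum_j v_j\|$, all overlaps equal some $a$ with $k a^2 = 1 + (k-1)\beta$. Along the unit tangent direction proportional to $v_1 - v_2$, the Riemannian Hessian of $g$ equals $p\, a^{p-2}\,[(p-2) - \beta(p+k-2)]$. Directions orthogonal to all $v_j$ always have negative curvature. So in this toy model, $\bar w$ becomes a local maximum of the correlation (a minimum of the loss) when $\beta > (p-2)/(p-2+k)$. This has the $c/(c+k)$ form with $c = p-2$. The constant $c$ for a general $\sigma^*$ is not derived here.

```python
import numpy as np

def equiangular_index_vectors(k, beta, d):
    """Unit vectors v_1..v_k in R^d with v_i . v_j = beta for i != j (illustrative construction)."""
    if d < k + 1 or not 0.0 <= beta <= 1.0:
        raise ValueError("need d >= k + 1 and 0 <= beta <= 1")
    V = np.zeros((k, d))
    V[:, :k] = np.sqrt(1.0 - beta) * np.eye(k)    # private orthogonal parts
    V[:, k] = np.sqrt(beta)                        # shared component
    return V

def tangent_hessian_at_average(V, p):
    """Eigenvalues of the Riemannian Hessian of g(w) = sum_j (w . v_j)^p on the unit
    sphere, evaluated at the normalized average w = sum_j v_j / ||sum_j v_j||."""
    s = V.sum(axis=0)
    w = s / np.linalg.norm(s)
    m = V @ w                                      # equal overlaps by symmetry
    grad = p * (m ** (p - 1)) @ V
    euclid_hess = p * (p - 1) * (V.T * m ** (p - 2)) @ V
    P = np.eye(V.shape[1]) - np.outer(w, w)
    riem_hess = P @ euclid_hess @ P - (w @ grad) * P
    _, _, Vt = np.linalg.svd(w[None, :])           # rows 1: span the tangent space w^perp
    T = Vt[1:]
    return np.linalg.eigvalsh(T @ riem_hess @ T.T)

def toy_threshold(k, p):
    """Threshold (p - 2) / (p - 2 + k) for this pure He_p toy model only."""
    return (p - 2) / (p - 2 + k)

if __name__ == "__main__":
    k, p, d = 4, 4, 8
    for beta in (0.2, toy_threshold(k, p), 0.5):
        top = tangent_hessian_at_average(equiangular_index_vectors(k, beta, d), p).max()
        print(f"beta={beta:.3f}  largest tangent curvature={top:+.4f}")
```

A positive largest eigenvalue means the average direction is a saddle of the correlation. A negative one means it is a local maximum of the correlation, which is a spurious minimum of the loss.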
Submission Number: 22