Learning Hierarchical Polynomials with Three-Layer Neural Networks

Published: 16 Jan 2024, Last Modified: 12 Mar 2024, ICLR 2024 poster
Code Of Ethics: I acknowledge that I and all co-authors of this work have read and commit to adhering to the ICLR Code of Ethics.
Keywords: Hierarchical polynomials, feature learning, three-layer networks, sample complexity, gradient descent
Submission Guidelines: I certify that this submission complies with the submission instructions as described on https://iclr.cc/Conferences/2024/AuthorGuide.
TL;DR: We show that three-layer neural networks learn hierarchical polynomials of the form $h = g \circ p$, where $p : \mathbb{R}^d \rightarrow \mathbb{R}$ is a degree $k$ polynomial, in $\widetilde O(d^k)$ samples.
Abstract: We study the problem of learning hierarchical polynomials over the standard Gaussian distribution with three-layer neural networks. We specifically consider target functions of the form $h = g \circ p$, where $p : \mathbb{R}^d \rightarrow \mathbb{R}$ is a degree $k$ polynomial and $g: \mathbb{R} \rightarrow \mathbb{R}$ is a degree $q$ polynomial. This function class generalizes the single-index model, which corresponds to $k=1$, and is a natural class of functions possessing an underlying hierarchical structure. Our main result shows that for a large subclass of degree $k$ polynomials $p$, a three-layer neural network trained via layerwise gradient descent on the square loss learns the target $h$ up to vanishing test error in $\widetilde O(d^k)$ samples and polynomial time. This is a strict improvement over kernel methods, which require $\widetilde \Theta(d^{kq})$ samples, as well as over existing guarantees for two-layer networks, which require the target function to be low-rank. Our result also generalizes prior works on three-layer neural networks, which were restricted to the case of $p$ being a quadratic. When $p$ is indeed a quadratic, we achieve the information-theoretically optimal sample complexity $\widetilde O(d^2)$, which is an improvement over prior work (Nichani et al., 2023) requiring a sample size of $\widetilde\Theta(d^4)$. Our proof proceeds by showing that during the initial stage of training the network performs feature learning to recover the feature $p$ with $\widetilde O(d^k)$ samples. This work demonstrates the ability of three-layer neural networks to learn complex features and, as a result, to learn a broad class of hierarchical functions.
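To make the target class in the abstract concrete, the following is a minimal sketch, not the authors' code, of a hierarchical polynomial $h = g \circ p$ evaluated on standard Gaussian inputs. The helper names `make_hierarchical_target` and `sample_scaling`, the specific choice of a centered quadratic $p(x) = x^\top A x - \mathrm{tr}(A)$ (so $k = 2$), and the random symmetric $A$ are all assumptions made for illustration; the sketch does not check whether this $p$ falls in the subclass covered by the main result. The scaling helper compares $d^k$ with $d^{kq}$ while ignoring the logarithmic factors and constants hidden in $\widetilde O$ and $\widetilde \Theta$.

```python
import numpy as np
from numpy.polynomial import polynomial as P


def make_hierarchical_target(d, g_coeffs, rng):
    """Build an illustrative target h = g o p with a quadratic feature p (k = 2).

    p(x) = x^T A x - tr(A) for a random symmetric A with unit Frobenius norm,
    so p has mean zero under x ~ N(0, I_d). g is the polynomial with
    coefficients g_coeffs, ordered from lowest to highest degree.
    """
    A = rng.standard_normal((d, d))
    A = (A + A.T) / 2.0          # symmetrize
    A /= np.linalg.norm(A)       # default matrix norm is Frobenius

    def p(X):
        # Quadratic form evaluated row-wise on X of shape (n, d).
        return np.einsum("ni,ij,nj->n", X, A, X) - np.trace(A)

    def h(X):
        # Outer link function g applied to the scalar feature p(x).
        return P.polyval(p(X), g_coeffs)

    return p, h


def sample_scaling(d, k, q):
    """Leading-order sample counts, ignoring log factors and constants."""
    return {"three_layer": d ** k, "kernel": d ** (k * q)}


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    p, h = make_hierarchical_target(d=20, g_coeffs=[0.0, 1.0, 0.0, 1.0], rng=rng)
    X = rng.standard_normal((1000, 20))   # inputs from the standard Gaussian
    y = h(X)                              # labels h(x) = p(x) + p(x)^3
    print(y.shape, sample_scaling(d=20, k=2, q=3))
```

With $k = 2$ and $q = 3$, the gap between the two scalings is a factor of $d^{k(q-1)} = d^4$, which illustrates why recovering the feature $p$ first is cheaper than fitting $h$ directly as a degree-$kq$ polynomial.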
Anonymous Url: I certify that there is no URL (e.g., github page) that could be used to find authors' identity.
No Acknowledgement Section: I certify that there is no acknowledgement section in this submission for double blind review.
Primary Area: learning theory
Submission Number: 6768