Proximal Mean Field Learning in Shallow Neural Networks

Published: 07 Jan 2024, Last Modified: 07 Jan 2024. Accepted by TMLR.
Abstract: We propose a custom learning algorithm for shallow over-parameterized neural networks, i.e., networks with a single hidden layer of infinite width. The infinite width of the hidden layer serves as an abstraction for over-parameterization. Building on recent mean field interpretations of learning dynamics in shallow neural networks, we realize mean field learning as a computational algorithm rather than as an analytical tool. Specifically, we design a Sinkhorn-regularized proximal algorithm to approximate the distributional flow of the learning dynamics over weighted point clouds. In this setting, a contractive fixed point recursion computes the time-varying weights, numerically realizing the interacting Wasserstein gradient flow of the parameter distribution supported over the neuronal ensemble. An appealing aspect of the proposed algorithm is that the measure-valued recursions allow meshless computation. We demonstrate the proposed computational framework of interacting weighted particle evolution on binary and multi-class classification. Our algorithm performs gradient descent on the free energy associated with the risk functional.
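The fixed point recursion for the weights can be illustrated with a minimal sketch of a generic entropic-regularized proximal (JKO-type) weight update over a weighted point cloud. This is not the authors' code from the linked repository; the function name `prox_weight_update` and its parameters are hypothetical. The sketch assumes a squared-Euclidean transport cost between consecutive particle locations, an inverse temperature `beta`, a time step `h`, an entropic regularization `eps`, and a potential vector `psi` evaluated at the new locations and held fixed during the inner iterations. In the interacting mean field setting the potential depends on the current distribution, so freezing it is a simplifying assumption. The recursion comes from the first-order optimality conditions of minimizing ½⟨C, M⟩ + ε⟨M, log M⟩ + h(⟨ψ, ρ⟩ + β⁻¹⟨ρ, log ρ⟩) over couplings M whose row sums equal ρ_{k-1}, with ρ = Mᵀ1. The optimal coupling has the form diag(y) Γ diag(z) with Γ = exp(-C/(2ε)), and the exponent 1/(1 + βε/h) < 1 in the z-update is what such schemes rely on for contraction.

```python
import numpy as np


def prox_weight_update(x_prev, x_new, rho_prev, psi, beta, h, eps,
                       tol=1e-10, max_iter=1000):
    """Hypothetical sketch of one entropic proximal update of particle weights.

    x_prev:   (N, d) particle locations at step k-1
    x_new:    (N, d) particle locations at step k
    rho_prev: (N,) positive weights at step k-1
    psi:      (N,) potential at x_new, held fixed (simplifying assumption)
    """
    # Squared Euclidean cost between old and new particle locations.
    C = np.sum((x_prev[:, None, :] - x_new[None, :, :]) ** 2, axis=-1)
    Gamma = np.exp(-C / (2.0 * eps))      # Gibbs kernel of the cost
    xi = np.exp(-beta * psi - 1.0)        # potential term
    expo = 1.0 / (1.0 + beta * eps / h)   # exponent < 1 in the z-update

    # Initialize the column scaling z, then match the row marginal with y.
    z = np.ones_like(rho_prev)
    y = rho_prev / (Gamma @ z)
    for _ in range(max_iter):
        z_next = (xi / (Gamma.T @ y)) ** expo
        y_next = rho_prev / (Gamma @ z_next)
        converged = (np.linalg.norm(z_next - z) < tol
                     and np.linalg.norm(y_next - y) < tol)
        y, z = y_next, z_next
        if converged:
            break

    # New weights are the column sums of the coupling diag(y) Gamma diag(z).
    return z * (Gamma.T @ y)
```

Because the final y is recomputed from the final z, the returned weights always have the same total mass as ρ_{k-1}. For instance, if the particles are so far apart that Γ is numerically the identity, no mass can move between them and the update returns ρ_{k-1} unchanged.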
Submission Length: Long submission (more than 12 pages of main content)
Changes Since Last Submission: The final camera-ready version submitted here includes the following changes with respect to the accepted version:
- We changed all blue text highlights back to regular black text.
- In Appendix D, we included the additional plot (Fig. 5), sent earlier to the Action Editor, which better illustrates the quality of the approximations of the sinusoid learned by the proposed algorithm across randomized runs. This figure is referenced in the last sentence of the first paragraph of Appendix D.
- We added an Acknowledgement section immediately before the bibliography to acknowledge the grant support and the reviewers' feedback.
Code: https://github.com/zalexis12/Proximal-Mean-Field-Learning.git
Assigned Action Editor: ~Russell_Tsuchida1
License: Creative Commons Attribution 4.0 International (CC BY 4.0)
Submission Number: 1512