Tracking the Median of Gradients with a Stochastic Proximal Point Method

TMLR Paper5211 Authors

26 Jun 2025 (modified: 07 Jul 2025) · Under review for TMLR · CC BY 4.0
Abstract: There are several applications of stochastic optimization where one can benefit from a robust estimate of the gradient: for example, distributed learning with corrupted nodes, the presence of large outliers in the training data, learning under privacy constraints, or heavy-tailed noise arising from the dynamics of the algorithm itself. Here we study SGD with robust gradient estimators based on estimating the median. We first derive iterative methods based on the stochastic proximal point method for computing the median gradient and generalizations thereof. We then propose an algorithm that estimates the median gradient across *iterations*, and find that several well-known methods are particular cases of this framework. For instance, we observe that different forms of clipping yield online estimators of the *median* of gradients, in contrast to (heavy-ball) momentum, which corresponds to an online estimator of the *mean*. Finally, we provide a theoretical framework for any algorithm computing the median gradient across *samples*, and show that the resulting method can converge even under heavy-tailed, state-dependent noise.
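The abstract contrasts clipping, which acts as an online estimator of the *median* of gradients, with heavy-ball momentum, an online estimator of the *mean*. The snippet below is a minimal sketch of that contrast, not the paper's algorithm: the update rules, parameter names (`beta`, `gamma`, `tau`), and the toy heavy-tailed noise are illustrative assumptions.

```python
import numpy as np

def momentum_mean_estimate(m, g, beta=0.9):
    # Exponential moving average: an online estimator of the *mean* gradient.
    return beta * m + (1.0 - beta) * g

def clipped_median_estimate(m, g, gamma=0.1, tau=1.0):
    # Clipped correction of the running estimate: bounding the step makes the
    # estimator track a median-like statistic of the gradients (sketch only).
    delta = g - m
    norm = np.linalg.norm(delta)
    if norm > tau:
        delta = delta * (tau / norm)
    return m + gamma * delta

# Toy usage: heavy-tailed (Cauchy) gradient noise around a true gradient of 1.0.
rng = np.random.default_rng(0)
m_mean, m_med = np.zeros(1), np.zeros(1)
for _ in range(1000):
    g = np.array([1.0]) + rng.standard_cauchy(1)
    m_mean = momentum_mean_estimate(m_mean, g)
    m_med = clipped_median_estimate(m_med, g)
print("mean-style estimate:", m_mean, "median-style estimate:", m_med)
```

Under such heavy-tailed noise the averaging estimator can be thrown far off by single outliers, while the clipped update stays near the true gradient, illustrating why a median-style estimator is attractive in the settings the abstract lists.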
Submission Length: Long submission (more than 12 pages of main content)
Assigned Action Editor: ~Matthew_J._Holland1
Submission Number: 5211