Robust Learning of a Group DRO Neuron
TL;DR: We propose an efficient primal-dual algorithm for learning a single neuron with robust guarantees under label noise and group-level distributional shifts.
Abstract: We study the problem of learning a single neuron under the standard squared loss in the presence of arbitrary label noise and group-level distributional shifts, for a broad family of covariate distributions. Our goal is to identify a "best-fit" neuron, parameterized by ${\boldsymbol w}^{\star}$, that performs well under the most challenging reweighting of the groups. Specifically, we address a Group Distributionally Robust Optimization (Group DRO) problem: given sample access to $K$ distinct distributions $\mathcal{p}_{[1]},\dots,\mathcal{p}_{[K]}$, we seek to approximate the vector ${\boldsymbol w}^{\star}$ that minimizes the worst-case objective over convex combinations of group distributions ${\boldsymbol \lambda} \in \Delta_K$, namely $\sum_{i \in [K]}\lambda_{[i]}\,\mathbb{E}_{(\mathbf x,y)\sim\mathcal{p}_{[i]}}\big[(\sigma(\boldsymbol w\cdot\boldsymbol x)-y)^2\big] - \nu\, d_f\big(\boldsymbol\lambda,\tfrac1K\boldsymbol 1\big)$, where $d_f$ is an $f$-divergence that imposes an (optional) penalty on deviations from uniform group weights, scaled by a parameter $\nu \geq 0$.
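For concreteness, the full minimax problem implied above can be written out (our transcription of the objective stated in the abstract, using the same notation) as

$$
{\boldsymbol w}^{\star} \in \arg\min_{\boldsymbol w}\; \max_{\boldsymbol\lambda \in \Delta_K}\; \left\{ \sum_{i \in [K]} \lambda_{[i]}\, \mathbb{E}_{(\mathbf x, y)\sim \mathcal{p}_{[i]}}\big[(\sigma(\boldsymbol w \cdot \boldsymbol x) - y)^2\big] \;-\; \nu\, d_f\big(\boldsymbol\lambda, \tfrac{1}{K}\boldsymbol 1\big) \right\}.
$$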
We develop a computationally efficient primal-dual algorithm that outputs a vector $\widehat{\boldsymbol w}$ which is constant-factor competitive with ${\boldsymbol w}^{\star}$ under the worst-case group weighting.
Our analytical framework directly confronts the inherent nonconvexity of the loss function, providing robust learning guarantees in the face of arbitrary label corruptions and group-specific distributional shifts. An implementation of the dual-extrapolation update motivated by our algorithmic framework shows promise on LLM pre-training benchmarks.
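As an illustration only, below is a minimal sketch of one generic gradient descent-ascent step for this objective; it is not the authors' algorithm, since the abstract does not specify the update rules. It assumes $\sigma$ is the logistic sigmoid, takes $d_f$ to be the KL divergence to the uniform weights, and all function names (`gda_step`, `group_losses`) and step sizes are hypothetical:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def group_losses(w, groups):
    # Empirical squared loss (sigma(w . x) - y)^2, averaged within each group.
    return np.array([np.mean((sigmoid(X @ w) - y) ** 2) for X, y in groups])

def gda_step(w, lam, groups, nu, eta_w=1e-2, eta_lam=1e-1):
    """One gradient descent-ascent step on the Group DRO objective:
    a descent step on w, then an exponentiated-gradient ascent step on
    the simplex weights lam, with d_f taken to be KL(lam || uniform)."""
    K = len(groups)
    # Gradient of the lam-weighted squared loss with respect to w.
    grad_w = np.zeros_like(w)
    for lam_i, (X, y) in zip(lam, groups):
        p = sigmoid(X @ w)
        grad_w += lam_i * (2.0 / len(y)) * (X.T @ ((p - y) * p * (1.0 - p)))
    w = w - eta_w * grad_w
    # Ascent direction for lam: per-group losses minus the gradient of the
    # penalty nu * KL(lam || uniform), which is nu * (log(K * lam) + 1).
    g = group_losses(w, groups) - nu * (np.log(K * lam) + 1.0)
    lam = lam * np.exp(eta_lam * g)   # multiplicative-weights update
    return w, lam / lam.sum()         # re-normalize onto the simplex

# Usage on synthetic data: K = 3 groups, d = 5 covariate dimensions.
rng = np.random.default_rng(0)
groups = [(rng.normal(size=(100, 5)), rng.uniform(size=100)) for _ in range(3)]
w, lam = np.zeros(5), np.ones(3) / 3
for _ in range(200):
    w, lam = gda_step(w, lam, groups, nu=0.1)
```

The dual-extrapolation update mentioned above would replace the plain ascent step on `lam` with an extrapolated two-step variant; those details are not recoverable from the abstract.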
Submission Number: 1729