Keywords: energy networks, EBMs, structured prediction, convex analysis, Fenchel conjugates
TL;DR: We propose a theoretically-grounded family of loss functions for learning energy networks.
Abstract: Energy-based models, a.k.a. energy networks, perform inference by optimizing
an energy function, typically parametrized by a neural network.
This allows one to capture potentially complex relationships between inputs and
outputs.
To learn the parameters of the energy function, the solution to that
optimization problem is typically fed into a loss function.
The key challenge in training energy networks is computing loss gradients,
as this usually requires argmin/argmax differentiation.
In this paper, building upon a generalized notion of conjugate function,
which replaces the usual bilinear pairing with a general energy function,
we propose generalized Fenchel-Young losses, a natural loss construction for
learning energy networks. Our losses enjoy many desirable properties and their
gradients can be computed efficiently without argmin/argmax differentiation.
We also prove the calibration of their excess risk in the case of linear-concave
energies. We demonstrate our losses on multilabel classification and
imitation learning tasks.
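
The construction summarized above can be sketched as follows. The notation is an assumption made here for illustration, not taken from the abstract: $\Phi(v, p)$ denotes the energy coupling a model output $v$ with a candidate prediction $p$, $\Omega$ a regularizer over a feasible set $\mathcal{C}$, and $y$ the ground truth. Treat this as a reading aid under those assumptions rather than the paper's exact statement.

```latex
% Generalized conjugate: the bilinear pairing <v, p> is replaced by an energy Phi(v, p).
\Omega^{\Phi}(v) := \max_{p \in \mathcal{C}} \; \Phi(v, p) - \Omega(p)

% Generalized Fenchel-Young loss built from that conjugate.
L_{\Omega}^{\Phi}(v, y) := \Omega^{\Phi}(v) + \Omega(y) - \Phi(v, y)

% Gradient via an envelope (Danskin-type) argument, assuming the maximizer
% p^\star(v) is unique and \Phi is differentiable in its first argument:
\nabla_v L_{\Omega}^{\Phi}(v, y)
  = \nabla_1 \Phi\bigl(v, p^\star(v)\bigr) - \nabla_1 \Phi(v, y),
\qquad
p^\star(v) \in \operatorname*{argmax}_{p \in \mathcal{C}} \; \Phi(v, p) - \Omega(p)
```

With $\Phi(v, p) = \langle v, p \rangle$, the first line reduces to the usual Fenchel conjugate and the second to a standard Fenchel-Young loss. The gradient uses only the value of $p^\star(v)$, not its Jacobian, which is why no argmin/argmax differentiation is needed.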
Supplementary Material: pdf