Generalized and Optimal Straight-Through Estimators
TL;DR: Axiomatic approach to approximate chain rule estimators, describing a general class and a minimum variance family
Abstract: Modern ML models often use discrete components within their computational graphs, which makes gradient-based training challenging.
In such cases, approximate-chain-rule gradient estimators can be applied. These estimators work reasonably well in practice, but they are typically derived by combining diverse rationales with ad hoc choices.
In this work, we propose a principled axiomatic approach that defines a general family of gradient estimators and show that it subsumes many existing methods. Within this family, we derive optimal estimators with respect to a minimum-variance criterion subject to interpretable bias-limiting constraints, covering both integer-valued and one-hot categorical discrete variables.
We empirically demonstrate that our estimator can achieve a better bias-variance trade-off than existing ones on synthetic problems and outperforms them on training variational auto-encoders with discrete latent variables.
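To make the setting concrete, below is a minimal sketch of the classic straight-through estimator, the simplest member of the approximate-chain-rule family the abstract refers to. It is an illustrative toy, not the paper's proposed method: the forward pass applies a hard rounding, and the backward pass replaces the derivative of `round` (which is zero almost everywhere) with the identity.

```python
import numpy as np

# Toy straight-through estimator (STE) sketch; names `forward`,
# `ste_grad`, and the toy objective are illustrative, not from the paper.

def forward(w):
    return np.round(w)          # discrete forward pass: quantize to integer

def ste_grad(loss_grad):
    return loss_grad            # straight-through: pretend d round(w)/dw = 1

# Toy objective: drive round(w) toward the integer target 3
# by minimizing (round(w) - 3)^2 with gradient descent.
target = 3.0
w = 0.2
lr = 0.1
for _ in range(100):
    y = forward(w)
    dL_dy = 2.0 * (y - target)  # gradient of the loss w.r.t. the output y
    w -= lr * ste_grad(dL_dy)   # chain rule with the approximate Jacobian

assert forward(w) == target     # the discrete output reaches the target
```

The identity surrogate is biased but has zero added variance; the bias-variance trade-off mentioned above arises from such choices of surrogate Jacobian.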
Submission Number: 1403