RASP Quadratures: Efficient Numerical Integration for High-Dimensional Mean-Field Variational Inference

21 Sept 2023 (modified: 11 Feb 2024) · Submitted to ICLR 2024
Primary Area: probabilistic methods (Bayesian methods, variational inference, sampling, UQ, etc.)
Code Of Ethics: I acknowledge that I and all co-authors of this work have read and commit to adhering to the ICLR Code of Ethics.
Keywords: quadrature, cubature, sigma points, loss topography, Hessian approximation, variational inference, quasi-Newton variational Bayes
Submission Guidelines: I certify that this submission complies with the submission instructions as described on https://iclr.cc/Conferences/2024/AuthorGuide.
TL;DR: RASP quadratures support variational inference by approximating locally averaged gradients and Hessian diagonals using only 3 gradient evaluations, eliminating errors from over half of all second-order terms.
Abstract: Efficient high-dimensional integration enables novel approaches to calibrate and control model uncertainty during training. In particular, the recently proposed projective integral update formulation of variational inference derives model uncertainty from expectations that extract the local loss topography. Thus, we propose random-affinity sigma-point (RASP) quadratures, which are designed to eliminate integration errors from basis functions that drive Gaussian mean-field updates. Using only 3 gradient evaluations, RASP quadratures can extract locally averaged gradients and Hessian diagonals from the loss, while eliminating errors from over half of all quadratic total-degree terms. Alternatively, we can use 6-point RASP quadratures to obtain 5th-order exactness in all univariate terms as well as 3rd-order exactness for two-thirds of bivariate terms. This work presents the design of RASP quadratures, theoretical guarantees on exactness, and an analysis of expected errors. We also provide an open-source PyTorch implementation of RASP quadratures with quasi-Newton variational Bayes (QNVB), i.e., the projective integral update algorithm for Gaussian mean fields. Although RASP quadratures are designed to support QNVB, they are also compatible with other forms of variational inference, such as stochastic gradient variational Bayes (SGVB). Our experiments compare alternative integration schemes and training methods on three different learning tasks and architectures, demonstrating that efficient integration can improve generalizability for architectures with suitable loss structure.
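To make the 3-evaluation idea concrete, the following is a minimal sketch of a symmetric sigma-point scheme in the spirit of the abstract (the function name, the Rademacher direction, and the equal weights are our illustrative assumptions, not the paper's actual RASP construction): one central and two antithetic gradient evaluations along a random sign direction yield a locally averaged gradient and a Hessian-diagonal estimate, both exact for quadratic losses with diagonal Hessians.

```python
import numpy as np

def three_point_estimates(grad_f, mu, sigma, rng):
    """Estimate the averaged gradient and Hessian diagonal of a loss
    under a Gaussian mean field N(mu, diag(sigma**2)), using only
    3 gradient evaluations (hypothetical sketch, not the paper's scheme)."""
    z = rng.choice([-1.0, 1.0], size=mu.shape)   # random Rademacher direction
    g_plus = grad_f(mu + sigma * z)
    g_minus = grad_f(mu - sigma * z)
    g_center = grad_f(mu)                        # third gradient evaluation
    # Symmetric average: exact for the gradient of any quadratic loss.
    avg_grad = (g_plus + g_minus + g_center) / 3.0
    # Central difference along z; since z**2 == 1, this recovers the
    # Hessian diagonal exactly when the Hessian is diagonal and constant.
    hess_diag = (g_plus - g_minus) * z / (2.0 * sigma)
    return avg_grad, hess_diag

# Usage on a quadratic loss f(x) = 0.5 * x.T @ diag(h) @ x + b @ x,
# whose gradient is h * x + b and whose Hessian diagonal is h.
rng = np.random.default_rng(0)
mu = np.array([0.5, -1.0, 2.0])
sigma = np.array([0.1, 0.2, 0.3])
h = np.array([2.0, 3.0, 0.5])
b = np.array([1.0, 0.0, -1.0])
avg_grad, hess_diag = three_point_estimates(lambda x: h * x + b, mu, sigma, rng)
```

On this quadratic example the estimates are exact: `avg_grad` equals `h * mu + b` and `hess_diag` equals `h`, for any draw of the sign direction.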
Anonymous Url: I certify that there is no URL (e.g., github page) that could be used to find authors' identity.
No Acknowledgement Section: I certify that there is no acknowledgement section in this submission for double blind review.
Submission Number: 3761