Sparsifying Bayesian neural networks with latent binary variables and normalizing flows

TMLR Paper 2611 Authors

02 May 2024 (modified: 09 May 2024) · Under review for TMLR · CC BY-SA 4.0
Abstract: Artificial neural networks are powerful machine learning methods used in many modern applications. A common issue is that they have millions or billions of parameters, and therefore tend to overfit. Bayesian neural networks (BNN) can improve on this since they incorporate parameter uncertainty. Latent binary Bayesian neural networks (LBBNN) further take into account structural uncertainty by allowing the weights to be turned on or off, enabling inference in the joint space of weights and structures. Mean-field variational inference is typically used for computation within such models. In this paper, we consider two extensions of variational inference for the LBBNN: Firstly, by using the local reparametrization trick (LCRT), we improve computational efficiency. Secondly, and more importantly, by using normalizing flows on the variational posterior distribution of the LBBNN parameters, we learn a more flexible variational posterior than the mean-field Gaussian. Experimental results on real data show that this improves predictive power compared to using mean-field variational inference on the LBBNN, while also obtaining sparser networks. We also perform two simulation studies. In the first, we consider variable selection in a logistic regression setting, where the more flexible variational distribution improves results. In the second study, we compare predictive uncertainty based on data generated from two-dimensional Gaussian distributions. Here, we argue that our Bayesian methods lead to more realistic estimates of predictive uncertainty.
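As a rough illustration of the construction described in the abstract (not the authors' code), the following Python sketch shows how one layer of a mean-field LBBNN could sample pre-activations with the local reparametrization trick. The names `mu`, `sigma` and `alpha` for the variational parameters, and the Gaussian approximation of the pre-activation sum, are assumptions based on the standard LCRT formulation rather than details taken from the paper.

```python
# Hypothetical sketch of one LBBNN layer with the local reparametrization
# trick (LCRT); not the authors' implementation.
import numpy as np

rng = np.random.default_rng(0)

def lbbnn_layer_lcrt(x, mu, sigma, alpha):
    """Sample pre-activations for a latent binary Bayesian layer.

    Each weight is w_ij = gamma_ij * theta_ij, with
    gamma_ij ~ Bernoulli(alpha_ij) and theta_ij ~ N(mu_ij, sigma_ij^2)
    under the mean-field variational posterior. Instead of sampling the
    weights, the LCRT samples the pre-activations directly from a
    Gaussian matching their mean and variance.
    """
    # E[gamma * theta] and Var[gamma * theta] for each weight
    w_mean = alpha * mu
    w_var = alpha * (sigma**2 + mu**2) - (alpha * mu)**2

    # Mean and variance of the pre-activation sum_i x_i * w_ij
    act_mean = x @ w_mean
    act_var = (x**2) @ w_var

    # Gaussian approximation of the pre-activation distribution
    eps = rng.standard_normal(act_mean.shape)
    return act_mean + np.sqrt(act_var) * eps

# Toy usage: batch of 4 inputs, 3 features, 2 output units
x = rng.standard_normal((4, 3))
mu = rng.standard_normal((3, 2))
sigma = np.full((3, 2), 0.1)
alpha = np.full((3, 2), 0.5)   # posterior inclusion probabilities
print(lbbnn_layer_lcrt(x, mu, sigma, alpha).shape)  # (4, 2)
```

The normalizing-flow extension highlighted in the abstract is not shown here; in the paper it is used to make the variational posterior more flexible than this mean-field form.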
Submission Length: Long submission (more than 12 pages of main content)
Previous TMLR Submission Url: https://openreview.net/forum?id=t3OGrWRUve
Changes Since Last Submission: Reviewer K8JC (round 2):

K8JC: ...I have three more significant outstanding issues and a few minor ones before I am ready to recommend acceptance.

Authors: We thank the reviewer for the thorough review. The major and minor comments are now fully addressed; see the responses below. We feel that working on the reviewer's comments has improved the paper, for which we are grateful. Additionally, we have further shortened Section 4 and made it clearer. Text revised in the second round is shown in teal, while first-round revisions are kept in purple.

Major comments:

K8JC: I appreciate the authors' conducting experiments with HMC-sampled BNNs. However, they should omit the HMC experiments from the paper due to the compute limitations they mention in their rebuttal. This is because the HMC results in the paper are incomparable to the other results since the authors use a significantly weaker architecture. Consequently, the conclusion drawn from these experiments in the main text, which claims that "[HMC] was underperforming" (bottom of page 9), is invalid and misleading.

Authors: We have removed the HMC-sampled BNNs (BNN-HMC). We agree that they do not fit the design of the experiments, in which we intended a single "treatment" effect of BNN-HMC versus the variational BNN, because of a) the reduced architectures used for BNN-HMC and b) the lack of convergence of HMC.

K8JC: Similarly, I appreciate the Gaussian process regression experiments; however, there is no description of the model the authors used, and hence, the numbers they obtained are not interpretable. They should at least include what covariance function they used and whether they performed exact GP regression or opted for some approximation.

Authors: On the regression datasets, we include exact Gaussian processes as an additional baseline, using an out-of-the-box version of the package from Varvia et al. (2023) with the Matérn 3/2 covariance function (see Chapter 4 of Rasmussen (2003) for its definition). For the mean function, we use the training data average. For the covariance function, the implementation exposes three hyperparameters: the kernel variance and length scale, which control the smoothness of the regression, and the error variance, which regulates how closely the training data are fit. We set these to 1, 5 and 0.1, respectively. (An illustrative sketch of this configuration is given after the responses below.)

K8JC: Every plot in the paper (Figures 5-10) is still pixelated. Please update them to a vector graphics format.

Authors: We have converted Figures 5-10 to vector graphics; additionally, for Figures 6-10 we increased the grid resolution to obtain smoother illustrations.

Minor comments:

K8JC: "Earlier work (Hubin & Storvik, 2019; Bai et al., 2020; Hubin & Storvik, 2024), have considered similar settings approaches, our main contributions are" - This newly added sentence sounds strange.

Authors: We have rewritten it as "Similar approaches have been considered in (Hubin & Storvik, 2019; Bai et al., 2020; Hubin & Storvik, 2024). Our main contributions are".

K8JC: I'm unsure what the authors now mean by "fully variational Bayesian model averaging." Do they mean that they are "fully variational" or "fully Bayesian"? I am not sure what "fully variational" would mean, and I would disagree with the claim that their method is "fully Bayesian" since they are doing variational inference and also do not integrate out the hyperparameters.

Authors: We have removed the terms "fully Bayesian" and "fully variational Bayesian" everywhere (both in the text and in the tables) and now simply say Bayesian and/or variational Bayesian. The term "fully" is now only used for fully connected neural networks.

K8JC: "Due to a more proper procedure for handling uncertainty, the Bayesian approach does, in many cases, result in more reliable solutions with less overfitting and better uncertainty measures." - I'm not sure what the authors mean by better uncertainty measures; I am not aware of any mainstream frequentist (or other) approach that attempts to quantify parameter uncertainty other than the Bayesian approach.

Authors: We have rewritten this as: "Incorporating parameter uncertainty can lead to less overfitting and better calibrated predictions, however this comes at the expense of extremely high computational costs."

K8JC: The starting sentence of Section 2 sounds quite strange, as it is called a function and a distribution in the same sentence.

Authors: We completely agree. The text around formula (1) at the start of Section 2 has now been rewritten.

K8JC: Eq. (10) follows directly from applying the chain rule for relative entropies; there is no need to reference another paper.

Authors: We have removed the reference.

K8JC: Please do not redirect to a website for a definition of hard-tanh; define it in the footnote.

Authors: We have defined it directly in the footnote at the bottom of page 7.

K8JC: In Table 9, the first appearance of BNN-HMC should just read "BNN," I think.

Authors: You are right. We have fixed that.
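For concreteness, below is a minimal sketch of the Gaussian process baseline configuration described above, written with scikit-learn rather than the package from Varvia et al. (2023). It should be read only as an illustration of the stated settings (Matérn 3/2 kernel, kernel variance 1, length scale 5, error variance 0.1, training-mean centering), not as the exact setup used in the paper; the toy data are invented for the example.

```python
# Illustrative exact GP regression matching the stated configuration;
# the paper uses the package from Varvia et al. (2023), not scikit-learn,
# and the toy data below are invented purely for demonstration.
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import ConstantKernel, Matern

rng = np.random.default_rng(0)
X_train = rng.uniform(-3.0, 3.0, size=(50, 1))
y_train = np.sin(X_train[:, 0]) + rng.normal(scale=0.3, size=50)

# Matern 3/2 kernel with variance 1 and length scale 5, both held fixed.
kernel = ConstantKernel(constant_value=1.0, constant_value_bounds="fixed") * \
    Matern(length_scale=5.0, length_scale_bounds="fixed", nu=1.5)

gp = GaussianProcessRegressor(
    kernel=kernel,
    alpha=0.1,        # error (noise) variance added to the kernel diagonal
    optimizer=None,   # keep all hyperparameters at the stated values
)

# Mean function = training data average: centre the targets before fitting.
y_mean = y_train.mean()
gp.fit(X_train, y_train - y_mean)

X_test = np.linspace(-3.0, 3.0, 5).reshape(-1, 1)
pred_mean, pred_std = gp.predict(X_test, return_std=True)
pred_mean += y_mean
print(pred_mean, pred_std)
```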
Assigned Action Editor: ~Pierre_Alquier1
Submission Number: 2611