R2D2-Net: Shrinking Bayesian Neural Networks via R2D2 Prior

18 Sept 2023 (modified: 25 Mar 2024) · ICLR 2024 Conference · Withdrawn Submission
Keywords: Bayesian deep learning, variable selection, global-local shrinkage prior, uncertainty estimation
TL;DR: We propose a novel Bayesian neural network design based on the R2D2 prior, which strongly shrinks irrelevant parameters while preserving predictive performance.
Abstract: Bayesian neural networks (BNNs) treat neural network weights as random variables, aiming to provide posterior uncertainty estimates and avoid overfitting by performing inference on the posterior of the weights. However, selecting an appropriate prior distribution remains challenging, and poor choices can cause BNNs to suffer from catastrophically inflated variance or poor predictive performance. Previous BNN designs apply various priors to the weights, but these priors either fail to sufficiently shrink noisy signals or tend to over-shrink important ones. To alleviate this problem, we propose the R2D2-Net, which imposes the $R^2$-induced Dirichlet Decomposition (R2D2) prior on the BNN weights. The R2D2-Net effectively shrinks irrelevant coefficients towards zero while preventing key features from being over-shrunk. To approximate the posterior distribution of the weights more accurately, we further propose a variational Gibbs inference algorithm that combines the Gibbs updating procedure with gradient-based optimization. We also analyze the ELBO and derive analytical forms of the KL divergences of the shrinkage parameters. Empirical studies on image classification and uncertainty estimation tasks demonstrate that our proposed method outperforms existing BNN designs with different priors, indicating that the R2D2-Net selects more relevant variables for predictive tasks. Moreover, we empirically show that the R2D2-Net yields better predictive performance and smaller variance as network depth increases, indicating that it alleviates the catastrophic inflation of variance when BNNs are scaled. Code is anonymously available at https://anonymous.4open.science/r/r2d2bnn-EF7D.
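For readers unfamiliar with the R2D2 prior, the sketch below draws one weight vector under the standard R2D2 hierarchy from the shrinkage-prior literature. It is a minimal illustration only: the hyperparameters (a_pi, a, b) and the exact parameterization are assumptions for this sketch, not necessarily those used by R2D2-Net, whose variational Gibbs inference is not reproduced here.

```python
import numpy as np

def sample_r2d2_weights(p, a_pi=0.5, a=1.0, b=0.5, sigma2=1.0, rng=None):
    """Draw one p-dimensional weight vector from a standard R2D2 prior.

    Assumed hierarchy (common R2D2 construction, illustrative only):
        phi   ~ Dirichlet(a_pi, ..., a_pi)   # local allocation of variance
        omega ~ BetaPrime(a, b)              # global scale
        psi_j ~ Exp(rate 1/2)                # per-weight mixing, mean 2
        w_j   ~ N(0, sigma2 * psi_j * phi_j * omega / 2)
    """
    rng = np.random.default_rng() if rng is None else rng
    phi = rng.dirichlet(np.full(p, a_pi))    # sums to 1: a fixed "R^2 budget"
    u = rng.beta(a, b)
    omega = u / (1.0 - u)                    # Beta(a, b) draw -> BetaPrime(a, b)
    psi = rng.exponential(scale=2.0, size=p) # Exp with rate 1/2 has scale 2
    var = sigma2 * psi * phi * omega / 2.0
    return rng.normal(0.0, np.sqrt(var))     # per-coordinate Gaussian scales

# Most sampled weights concentrate near zero while a few remain large,
# illustrating the shrink-noise-but-keep-signal behaviour described above.
w = sample_r2d2_weights(p=10_000)
print((np.abs(w) < 1e-3).mean())
```

Because the Dirichlet allocation phi sums to one, giving more prior variance to one coordinate necessarily shrinks the others; this global-local trade-off is what lets the prior suppress noisy weights without over-shrinking strong signals.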
Primary Area: probabilistic methods (Bayesian methods, variational inference, sampling, UQ, etc.)
Code Of Ethics: I acknowledge that I and all co-authors of this work have read and commit to adhering to the ICLR Code of Ethics.
Submission Guidelines: I certify that this submission complies with the submission instructions as described on https://iclr.cc/Conferences/2024/AuthorGuide.
Anonymous Url: I certify that there is no URL (e.g., github page) that could be used to find authors' identity.
No Acknowledgement Section: I certify that there is no acknowledgement section in this submission for double blind review.
Submission Number: 1262