Reliable Estimation of KL Divergence using a Discriminator in Reproducing Kernel Hilbert Space

Published: 09 Nov 2021, Last Modified: 05 May 2023, NeurIPS 2021 Spotlight, Readers: Everyone
Keywords: RKHS, KL Divergence, Sample Complexity, Statistical Learning Theory, Kernel Methods, Bayesian inference, Mutual Information, Reliable estimation
Abstract: Estimating Kullback–Leibler (KL) divergence from samples of two distributions is essential in many machine learning problems. Variational methods using a neural network discriminator have been proposed to achieve this task in a scalable manner. However, we noticed that most of these methods suffer from high fluctuations (variance) in the estimates and from instability in training. In this paper, we look at this issue from the perspectives of statistical learning theory and function space complexity to understand why this happens and how to solve it. We argue that these pathologies are caused by a lack of control over the complexity of the neural network discriminator function, and that they can be mitigated by controlling that complexity. To achieve this objective, we 1) present a novel construction of the discriminator in a Reproducing Kernel Hilbert Space (RKHS), 2) theoretically relate the error probability bound of the KL estimates to the complexity of the discriminator in the RKHS, 3) present a scalable way to control the complexity (RKHS norm) of the discriminator for reliable estimation of KL divergence, and 4) prove the consistency of the proposed estimator. In three different applications of KL divergence (estimation of KL, estimation of mutual information, and Variational Bayes), we show that by controlling the complexity as developed in the theory, we are able to reduce the variance of the KL estimates and stabilize the training.
TL;DR: We propose a new way to construct a neural function so that it lies in an RKHS, and we use this function as the discriminator to estimate KL divergence from samples. We provide theoretical insights and a consistency guarantee for the KL estimator.
Supplementary Material: pdf
Code: https://github.com/sandeshgh/Reliable-KL-estimation
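
The abstract refers to variational, discriminator-based estimation of KL divergence from samples but does not state the objective used. As a minimal sketch of the general idea only, the example below assumes the Donsker-Varadhan lower bound, KL(P||Q) >= E_P[f] - log E_Q[exp(f)], which is one common variational form. The function name `dv_kl_estimate`, the two Gaussian distributions, and the hand-written critics are all illustrative assumptions; this is not the paper's RKHS discriminator or its training procedure.

```python
import math
import random


def dv_kl_estimate(critic, samples_p, samples_q):
    """Donsker-Varadhan lower bound on KL(P||Q): E_P[f] - log E_Q[exp(f)]."""
    mean_p = sum(critic(x) for x in samples_p) / len(samples_p)
    mean_exp_q = sum(math.exp(critic(x)) for x in samples_q) / len(samples_q)
    return mean_p - math.log(mean_exp_q)


# Assumed toy setting: P = N(1, 1), Q = N(0, 1), so the true KL(P||Q) is 0.5.
rng = random.Random(0)
n = 100_000
samples_p = [rng.gauss(1.0, 1.0) for _ in range(n)]
samples_q = [rng.gauss(0.0, 1.0) for _ in range(n)]


def optimal_critic(x):
    # Exact log density ratio log p(x)/q(x) for the two Gaussians above.
    return x - 0.5


def damped_critic(x):
    # A deliberately suboptimal critic: the bound becomes looser (0.375 in expectation).
    return 0.5 * (x - 0.5)


kl_optimal = dv_kl_estimate(optimal_critic, samples_p, samples_q)
kl_damped = dv_kl_estimate(damped_critic, samples_p, samples_q)

print(f"optimal critic: {kl_optimal:.3f}")
print(f"damped critic:  {kl_damped:.3f}")
```

With the exact log-ratio as critic the bound is tight, while a weaker critic gives a smaller value. In practice the critic is learned from the samples, and the abstract's argument is that an unconstrained critic class makes this estimate fluctuate, which is why the authors control the discriminator's RKHS norm.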