$\alpha$-VAEs: Optimising variational inference by learning data-dependent divergence skew

Published: 15 Jun 2021, Last Modified: 05 May 2023 | INNF+ 2021 poster | Readers: Everyone
Keywords: variational inference, variational autoencoders, constrained optimisation, lossy compression
TL;DR: Learning the skew of the geometric Jensen-Shannon divergence generalises VAE optimisation to competing loss terms, improves reconstruction, and automatically balances the KL terms.
Abstract: The {\em skew-geometric Jensen-Shannon divergence} $\left(\textrm{JS}^{\textrm{G}_{\alpha}}\right)$ allows for an intuitive interpolation between forward and reverse Kullback-Leibler (KL) divergence based on the skew parameter $\alpha$. While the benefit of the skew in $\textrm{JS}^{\textrm{G}_{\alpha}}$ is clear (it balances forward and reverse KL in a comprehensible manner), the choice of optimal skew remains opaque and requires an expensive grid search. In this paper we introduce $\alpha$-VAEs, which extend the $\textrm{JS}^{\textrm{G}_{\alpha}}$ variational autoencoder by allowing for learnable, and therefore data-dependent, skew. We motivate the use of a parameterised skew in the dual divergence by analysing trends dependent on data complexity in synthetic examples. We also prove and discuss the dependency of the divergence minimum on the input data and encoder parameters, before empirically demonstrating that this dependency does not reduce to either direction of KL divergence for benchmark datasets. Finally, we demonstrate that optimised skew values consistently converge across a range of initial values and provide improved denoising and reconstruction properties. These properties make $\alpha$-VAEs an efficient and practical modelling choice across a range of tasks, datasets, and domains.
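The interpolation described in the abstract can be illustrated on small discrete distributions. The following is a minimal sketch, not the paper's implementation: it assumes the skew parameterisation in which the weighted geometric mean is proportional to $p^{\alpha} q^{1-\alpha}$, so that $\alpha = 0$ recovers $\textrm{KL}(p \,\|\, q)$ and $\alpha = 1$ recovers $\textrm{KL}(q \,\|\, p)$. The paper's exact convention may differ, and the function names `kl`, `geometric_mean`, and `skew_geometric_js` are hypothetical.

```python
import math


def kl(p, q):
    """KL(p || q) for discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def geometric_mean(p, q, alpha):
    """Normalised weighted geometric mean, proportional to p^alpha * q^(1 - alpha).

    Assumption: this weighting is chosen so that the endpoints of the skew
    recover the two KL directions; it is not taken verbatim from the paper.
    """
    unnorm = [pi ** alpha * qi ** (1 - alpha) for pi, qi in zip(p, q)]
    z = sum(unnorm)
    return [u / z for u in unnorm]


def skew_geometric_js(p, q, alpha):
    """(1 - alpha) * KL(p || G) + alpha * KL(q || G), with G the geometric mean."""
    g = geometric_mean(p, q, alpha)
    return (1 - alpha) * kl(p, g) + alpha * kl(q, g)


# Illustrative distributions (made up for this sketch).
p = [0.7, 0.2, 0.1]
q = [0.2, 0.3, 0.5]

for alpha in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"alpha={alpha:.2f}  divergence={skew_geometric_js(p, q, alpha):.4f}")
```

Under this assumed parameterisation, a fixed $\alpha$ selects one point between the two KL directions. An $\alpha$-VAE would instead treat $\alpha$ as a trainable parameter, which this sketch does not attempt.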
Questions/feedback Request For Reviewers: We note in the paper that practical optimisation of the skew in the dual divergence is well behaved. This suggests that training the skew in the dual divergence is also well-posed and does not reduce to an invalid divergence or to one of the KL directions. However, an equivalent proof of convexity via differentiation is not obvious, because a mixture term sits outside the log. We are particularly interested in feedback on this line of work, including whether a convexity result (or a related result) is possible and whether convergence guarantees would follow from it.