% \documentclass{uai2022} % for initial submission
\documentclass[accepted]{uai2022} % after acceptance, for a revised
                                    % version; also before submission to
                                    % see how the non-anonymous paper
                                    % would look like
%% There is a class option to choose the math font
% \documentclass[mathfont=ptmx]{uai2022} % ptmx math instead of Computer
                                         % Modern (has noticable issues)
% \documentclass[mathfont=newtx]{uai2022} % newtx fonts (improves upon
                                          % ptmx; less tested, no support)
% NOTE: Only keep *one* line above as appropriate, as it will be replaced
%       automatically for papers to be published. Do not make any other
%       change above this note for an accepted version.

%% Choose your variant of English; be consistent
\usepackage[american]{babel}
% \usepackage[british]{babel}

%% Some suggested packages, as needed:
\usepackage{natbib} % has a nice set of citation styles and commands
    % \bibliographystyle{plainnat}
    % \renewcommand{\bibsection}{\subsubsection*{References}}
    
    
% \usepackage[pdftex]{graphicx}
\usepackage{subcaption}
\usepackage[inline]{enumitem}
\usepackage[linesnumbered,ruled,vlined]{algorithm2e}
\usepackage{wrapfig}
\usepackage{xcolor}
\usepackage{amssymb}
\usepackage{tabularx}
\usepackage[page]{appendix}

\usepackage{xr-hyper}
\usepackage{hyperref}

\makeatletter
\newcommand*{\addFileDependency}[1]{% argument=file name and extension
  \typeout{(#1)}
  \@addtofilelist{#1}
  \IfFileExists{#1}{}{\typeout{No file #1.}}
}
\makeatother

\newcommand*{\myexternaldocument}[1]{%
    \externaldocument{#1}%
    \addFileDependency{#1.tex}%
    \addFileDependency{#1.aux}%
}

\myexternaldocument{menon_319-supp}

\usepackage{mathtools} % amsmath with fixes and additions
% \usepackage{siunitx} % for proper typesetting of numbers and units
\usepackage{booktabs} % commands to create good-looking tables
\usepackage{tikz} % nice language for creating drawings and diagrams

%% Provided macros
% \smaller: Because the class footnote size is essentially LaTeX's \small,
%           redefining \footnotesize, we provide the original \footnotesize
%           using this macro.
%           (Use only sparingly, e.g., in drawings, as it is quite small.)

%% Self-defined macros
\newcommand{\swap}[3][-]{#3#1#2} % just an example

\title{Forget-me-not! Contrastive Critics for Mitigating Posterior Collapse}

% The standard author block has changed for UAI 2022 to provide
% more space for long author lists and allow for complex affiliations
%
% All author information is automatically removed by the class for the
% anonymous submission version of your paper, so you can already add your
% information below.
%
% Add authors

\author[1]{\href{mailto:<sachit.menon@columbia.edu>?Subject=Your UAI 2022 paper}{Sachit Menon}{}}
\author[1]{David Blei}
\author[1]{Carl Vondrick}
% Add affiliations after the authors
\affil[1]{%
    Computer Science Dept.\\
    Columbia University\\
    New York, New York, USA
}
\begin{document}
\maketitle

\begin{abstract}
Variational autoencoders (VAEs) suffer from posterior collapse, where the powerful neural networks used for modeling and inference optimize the objective without meaningfully using the latent representation. We introduce \emph{inference critics} that detect and incentivize against posterior collapse by requiring correspondence between latent variables and the observations. By connecting the critic’s objective to the literature in self-supervised contrastive representation learning, we show both theoretically and empirically that optimizing inference critics increases the mutual information between observations and latents, mitigating posterior collapse. This approach is straightforward to implement and requires significantly less training time than prior methods, yet obtains competitive results on three established datasets. Overall, the approach lays the foundation to bridge the previously disconnected frameworks of contrastive learning and probabilistic modeling with variational autoencoders, underscoring the benefits both communities may find at their intersection.%enabling better representations endowed with the strengths of both approaches
\end{abstract}

\section{Introduction}

% \cv{Throughout paper: Sometimes we use single quotes `critic' and other times double quotes ``critic'' -- we should be consistent}
    Variational autoencoders (VAEs) provide an integrated approach for simultaneously performing representation learning and generative modeling. Unlike other approaches, such as generative adversarial networks (GANs), VAEs marry the two steps of probabilistic machine learning -- inference and modeling -- into one framework. They have seen wide success in a number of applications, such as in vision, language, and drug discovery \citep{kingma_auto-encoding_2014, kingma_introduction_2019}. 
    
    VAEs posit a very general model,  where latent variables $\boldsymbol{z}$ give rise to the data $\textbf{x}$. The model thus defines the joint distribution $p_\theta(\textbf{x}, \textbf{z})$, which factorizes as $p(\textbf{z})p_\theta(\textbf{x}|\textbf{z})$. In this factorization, $p(\textbf{z})$ corresponds to a prior (for example, a spherical Gaussian), while $p_\theta(\textbf{x}|\textbf{z})$ defines an exponential family likelihood (usually a Gaussian) with natural parameter dependent on $\textbf{z}$. Much of the power of VAEs as generative models comes from how we define this dependence. Typically, we use the powerful function approximation afforded by neural networks to parametrize this relationship. 
    
    But in a VAE, the power of neural networks can also be its downfall. With a flexible likelihood, the model can learn to abandon the latents entirely.
    %, finding it easier to try to directly approximate the data distribution. In other words, the conditional distribution can produce reasonable samples that have no relationship to the latents used to produce them, 
    This allows the approximate posterior, which is also powered by a neural net, to exactly match the prior. This conspiracy of the inference network and the model network allows the VAE to achieve high values for its objective despite both networks forgetting their respective inputs. While we may achieve some generative modeling goals, this \emph{posterior collapse} phenomenon fails at the goal of representation learning \citep{bowman_generating_2016}. 
    % This is referred to as ``posterior collapse" and is well-studied in the VAE literature (see Related Work). %[CITE]
    % In this work, we propose a simple trick towards mitigating posterior collapse. Suppose we have a batch of latent samples $z_0, \ldots, z_k$, and a batch of corresponding samples from the approximated data distribution $x_0, \ldots, x_k$. If posterior collapse has occurred, corresponding samples are independent, since the likelihood is not using information from the latents - that is, $z_i \perp x_i$. Conversely, if we are able to `pair up' corresponding pairs, we have not collapsed. 
    
    This paper proposes a new approach to mitigate posterior collapse. The key idea is that we can use a %self-supervised \cv{I would remove self-supervised here. I'm not sure that aspect is integral.} 
    \emph{critic} to detect posterior collapse and directly incentivize against it. Consider a set of samples of latent variables and the corresponding observations. If posterior collapse has occurred, corresponding latent/observation pairs are independent: the model is not using the latents, and the approximate posterior just produces independent samples from the prior. 
    %Since the corresponding samples are independent, it is not possible to establish matches between them.
    On the other hand, if we \textit{are} able to match each latent to its corresponding observation, the two must share some information that makes this possible, and there is no collapse. With this intuition, we create a critic to accomplish precisely this pairing and integrate it into the VAE objective. The critic constrains the neural network to preserve the mutual information between the latent variables and the observations. The resulting generative model must use the information in the data in its posterior over the latent variables. We call this \textit{forget-me-not regularization}.
    %\cv{Last sentence feels a bit awkward.}
    
    \begin{figure*}
        \setlength\fboxsep{1pt} 
        \centering
        \includegraphics[width=0.97\linewidth]{Copy of VAE_dist_embedding.drawio.pdf}
        \vspace{-1em}
        \caption{An illustration of the critic. On the left, we have the normal \textcolor[RGB]{172,90,83}{variational network} mapping observations to \textcolor[RGB]{108, 135, 85}{variational parameters} (distributions in green). On the right, we show a \textcolor[RGB]{44,65,245}{critic}'s task for a particular \textcolor[RGB]{126, 141, 166}{latent sample} - it must determine which of the blue arrows marks a true pair.}
        \label{fig:illustcritic}
        \vspace{-1em}
    \end{figure*}
  
   % \cv{This paragraph I think doesn't belong in the introduction. The previous paragraph should have described the key insight. You can maybe move this to the the method section to provide an overview of the approach.}
   % This is the intuition behind the method: we will encourage this pairing up with a `critic' that tries to accomplish exactly this. Given a set of observations and latents, it asks: which latent goes with each of these observations? In order to keep the critic happy (able to solve its task), the VAE must use the information in the data when constructing latent samples (Figure \ref{fig:critics}) - in other words, the critic does not allow the neural network to forget what it should be conditioning on. %the information in the latent when constructing samples from the approximated data distribution (on the model side, Figure \ref{fig:modelcritic}), or
    
    Inference critics introduce minimal computational overhead and are easy to train. Unlike other strategies for mitigating posterior collapse (cf.\ \citealp{zhao_infovae_2018}), this critic does not require adversarial training. We are not trying to fool the critic and have it fail its task of distinguishing corresponding pairs (which would actually encourage posterior collapse). Rather, its loss serves as a regularizer, biasing the VAE towards solutions where the latent meaningfully relates to its counterpart in observation space. Moreover, this approach avoids the practical difficulties posed by the `KL annealing' trick \citep{bowman_generating_2016}, and it does not require multiple experiments to determine a hyperparameter schedule. By connecting the critic to recent advances in self-supervised contrastive learning \citep{oord_conditional_2016}, we show both theoretically and empirically that optimizing the inference critic increases the mutual information between the observations and the latents. 
    
    Experimental results on three standard datasets (across text and image modalities) show that the inference critic provides a robust strategy for mitigating posterior collapse. The approach is also practical, adding only minimal computational overhead to the standard VAE. It provides significant efficiency gains over established collapse mitigation strategies while achieving competitive performance. %Since the correspondence between the samples and the latent variables is naturally known, the classifier and therefore the full framework remains unsupervised.
    
    Our contributions are summarized as: 
    \begin{enumerate*}[label=(\roman*)]
        \item We introduce forget-me-not regularization with inference critics, a self-supervised modification to standard VAEs that substantially reduces the effects of posterior collapse.
        \item We show that this modified ELBO formulation directly incentivizes higher mutual information between observations and latents.
        \item We introduce three types of critic: a \textit{neural network critic}, which adds a third neural network to the VAE to act as the critic; a \textit{self critic}, which uses the existing networks to obtain a closed-form optimal solution to the auxiliary task; and a \textit{hybrid critic}, which shares some parameters with the variational network but not all. We contrast these and the effect they have on the final results.
        \item We demonstrate that the method adds less computational overhead to the standard VAE than other methods for combating posterior collapse.
    \end{enumerate*}


\section{Posterior Collapse in VAEs}
\label{gen_inst}
\vspace{-0.5em}

\subsection{VAE Fundamentals}

    % \cv{People do not read papers from top to bottom. We should avoid pointers to other parts of the text. The reader may skip the introduction completely. Maybe you can move the second paragraph of the introduction to here (the one starting with ``VAEs posit a very general model structure...'')}
    To fit the parameters of a deep generative model, we would ideally maximize the marginal likelihood (the evidence) of the data. However, this is generally an intractable quantity as it involves integrating out the hidden variables.
% \begin{wrapfigure}{r}{0.6\textwidth}
% %\vspace{-1em}
%     \centering
%     \includegraphics[width=0.8\linewidth]{VAE Variational Family_Critic.png}
%     \caption{The forget-me-not regularization in reference to the graphical models acted on. Dashed lines show computation flow. The inference critic works against information loss in the variational family, trying to ensure it can tell $z_i$ and $x_i$ correspond. It does this by solving a classification task: among a batch of different repetitions of these variables, classify which belong together. This task can only be solved if there is shared information between variables belonging to the same subgraph (repeated by the plate) to distinguish them from those belonging to others.}
%     \label{fig:critics}
% \end{wrapfigure}
    Instead, the most common approach is to use variational inference, which allows us to posit a variational family and maximize a tractable lower bound, the ELBO, over its parameters. 
    
    Specifically, the VAE makes use of \textit{amortized} variational inference, which learns a function mapping observations to variational parameters, providing us with an approximate posterior over latent variables given observations, $q_\phi(\textbf{z} \mid \textbf{x})$. This function, usually parametrized by a neural network with parameters $\phi$, is shared across data points, hence the amortization. This mechanism for amortized inference is also called the `encoder' in analogy to deterministic autoencoders, with the model referred to as the `decoder'.
    
    There are many equivalent ways to write the ELBO \citep{hoffman_elbo_2016}. Here, we will focus on a couple that illustrate the problem we are addressing and motivate the approach we propose. First consider:
    \begin{equation}
        \begin{aligned}\label{eqn:elbo1}
        \mathrm{ELBO} =&\mathbb{E}_{p_{\mathcal{D}}(\boldsymbol{x})} \mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log p_{\theta}(\boldsymbol{x} \mid \boldsymbol{z})\right]\\ &-\mathrm{KL}\left(q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x}) \| p(\boldsymbol{z})\right)
        \end{aligned}    
    \end{equation}
    where $p_{\mathcal{D}}(\boldsymbol{x})$ is the empirical distribution of observations from the  dataset $\mathcal{D}$. The first term can be thought of as the model conditional likelihood (reconstruction), while the second is the KL divergence between the approximate posterior and the prior.
    
\subsection{Pitfalls in VAE Training}
    
    The form of the ELBO in Equation \ref{eqn:elbo1} illustrates one reason behind the phenomenon of posterior collapse. If the chosen parametrization of the likelihood is flexible enough to learn to always output (a good approximation of) the data distribution, there is no incentive to incur the penalty from the second term: the model can keep the first term high even without letting the approximate posterior deviate from the prior. In other words, the model does not need the latent code to maximize the likelihood and thus ignores it.%\footnote{We note that there are other hypotheses regarding factors that contribute to posterior collapse, such as pathologies in the optimization landscape of the ELBO or problems arising from maximum-marginal-likelihood optimization itself \citep{lucas_don_2019} even in the linear VAE/probabilistic PCA case. In this work we focus on the perspective arising from the flexibility afforded by deep neural nets.}
    
   %\textbf{The Mutual Information Problem:}  
   Another way of writing the ELBO offers further insight into this case. Consider the variational joint distribution
   \begin{align}\label{eqn:varjoint}
        q_{\phi}(\textbf{x}, \boldsymbol{z}) &= p_{\mathcal{D}}(\boldsymbol{x}) q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})    
   \end{align}
   
   and aggregate posterior
    \begin{align}\label{eqn:aggpost}
        q_{\phi}(\boldsymbol{z}) &= \mathbb{E}_{p_{\mathcal{D}}(\boldsymbol{x})} q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})
   \end{align}
   where $p_{\mathcal{D}}(\boldsymbol{x})$ is again the empirical data distribution \citep{zhao_infovae_2018,hoffman_elbo_2016,dieng_avoiding_2019,tomczak_vae_2017}.
    % \begin{equation}\label{eqn:vardists}
    %     q_{\phi}(\textbf{x}, \boldsymbol{z})=p_{\mathcal{D}}(\boldsymbol{x}) q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})
    %     \hspace{1.5cm}
    %     q_{\phi}(\boldsymbol{z})=\mathbb{E}_{p_{\mathcal{D}}(\boldsymbol{x})} q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})
    % \end{equation}
    % \begin{equation}
    %     q_{\phi}(\boldsymbol{z})=\mathbb{E}_{p_{\mathcal{D}}(\boldsymbol{x})} q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})
    % \end{equation}
    Hoffman and Johnson perform `ELBO surgery' \citep{hoffman_elbo_2016} to rearrange the ELBO from Equation \ref{eqn:elbo1} into the following:
    \begin{equation}\label{eqn:elbo2}
            \begin{aligned}
        \mathrm{ELBO} =&\mathbb{E}_{p_{\mathcal{D}}(\boldsymbol{x})} \mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log p_{\theta}(\boldsymbol{x} \mid \boldsymbol{z})\right]\\ &-\mathcal{I}_{q}(\boldsymbol{x}, \boldsymbol{z})-\mathrm{KL}\left(q_{\phi}(\boldsymbol{z}) \| p(\boldsymbol{z})\right)
        \end{aligned}
    \end{equation}
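    where $\mathcal{I}_{q}(\boldsymbol{x}, \boldsymbol{z})$ is the mutual information between $\boldsymbol{x}$ and $\boldsymbol{z}$ under the variational joint distribution of Equation \ref{eqn:varjoint}. (For a discussion of the variational joint MI $\mathcal{I}_q(\boldsymbol{x}, \boldsymbol{z})$ and the model joint MI $\mathcal{I}_p(\boldsymbol{x}, \boldsymbol{z})$, see Appendix \ref{appendix:mi}.)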
    
    In other words, the KL divergence between the approximate posterior and the prior decomposes into a mutual information (MI) penalty and a KL term that encourages matching the aggregate posterior and the prior. Maximizing the ELBO thus explicitly discourages high mutual information between observations and latents, pushing them towards independence. As \citet{hoffman_elbo_2016} show, decreasing the value of the MI does not impact the likelihood term, which tends to dominate. As such, the objective enables the flexible neural nets to achieve solutions exhibiting posterior collapse.
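
    To make this decomposition explicit, the following short derivation sketches the step. It assumes that the KL term of Equation \ref{eqn:elbo1} is averaged over $p_{\mathcal{D}}(\boldsymbol{x})$, and it uses only the definitions in Equations \ref{eqn:varjoint} and \ref{eqn:aggpost}:
    \begin{equation*}
        \begin{aligned}
        &\mathbb{E}_{p_{\mathcal{D}}(\boldsymbol{x})} \mathrm{KL}\left(q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x}) \| p(\boldsymbol{z})\right) \\
        &\quad= \mathbb{E}_{q_{\phi}(\boldsymbol{x}, \boldsymbol{z})}\left[\log \frac{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}{q_{\phi}(\boldsymbol{z})}\right] \\
        &\quad\quad+ \mathbb{E}_{q_{\phi}(\boldsymbol{z})}\left[\log \frac{q_{\phi}(\boldsymbol{z})}{p(\boldsymbol{z})}\right] \\
        &\quad= \mathcal{I}_{q}(\boldsymbol{x}, \boldsymbol{z}) + \mathrm{KL}\left(q_{\phi}(\boldsymbol{z}) \| p(\boldsymbol{z})\right)
        \end{aligned}
    \end{equation*}
    The first equality multiplies and divides by the aggregate posterior $q_{\phi}(\boldsymbol{z})$ inside the logarithm. The second recognizes the first expectation as the mutual information under $q_{\phi}(\boldsymbol{x}, \boldsymbol{z})$, since $q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x}) / q_{\phi}(\boldsymbol{z})$ equals the ratio of the joint to the product of its marginals.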
    
    This loss of information means that the latents will not be informative about the observations (and thus cannot be useful representations). In the next section, we introduce a method to explicitly prevent this information loss or `forgetting'. By incorporating a regularization term that incentivizes higher MI into the ELBO, we will counteract the effect of the MI penalty from Equation \ref{eqn:elbo2}.
    
    % In order to reduce the harmful effects of this penalty, an appealing idea is to match the marginals directly. Direct marginal matching would replace the usual KL term with $\mathrm{KL}\left(q_{\phi}(\boldsymbol{z}) \| p(\boldsymbol{z})\right)$, which no longer imposes an MI penalty. While there is promising work in this direction, this is difficult as existing methods resort to adversarial training, Stein variational gradient descent, or difficult-to-tune kernel methods \citep{zhao_infovae_2018}. Instead, in the next section, we introduce a method to explicitly prevent this information loss.
  
\section{Forget-me-not Regularization} \label{section:method}
    % \section{Auxiliary Classification Regularization}
    
    % Our main observation is that the training process of the classic VAE provides everything we need to ensure information is shared between $\textbf{x}$ and $\textbf{z}$. 
    We will employ a critic that imposes a penalty on the objective if observations and their corresponding latents cannot be distinguished from non-corresponding pairs. The intuition is that this matching is only possible if there is information shared between observations and latents. %We will
    %implement this critic in a fully differentiable way and
    %incorporate the penalty of the critic into the training objective, thus using its signal to mitigate posterior collapse.  %Consider a joint distribution $q(\textbf{x}, \textbf{z})$ coming from the empirical data distribution $x_0, \ldots, x_k$ and their corresponding latent samples $z_0, \ldots, z_k$. If we are able to distinguish between this distribution and a noise distribution, then there must be some information shared between $\textbf{x}$ and $\textbf{z}$.
    
    \subsection{The Inference Critic}
    Consider a batch of samples from the empirical data distribution $x_0, \ldots, x_k$, and a corresponding batch of latent samples $z_0, \ldots, z_k$ (obtained by encoding the $x_i$). Every pair with the correct correspondence comes from the variational joint distribution $q_\phi(\textbf{x}, \textbf{z})$, while non-corresponding pairs are independent and come from the product of marginals ($\textbf{z}$ via ancestral sampling) \citep{alemi_fixing_2018}. %Similarly, we can tell an analogous story for a batch of observations from the empirical data distribution and the corresponding samples from the variational posterior. 
    Formally:
    % Suppose we obtain a batch of latent samples $z_0, \ldots, z_k$, and a batch of corresponding observations $x_0, \ldots, x_k$ (either by sampling $z$ from the prior and $x$ from the approximate data distribution, or by sampling ). Then we can view every corresponding pair as coming from the joint distribution $p(\textbf{x}, \textbf{z})$, while non-corresponding pairs are independent and come from the product of marginals ($\textbf{x}$ via ancestral sampling). We can write this as the following:
    % might need to replace prior with posterior/encoding
    \begin{equation}\label{eqn:sampling}
        % (z_i, x_j) \sim 
    %     \begin{cases} 
    %       p_\theta(\textbf{z}, \textbf{x}) & i = j \\
    %       p(\textbf{z})p_\theta(\textbf{x}) & i \neq j
    %   \end{cases}
    %   \hspace{0.75cm} \text{or} \hspace{0.75cm}
       (z_i, x_j) \sim 
        \begin{cases} 
          q_\phi(\textbf{z}, \textbf{x}) & i = j \\
          q_\phi(\textbf{z})p_\mathcal{D}(\textbf{x}) & i \neq j
       \end{cases}
    \end{equation}
%     \[ (z_i, x_j) \sim 
%     \begin{cases} 
%       p_\theta(\textbf{z}, \textbf{x}) & i = j \\
%       p(\textbf{z})p_\theta(\textbf{x}) & i \neq j
%   \end{cases}
%   \hspace{0.75cm} \text{or} \hspace{0.75cm}
%   (z_i, x_j) \sim 
%     \begin{cases} 
%       q_\phi(\textbf{z}, \textbf{x}) & i = j \\
%       q_\phi(\textbf{z})p_\mathcal{D}(\textbf{x}) & i \neq j
%   \end{cases}
%     \]\label{eqn:sampling}
    %where the left corresponds to the former and the right the latter. We will follow the story of pairing up samples for the left for notational purposes, but keep both in mind.

    If we can distinguish the joint distribution from the product of marginals, there must be some dependence between $\textbf{x}$ and $\textbf{z}$. Given samples from both distributions, the critic will try to pick out which pairs belong to which distribution. The more successful it is, the more different the distributions must be, and therefore the more $\textbf{x}$ and $\textbf{z}$ must be related. 
    
    % One idea might be to train a binary (Bernoulli likelihood) probabilistic classifier to distinguish between corresponding and non-corresponding pairs. That is, given a pair $(z_i, x_j)$, yield the indicator corresponding to whether $i=j$ or not. Intuitively, the better we can solve this problem the more we can tell samples belong together; there is more information shared between $\textbf{z}$ and $\textbf{x}$ (if they did not share information, we would not be able to determine they belong together). If we augment the ELBO loss term to tell the VAE, `you must satisfy this critic,' we are encouraging it to preserve information-sharing between the variables - to not forget. Please see Appendix \ref{appendix:binclas} for discussion of the more formal connection to MI.
    
    % However, we may not be using all the information at our disposal by treating each pair in isolation (this is informal, please see Appendix \ref{appendix:multirat} for the statistical reasoning). Rather,
    The classifier, which we will denote $f$, needs to distinguish between the joint and the product of marginals by directly \textit{contrasting} the correct pairings with the incorrect ones.
    We know that for every observation $x_i$, there is one latent in the batch $z_i$ that corresponds to it (and vice versa). We  also know that the others \textit{should not}. %Thus instead of binary classification (does this pair correspond or not?), we might consider multi-sample classification (which of these $K$ options is the corresponding counterpart?). 
    This corresponds to softmax classification \citep{bishop_pattern_2006}, or a probabilistic classifier with a categorical likelihood. %the generalization of the Bernoulli likelihood we used before to the categorical/`multinoulli' - that it has similar properties should not surprise us. 
    The classifier $f$ maximizes the objective: 
    \begin{equation}\label{eqn:classification}
        c(\textbf{x},\textbf{z}) = \mathbb{E} \left[\log \frac{f\left(x^{+}, z^{+}\right)}{\sum_{x \in X } f\left(x, z^{+}\right)}\right]
    \end{equation}
    % \textcolor{red}{TODO: fix notation}
    which is the critic's expected value for the true corresponding pairs (denoted by the $+$) relative to all the other (non-corresponding) pairs, averaged across pairs.\footnote{Note that writing $f(x,z)$ as $\exp(f_0(x,z))$ recovers the softmax classifier exactly.} (The notation within the sum is in reference to a particular positive $z^{+}$; we can think of it as considering a particular $z^{+}$ and trying to find the associated $x^{+}$ among the options for $x$, then taking the expectation over $z$. This is symmetric with respect to choosing an $x^{+}$ and finding the associated $z^{+}$.) %This behavior resembles the binary classifier we discussed before.
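
    As a concrete illustration, a minibatch estimate of Equation \ref{eqn:classification} can be sketched as follows. This sketch assumes that the candidate set $X$ is the current batch of $K$ observations and that the critic is written as $f = \exp(f_0)$, as in the footnote; the symbol $\hat{c}$ is introduced here only for illustration:
    \begin{equation*}
        \hat{c} = \frac{1}{K} \sum_{i=1}^{K} \log \frac{\exp\left(f_0(x^{(i)}, z^{(i)})\right)}{\sum_{j=1}^{K} \exp\left(f_0(x^{(j)}, z^{(i)})\right)}
    \end{equation*}
    Under this form, a critic that cannot tell pairs apart (constant $f_0$) yields $\hat{c} = -\log K$, while a critic that identifies every pair with certainty drives $\hat{c}$ towards $0$.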
    
    \begin{algorithm}[h!]
      \KwIn{Dataset $\mathcal{D}$, batch size $K$, initial VAE parameters $\theta, \phi$, initial critic $f$ parametrized by $\psi$, regularization weight $\lambda$}
    %   \KwOut{Your output \cv{VAE parameters $\theta$?}}
    %   \KwData{Testing set $x$}
      \While{not converged}
       {
       		Sample from $p_\mathcal{D}(\textbf{x})$ $K$ times to obtain $(x^{(i)})_{i=1}^K$
       		
       		Sample $z^{(i)}\sim q_\phi(\textbf{z} \mid x^{(i)})\ \forall i\in \{1,\ldots,K\}$
       		
       		Compute $\mathcal{L}_0 = \sum_{i=1}^K \text{ELBO}(\theta, \phi, x^{(i)}, z^{(i)})$ per Eqn \ref{eqn:elbo1} \tcp*{standard minibatch ELBO}
       		
       		Compute $f_\psi (x^{(i)},z^{(j)})\ \forall i,j$ \tcp*{inference critic values}
       		
       		$\mathcal{L}_1 \leftarrow c(\textbf{x},\textbf{z})$ per Eqn \ref{eqn:classification} \tcp*{inference critic objective}
       		
       		$\mathcal{L} \leftarrow \mathcal{L}_0 + \lambda \mathcal{L}_1$
       		
       		Perform gradient update for VAE parameters %$\theta \leftarrow \theta + \nabla_\theta \mathcal{L}$
       		
       		Perform gradient update for critic parameters %$\psi \leftarrow \psi + \nabla_\psi \mathcal{L}$
       }
    \caption{Forget-me-not regularization with neural network inference critic.}\label{alg:nncrit}
    \end{algorithm}
    
    %(\textcolor{red}{NOTE: the optimal function for \ref{eqn:classification} is proportional to it}) , 
    Crucially, this objective constitutes a lower bound on the mutual information, as shown by \citet{oord_representation_2019} in the context of self-supervised representation learning. By maximizing Equation \ref{eqn:classification}, the classifier approximates the density ratio between the joint distribution and the product of marginals \citep{oord_representation_2019,song_multi-label_2020}, which is precisely the ratio appearing in the mutual information.
    %and we show the full proof in Appendix \ref{appendix:multimi}.
    Therefore, by optimizing the parameters of the VAE with this objective added to the ELBO as a regularizer, we push up on this lower bound. This push increases the MI between the latents and the observations, effectively mitigating posterior collapse as desired.
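
    For reference, the bound can be stated compactly. The form below is our paraphrase of the result in \citet{oord_representation_2019}, under the assumption that the sum in Equation \ref{eqn:classification} ranges over $K$ candidates (one corresponding observation and $K-1$ drawn independently):
    \begin{equation*}
        \mathcal{I}_{q}(\boldsymbol{x}, \boldsymbol{z}) \geq \log K + c(\textbf{x}, \textbf{z}),
    \end{equation*}
    with the optimal critic proportional to the density ratio, $f^{*}(x, z) \propto q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x}) / q_{\phi}(\boldsymbol{z})$. Since $c(\textbf{x}, \textbf{z}) \leq 0$ for a positive-valued critic, the bound can certify at most $\log K$ nats of mutual information, so larger batches permit a tighter bound.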
    
    This critic establishes a tight connection to the contrastive learning literature, in particular the InfoNCE loss from Contrastive Predictive Coding (CPC) \citep{oord_representation_2019}, enabling the many advances in self-supervised learning to be applied to VAEs. See Appendix \ref{appendix:multimi} for details of how this connection and bound apply. Note that, in the MI estimation literature, classifiers that estimate the mutual information like CPC are sometimes also referred to as critics \citep{poole_variational_2019}. We intentionally overload this word here: the inference critics of this paper critique the variational inference optimization.
    
    % The key towards realizing this is observing that the problem we have constructed and its associated objective correspond to the contrastive learning problem in self-supervised representation learning -- in particular the InfoNCE loss from Contrastive Predictive Coding (CPC) \citep{oord_representation_2019} (see Appendix \ref{appendix:multimi} for details of how this connection and bound apply in our case). Contrastive learners learn representations via a pretext task of pairing up samples that  and %Contrastive learners aim to learn good representations by distinguishing them from noise, increasing the MI between representations and data. However, their  Their objectives can often be interpreted as estimates of the mutual information \citep{oord_representation_2019}. %The InfoNCE estimate of mutual information is known to underestimate the MI when the true MI is high \citep{poole_variational_2019} - however, this is not a problem in our situation as the true MI is close to zero when we are fighting posterior collapse. 
    
   
   
    % Given a representation $z_i = f(x_i)$ of a data point $x_i$, $z_i$ is a good representation if we can tell it corresponds to $x_i$ and not some other data point. The original CPC applies this idea to learning representations for speaker recognition from audio: their different data points $x_i$ correspond to different samples of audio, and they obtain $z_i$ via an autoregressive sequence prediction model applied to the data; the goal is to increase MI to get $z_i$ that represent $x_i$ well, and there is no decoder. But the bound works more generally than this setting - we just need corresponding (and non-corresponding) samples on two sides of a conditional distribution! Please refer to Appendix \ref{appendix:multimi} for details on using this connection in our case to obtain the MI bound, showing that optimizing this objective does indeed incentivize higher MI. We point out that in the MI estimation literature, classifiers that estimate the mutual information like CPC are sometimes also referred to as critics \citep{poole_variational_2019}; we intentionally overload this word here, as our inference critics critique our variational inference optimization. \textcolor{red}{Sachit: worth including or not?} %(See the Related Work for other examples of application of the InfoNCE MI bound.) 
    
    \subsection{Regularization}
    
    
    % The properties of the critic provide a new lower bound on the evidence. 
    In order to prevent posterior collapse, we want to maximize the mutual information between the latents and the samples. We therefore add the critic objective $c(\textbf{x}, \textbf{z})$, which the critic aims to maximize, to the ELBO:
    \begin{equation}\label{eqn:elbo-ours}
        \begin{aligned}
        \mathrm{ELBO_{CRITIC}} &=\mathbb{E}_{p_{\mathcal{D}}(\boldsymbol{x})} \mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log p_{\theta}(\boldsymbol{x} \mid \boldsymbol{z})\right] \\&-\mathrm{KL}\left(q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x}) \| p(\boldsymbol{z})\right) + c(\textbf{x}, \textbf{z})
        \end{aligned}    
    \end{equation}
    % where $\lambda$ is a scalar hyperparameter. (We find optimization is not very sensitive to this value, and opt to simply set it to 1 in all experiments. %\textcolor{red}{Sachit: is it worth explicitly pointing out why this is still a lower bound?}
    We optimize the parameters for the variational family, model, and critic jointly; the critic prevents the usual conspiracy between the model and variational family over the course of training, avoiding collapse.     Algorithm \ref{alg:nncrit} illustrates the training procedure.
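
    As a minimal sketch of one joint update, assume plain stochastic gradient ascent with a step size $\eta$ (both are illustrative assumptions; any stochastic optimizer can take their place), and write $\psi$ for the parameters of the critic, if it has any of its own:
    \begin{equation*}
        \begin{aligned}
        (\theta, \phi) &\leftarrow (\theta, \phi) + \eta \nabla_{\theta, \phi}\, \mathrm{ELBO_{CRITIC}}, \\
        \psi &\leftarrow \psi + \eta \nabla_{\psi}\, c(\textbf{x}, \textbf{z}),
        \end{aligned}
    \end{equation*}
    with both gradients computed on the same minibatch. The critic parameters receive gradient only through $c(\textbf{x}, \textbf{z})$, since the other terms of Equation \ref{eqn:elbo-ours} do not depend on them.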
    
    %Note that we have wrapped up the parameters of both the variational family and model into $\theta$ for notational convenience.
    \begin{figure}%{0.6\textwidth}
        % \vspace{-2em}
        \centering
        \includegraphics[width=0.9\linewidth]{VAE Variational Family_Critic.pdf}
        \caption{Forget-me-not regularization in reference to the graphical models acted on. Dashed lines show computation flow. The inference critic works against information loss in the variational family, trying to ensure it can tell $z_i$ and $x_i$ correspond. It does this by solving a classification task: among a batch of different repetitions of these variables, classify which belong together. This task can only be solved if there is shared information between variables belonging to the same subgraph (repeated by the plate) to distinguish them from those belonging to others. \vspace{-1em}}
        \label{fig:critics}
    \end{figure}
    
    Notice that the mutual information appears in both the KL-divergence term and the critic term. When our estimate equals the true mutual information, the regularization cancels the mutual information term from Equation \ref{eqn:elbo2}, and the remaining part of the KL term only matches the marginals for $z$. This approach does not need to resort to adversarial training techniques, which are difficult to use in practice. Instead, the critic provides a straightforward mechanism to mitigate posterior collapse by increasing the mutual information term. 
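
    To make the cancellation explicit, the following sketch assumes that Equation \ref{eqn:elbo2} is the standard decomposition of the expected KL term from \citet{hoffman_elbo_2016}; the aggregate posterior $q_{\phi}(\boldsymbol{z}) = \mathbb{E}_{p_{\mathcal{D}}(\boldsymbol{x})}\left[q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})\right]$ is notation introduced here for illustration:
    \begin{equation*}
        \begin{aligned}
        &\mathbb{E}_{p_{\mathcal{D}}(\boldsymbol{x})}\left[\mathrm{KL}\left(q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x}) \| p(\boldsymbol{z})\right)\right] \\
        &\quad = I_q(x;z) + \mathrm{KL}\left(q_{\phi}(\boldsymbol{z}) \| p(\boldsymbol{z})\right).
        \end{aligned}
    \end{equation*}
    Substituting this into Equation \ref{eqn:elbo-ours}, the objective contains $-I_q(x;z) + c(\textbf{x}, \textbf{z})$. In the idealized case $c(\textbf{x}, \textbf{z}) = I_q(x;z)$ these two terms cancel, leaving
    \begin{equation*}
        \begin{aligned}
        \mathrm{ELBO_{CRITIC}} &= \mathbb{E}_{p_{\mathcal{D}}(\boldsymbol{x})} \mathbb{E}_{q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log p_{\theta}(\boldsymbol{x} \mid \boldsymbol{z})\right] \\
        &\quad - \mathrm{KL}\left(q_{\phi}(\boldsymbol{z}) \| p(\boldsymbol{z})\right),
        \end{aligned}
    \end{equation*}
    so the only remaining penalty is the mismatch between the aggregate posterior and the prior.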
    %We point out that in the case where the true mutual information is equal to our estimate, this corresponds to matching the marginals for $z$, as the mutual information term from Equation \ref{eqn:elbo2} cancels out - without resorting to difficult to use techniques like adversarial training.
    %The estimate is accurate when 1) the lower bound on the mutual information is tight and 2) 
    %\sim p(\textbf{z})
    %\sim p(\textbf{x}|z_0), \ldots p(\textbf{x}|z_k)
    
    We apply this regularization across the variational family, yielding an \textit{inference} critic, illustrated in Figure \ref{fig:critics}. %This immediately leads to the idea of applying the same technique across the model, as a \textit{model critic}. Such a critic would solve the same task with observations and latents across the model conditional rather than the approximate posterior conditional. 
    See Appendix \ref{appendix:critics} for discussion of its potential application across the model. % TODO: add `Algorithm' and refer to it 
    Furthermore, while we consider posterior collapse here by examining the ELBO for the VAE, posterior collapse is a more general phenomenon, observed in many different models when using amortized variational inference with deep neural networks. In principle, this approach applies to any such problem with amortized inference. 
    %We here focus on VAEs per previous work, as they demonstrate severe collapse even with a simple model structure. %We roughly follow \citep{zhao_infovae_2018} and \citep{dieng_avoiding_2019} for preliminaries.
    
    
    % \begin{algorithm}[t]\label{alg:nncrit}
    %   \KwIn{Dataset $\mathcal{D}$, batch size $K$, initial VAE parameters $\theta$, initial critic $f$ parametrized by $\psi$, regularization weight $\lambda$}
    %   \KwOut{Your output}
    % %   \KwData{Testing set $x$}
    %   \While{not converged}
    %   {
    %   		Sample from $p_\mathcal{D}(\textbf{x})$ $K$ times to obtain $(x^{(i)})_{i=1}^K$
       		
    %   		Sample $z^{(i)}\sim q_\phi(\textbf{z}|x^{(i)}\ \forall i\in \{1,\ldots,K\}$
       		
    %   		Compute $\mathcal{L}_0 = \sum_{i=1}^K \text{ELBO}(\theta, \phi, x^{(i)}, z^{(i)})$ per Eqn \ref{eqn:elbo1} \tcp*{standard minibatch ELBO}
       		
    %   		Compute $f_\psi (x_i,z_j) \forall i,j$ \tcp*{inference critic values}
       		
    %   		$\mathcal{L}_1 \leftarrow c(\textbf{x},\textbf{z})$ per Eqn \ref{eqn:classification} \tcp*{inference critic objective}
       		
    %   		$\mathcal{L} \leftarrow \mathcal{L}_0 + \lambda*\mathcal{L}_1$
       		
    %   		Perform gradient update for VAE parameters %$\theta \leftarrow \theta + \nabla_\theta \mathcal{L}$
       		
    %   		Perform gradient update for critic parameters %$\psi \leftarrow \psi + \nabla_\psi \mathcal{L}$
    %   }
    % \caption{Forget-me-not regularization with neural network inference critic.}
    % \end{algorithm}
    

    
    % Applying this bound across the model in the VAE (Equation \ref{eqn:sampling}, left) - a \textit{model critic} - encourages higher MI $I_p(x;z)$ by matching samples that should correspond under $p_\theta(x|z)$, Figure \ref{fig:modelcritic}. The corresponding idea of an \textit{inference critic} for $q_\phi(z|x)$ instead encourages higher $I_q(x;z)$, Figure \ref{fig:infcritic}.\footnote{In this work we explore using one or the other, but it is easy to imagine combining them.} Appendix \ref{appendix:critics} discusses the differences.

    \subsection{Types of Inference Critics}
    
    This framework affords multiple types of critics (as forms for $f$ in Equation \ref{eqn:classification}) that correspond to the mutual information between the latents and the samples. We propose three types of inference critics. 
    
    The most straightforward critic is a \textbf{neural network critic}, implemented by a third network that is entirely separate from those used in the VAE. The choice of network design depends on the data modality. For example, we could use an embedding layer followed by an LSTM for text data. 
    
    We also propose a \textbf{self-critic} that uses the variational family as its own critic, providing the tightest estimate of the mutual information. This originates from an idea in noise contrastive estimation: if we have a tractable conditional (as is the case with the variational family in the VAE), we can use it directly to estimate the mutual information, and in particular the tightest estimate of the MI is obtained with $f(x,z) = \log q(z|x)$ \citep{poole_variational_2019}. This formulation incentivizes the log-likelihood under the conditional to be highest for samples that actually belong together, as desired. This critic also has the advantage of requiring no additional parameters.
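
    To see what this choice amounts to, suppose Equation \ref{eqn:classification} takes the common InfoNCE form over a batch of $K$ pairs $(x_i, z_i)$ with $z_i \sim q(z|x_i)$. This is an assumption made for this sketch, and the normalization may differ from the one used in the main derivation. Substituting $f(x,z) = \log q(z|x)$ gives
    \begin{equation*}
        c(\textbf{x}, \textbf{z}) = \frac{1}{K} \sum_{i=1}^{K} \log \frac{q(z_i|x_i)}{\frac{1}{K} \sum_{j=1}^{K} q(z_i|x_j)} \leq \log K .
    \end{equation*}
    The denominator is a minibatch estimate of the aggregate posterior density at $z_i$, so each summand compares the density of $z_i$ under its own conditional with its average density under all conditionals in the batch. The inequality holds because the sum in the denominator includes the term $j = i$.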
    
    The \textbf{hybrid critic}, as the name suggests, sits between the self-critic and the neural network critic. Rather than use an entirely separate neural network, this critic shares early layers with the variational network, after which point it has its own parameters. This approach can be well-suited to text data, where we may wish to share the embedding weights between the variational network and critic but keep the remaining layers separate. This presents a way to compromise between using no additional parameters (as in the self-critic) and using a whole additional network's worth of parameters (as in the neural network critic). 
    
    
    % \subsubsection{The Optimal Critic}
    
    % Having noted the connection between the new objective and Noise Contrastive Estimation, we can reach into 

 \begin{table*}[h!]
        \centering
        \centerline{
        \scalebox{0.9}{
        \begin{tabular}{lcccc|cccc}
        \hline & \multicolumn{4}{c|}{Yahoo} & \multicolumn{4}{c}{Yelp} \\
        Model & NLL & KL & MI ($I_q$) & AU & NLL & KL & MI ($I_q$) & AU \\
        % \hline & & \multicolumn{4}{c} { \textbf{Standard VAE Objective} } \\
        \hline VAE & 
        328.9 (0.1) & 0.0 (0.0) & 0.0 (0.0) & 0.0 (0.0) &
        358.3 (0.2) & 0.0 (0.0) & 0.0 (0.0) & 0.0 (0.0) \\
        SA-VAE \citep{kim_semi-amortized_2018} & 
        329.2 (0.2) & 0.1 (0.0) & 0.1 (0.0) & 0.8 (0.4) &
        357.8 (0.2) & 0.3 (0.1) & 0.3 (0.0) & 1.0 (0.0) \\
        Skip-VAE \citep{dieng_avoiding_2019} & 
        328.7 (0.3) & 0.22 (0.1) & 0.0 (0.0) & 7.0 (0.6) &
        358.1 (0.3) & 0.15 (0.0) & 0.0 (0.0) & 4.6 (0.5) \\
        Lagging-VAE \citep{he_lagging_2019} & 
        \textbf{328.2} (0.2) & 5.6 (0.2) & 3.0 (0.0) & 8.0 (0.0) &
        \textbf{356.9} (0.2) & 3.4 (0.3) & 2.4 (0.1) & 7.4 (1.3) \\
        VAE + Inference Critic (Self) & 
        328.7 (0.2) & 3.6 (0.1) & 2.6 (0.0) & 3.0 (0.0) &
        358.2 (0.2) & 3.8 (0.1) & 2.7 (0.0) & 3.0 (0.0) \\
        VAE + Inference Critic (Hybrid) & 
        \textbf{328.2} (0.1) & 4.3 (0.1) & 2.8 (0.0) & \textbf{11.0} (0.4) &
        357.7 (0.2) & 4.0 (0.1) & 2.8 (0.0) & 7.0 (0.0) \\
        % VAE + Model Critic & \textbf{328.1} & 0.0 & 0.0 & 0.0 \\
        VAE + Inference Critic (NN) & 
        338.9 (0.4) & \textbf{17.5} (1.1) & \textbf{3.3} (0.0) & 8.0 (1.0) & 
        370.5 (0.5) & \textbf{18.6} (1.8) & \textbf{3.2} (0.1) & \textbf{12.0} (2.0)
        \end{tabular}
        }
        }
        \caption{Quantitative results on the Yahoo and Yelp text corpora. Each critic improves on collapse metrics when added to the standard VAE with no other changes. Results for the comparison methods, obtained without KL annealing, are taken from \citet{he_lagging_2019} or re-implemented in the same framework. All results are averages of 5 runs, with standard deviations given in parentheses. (We follow the training details of the aforementioned methods, training until the validation ELBO converges.)}
        \label{table:textresults}
  %  \end{table}
   %\begin{table}[ht!]
        \centering
        \begin{tabular}{lcccc}
        \hline & & \multicolumn{2}{c} { Omniglot } &  \\
        Model & NLL & KL & MI ($I_q$) & AU \\
        % \hline & & \multicolumn{4}{c} { \textbf{Standard VAE Objective} } \\
        \hline VAE & 
        89.41 (0.04) & 1.51 (0.05) & 1.43 (0.07) & 3.0 (0.0)\\
        SA-VAE \citep{kim_semi-amortized_2018} & 
        89.29 (0.04) & 2.55 (0.05) & 2.20 (0.03) & 4.0 (0.0)\\
        Skip-VAE \citep{dieng_avoiding_2019} & 
        89.41 (0.05) & 1.75 (0.20) & 1.61 (0.10) & 3.0 (0.4)\\
        Lagging-VAE \citep{he_lagging_2019} & 
        \textbf{89.05} (0.05) & 2.51 (0.14) & 2.19 (0.08) & 5.6 (0.5) \\
        VAE + Inference Critic (Self) & 
        89.18 (0.04) & 6.30 (0.12) & 3.75 (0.04) & 11.0 (1.0)\\
        VAE + Inference Critic (Hybrid) & 
        89.16 (0.04) & 6.41 (0.15) & 3.78 (0.03) & 13.0 (0.7) \\
        % VAE + Model Critic & \textbf{328.1} & 0.0 & 0.0 & 0.0 \\
        VAE + Inference Critic (NN) & 
        89.24 (0.05) & \textbf{7.66} (0.14) & \textbf{3.82} (0.04) & \textbf{28.0} (0.0)
        \end{tabular}
        \caption{Quantitative results on the Omniglot image dataset. We find each critic improves on collapse metrics when added to the standard VAE with no other changes. Results for the comparison methods, obtained without KL annealing, are taken from \citet{he_lagging_2019} or re-implemented in the same framework. All results are averages of 5 runs, with standard deviations given in parentheses. (We follow the training details of the aforementioned methods, training until the validation ELBO converges.)}
        \label{table:imgresults}
        %\vspace{-1em}
    \end{table*}
    

\section{Experiments}

The basic objective of our experiments is to analyze inference critics under posterior collapse. We report results across three established image and text datasets.

\subsection{Common Experimental Setup}

We follow a common setup throughout our experiments. Since the method is compatible with the VAE family, it can also be added on top of existing methods to mitigate posterior collapse. We conducted experiments adding forget-me-not regularization to a standard VAE to assess whether the theoretical improvements yield empirical benefits, following the setup of \citet{he_lagging_2019} and the collapse metrics of \citet{dieng_avoiding_2019}. On held-out data, we measure the approximate negative log-likelihood (NLL, estimated with $500$ importance samples, which gives a tighter bound for evaluation than the ELBO), the ELBO's KL term, a Monte Carlo estimate of the MI under the variational joint $I_q(x;z)$, and the number of active units (AU) \citep{burda_importance_2016}.

We report results for the Yahoo, Yelp, and Omniglot datasets, allowing systematic comparison to prior work \citep{he_lagging_2019,kim_semi-amortized_2018,dieng_avoiding_2019}. As these include both image and text datasets, we use existing, appropriate neural network architectures for each modality, which we describe alongside each experiment. For all datasets, we follow the standard train/validation splits provided by the original dataset authors. We obtained all results on a single NVIDIA GeForce RTX 2048 GPU. 

\textbf{Baselines:} For all experiments, we compare against multiple established baselines. We use the standard \textbf{VAE \citep{kingma_auto-encoding_2014}} without any additional strategy for handling posterior collapse. We also compare to \textbf{SA-VAE \citep{kim_semi-amortized_2018}}. Rather than relying solely on amortized inference to obtain variational parameters, the semi-amortized VAE (SA-VAE) uses amortized inference to obtain an initialization and then updates the parameters directly from that point. Finally, we compare to the \textbf{Lagging-VAE \citep{he_lagging_2019}}, which aggressively updates the variational family many more times (e.g., 50x) than the model. We choose Lagging-VAE as representative because it holds the previous state of the art on both modeling and posterior collapse metrics without requiring the KL annealing trick, making it a competitive baseline. 

\subsection{Evaluation Metrics}

We evaluate all models and baselines using the standard metrics for evaluating the posterior collapse of VAEs. We report the following:

\textbf{Negative Log Likelihood (NLL):} The negative log-likelihood indicates the modeling performance on held-out 
data. A lower NLL indicates that the model generalizes better to unseen data samples.
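
For concreteness, a standard importance-weighted estimate with $S$ samples, in the style of \citet{burda_importance_2016}, is sketched below. We assume this is the form of the estimator; the setup above fixes only $S = 500$.
\begin{equation*}
    \begin{aligned}
    \mathrm{NLL}(\boldsymbol{x}) &\approx -\log \frac{1}{S} \sum_{s=1}^{S} \frac{p_{\theta}(\boldsymbol{x} \mid \boldsymbol{z}_s)\, p(\boldsymbol{z}_s)}{q_{\phi}(\boldsymbol{z}_s \mid \boldsymbol{x})}, \\
    \boldsymbol{z}_s &\sim q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x}).
    \end{aligned}
\end{equation*}
With $S = 1$ the expectation of this quantity is the negative ELBO for $\boldsymbol{x}$; larger $S$ tightens the bound on $-\log p_{\theta}(\boldsymbol{x})$.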

\textbf{KL-Divergence (KL):} The KL-divergence term of the ELBO in Equation \ref{eqn:elbo1} is a common indicator of collapse. If we obtain a good ELBO but the KL term is low, the optimization progress comes from the model likelihood (the first term of Equation \ref{eqn:elbo1}). In this case, especially if the KL is at or near zero, the posterior matches the prior too well, suggesting that collapse has occurred. 

\textbf{Mutual Information (MI):} 
The estimate of the mutual information across the variational family, $I_q(x;z)$, indicates whether the latents have become independent of the data (over the variational joint). We compute this estimate as the difference between the previously described KL term and the marginal KL term from Equation \ref{eqn:elbo2}, following \citet{hoffman_elbo_2016}. Both KL terms are obtained by Monte Carlo. The first is obtained naturally when computing the ELBO, while the second can be computed via ancestral sampling (sampling from the dataset, then from the approximate posterior), again as in \citet{hoffman_elbo_2016}. However, as pointed out by \citet{he_lagging_2019}, this estimate is biased; specifically, it is an upper bound. 
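
The estimator can be written out as follows. This is a sketch under the assumption that Equation \ref{eqn:elbo2} is the decomposition of \citet{hoffman_elbo_2016}; the aggregate posterior $q_{\phi}(\boldsymbol{z}) = \mathbb{E}_{p_{\mathcal{D}}(\boldsymbol{x})}[q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x})]$ and the sample count $M$ are notation introduced here for illustration.
\begin{equation*}
    \begin{aligned}
    I_q(x;z) &= \mathbb{E}_{p_{\mathcal{D}}(\boldsymbol{x})}\left[\mathrm{KL}\left(q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x}) \| p(\boldsymbol{z})\right)\right] \\
    &\quad - \mathrm{KL}\left(q_{\phi}(\boldsymbol{z}) \| p(\boldsymbol{z})\right), \\
    \mathrm{KL}\left(q_{\phi}(\boldsymbol{z}) \| p(\boldsymbol{z})\right) &\approx \frac{1}{M} \sum_{m=1}^{M} \log \frac{q_{\phi}(\boldsymbol{z}_m)}{p(\boldsymbol{z}_m)} .
    \end{aligned}
\end{equation*}
Each $\boldsymbol{z}_m$ is drawn by ancestral sampling ($\boldsymbol{x}_m \sim p_{\mathcal{D}}(\boldsymbol{x})$, then $\boldsymbol{z}_m \sim q_{\phi}(\boldsymbol{z} \mid \boldsymbol{x}_m)$), and the aggregate density $q_{\phi}(\boldsymbol{z}_m)$ is itself approximated by averaging $q_{\phi}(\boldsymbol{z}_m \mid \boldsymbol{x})$ over a finite set of data points.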

\textbf{Active Units (AU):} A standard metric from prior work, the number of active units measures how many latent dimensions are active, that is, how many of the stochastic units show any variation when the input varies. If few are active, the model has likely collapsed. Activity is measured by $A_{z}=\operatorname{Cov}_{\mathbf{x}}\left(\mathbb{E}_{z \sim q(z \mid \mathbf{x})}[z]\right)$, with a unit considered active if its activity is above some threshold (we follow \citet{dieng_avoiding_2019} and \citet{he_lagging_2019} with a threshold of $0.01$). 
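
Written out for a $d$-dimensional latent space, the count is
\begin{equation*}
    \mathrm{AU} = \sum_{u=1}^{d} \mathbf{1}\left[\operatorname{Cov}_{\mathbf{x}}\left(\mathbb{E}_{z_u \sim q(z_u \mid \mathbf{x})}[z_u]\right) > 0.01\right],
\end{equation*}
where $u$ indexes latent dimensions (an index introduced here for illustration) and $\mathbf{1}[\cdot]$ is the indicator function. If the variational family is Gaussian, which we assume only for this remark, the inner expectation is the mean output by the variational network, so AU counts the dimensions whose posterior mean has variance above the threshold across held-out inputs.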
    % Quantitative results for image data (on MNIST) can be found in Table \ref{table:results} with experimental details. We observe some positive effect on the collapse metrics, which would likely be stronger for stronger models. Interestingly, we observe the ELBO values are slightly better or around the same - somewhat surprisingly, we do not seem to pay a price for having to work with the auxiliary penalty. We also see qualitatively that the samples (Figure \ref{fig:samples}) and reconstructions (Figure \ref{fig:recons}) do not seem to degrade either. 


    \begin{table}[t!]
        \centering
        \begin{tabularx}{\linewidth}{Xcc}
        \hline & \multicolumn{2}{c} {Running Time}  \\
        Model & Factor & Absolute (Hrs) \\
        % \hline & & \multicolumn{4}{c} { \textbf{Standard VAE Objective} } \\
        \hline VAE & 
        1.00 & 3.5 \\
        % SA-VAE \citep{kim_semiamortized_2018} & 
        % 89.29 & 2.55 & 2.20 & 4.0 \\
        Lagging-VAE & 
        3.6 & 12.7 \\
        VAE + Inf. Critic (Hybrid) & 
        1.06 & 3.7
        % VAE + Inference Critic (Self) & 
        % 89.18 & 6.30 & 3.75 & 11.0 \\
        % % VAE + Model Critic & \textbf{328.1} & 0.0 & 0.0 & 0.0 \\
        % VAE + Inference Critic (NN) & 
        % 89.24 & \textbf{7.66} & \textbf{3.82} & \textbf{28.0}
        \end{tabularx}
        \caption{Speed comparison results. We report wall-clock time (`Absolute') and factor increase over the baseline (`Factor') on the Yahoo benchmark. Adding an inference critic adds minimal overhead. Following the experimental procedure of \citet{he_lagging_2019}, we run to convergence on the validation ELBO; the standard VAE converges after 54 epochs, Lagging-VAE after 49 epochs, and the VAE with a hybrid critic after 54 epochs. For the same number of epochs, the speed difference only increases.}
        \label{table:timeresults}
        \vspace{-1.5em}
    \end{table}

\subsection{Results on Text}

\textbf{Experimental Setup:} In this experiment, we evaluate on the Yahoo and Yelp benchmarks \citep{yang_improved_2017}. All methods use the standard train/val/test splits and follow the experimental protocol of \citet{he_lagging_2019}, fully described in Appendix \ref{appendix:protocol}. 
The variational network and the model network are each a 1-layer LSTM with learnable embeddings. Following this protocol, all methods use a 32-dimensional latent space and a batch size of 32. 
% We see using an inference critic mitigates collapse effectively without using the `KL annealing' trick from \citep{bowman_generating_2016}. (Note that the weight of the regularization is always $1$ in these experiments.) 
% \textcolor{red}{consider if anything else needed here}


\textbf{Quantitative Results:} The quantitative results in Table \ref{table:textresults} show that all critics substantially improve on the collapse metrics compared to the baseline on both datasets, showing forget-me-not regularization is able to significantly mitigate posterior collapse on text data. 

Our results show that different critics behave differently. Because the self-critic optimally solves the auxiliary task at each step, it is able to pair observations and latents successfully if there is any information shared between them. The neural network critic, on the other hand, is entirely separate from the variational network and is decidedly sub-optimal at the auxiliary task; it thus needs more information to be shared, as that increases its chances of finding a way to tie together corresponding pairs. As the experimental results show, this encourages less collapse. In particular, the neural network critic reaches close to the theoretical maximum $I_q$ increase offered by applying an inference critic with the batch size used ($\log(32) \approx 3.47$, see Appendix \ref{appendix:multimi}). At the same time, this may lead to solutions that do not perform as well in terms of NLL, because there is a much stronger pull away from collapse. The hybrid critic strikes a balance between the two: by not being optimal at the auxiliary task, it pulls more strongly away from collapse, but by sharing some parameters with the variational network, it reaches better modeling solutions than the entirely separate neural network critic, actually improving the NLL slightly over the standard VAE. It is interesting that the neural network critic, which uses an entirely separate set of parameters and thus should be strictly more flexible, fails to reach such an optimum; we hypothesize that this comes from difficulty in optimization, which is eased when the weights are tied. None of the critics requires additional hyperparameter scheduling, such as a KL annealing schedule.

Figure \ref{fig:mi_comparison} shows that all critics progressively push the mutual information up over the course of training. In contrast, the standard VAE remains collapsed throughout training. 

\begin{figure}%{0.5\linewidth}
    \centering
    % \vspace{-2em}
    \includegraphics[width=0.9\linewidth]{mi_comparison.pdf}
    \caption{Comparison of mutual information across the variational family ($I_q$) for various critics vs baseline over the first 20 epochs of training on the Yahoo benchmark. This is cropped for clarity; the plot over the entirety of training can be seen in Appendix \ref{appendix:miexpanded}.}
    \vspace{-2em}
    \label{fig:mi_comparison}
\end{figure}
% \vspace{-3cm}

% mitigate posterior collapse for text data, suggesting the critics effectively implement . We evaluate the previously-mentioned metrics on the Yahoo and Yelp benchmarks (\citep{yang_improved_2017}) and find that each critic out

% Quantitative results for text data (on the Yahoo and Yelp benchmarks \citep{yang_improved_2017}) and discussion can be found in Table \ref{table:textresults}.

\subsection{Results on Images}

\textbf{Experimental Setup:} We next evaluate on the 
Omniglot benchmark \citep{lake_human-level_2015} with the provided train/val/test split. All methods follow the experimental protocol of \citet{he_lagging_2019}, fully described in the Appendix. For the architecture, all methods use a ResNet \citep{he_deep_2015} for the variational network and a 13-layer Gated PixelCNN \citep{oord_conditional_2016} for the model.
All methods use a 32-dimensional latent space with a batch size of 50. 

\textbf{Quantitative Results:} The results on images in Table \ref{table:imgresults} show a similar trend to the results on text, suggesting that the approach is robust across modalities. We find that inference critics mitigate collapse for images as well, improving substantially on collapse metrics compared to the baseline. 

Furthermore, the relative quantitative behavior of the different critics is the same on images as on text. 
%We observe similar relative behavior of the different critics on images as on text, suggesting the observations indeed hold across modalities. 
In particular, the neural network critic reaches close to the theoretical maximum $I_q$ increase offered by applying an inference critic with the batch size used ($\log(50) \approx 3.9$; see Appendix \ref{appendix:multimi} for more details). Finally, unlike for text, all critics are able to improve NLL over the baseline. It is interesting that the regularization actually seems to help the original objective; this behavior is unlike that of approaches like Beta-VAE \citep{higgins_beta-vae_2016}. We hypothesize that while the collapsed solution may be `easy' to arrive at with SGD, there exist better optima that actually use the latents, toward which inference critics drive the parameters. 
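
For reference, taking natural logarithms, the two ceilings quoted for the text and image experiments work out to
\begin{equation*}
    \log 32 = 5 \log 2 \approx 3.47, \qquad \log 50 \approx 3.91 ,
\end{equation*}
against which the neural network critic attains $I_q$ estimates of $3.3$ on Yahoo and $3.2$ on Yelp (Table \ref{table:textresults}) and $3.82$ on Omniglot (Table \ref{table:imgresults}).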
% \cv{I don't understand why this paragraph is here. What is different in the below paragraph from the ones above? The topic sentence doesn't say.}
% We follow the experimental protocol in \citep{he_lagging_2019}, using a ResNet \citep{he_deep_2015} for the variational network and a 13-layer Gated PixelCNN \citep{oord_conditional_2016} for the model network. We again see using an inference critic prevents collapse more effectively than other methods by a significant margin without using the `KL annealing' trick from \citep{bowman_generating_2016}, and in fact . We again reach close to the theoretical maximum $I_q$ increase offered by applying an inference critic with the batch size used ($\log(50) \approx 3.9$, see Appendix \ref{appendix:multimi}). 
\subsection{Efficiency and Running Time Performance}

A key advantage of this approach is that it adds minimal overhead to standard VAE training. As Table \ref{table:timeresults} shows, the hybrid critic completes training in only $1.06\times$ the time the standard VAE takes (3.7 hours vs.\ 3.5 hours of wall-clock time on one NVIDIA RTX 2048). This comes at no performance trade-off: the hybrid critic improves on the collapse metrics KL, MI, and AU over the standard baseline with no degradation in NLL. The time performance is particularly important because the overhead of previous approaches to avoiding posterior collapse is typically too high for large-scale problems. While Lagging-VAE does help against posterior collapse, it took 12.7 hours to train ($3.4\times$ the time of the hybrid critic). Other approaches are reported in the literature to require further computation; for example, SA-VAE takes between 4 and 8 times as long as Lagging-VAE according to \cite{he_lagging_2019}.
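
The factors quoted above follow directly from the reported wall-clock times (a worked check using only those numbers, rounded to the precision shown):
\begin{align*}
\frac{3.7}{3.5} &\approx 1.06 && \text{(hybrid vs.\ VAE)}, \\
\frac{12.7}{3.7} &\approx 3.4 && \text{(Lagging-VAE vs.\ hybrid)}, \\
\frac{12.7}{3.5} &\approx 3.6 && \text{(Lagging-VAE vs.\ VAE)}.
\end{align*}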

How can our method have such low overhead, given that contrastive learning can often be quite expensive? The 6\% overhead is made possible by the hybrid critic, which shares most of its parameters with the inference network. The bulk of the computation is thus done only once, and overhead is incurred only where the two networks diverge. Additionally, inspired by current work in self-supervised learning, we employ critics that are separable \citep{poole_variational_2019}, meaning they do not need a forward pass for every $(x,z)$ pair (which would be $n^2$ forward passes for a batch of size $n$) but only one for each $x$. This reduces the number of forward passes per batch from quadratic to linear in $n$ compared to approaches that must process every potential pairing.
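
To make the saving concrete, consider a generic separable critic of the form $f(x,z) = g(x)^\top h(z)$, where $g$ and $h$ are illustrative names introduced here (an assumption about the general shape of such critics, not necessarily the exact parameterization used in our experiments). For a batch of $n$ pairs,
\begin{align*}
\text{joint:} \quad & n^2 \text{ network evaluations of } f(x_i, z_j), \\
\text{separable:} \quad & n \text{ evaluations of } g(x_i), \\
& \text{at most } n \text{ evaluations of } h(z_j), \\
& \text{plus } n^2 \text{ inexpensive inner products}.
\end{align*}
Taking $n = 50$, the batch size quoted for the image experiments, this is $2500$ network evaluations versus at most $100$.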
    % Quantitative results for image data (on the Omniglot benchmark \citep{lake_human-level_2015}) and discussion can be found in Table \ref{table:imgresults}.  We see a strong effect from the inference critic on the collapse metrics.
    % , likely because 1) a strong LSTM decoder was used and 2) text/discrete data is particularly prone to collapse.
    % We also take some penalty in terms of the NLL, but much less than the amount we have increased the KL by in avoiding collapse.
    % Of particular interest is the apparent trade-offs between the different types of critics. While the neural network critic was most effective at improving the collapse metrics, it resulted in worse NLL on the text datasets. The self-critic \textcolor{red}{fill this in}. The hybrid critic performed quite well on the collapse metrics while actually improving the NLL. 
   
 
    % \begin{table}[]
    % \begin{tabular}{l|l|l|}
    % \cline{2-3}
    % & \textbf{Factor} & \textbf{Absolute} \\ \hline
    % \multicolumn{1}{|l|}{VAE}                             & 1.00            & 3.5               \\ \hline
    % \multicolumn{1}{|l|}{Lagging-VAE}                     & 3.6             & 12.7              \\ \hline
    % \multicolumn{1}{|l|}{VAE + Inference Critic (Hybrid)} & 1.06            & 3.7               \\ \hline
    % \end{tabular}
    % \caption{Speed comparison results on the Yahoo benchmark. \\
    % We report wall clock time (`Absolute') and factor increase over the baseline (`Factor'). Adding an inference critic adds minimal overhead. Per the experimental procedure used in \citep{he_lagging_2019}, we to convergence on the validation ELBO; the standard VAE converges after 54 epochs, Lagging-VAE converges after 49 epochs, and the VAE with a hybrid critic converges after 54 epochs. (For the same number of epochs, the speed difference only increases.)}
    % \end{table}

\section{Related Work}
    % The goal of this project is to present a first step towards linking work in very different areas, so there is quite a bit of related work. Here I present a brief overview of each of these areas and the connections to this work. 
    The line of work most relevant to ours consists of methods that try to avoid posterior collapse, especially those that do so by increasing the mutual information. \cite{bowman_generating_2016} identify the problem of posterior collapse for VAEs endowed with powerful generators. They prescribe `KL annealing', or slowly increasing the KL-penalty in Equation \ref{eqn:elbo1} at the start of training. From Equation \ref{eqn:elbo2}, we can interpret this as slowly ramping up the MI penalty inherent to the ELBO. Approaches like those of \cite{chen_variational_2017}, \cite{gulrajani_pixelvae_2016}, and \cite{yang_improved_2017} modify the architecture of the generative model to reduce its flexibility, with the hope that this will prevent it from finding solutions that `forget'. \cite{dieng_avoiding_2019} instead add skip connections to the VAE model network to increase $I_p(x;z)$. 
    
    Other approaches aim to explicitly encourage higher mutual information by modifying the objective, towards the goal of `fixing a broken ELBO' as identified by \cite{alemi_fixing_2018}. Forget-me-not regularization falls into this category. (Note that we do not claim novelty in the general idea of modifying the objective to avoid MI loss; rather, forget-me-not regularization presents a new way of accomplishing this with distinct advantages in optimization ease, simplicity, and speed.) \cite{zhao_infovae_2018} introduce an explicit marginal-KL penalty which can be traded off with the usual KL term in Equation \ref{eqn:elbo1}. They do this with an adversarial classifier that tries to guess whether a sample is from the aggregate posterior or the prior, with Stein variational gradient descent \citep{liu_stein_2019}, or with kernel-based methods (e.g., MMD \citep{gretton_kernel_2012}), each of which is difficult to work with. They show that adversarial autoencoders \citep{makhzani_adversarial_2016} are a special case of the first of these. \cite{razavi_preventing_2019} enforce a minimum KL instead of using a penalty.
    
    Wasserstein autoencoders \citep{tolstikhin_wasserstein_2019} provide another approach to marginal matching, aiming to avoid the traditional KL term (and its associated MI penalty) entirely by matching marginals with the Wasserstein distance; this also requires adversarial training or MMD, and the approach generalizes adversarial autoencoders. \cite{phuong_mutual_2018} add an explicit penalty for the MI with the Barber and Agakov lower bound \citep{barber_im_2004}, but this is not fully differentiable and requires use of the high-variance REINFORCE gradient estimator. Closely related to this work is \cite{rezaabad_learning_2020}, who also use an auxiliary network on batches of samples (from the variational joint) to obtain a penalty term; instead of solving a straightforward classification problem as done here, they use these to try to estimate the dual form of the mutual information. VAE-MINE \citep{qian_enhancing_2019} uses the MINE bound (a cousin of the CPC bound) to aim to increase MI, but as pointed out by \cite{poole_variational_2019}, the way it is computed does not constitute a correct MI bound. (See Appendix \ref{appendix:vaemine} for a more detailed discussion of VAE-MINE.) Along different lines, \cite{kim_semi-amortized_2018} address posterior collapse by using the inference network's outputs as an initialization for SVI, creating a hybrid procedure between amortized and non-amortized inference. \cite{he_lagging_2019} modify the optimization procedure to take more steps for the inference network than the model network.
    
    The idea of adding an auxiliary classification task to generative models has a long history. The most famous of these are GANs \citep{goodfellow_generative_2014}, which train an auxiliary classifier adversarially on the empirical data distribution and the model data distribution, aiming to make these indistinguishable. \cite{uehara_generative_2016} interpret this in the framework of density ratio estimation. Bayesian GANs \citep{tran_hierarchical_2017} modify these with Bayesian neural networks for Bayesian inference. Hybrids between VAEs and GANs have been introduced in many forms \citep{larsen_autoencoding_2016,donahue_adversarial_2017,srivastava_veegan_2017,mescheder_adversarial_2017}. All of these use the auxiliary classifier adversarially, to encourage distribution matching, whereas we are encouraging distributions \textit{not} to match: specifically, we are encouraging dependence between latents and observations by making the joint different from the product of marginals. \cite{aneja_contrastive_2021} use an auxiliary classifier for noise-contrastive estimation between the prior and the aggregate posterior to address the `prior hole' problem. Rather than auxiliary classifiers, \cite{seybold_dueling_2019} add auxiliary decoders aiming to avoid collapsed optima.
    
    This work provides a foundation for connecting advances in contrastive learning with the VAE framework, giving us control over \textit{what} information is preserved (vs what the representations are invariant to).
    %We believe this is to be a critical connection this paper makes.
    CPC \citep{oord_representation_2019} introduces the InfoNCE mutual information bound, applying it to representations obtained by applying autoregressive models to sequential data. \cite{wu_unsupervised_2018} propose the instance discrimination problem, which suggests matching observations with encodings of perturbed versions of the same observation (for example, two different augmentations of an image, such as brightness shifts), thereby learning representations that are invariant to these factors. MoCo \citep{he_momentum_2020} proposes a `momentum queue' that holds encodings from recent batches to substantially increase the number of `negatives' being compared against; it surpasses the performance of representations learned by supervised neural networks on various computer vision tasks. %\citep{chen_simple_2020} shows that stronger augmentations yield further gains, in their follow-up \citep{chen_big_2020} showing strong performance for such models in the semi-supervised regime. 
    \cite{zhu_eqco_2020} allow for contrastive learning without large numbers of negatives via a small modification to the InfoNCE objective, which provides a more practical bound on the mutual information. \cite{tschannen_mutual_2020} connect these techniques and the associated mutual information bounds to metric learning, which could also be an interesting perspective to bring to representation learning with VAEs. \vspace{-1em}
\section{Conclusions}
% \textcolor{red}{restate abstract - what did you do, state facts}
We present a new method for protecting against posterior collapse in VAEs. In doing so, we establish a connection between VAEs and contrastive representation learning. We show that inference critics increase the mutual information between latents and observations by maximizing the CPC lower bound. Experiments on three datasets show the effectiveness of the approach, with significant efficiency gains. %Experimentally, we demonstrate competitive performance with established baselines in preventing posterior collapse with substantial speed improvements.

% Improving the quality of representations learned during VAE training could lead to advances in a wide range of downstream tasks, such as classification, future/past prediction (e.g. in the context of world models \citep{ha_world_2018}), and more. These tasks in turn have uncountable applications. Some, like solving robotics tasks or enabling autonomous vehicles, stand to benefit society substantially. On the other hand, these same techniques can be applied in inappropriate or even malicious ways, such as in developing biased classifiers used for hiring purposes \citep{a_amazon_a}. We acknowledge both types of applications. 

% \subsubsection*{References}
\bibliographystyle{plainnat}
\bibliography{references.bib}

% \clearpage



\end{document}
