\section{EXISTENCE OF AN IDEAL INFERENCE FUNCTION}
\label{sec:existence}

% A first step in understanding whether A-VI should be applied to a model is to establish the existence of an ideal inference function.
% That is we need to check that $x_n = x_m$ implies $\nu^*_n = \nu^*_m$.
%  and moreover that $x_n$ provides all the information we need to distinguish the different factors of $q(z_n \s \nu_n^*)$.
% We will see that when an ideal inference function exists, the optimal variational factor takes the form
% \begin{equation} \label{eq:learnable-density}
%     q(z_n \s \nu^*_n) \propto \int_\Theta g(\theta, {\bf x}) (m(\theta, z_n) + h(\theta, z_n, x_n)) \ \text d \theta.
% \end{equation}
% %
% Through $g$, the ideal inference function depends on the entire dataset.
% However, since $g$ is common to all factors, \mbox{$x_n = x_m \implies q(z_n \s \nu^*_n) = q(z_m \s \nu^*_m)$} and so an ideal inference function exists.

We first present a lemma that characterizes the optimal variational parameters of F-VI for any model. 
%
\begin{lemma}[CAVI rule] \label{lemma:cavi}
  Consider a probabilistic model $p(\theta, \mbz, \mbx)$.
  The optimal solution of F-VI satisfies
  \begin{align}
    q(z_n \s \nu^*_n)
    \propto
    \exp \left \{
    \EE{q(\theta \s \nu_0^*)}{
      \EE{q(\mbz_{-n} \s \nu^*)}{
        \log p(\theta, \mbz, \mbx)}} \right \},
  \end{align}
   where $\mathbb{E}_{q(\mbz_{-n} \s \nu^*)}$ is the expectation with respect to all $z_j$'s except $z_n$.
\end{lemma}
%
\begin{proof}
  See Appendix~A.1.
\end{proof}
%
The proof follows from applying the coordinate ascent VI update rule \citep[Eq. 17]{Blei:2017} at the optimal solution $\nu^*$.
Note that the optimal variational parameters depend on the data and so, where helpful, we write $\nu^* = \nu^*({\bf x})$.

\begin{figure*}
      \centering
      \begin{tikzpicture}
    [
      Empty/.style={circle, draw=white!, fill=green!0, thick, minimum size=1mm},
      Round/.style={circle, draw=black!, fill=green!0, thick, minimum size=10mm},
    ]

    % Simple hierarchical
    \node[Empty] (g1) at (1, 1){Simple hierarchical};
    \node[Round] (z_1) at (0, 0){$z_{n - 1}$};
    \node[Round] (z_2) at (2, 0){$z_n$};
    \node[Round] (x_1) at (0, -1.5){$x_{n - 1}$};
    \node[Round] (x_2) at (2, -1.5){$x_n$};
    \node[Round] (theta) at (1, -3) {$\theta$};

    \path[->, draw] (z_1) -- (x_1);
    \path[->, draw] (z_2) -- (x_2);
    \path[->, draw] (theta) -- (x_1);
    \path[->, draw] (theta) -- (x_2);

    % Saw time-series
    \node[Empty] (g2) at (4.5, 1){Saw time series};
    \node[Round] (2z_1) at (3.5, 0){$z_{n - 1}$};
    \node[Round] (2z_2) at (5.5, 0){$z_n$};
    \node[Round] (2x_1) at (3.5, -1.5){$x_{n - 1}$};
    \node[Round] (2x_2) at (5.5, -1.5){$x_n$};
    \node[Round] (2theta) at (4.5, -3) {$\theta$};

    \path[->, draw] (2z_1) -- (2x_1);
    \path[->, draw] (2z_2) -- (2x_2);
    \path[->, draw] (2x_1) -- (2z_2);
    \path[->, draw] (2theta) -- (2x_1);
    \path[->, draw] (2theta) -- (2x_2);

    % Hidden Markov model
    \node[Empty] (g3) at (8, 1){Hidden Markov model};
    \node[Round] (3z_1) at (7, 0){$z_{n - 1}$};
    \node[Round] (3z_2) at (9, 0){$z_n$};
    \node[Round] (3x_1) at (7, -1.5){$x_{n - 1}$};
    \node[Round] (3x_2) at (9, -1.5){$x_n$};
    \node[Round] (3theta) at (8, -3) {$\theta$};

    \path[->, draw] (3z_1) -- (3x_1);
    \path[->, draw] (3z_2) -- (3x_2);
    \path[->, draw] (3z_1) -- (3z_2);
    \path[->, draw] (3theta) -- (3x_1);
    \path[->, draw] (3theta) -- (3x_2);

    % Dense hierarchical
    \node[Empty] (g4) at (11.5, 1){Dense hierarchical};
    \node[Round] (4z_1) at (10.5, 0){$z_{n - 1}$};
    \node[Round] (4z_2) at (12.5, 0){$z_n$};
    \node[Round] (4x_1) at (10.5, -1.5){$x_{n - 1}$};
    \node[Round] (4x_2) at (12.5, -1.5){$x_n$};
    \node[Round] (4theta) at (11.5, -3) {$\theta$};

    % \path[->, draw, dotted] (4z_1) -- (4z_2);
    \path[->, draw] (4z_1) -- (4x_1);
    \path[->, draw] (4z_2) -- (4x_2);
    \path[->, draw] (4z_1) -- (4x_2);
    \path[->, draw] (4z_2) -- (4x_1);
    \path[->, draw] (4theta) -- (4x_1);
    \path[->, draw] (4theta) -- (4x_2);
    

    \end{tikzpicture}
    \caption{
    \textit{For the simple hierarchical model (\Cref{eq:simple-hier}), an ideal inference function $f_{\bf x}$ such that $f_{\bf x}(x_n) = \nu_n^*$ exists.
    The saw time series requires learning a map with two inputs $(x_{n -1}, x_n)$.
    For the hidden Markov and dense hierarchical graphs, there is no ideal inference function.
    In the dense hierarchical model, there is an edge between every element of $\mbz$ and every element of $\mbx$.
    For clarity we removed edges between $\theta$ and $z_n$ in all graphs.
    }}
    \label{fig:models}
  \end{figure*}

The CAVI rule uses the factorization of $q$ but makes no assumption about the model.
We will reason about the factorization of $p(\theta, {\bf z}, {\bf x})$ using a directed acyclic graph (DAG) representation and define an \textit{exchangeable latent variable model} based on a set of common assumptions.
%
\begin{definition}
An exchangeable latent variable model $p(\theta, {\bf z}, {\bf x})$ verifies 
%
\begin{itemize}
  \item[] (i) {\normalfont local dependence}, i.e. there is an edge between $x_n$ and $z_n$ and \mbox{$p(x_n \mid \mbz, \theta) \neq p(x_n \mid \mbz_{-n}, \theta)$}.

  \item[] (ii) {\normalfont conditional independence} of $x_n$ from $\mbx_{-n}$ given $\mbz$ and $\theta$, i.e. \mbox{$p({\bf x} \mid {\bf z}, \theta) = \prod_{n = 1}^N p(x_n \mid {\bf z}, \theta)$}.

  \item[] (iii) {\normalfont common distributional forms}: no distribution involving $\theta$, ${\bf z}$ or ${\bf x}$ depends on the index of the random variables.

\end{itemize}

% (i) {\normalfont local dependence}, i.e. there is an edge between $x_n$ and $z_n$ and \mbox{$p(x_n \mid \mbz, \theta) \neq p(x_n \mid \mbz_{-n}, \theta)$}, (ii) {\normalfont conditional independence}, i.e. \mbox{$p({\bf x} \mid {\bf z}, \theta) = \prod_{n = 1}^N p(x_n \mid {\bf z}, \theta)$} and (iii) {\normalfont common distributional forms}: no distribution involving $\theta$, ${\bf z}$ or ${\bf x}$ depends on the index of the random variables.
\end{definition}
%
Figure~\ref{fig:models} presents several graphical models which conform to the above definition, including hierarchical models and certain time series.

\begin{definition}
   Following \citet{Agrawal:2021}, we define a \textit{simple hierarchical model} as an exchangeable latent variable model that factorizes according to:
   %
   \begin{equation*}
     p(\theta, \mbz, \mbx) = p(\theta) \prod_{n = 1}^N p(z_n \mid \theta) p(x_n \mid z_n, \theta).
   \end{equation*}
\end{definition}


The main result of this section states that the existence of an ideal inference function is, in general, equivalent to $p(\theta, {\bf z}, {\bf x})$ being a simple hierarchical model.
%
\begin{theorem}  \label{thm:learnable}
  Consider an exchangeable latent variable model $p(\theta, {\bf z}, {\bf x})$.
  \begin{enumerate}
      \item Suppose $p(\theta, {\bf z}, {\bf x})$ is a simple hierarchical model. Then an ideal inference function exists.

      \item Suppose that, for a given graph, an ideal inference function exists for every $p(\theta, \mbz, \mbx)$ that factorizes according to this graph. Then the graph is that of a simple hierarchical model.
  \end{enumerate}
\end{theorem}
%
\begin{proof}
    See Appendix~A.2.
\end{proof}
%
\begin{remark}
The converse in Theorem~\ref{thm:learnable} (item 2) is stated for a class of models, meaning the result must hold for any choice of distribution $p(\theta, {\bf z}, {\bf x})$ supported by the graph.
This excludes edge cases that arise due to trivial symmetries (see Appendix~A.3).
\end{remark}
%This is to exclude cases where we have a model that is not a simple hierarchical model but for a very particular choice of distribution, an ideal inference function still exists.
% Such a case can be constructed using trivial symmetries (see Appendix).

We now provide an outline of the proof for Theorem~\ref{thm:learnable}.
Applying the CAVI rule to the simple hierarchical model, we can show that F-VI's optimal solution takes the form
%
\begin{align} \label{eq:simple-map}
  q(z_n \s \nu^*_n) \propto \exp \left \{ \int_\Theta g(\theta, \mbx) [m(\theta, z_n) + h(\theta, z_n, x_n)] \, \mathrm d \theta \right \},
\end{align}
%
where $g(\theta, \mbx) = q(\theta \s \nu^*_0(\mbx))$, $m(\theta, z_n) = \log p(z_n \mid \theta)$ and $h(\theta, z_n, x_n) = \log p(x_n \mid z_n, \theta)$.
While the function $g$ depends on the entire data set $\mbx$, it is common to all factors of $q(\mbz)$.
Meanwhile, elements specific to $z_n$ only depend on $x_n$.
Hence, $x_n = x_m$ implies $q(z_n \s \nu^*_n) = q(z_m \s \nu^*_m)$.
Provided the variational family is identifiable, i.e.\ each density in the family corresponds to a unique parameter, we also obtain a map from $x_n$ to $\nu^*_n$.

% \begin{align} \label{eq:simple-map}
%   & q(z_n \s \nu^*_n) \nonumber \\ 
%   & \propto \exp \left \{ \int_\Theta q(\text d \theta \s \nu^*_0({\bf x}))
%   \log p(z_n \mid \theta) + \log p(x_n \mid z_n, \theta) \right \}.
% \end{align}
% }
%
% \hspace{-0.14in}
%
% \Cref{eq:simple-map} takes the form prescribed by \Cref{eq:learnable-density}, with $g(\theta, {\bf x}) = q(\text d \theta \s \nu^*_0({\bf x}))$, $m(\theta, z_n) = \log p(z_n \mid \theta)$ and \mbox{$h(\theta, x_n) =  \log p(x_n \mid z_n, \theta)$}.
% That is we have a global contribution which depends on all elements of ${\bf x}$ and is common to all factors of $q$, and a local contribution which only depends on $x_n$.
% Hence there is a dataset-dependent function from $x_n$ to $q(z_n \s \nu^*_n)$. 

% Thus per Propositions~\ref{prop:interpolation} and \ref{prop:condition} the amortization gap can be closed.

We show item (2) by starting with the CAVI rule for a general $p(\theta, {\bf z}, {\bf x})$ and identifying terms in the kernel of $q(z_n \s \nu^*_n)$ which are not common to all factors of $q(\mbz)$ but depend on elements of ${\bf x}$ other than $x_n$.
To ``eliminate'' these terms, we need to sever edges in the graphical representation of $p(\theta, {\bf z}, {\bf x})$.
Once we remove the offending terms, we are left with a simple hierarchical model.

\subsection{EXAMPLE: LINEAR PROBABILISTIC MODEL} \label{sec:linear}

We now provide an illustrative example where an ideal inference function can be written in closed form.
% In this example, $p(z_n \mid \mbx) \neq p(z_n \mid x_n)$ and furthermore, we find that $f_\phi$ can solve the amortization interpolation problem using a constant number of variational parameters, no matter how large $N$ is.
%Furthermore we find that in this example the inference function $f_\phi$ can solve the amortization interpolation problem by learning $f_\mbx$ with a constant number of variational parameters, no matter how large $N$ is.
%
Consider the simple hierarchical model,
\begin{equation}
    p(\theta) \propto 1; \ \  p(z_n) = \mathcal N (0, 1); \ \ p(x_n \mid z_n, \theta) = \mathcal N (\theta + \tau z_n, \sigma^2),
\end{equation}
%
where $\tau \in \mathbb R$ and $\sigma > 0$ are fixed.
In this example, the marginal posterior $p(z_n \mid \mbx)$ depends on the entire data set $\mbx$, rather than on $x_n$ alone, a phenomenon known as \textit{partial pooling} in the Bayesian statistics literature \citep{Gelman:2013}.
This is because rather than holding $\theta$ fixed, we marginalize over it to do full Bayesian inference.


% In our notation, the second argument of normal() is the standard deviation.
% For this example, FVI's optimal solution can be worked out analytically.
%
\begin{proposition}  \label{thm:linear}
   Let $q(z_n \s \nu^*_n)$ be the optimal solution returned by F-VI, when optimizing over the family of factorized Gaussians.
   Then
   \begin{equation} \label{eq:FVI-linear}
     \mathbb E_{q(z_n \s \nu^*_n)} (z_n) = \frac{\tau}{\sigma^2 + \tau^2} (x_n - \bar x); \ \ \mathrm{Var}_{q(z_n \s \nu^*_n)} (z_n) = \xi^2,
   \end{equation}
   where $\bar x$ is the average of $x_{1:N}$, and $\xi$ is a constant.
\end{proposition}
%
\begin{proof}
  See Appendix~A.4.
\end{proof}
%
The bulk of the proof is to work out the posterior $p(\theta, \mbz \mid \mbx)$ analytically.
Note that a simple conjugacy argument does not suffice, since we also need to marginalize over $\theta$.
%
% \begin{equation}  \label{eq:posterior}
% p(z_n \mid {\bf x}) = \mathcal N \left (\frac{\tau}{\sigma^2 + \tau^2} (x_n - \bar x), s  \right ),
%  \end{equation}
 %
% for some constant $s \neq \xi$.
% 
% Since the posterior is normal, F-VI correctly estimates the correct posterior mean \citep{Turner:2011, Margossian:2023} and, in this case, the variational standard deviation is constant.
% % (F-VI  only correctly estimates the mean and $s \neq \xi$.)
% The thorny details are left to the Appendix.

We can rewrite the optimal mean (\Cref{eq:FVI-linear}) as a linear function,
 \begin{gather}
   \mathbb E_{q(z_n \s \nu^*_n)} (z_n) = \alpha_0({\bf x}) + \alpha x_n, \nonumber \\
   \alpha_0 ({\bf x}) = - \frac{\tau \bar x}{\sigma^2 + \tau^2}; \quad
   \alpha = \frac{\tau}{\sigma^2 + \tau^2}.
 \end{gather}
 %
 For A-VI to match F-VI's solution, we need to learn a linear function for the mean and a constant for the variance, and so, regardless of the number of observations $N$, we can close the amortization gap by learning 3 variational parameters.


% Now, the optimal mean (\cref{eq:FVI-linear}) may be written as a linear function,
%  \begin{eqnarray}
%    \mathbb E_{q(z_n \s \nu^*_n)} (z_n) = \alpha_0({\bf x}) + \alpha x_n; \nonumber \\
%    \alpha_0 ({\bf x}) = - \frac{\tau \bar x}{\sigma^2 + \tau^2}; \ \
%    \alpha = \frac{\tau}{\sigma^2 + \tau^2},
%  \end{eqnarray}
% %
% while the optimal standard deviation a constant.
% So, for A-VI to match F-VI's solution, we need to learn a linear function for the mean and a constant for the standard deviation. 
% Regardless of the number of observations $N$, we can close the amortization gap by learning 3 variational parameters.
% We will corroborate this numerically in \Cref{sec:experiment}.
 
This example provides intuition behind Theorem~\ref{thm:learnable}, which connects A-VI to classical ideas in hierarchical Bayesian modeling.
In the considered example, the posterior mean demonstrates partial pooling, a key property of hierarchical models \citep{Gelman:2013}:
the posterior mean of $z_n$ depends on both the local observation $x_n$ and on the non-local observations through $\bar x$.
Even though $p(z_n \mid {\bf x}) \neq p(z_n \mid x_n)$, the posterior density of each latent variable is distinguished by the local influence of $x_n$, while the global influence of $\bar x$ is the same for all latent variables.
As a result
% \begin{equation}
    $x_n = x_m \implies p(z_n \mid {\bf x}) = p(z_m \mid {\bf x})$,
 %\end{equation}
 and an ideal inference function exists.

\begin{figure*}
      \centering
      \includegraphics[width=1.67in]{figures/elbo_time_lin_normal_1959.pdf}
      \includegraphics[width=1.6in]{figures/elbo_time_nonlin_normal_1797.pdf}
      \includegraphics[width=1.6in]{figures/elbo_time_BNN_1954.pdf}
      \caption{ \textit{Examples of optimization paths. As benchmarks, we use F-VI and a constant factor algorithm which assigns the same distribution to all $q(z_n)$. A-VI is then run using different classes of inference functions: (left) we vary the degree $d$ of a polynomial inference function; (middle, right) we vary the width $k$ of an inference neural network.
      For a sufficiently complex inference function, we find that A-VI attains the same ELBO as F-VI, meaning the amortization gap is closed.
      For results across multiple seeds, see Figure~\ref{fig:iter_to_convergence}.
      }}
      \label{fig:optimization_paths}
\end{figure*}

\subsection{FURTHER FACTORIZATIONS OF $p(\theta, \mbz, \mbx)$}

Theorem~\ref{thm:learnable} tells us that in general, A-VI cannot achieve F-VI's solution for latent variable models other than the simple hierarchical model.
We show however that for certain models, it is possible to extend the domain of the inference function in order to close the amortization gap.

A general strategy to verify if the amortization interpolation problem can be solved is to prove the existence of a (potentially expanded) ideal inference function by applying the CAVI rule (Lemma~\ref{lemma:cavi}) to any model $p(\theta, \mbz, \mbx)$ of interest.

%In another example, we establish that such a strategy is not feasible.

{\bf Saw time series.} Consider the saw time series model,
%
\begin{equation} \label{eq:saw_time}
  p(\theta, {\bf z}, {\bf x}) = p(\theta) \prod_{n = 1}^N p(z_n \mid x_{n - 1}) p(x_n \mid z_n, \theta),
\end{equation}
%
where each latent variable $z_n$ depends on the previous observation $x_{n - 1}$.
Applying the CAVI rule, we have
%
\begin{equation}  \label{eq:saw-FVI}
          q(z_n \s \nu^*_n) \propto p(z_n \mid x_{n - 1}) \exp \left \{\mathbb E_{q(\theta \s \nu_0^*)} [\log p(x_n \mid z_n, \theta)] \right \},
\end{equation}
 %
which defines a (dataset-dependent) function from $(x_{n - 1}, x_n)$ to the optimal variational factor.
There is no ideal inference function \mbox{$f_{\bf x}: \mathcal X \to \mathcal U$}; however, there exists an ideal inference function $f_{\bf x}: \mathcal X \times \mathcal X \to \mathcal U$ such that $f_{\bf x}(x_{n - 1}, x_n) = \nu^*_n$ for all $n > 1$.

\begin{remark}  \label{remark:edge}
  When extending the domain of the inference function, we must address edge cases which may not have the requisite argument.
  For example, inference for $z_1$ requires passing $(x_0, x_1)$ to $f_\phi$, but $x_0$ is not observed.
  In this case, we assign a distinct variational parameter $\nu_1$ to the factor $q(z_1)$, rather than use amortization.
\end{remark}


% In this case extending the domain of the inference function ensures the existence of a solution to the amortization interpolation problem.

{\bf Hidden Markov model (HMM).} We now consider another time series, where even after expanding the domain of the inference function, the amortization gap cannot be closed.
The joint distribution of the HMM is
%
\begin{equation} \label{eq:hmm}
      p(\theta, {\bf z}, {\bf x}) = p(\theta) \prod_{n = 1}^N p(z_n \mid z_{n - 1}) p (x_n \mid z_n, \theta).
\end{equation}
%
The next proposition states that there is, in general, no ideal inference function $f_{\bf x}: \mathcal X \to \mathcal U$, and furthermore that expanding the domain of the inference function to any strict subset of the data still does not yield an ideal inference function.
%
\begin{proposition} \label{thm:hmm}
  Consider the HMM of \Cref{eq:hmm}.
  Let ${\bf w}_n \in \mathcal W$ be a strict subset of ${\bf x}$.
  There exist HMMs with no ideal inference function $f_{\bf x}: \mathcal W \to \mathcal U$.
  That is, we cannot construct an $f_\mbx$ such that $f_{\bf x}({\bf w}_n) = \nu^*_n$ for all $n$ which do not constitute an edge case (\Cref{remark:edge}).
\end{proposition}
%
\begin{proof}
  See Appendix~A.5.
\end{proof}

\begin{remark}
    If we extend the domain of the inference function to the entire data set $\mbx$, then A-VI reduces to F-VI and the amortization gap is trivially closed.
\end{remark}

The proof of Proposition~\ref{thm:hmm} is obtained by constructing a (non-adversarial) example.
We argue the above result holds in general and provide a conceptual explanation as to why.
In the simple hierarchical and saw time series models the existence of an ideal inference function (respectively over $\mathcal X$ and $\mathcal X \times \mathcal X$) is due to the fact that each data point either has a local or a global influence on $q(z_n \s \nu^*_n)$.
In the case of an HMM, there is no common global influence: any observation $x_m$ will have a different influence on the variational factor for each latent variable $z_n$.
Moreover, each observation is, to a varying degree, local to any latent variable.
%
A similar reasoning can be applied to the dense hierarchical model (Figure~\ref{fig:models}), which includes Gaussian process models.
