% \documentclass{templates/uai2024} % for initial submission
\documentclass[accepted]{uai2024} % after acceptance, for a revised version; 
% also before submission to see how the non-anonymous paper would look like 
                        
%% There is a class option to choose the math font
% \documentclass[mathfont=ptmx]{uai2024} % ptmx math instead of Computer
                                         % Modern (has noticeable issues)
% \documentclass[mathfont=newtx]{uai2024} % newtx fonts (improves upon
                                          % ptmx; less tested, no support)
% NOTE: Only keep *one* line above as appropriate, as it will be replaced
%       automatically for papers to be published. Do not make any other
%       change above this note for an accepted version.

%% Choose your variant of English; be consistent
% \usepackage[american]{babel}
\usepackage[british]{babel}

%% Some suggested packages, as needed:
\usepackage{natbib} % has a nice set of citation styles and commands
    \bibliographystyle{plainnat}
    \renewcommand{\bibsection}{\subsubsection*{References}}
\usepackage{mathtools} % amsmath with fixes and additions
% \usepackage{siunitx} % for proper typesetting of numbers and units
\usepackage{booktabs} % commands to create good-looking tables
\usepackage{tikz} % nice language for creating drawings and diagrams

%% Provided macros
% \smaller: Because the class footnote size is essentially LaTeX's \small,
%           redefining \footnotesize, we provide the original \footnotesize
%           using this macro.
%           (Use only sparingly, e.g., in drawings, as it is quite small.)

\usepackage{algorithm}
\usepackage{algpseudocode}
%\usepackage{algorithm}
%\usepackage{algorithmic}
%\usepackage{algpseudocode}

\usepackage{tikz} 
\usetikzlibrary{bayesnet}
\usetikzlibrary{arrows}% nice language for creating drawings and diagrams

\usepackage{mathtools,amsmath,amssymb,amsfonts}

\input{macros}

\title{Using Autodiff to Estimate Posterior Moments, Marginals and Samples}

% The standard author block has changed for UAI 2024 to provide
% more space for long author lists and allow for complex affiliations
%
% All author information is automatically removed by the class for the
% anonymous submission version of your paper, so you can already add your
% information below.
%
% Add authors
% \author[1]{\href{mailto:<jj@example.edu>?Subject=Your UAI 2024 paper}{Jane~J.~von~O'L\'opez}{}}
% \author[1]{Harry~Q.~Bovik}
% \author[1,2]{Further~Coauthor}
\author[1]{Sam Bowyer}
\author[2]{Thomas Heap}
\author[2]{Laurence Aitchison}



% Add affiliations after the authors
\affil[1]{%
    School of Mathematics\\
    University of Bristol
}
\affil[2]{%
    Department of Computer Science\\
    University of Bristol
}
% \affil[0]{%
%     Equal Contribution
%   }
  
\begin{document}
\maketitle

\begin{abstract}
  Importance sampling is a popular technique in Bayesian inference: by reweighting samples drawn from a proposal distribution we are able to obtain samples and moment estimates from a Bayesian posterior over latent variables.
  Recent work, however, indicates that importance sampling scales poorly --- in order to accurately approximate the true posterior, the required number of importance samples grows exponentially in the number of latent variables \citep{chatterjee2018sample}.
  Massively parallel importance sampling works around this issue by drawing $K$ samples for each of the $n$ latent variables and reasoning about all $K^n$ combinations of latent samples.
  In principle, we can reason efficiently over $K^n$ combinations of samples by exploiting conditional independencies in the generative model.
  Previous work only detailed how to compute an ELBO/marginal likelihood estimator by summing over all $K^n$ combinations.
  However, that work did not give an approach for computing other quantities of interest, namely posterior expectations, marginals and samples, as computing these quantities is far more complex. 
  Specifically, these computations involve iterating forward (following the generative process), then iterating backward through the generative model.
  These backward traversals can be very complex, and a different backward traversal is required for each operation of interest.
  Our contribution is to exploit the source term trick from physics to entirely avoid the need to hand-write backward traversals.  
  Instead, we demonstrate how to simply and easily compute all the required quantities --- posterior expectations, marginals and samples --- by differentiating through a slightly modified marginal likelihood estimator.
  %Massively parallel methods provide a potential route to addressing this problem, as they draw $K$ samples for each of the $n$ latent variables and reason about all $K^n$ combinations of latent samples.
  %However, previous work only detailed how to compute an ELBO/marginal likelihood estimator by summing over all $K^n$ combinations, it did not give an approach for computing other quantities of interest, namely posterior expectations, marginals and samples.
  %That is because computing posterior expectations, marginals and samples seems to require complex algorithms that traverse backwards through the graphical model.
  %Our contribution is to exploit the source term trick from physics to entirely avoid the need to hand-write backward traversals.  
  %Instead, we demonstrate how to simply and easily compute all the required quantities --- posterior expectations, marginals and samples --- by differentiating through a slightly modified marginal likelihood estimator.
\end{abstract}


\section{Introduction}

Importance weighting allows us to reweight samples drawn from a proposal in order to compute expectations of a different distribution, such as a Bayesian posterior.
However, importance weighting breaks down in larger models.
To demonstrate this, \citet{chatterjee2018sample} considered a model with data, $x$, latent variables, $z$, a true posterior, $\Pc{z}{x}$, and a proposal, $\Q{z}$.
They showed that the number of samples required to accurately approximate the true posterior scales as \(\exp\b{D_{\text{KL}}\b{\Pc{z}{x} \,\Vert\, \Q{z}}}\).
Problematically, we expect the KL divergence to scale with \(n\), the number of latent variables.
Indeed, if we have $n$ latent variables, and \(\Pc{z}{x} = \prod_{i=1}^n \Pc{z_i}{x}\) and \(\Q{z} = \prod_{i=1}^n \Q{z_i}\) are IID over those \(n\) latent variables, then the KL-divergence is exactly proportional to \(n\).
Thus, we expect the required number of importance samples to be exponential in the number of latent variables, and hence we expect accurate importance sampling to be intractable in larger models.

To resolve this issue we use a massively parallel importance sampling scheme that in effect uses an exponential number of samples to compute posterior expectations, marginals and samples  \citep{ml}.
This involves drawing \(K\) samples of each of the \(n\) latent variables from the proposal, then individually reweighting all \(K^n\) combinations of all samples of all latent variables.
While reasoning about all \(K^n\) combinations of samples might seem intractable, we should in principle be able to perform efficient computations by exploiting conditional independencies in the underlying probabilistic generative model.
These conditional independencies can be depicted by drawing a graph with latent variables as the nodes, and dependencies as the edges; such models are often known as ``probabilistic graphical models'', or even just ``graphical models'' \citep{jordan2003introduction,koller2009probabilistic}.

However, many computations that are possible in principle are extremely complex in practice, and that turns out to be the case here.
While we should be able to perform this reasoning over all \(K^n\) combinations of samples using methods from the discrete variable graphical model literature, this turned out to be less helpful than we had hoped because these algorithms involve highly complex backward traversals of the generative model. 
Worse, different traversals are needed for computing posterior expectations, marginals and samples, making a general implementation challenging.
Our contribution is to develop a much simpler approach to computing posterior expectations, marginals and samples, which entirely avoids the need to explicitly write backwards computations.
Specifically, we show that posterior expectations, marginals and samples can be obtained simply by differentiating through (a slightly modified) forward computation that produces an estimate of the marginal likelihood.
The required gradients can be computed straightforwardly using modern autodiff, and the resulting implicit backward computations automatically inherit potentially complex optimizations from the forward pass.

%However, our massively parallel methods allow a broad class of proposal distributions in a broad class of probabilistic models. The divide-and-conquer algorithm \citep{kuntz2021divide} proposes a sequential monte-carlo estimator that is similar to ours, but requires a more complex implementation involving reweighting steps.
%In contrast to the previously mentioned methods, we develop a new approach to computing posterior expectations, marginals and samples by differentiating through a slightly modified marginal likelihood estimator.

%There is of course a large body of work doing inference in various discrete graphical model settings \citep{dawid1992applications,pfeffer2005design,bidyuk2007cutset,geldenhuys2012probabilistic,claret2013bayesian,sankaranarayanan2013static,goodman2014design,gehr2016psi,narayanan2016probabilistic,albarghouthi2017fairsquare,wang2018pmaf,obermeyer2019tensor,holtzen2020scaling}.
%However, this past work 
%does not use our approach to use a modern autodiff frameowkr
%One approach to computing expectations, marginals and samples would be to notice a connection between reasoning over \(K^n\) particles, and message passing in a discrete graphical model.
%Of course, one of our key contributions over this past work
%As such, perhaps the closet prior work is \citet{obermeyer2019tensor}, in that they provide methods for reasoning about plated discrete graphical models.
%However, they do not provide methods for the massively parallel context, and their methods are considerably more complex than those proposed here.
%In particular, to compute expectations, marginals and samples, they ``proceed by recording an adjoint compute graph alongside the forward computation and then traversing the adjoint graph backwards starting from the final result of the forward computation''.
%This is necessarily far more complex than our approach of just differentiating through the forward computation, if for no other reason than they need to implement separate backwards traversals for each operation of interest, and these traversals need to correctly handle all special cases (including optimized implementations of plates, timeseries and Gaussian processes).
%In contrast, for approach of differentiating through the forward computation, you only need to correctly implement the forward computation.
%Indeed, this complexity is mirrored in other approaches to reasoning about discrete graphical models.



%
%
%This approach is much simpler than past methods for discrete graphical models, for two key reasons.
%First, we do not need to explicitly represent the graphical model: we just feed the conditional probability tensors into opt-einsum, which gives the answer.
%Second, we do not need to make any choices (such as the nodes to merge in a junction tree) as the opt-einsum implementation optimizes the order of the tensor reduction.
%

%One approach to computing expectations, marginals and samples would be to notice a connection between reasoning over \(K^n\) particles, and message passing in a discrete graphical model.
%While classical message passing is of course restricted to only tree-structured models, there are generalisations such as junction-trees which can deal with arbitrary graphs.
%However there are three particular difficulties in the probabilistic programming setting that make direct application of e.g.\ junction trees difficult.
%In particular, modern practical probabilistic programming languages must provide optimized implementations of e.g.\ plates (for IID sampling) and timeseries, which exploit the implied structure to dramatically speed up computations.
%However, to our knowledge there are no generalisations of e.g.\ junction tree algorithms that take account of e.g. plates and timeseries.
%While we could do the generalisation ourselves, that would itself be a considerable body of new research.
%In contrast, assuming that we have generalised the marginal likelihood estimator to e.g. timeseries and plates (which must be done anyway), our approach gives expectations, marginals and samples for free.
%
%%The requirement to exploit this structure adds considerable complexity to e.g.\ junction trees.
%%For instance, how do we merge nodes in the context of plates and timeseries? Does it make sense to merge a standard variable and a timeseries?
%%We were unable to find pre-existing junction tree algorithms that have been generalised to these settings, and indeed, generalising junction trees to these more complex settings is a formidable exercise, constituting a signficant amout of new research.
%
%Indeed, previous probabilistic programming languages have included discrete 
%
%First, we need to algorithms to choose the nodes to merge in the junction tree.
%Second, message passing algorithms require an explicit representation of the graph, and the ability to pass explicit messages in that graph.
%Third, while it may be possible to share some code for computing posterior expectations, marginals and samples, it is unclear that this sharing will be particularly extensive.
%These difficulties are of course surmountable in simple cases. 
%However, they become prohibitive in the very general probabilistic programming setting, for two reasons.
%First, in probabilistic programming, we would like to support the full range of programming constructs offered by the underlying library.
%Attempting to represent the full range of programming constructs in a graph is not typically possible (or at least is very difficult and puts a considerable burden on the user) in traditional programming environments such as Python.
%In contrast, in probabilistic programming, we typically avoid an explicit representation 
%First, in probabilistic programming we need to provide optimized implementations of e.g.\ plates (for IID sampling) and timeseries, which exploit the implied structure in the graphical model.
%The requirement to exploit this structure adds considerable complexity to e.g.\ junction trees.
%For instance, how do we merge nodes in the context of plates and timeseries? Does it make sense to merge a standard variable and a timeseries?
%We were unable to find pre-existing junction tree algorithms that have been generalised to these settings, and indeed, generalising junction trees to these more complex settings is a formidable exercise, constituting a signficant amout of new research.
%
%In the modern machine learning literature, there is some work that considers combinations of many samples in the context of for instance an importance weighted autoencoder.
%
%This work did consider similar schemes involving \(K^n\) particles, and also used similar marginal likelihood estimators.
%However, this work was entirely focused on learning proposals/approximate posteriors using VI or RWS. 
%They did not consider at all how to use importance weighting or sampling to compute critical quantities of interest in Bayesian statistics, namely posterior expectations, marginals and means.
%
%Particle filters / particle Gibbs.


%While these questions likely could be resolved, this would be a significant body of research,
%Instead, it is far more desirable avoid a graph-based representation, and instead consider only ...

%The reasoning about all \(K^n\) combinations of samples in our importance sampling scheme bears some resemblance to reasoning about the exponentially many states in a discrete graphical model.
%Of course, inference in discrete graphical models is a well-studied problem.
%First, classical message passing algorithms are usually defined only for models where the graphical model structure forms a tree.
%Generalising these algorithms to general models non-tree structured models is a complex topic.
%For instance, approaches such as ``loopy belief propagation'', but this represents an approximation, whereas we consider only exact algorithms.
%Alternatively, there are exact methods such junction trees and cutset conditioning.
%However, these algorithms all have in common that they require an explicit representation of the underlying graph and complex special-purpose code to perform the three key operations: computing posterior expectations, marginals and samples.
%In contrast, we leverage modern autodiff frameworks to massively simplify these computations.
%In particular, we reduce the problems of computing posterior expectations, marginals and samples to the problem of computing gradients of a large tensor product.
%This tensor product can be computed efficiently and exactly using an opt-einsum implementation, and the gradients can be computed efficiently using modern autodiff.

\section{Background}
\label{sec:back}

\textbf{Bayesian inference.}
In Bayesian inference, we have a prior, \(\P{z'}\), over latent variables (sometimes in the statistical literature called parameters), \(z' \in \mathcal{Z}\), and a likelihood, \(\Pc{x}{z'}\), connecting the latents to the data, \(x\).
Here, we use \(z'\) rather than \(z\) because we reserve \(z\) for future use as a collection of \(K\) samples (Eq.~\ref{eq:z}).
Our goal is to compute the posterior distribution over latent variables conditioned on observed data,
\begin{align}
  \label{eq:bayes}
  \Pc{z'}{x} &= \frac{\Pc{x}{z'} \P{z'}}{\int dz'' \; \P{x, z''}}.
\end{align}
We often seek to compute posterior expectations,
\begin{align}
  \label{eq:mpost}
  m_\post &= \int dz' \; \Pc{z'}{x} m(z'),
\end{align}
but these are usually intractable, so instead we are forced to use an alternative method such as importance weighting.

\textbf{Importance weighting.} In importance weighting, we draw a collection of \(K\) samples from the full joint latent space.
A single sample is denoted \(z^k\in\mathcal{Z}\), while the collection of $K$ samples is denoted,
\begin{align}
  \label{eq:z}
  z &= (z^1, z^2, \dotsc, z^K) \in \mathcal{Z}^K.
\end{align}
The collection of \(K\) samples, \(z\), is drawn by sampling \(K\) times from the proposal,
\begin{align}
  \label{eq:Qglobal}
  \Q{z} &= \prod_{k\in\mathcal{K}} \Q{z^k},
\end{align}
where \(\mathcal{K}\) is the set of possible indices, \(\mathcal{K} = \{1,\dotsc,K\}\).
As the true posterior moment is usually intractable, one approach is to use a self-normalized importance sampling estimate, \(m_\glob(z)\).
We call this a ``global'' importance weighted estimate following terminology in \citet{geffner2022variational}.
The global importance weighted moment estimate is,
\begin{align}
  \label{eq:mglobal}
  m_\glob(z) &= \tfrac{1}{K} \sum_{k\in\mathcal{K}} \frac{r_k(z)}{\Pglob(z)} m(z^k)\\
  \intertext{where the samples, \(z\), are drawn from the proposal (Eq.~\ref{eq:Qglobal}), and}
  \label{eq:rglobal}
  r_k(z) &= \frac{\P{x, z^k}}{\Q{z^k}}\\
  \label{eq:Pglobal}
  \Pglob(z) &= \tfrac{1}{K} \sum_{k\in\mathcal{K}} r_{k}(z)
\end{align}
Here, \(r_k(z)\) is the ratio of the generative and proposal probabilities, and \(\Pglob(z)\) is an unbiased estimator of the marginal likelihood,
\begin{align}
  \nonumber
  \E[\Q{z}]{\Pglob(z)} &= \E[\Q{z^k}]{\frac{\P{x, z^k}}{\Q{z^k}}} \\
  &= \int dz^k \; \P{x, z^k} = \P{x}
\end{align}
The first equality arises because \(\Pglob(z)\) is the average of \(K\) IID terms, \(\P{x, z^k}/\Q{z^k}\), so is equal to the expectation of a single term, and the second equality arises if we write the expectation as an integral.

\textbf{Source term trick.}
Here, we outline a standard trick from physics that can be used to compute expectations under an arbitrary probability distribution by differentiating a modified log-normalizing constant.
This trick is used frequently in Quantum Field Theory, for instance \citep[Chapter~16]{weinberg1995quantum}, and also turns up in the theory of neural networks \citep{zavatone2021asymptotics}.
But the trick is simple enough that we can give a self-contained introduction here.

In our context, Bayes theorem (Eq.~\ref{eq:bayes}) defines an unnormalized density, \(\Pc{z'}{x} \propto \P{x,z'}\), with normalizing constant, \(\int dz'' \; \P{x, z''}\).
Of course, the normalizing constant is usually intractable, but one of our contributions will be to show that the massively parallel estimate of the normalizing constant is sufficient to apply the source term trick.
It turns out that we can compute posterior expectations using a slightly modified normalizing constant,
\begin{align}
  Z_m(J) &= \int dz' \; \P{x, z'} e^{J m(z')},
\end{align}
where \(e^{J m(z')}\) is known as a source term, and $m_\post = \E[{\Pc{z'}{x}}]{m(z')}$ is the moment we wish to compute.
Note that setting \(J\) to zero recovers the usual normalizing constant, $Z_m(J=0) = \int dz' \; \P{x, z'}$.
%\begin{align}
%  \label{eq:Z(J=0)}
%  Z_m(J=0) &= \sum_{z'} \P{x, z'}.
%\end{align}
Now, we can extract the posterior moment by evaluating the gradient of \(\log Z_m(J)\) at \(J=0\),
\begin{align}
  \nonumber
  \at{\dd{J}}_{J=0} & \log Z_m(J) \\
  \nonumber
  &= \at{\dd{J}}_{J=0} \log \int dz' \; \P{x, z'} e^{J m(z')} \\
  \nonumber
  &= \frac{\int dz' \; \P{x, z'} \at{\dd{J}}_{J=0} e^{J m(z')}}{\int dz'' \; \P{x, z''}}.
  \intertext{Differentiating the exponential at \(J=0\) (first equality), and identifying the posterior using Bayes theorem (Eq.~\ref{eq:bayes}) (second equality),}
  \nonumber
  &= \int dz' \; \frac{\P{x, z'}}{\int dz'' \; \P{x, z''}} m(z') \\
  &= \int dz' \; \Pc{z'}{x} m(z') = m_\post
\end{align}
We get back exactly the form for the posterior moment in Eq.~\eqref{eq:mpost}.

\textbf{Massively parallel marginal likelihood estimators.}
In the massively parallel setting, we assume a probabilistic graphical model with multiple latent variables, $z'_i \in \mathcal{Z}_i$, indexed by $i \in \{1, \dotsc, n\}$.
We can then form $z'$, used above, as a tuple containing all latent variables,
\begin{align}
  z' &= (z'_1, z'_2, \dotsc, z'_n) \in \mathcal{Z}.
\end{align}
We draw $K$ samples for each latent variable.
%To get an accurate yet computationally tractable marginal likelihood estimator, we use a massively parallel estimator which individually weights all \(K^n\) combinations of \(K\) samples on \(n\) latent variables.
We write the $k$th sample of the $i$th latent variable as $z_i^k$, so all $K$ samples of the $i$th latent variable can be written,
\begin{align}
  z_i &= (z^1_i, z^2_i, \dotsc, z^K_i) \in \mathcal{Z}_i^K,
\end{align}
where $\mathcal{Z}_i$ is the space for the $i$th latent variable.
Finally, $z$ is the collection of all $K$ samples of all $n$ latent variables (as in Eq.~\ref{eq:z}),
\begin{align}
  z &= (z_1, z_2, \dotsc, z_n) \in \mathcal{Z}^K.
\end{align}
In the massively parallel setting, proposals have graphical model structure,
\begin{align}
  \Qmp{z} &= \prod_{i=1}^n \Qmpc{z_i}{z_j \text{ for } j \in \qa{i}},
\end{align}
where \(\qa{i}\) is the set of indices of parents of \(z_i\) under that graphical model.
This massively parallel proposal over all copies of the $i$th latent variable, $z_i = (z^1_i, z^2_i, \dotsc, z^K_i)$, arises from a user-specified, single-sample approximate posterior, $Q(z_i'|z_j' \text{ for } j \in \qa{i})$, where $z_i'$ and $z_j'$ are single copies of the $i$th and $j$th latent variables.
For instance, the massively parallel proposal might be IID over the $K$ copies, $z_i^k$, and be based on a uniform mixture over all parent samples (other alternatives are available: see \citet{ml} for further details).
%However, 
%
%Note that this form implies that $z_i$ may depends on all samples of the parent latent variables, and that the $K$ samples of $z_i$, namely $z_i^1,\dotsc,z_i^K$, may have some dependencies.
%Usually, we derive the massively-parallel proposal from an underlying, user-specified, single-sample proposal.
%The simplest approach is just to draw $K$ samples of $z_i$ independently from the full joint latent space (other alternatives are available: see \citep{ml} for further details).
%The key property of these proposals is that the single-sample marginals, $\Qmpc{z_i^k}{z_j \text{ for } j \in \qa{i}}$ are easily computable.
%In all our experiments, we will use a permutation over samples of the parent latent variables.

For the generative model, we need to explicitly consider all \(K^n\) combinations of \(K\) samples on \(n\) latent variables.
To help us write down these combinations, we define a vector of indices, $\k$, with one index, \(k_i\), for each latent variable, \(z_i\).
\begin{align}
  \k &= \b{k_1, k_2, \dotsc, k_n} \in \mathcal{K}^n, \\
  z^\k &= \b{z_1^{k_1}, z_2^{k_2}, \dotsc, z_n^{k_n}} \in \mathcal{Z},
\end{align}
%\begin{align}
%  z^\k &= \b{z_1^{k_1}, z_2^{k_2}, \dotsc, z_n^{k_n}} \in \mathcal{Z},
%\end{align}
This allows us to write the ``indexed'' latent variables, $z^\k$, which represent a single sample from the full joint latent space.
The generative model also has graphical model structure, with the set of indices of parents of the \(i\)th latent variable under the generative model being denoted \(\pa{i}\) (contrast this with \(\qa{i}\) which is the parents of the \(i\)th latent variable under the proposal).
The generative probability for a single combination of samples, denoted \(z^\k\), can be written as,
\begin{align}
  \label{eq:gen}
  \begin{split}
  \P{x, z^\k} = &\Pc{x}{z_j^{k_j} \text{ for all } j \in \pa{x}} \\ 
  &\prod_{i=1}^n \Pc{z^{k_i}_i}{z_j^{k_j} \text{ for all } j \in \pa{i}}.
  \end{split}
\end{align}
Thus, we can write a massively parallel marginal likelihood estimator as,
\begin{align}
  %\label{eq:rtmc}
  %\label{eq:Pmp}
  \label{eq:Prmp}
  \Pmp(z) &= {\tfrac{1}{K^n} \sum_{\k\in\mathcal{K}^n} r_{\k}(z)} \\ 
  \intertext{where}
  r_{\k}(z) &= \frac{\P{x, z^\k}}{\prod_i \Qmpc{z_i^{k_i}}{z_j \text{ for } j \in \qa{i}}}
\end{align}
%While this looks intuitively reasonable, proving that Eq.~\eqref{eq:Prmp} is a valid marginal likelihood estimator is nontrivial: the full proofs are given in \citep{ml} (their Appendix C.1.3).
%
The next challenge is to compute the sum in Eq.~\eqref{eq:Prmp}.
The sum looks intractable as we have to sum over \(K^n\) settings of \(\k\).
However, it turns out that these sums are usually tractable.
The reason is that if we fix the samples, \(z\), then \(r_\k(z)\) can be understood as a product of low-rank tensors,
\begin{align}
  \label{eq:r_prod}
  r_\k(z) &= f^x_{\k_{\pa{x}}}(z) \prod_i f^i_{k_i, \k_\pa{i}}(z)\\
  f^x_{\k_{\pa{x}}}(z) &= \Pc{x}{z_j^{k_j} \text{ for all } j \in \pa{x}},\\%\in \mathbb{R}^{K^{\abs{\pa{x}}}}.\\
  \label{eq:fi}
  f^i_{k_i, \k_{\pa{i}}}(z) &= \frac{\Pc{z^{k_i}_i}{z_j^{k_j} \text{ for all } j \in \pa{i}}}{\Qmpc{z_i^{k_i}}{z_j \text{ for all } j \in \qa{i}}}.% \in \mathbb{R}^{K^{1+\abs{\pa{i}}}}.
  %f^x_{\k_\pa{x}}(z) &= \P{x| z^{\k_{\pa{x}}}_{\pa{x}}}\\
  %f^i_{k_i,\k_\pa{i}}(z) &= \frac{\P{z^{k_i}_i| z^{\k_{\pa{i}}}_{\pa{i}}}}{\Qc{z_i^{k_i}}{z_{\qa{i}}^{k_i}}}
\end{align}
Here, \(f^x_{\k_\pa{x}}(z)\) is a tensor of rank \(\abs{\pa{x}}\), and \(f^i_{k_i,\k_\pa{i}}(z)\) are tensors of rank \(1+\abs{\pa{i}}\), where \(\abs{\pa{i}}\) is the number of parents of the \(i\)th latent variable.
Thus, Eq.~\eqref{eq:Prmp} is a large tensor contraction (a sum over a product of many small tensors),
\begin{align}
  \label{eq:Pmp_tensor_product}
  \Pmp(z) &= \tfrac{1}{K^n} \sum_{\k\in\mathcal{K}^n} f^x_{\k_{\pa{x}}}(z) \prod_i f^i_{k_i, \k_\pa{i}}(z)
\end{align}
which can be efficiently computed using an opt-einsum implementation.

Now, we are in a position to define an importance sampling scheme that operates on all \(K^n\) combinations of samples,
\begin{align}
  \label{eq:mmp}
  m_\mp(z) &= \tfrac{1}{K^n} \sum_{\k\in\mathcal{K}^n} \frac{r_\k(z)}{\Pmp(z)} m(z^\k).
\end{align}
This looks very similar to the standard global importance sampling scheme in Eq.~\eqref{eq:mglobal}, except that Eq.~\eqref{eq:mglobal} averages only over $K$ samples, whereas this massively parallel moment estimator averages over all $K^n$ combinations of samples.
For a proof that this is a valid importance-sampled moment estimator, see Appendix~\ref{app:deriv}.% \textbf{Derivations}.

\section{Methods}
\label{sec:methods}

The contribution of this paper is not the unbiased massively parallel marginal likelihood estimator itself, which has previously been used for learning general probabilistic models. Instead, our major contribution is a novel approach to computing key quantities of interest in Bayesian computation, obtained by applying the source term trick to the massively parallel marginal likelihood estimator. In the following sections, we outline in turn how to compute posterior expectations, marginals and samples.

\textbf{Interpreting massively parallel importance weighting as inference in a discrete variable graphical model.}
Now, the weights on the individual terms in Eq.~\eqref{eq:mmp} can be understood as a normalized probability distribution over $\k$,
\begin{align}
  \P[z]{\k} &= \frac{1}{K^n} \frac{r_\k(z)}{\Pmp(z)}.
\end{align}
In particular, this quantity is always non-negative, and we can show that it normalizes to $1$ by substituting the definition of $\Pmp(z)$ from Eq.~\eqref{eq:Prmp},
\begin{align}
  \sum_\k \P[z]{\k} &= \tfrac{1}{K^n} \sum_\k \frac{r_\k(z)}{\Pmp(z)} = 
  \frac{\tfrac{1}{K^n} \sum_\k r_\k(z)}{\tfrac{1}{K^n} \sum_{\k'} r_{\k'}(z)} = 1
\end{align}
As such, we can in principle use methods for discrete variable graphical models, treating $\k$ as a random variable.
However, as discussed in Related Work, computing posterior expectations, marginals and samples in discrete variable graphical models may still involve complex backward traversals, which are especially difficult if we want to exploit structure such as plates or timeseries to speed up the computations.
%Additionally, this quantity can be written as a product of tensors (Eq.~\ref{eq:r_prod}).
%This product can be reinterpreted as an unnormalized probability density for a factor-graph over random variables, \(\k=(k_1,\dotsc,k_n)\), in which case, the massively parallel moment estimates (Eq.~\ref{eq:mmp}) can be understood as expectations under this unnormalized density.
%This interpretation justifies the use of methods for discrete graphical models.

\textbf{Computing expectations by differentiating an estimate of the normalizing constant.}
Instead, we modify our marginal likelihood estimator with a source term, $e^{J m(z^\k)}$,
\begin{align}
  \label{eq:Pmpmom}
  \Pmpexp(z, J) &= \frac{1}{K^n} \sum_{\k\in\mathcal{K}^n} r_\k(z) e^{J m(z^\k)}.
\end{align}
Remember that \(r_\k(z)\) is a product of low-rank tensors, indexed by subsets of \(\k\) (Eq.~\ref{eq:r_prod}), so the sum can be computed efficiently using opt-einsum.
Critically, the source term is just another factor with indices given by a subset of \(\k\).
For instance, most often \(m\) (the function whose expectation we want to compute) will depend on only a single latent variable \(m(z^\k) = m(z_i^{k_i})\), in which case the source term can be understood as just another tensor in the tensor product (Eq.~\ref{eq:Pmp_tensor_product}), with one index, \(k_i\).
Now, we prove that differentiating the logarithm of this modified marginal likelihood estimator gives back a massively parallel moment estimator. 
% \begin{align}
%   \Pmpexp(z, 0) &= \Pmp(z).
% \end{align}
In particular, we differentiate \(\log \Pmpexp(z, J)\) at \(J=0\) (first equality). 
Then in the numerator we substitute \(\Pmpexp(z, J)\) from Eq.~\eqref{eq:Pmpmom}, and in the denominator, we remember that \(\Pmpexp(z, J=0) = \Pmp(z)\),
\begin{align}
  \nonumber
  \at{\dd{J}}_{J=0} &\log \Pmpexp(z, J) \\
  \nonumber
  &= \frac{\at{\dd{J}}_{J=0} \Pmpexp(z, J)}{\Pmpexp(z, 0)} \\
  \nonumber
  &= \frac{\tfrac{1}{K^n} \sum_\k r_\k(z) \at{\dd{J}}_{J=0} e^{J m(z^\k)}}{\Pmp(z)}
  \intertext{Computing the gradient of \(e^{J m(z^\k)}\) at \(J=0\),}
  &= \frac{\tfrac{1}{K^n} \sum_\k r_\k(z) m(z^\k)}{\Pmp(z)} = m_\mp(z)
\end{align}
where the final equality comes from the definition of \(m_\mp(z)\) in Eq.~\eqref{eq:mmp}.
Note that this derivation is quite different from the standard ``source-term trick'' from physics described in Background, which works with either the true normalizing constant, or with a low-order perturbation to that normalizing constant.
In contrast, here we use a very different massively parallel sample-based estimate of the marginal likelihood.
Importantly, the next two derivations depart even further from the usual uses of the ``source-term trick'' in physics.
In particular, the source-term trick is almost always used to compute moments/expectations in physics, whereas the next two derivations use the same trick to compute quite different quantities (namely, probability distributions over sample indices).

Pseudocode for all procedures can be found in Appendix~\ref{app:algo}. %\textbf{Algorithms}.


\textbf{Computing marginal importance weights.}
Computing expectations directly is very powerful and almost certainly necessary for computing complex quantities that depend on multiple latent variables.  
However, if we are primarily interested in posterior expectations of individual variables, then it is considerably more flexible to compute ``marginal'' posterior importance weights.
Once we have these marginal importance weights, we can easily compute arbitrary posterior expectations for individual variables, along with other quantities such as effective sample sizes.
To define the marginal weights for the \(i\)th latent variable, note that a moment that depends only on the \(i\)th latent variable, \(m(z^\k) = m(z_i^{k_i})\), can be written as a sum over \(k_i\),
\begin{align}
  m_\mp(z) &= \frac{1}{K^n} \sum_{\k\in\mathcal{K}^n} \frac{r_\k(z)}{\Pmp(z)} m(z_i^{k_i}) = \sum_{k_i} w^i_{k_i} m(z_i^{k_i}),\\
  \intertext{where \(w^i_{k_i}\) are the marginal importance weights for the \(i\)th latent variable, which are defined by,}
  \label{eq:wmarg}
  w^i_{k_i} &= \frac{\tfrac{1}{K^n} \sum_{\k/k_i\in\mathcal{K}^{n-1}} r_\k(z)}{\Pmp(z)},
\end{align}
where the sum is over all \(\k\) except \(k_i\).
Formally,
\begin{align}
  \k/k_i &= \b{k_1,\dotsc,k_{i-1}, k_{i+1}, \dotsc,k_n} \in \mathcal{K}^{n-1}.
\end{align}
Again we can compute the marginal importance weights using gradients of a slightly different modified marginal likelihood estimator.
Specifically, we now use a vector-valued \(\J \in \mathbb{R}^{K}\), with one element for each possible value of \(k_i\), in the following modified marginal likelihood estimator,
\begin{align}
  \Pmpmarg(z, \J) &= \tfrac{1}{K^n} \sum_{\k} r_\k(z) e^{J_{k_i}}.
\end{align}
Again, \(\Pmpmarg(z, \0) = \Pmp(z)\). As before, we differentiate \(\log \Pmpmarg(z, \J)\) at \(\J=\0\),
\begin{align}
  \nonumber
  \at{\dd{J_{k_i'}}}_{\J=\0} \log & \Pmpmarg(z, \J) \\ 
  &= \frac{\at{\dd{J_{k_i'}}}_{\J=\0} \Pmpmarg(z, \J)}{\Pmpmarg(z, \0)}.
  \intertext{Substituting for \(\Pmpmarg(z, \J)\) in the numerator,}
  &= \frac{\tfrac{1}{K^n} \sum_\k r_\k(z) \at{\dd{J_{k_i'}}}_{\J=\0} e^{J_{k_i}}}{\Pmp(z)}.
  \intertext{The gradient is \(1\) when \(k_i'=k_i\) and zero otherwise which can be represented using a Kronecker delta,}
  &= \frac{\tfrac{1}{K^n} \sum_\k r_\k(z) \delta_{k_i', k_i}}{\Pmp(z)}.
  \intertext{The delta fixes \(k_i = k'_i\), so we can rewrite this as a sum over all \(\k\) except \(k_i\),}
  \nonumber
  &= \frac{\tfrac{1}{K^n} \sum_{\k/k_i \in \mathcal{K}^{n-1}} r_\k(z)}{\Pmp(z)} = w^i_{k'_i},
\end{align}
which is exactly the definition of the marginal importance weights in Eq.~\eqref{eq:wmarg}.


\begin{figure*}[t]
  \begin{center}
  \includegraphics[width=0.9\textwidth]{figures/IS_per_K.pdf}
  \end{center}
  \caption{ELBO (top row) and predictive log-likelihood (bottom row) across the four datasets achieved via MP IS and global IS, for varying values of \(K\). The error bars represent the standard error across 100 repeated experiments on the same data but using different random seeds.}
  \label{fig:IS_per_K}
\end{figure*}


\textbf{Computing conditional distributions for importance sampling.}
A common alternative to importance weighting is importance sampling. 
In importance sampling, we rewrite the usual estimates of the expectations in terms of a distribution over indices, \(\P[z]{\k}\),
\begin{align}
  \label{eq:Pk}
  m_\mp(z) &= \sum_{\k\in \mathcal{K}^n} \P[z]{\k} m(z^\k) \\ \intertext{where} \P[z]{\k} &= \frac{1}{K^n}\frac{1}{\Pmp(z)} r_\k(z)
\end{align}
We can obtain (approximate) posterior samples, $z^\k$, by sampling $\k$ from $\P[z]{\k}$.
However, sampling from \(\P[z]{\k}\) is difficult in our context, as there are \(K^n\) possible settings of \(\k\), so we cannot explicitly compute the full distribution.
Instead, we can factorise the distribution, and iteratively sample (e.g.\ we sample \(k_1\) from \(\P[z]{k_1}\) then sample \(k_2\) from \(\Pc[z]{k_2}{k_1}\) etc).
Formally, we use,
\begin{align}
  \P[z]{\k} &= \prod_i \Pc[z]{k_i}{\k_{\pa{i}}},
\end{align}
where, remember, \(\pa{i}\) is the set of indices of parents of the \(i\)th latent variable under the generative model, and $\k_\pa{i}$ is the value of $k_j$ for each of those parents,
\begin{align}
  \k_\pa{i} &= \b{k_j \text{ for all } j \in \pa{i}}.
\end{align}
Note that this quantity is similar to the ``backward kernels'' in the SMC literature \citep{del2006sequential}; what differs is our approach to computing it, which uses the source-term trick to avoid the need for explicit backward traversals.
Now, we have the problem of computing the conditionals, \(\Pc[z]{k_i}{\k_{\pa{i}}}\).
We can compute the conditionals from the marginals using Bayes' theorem,
\begin{align}
  \Pc[z]{k_i}{\k_{\pa{i}}} &= \frac{\P[z]{k_i, \k_{\pa{i}}}}{\sum_{k'_i} \P[z]{k'_i, \k_{\pa{i}}}} \\ 
  \intertext{where the marginals are given by}
  \label{eq:marg}
    \P[z]{k_i, \k_{\pa{i}}} &= \sum_{\k/(k_i, \k_{\pa{i}})} \P[z]{\k}.
\end{align}
%where the marginals are given by,
%\begin{align}
%  \label{eq:marg}
%  \P{k_i, \k_{\pa{i}}} &= \sum_{\k/(k_i, \k_{\pa{i}})} \P{\k}
%\end{align}
Again, we can compute these marginals efficiently by differentiating a modified estimate of the marginal likelihood.
This time, we take a tensor-valued \(\J \in \mathbb{R}^{K^{1+\abs{\pa{i}}}}\), where remember \(\abs{\pa{i}}\) is the number of parents of the \(i\)th latent variable under the generative model.
\begin{align}
  \Pmpsamp(z, \J) &= \frac{1}{K^n} \sum_{\k} r_\k(z) e^{J_{k_i, \k_{\pa{i}}}}.
\end{align}

\begin{figure*}[t]
\begin{center}
\includegraphics[width=0.9\textwidth]{figures/IS_per_K_time.pdf}
\end{center}
\caption{Results analogous to those in Fig.~\ref{fig:IS_per_K} but with time on the x-axis. Again, the error bars represent standard errors over 100 runs on the same data but using different random seeds.}
\label{fig:IS_per_K_TIME}
\end{figure*}

As usual, we differentiate with respect to \(\J\) at \(\J=\0\),
\begin{align}
  \nonumber
  \at{\dd{J_{k'_i, \k'_\pa{i}}}}_{\J=\0} &\log \Pmpsamp(z, \J) \\
  \nonumber
  &= \frac{\tfrac{1}{K^n} \sum_\k r_\k(z) \delta_{(k_i, \k_{\pa{i}}), (k'_i, \k'_{\pa{i}})}}{\Pmp(z)}
  \intertext{Here, \(\delta_{(k_i, \k_{\pa{i}}), (k'_i, \k'_{\pa{i}})}\) is a generalisation of the Kronecker delta. It is \(1\) when all the indices match (i.e.\ \(k_i = k_i'\) and \(\k_{\pa{i}} = \k'_{\pa{i}}\)) and zero otherwise.
  The delta therefore fixes \(k_i = k'_i\) and \(\k_{\pa{i}} = \k'_{\pa{i}}\), leaving a sum over the remaining indices, which gives precisely the marginals in Eq.~\eqref{eq:marg},}
  \nonumber
  &= \frac{\tfrac{1}{K^n} \sum_{\k/(k_i, \k_{\pa{i}})} r_\k(z) }{\Pmp(z)} \\ 
  &= \sum_{\mathclap{\k/(k_i, \k_{\pa{i}})}} \P[z]{\k} = \P[z]{k'_i, \k'_{\pa{i}}}.
\end{align}


\section{Experiments}

\begin{figure*}[t]
\begin{center}
\includegraphics[width=0.9\textwidth]{figures/summary.pdf}
\end{center}
\caption{Comparing the ELBO and predictive log-likelihood of MP IS with \(K=15\) (red star), global IS with \(K=10{,}000\) (blue star), VI (green lines), IWAE (purple lines) and RWS (orange lines). 
For MP IS and global IS we have plotted the highest \(K\) value that we were able to compute for all models, specifically \(K=15\) and \(K=10{,}000\) respectively. 
Single-sample VI was performed (i.e. \(K=1\)), whereas IWAE and RWS were done using \(K=10\). 
The only iterative method we report for the Occupancy model is RWS, as this model contains discrete latent variables which precludes the use of gradient-based methods such as VI and IWAE.}
\label{fig:results_summary}
\end{figure*}


%We provide empirical results comparing the global and massively parallel importance weighting/sampling methods, as well as single-sample and multi-sample VI (the latter denoted ``IWAE'' hereafter), RWS and HMC baselines.
We considered four datasets: NYC Bus Breakdown, Chimpanzee Prosociality, MovieLens100K and North American Breeding Bird Survey.
NYC Bus Breakdown describes the length of around 150,000 delays to New York school bus journeys, segregated by year, borough, bus company and journey type.
The Chimpanzee Prosociality dataset contains the actions taken in a controlled experiment by 7 chimpanzees given the repeated choice of whether to give food to another chimpanzee or not when receiving food themselves, repeated 12 times on 6 separate occasions per chimpanzee.
MovieLens100K \citep{harper2015movielens} contains 100k ratings of 1682 films from among 943 users. 
Finally, the North American Breeding Bird Survey records the number of sightings of over 700 bird species from 1966--2021 (excluding 2020), along thousands of different road-side routes in the United States and Canada.
We use hierarchical probabilistic graphical models for these datasets, subsampling each dataset except for Chimpanzee Prosociality.
For each dataset we define a generative model $\P{x, z^\k}$ and a factorised proposal $\Q{z^\k}$ which contains the same latents as the generative model, but each latent is independently parameterised by a simple (usually standard Normal) distribution.
Note that we don't necessarily require our proposal $\Q{z^\k}$ to be factorised---this is done only for simplicity.
The full details of these models are described in Appendix section \ref{app:exps} and a repository containing the code needed to reproduce these experiments can be found at \url{https://github.com/sambowyer/MPIS}.%\textbf{Experimental Datasets And Models}.

We focus on two metrics to assess the quality of importance weighting and importance sampling: the importance weighted evidence lower bound (ELBO) and predictive log-likelihood.
We use predictive log-likelihood to evaluate importance sample quality by drawing latent samples conditioned on observed `training' data and use these to predict unobserved `test' data. 
Therefore a higher predictive log-likelihood would correspond to higher-quality sampled latents (which are closer to the true posterior).
The quality of importance weighting is measured by the tightness of the importance weighted ELBO, a bound on the model evidence.
This is widely accepted as a good proxy for the quality of the importance weighted posterior estimate \citep{geffner2022variational, agrawal2021amortized}, and this was confirmed empirically by \citet{domke_divide_2019}.
This is because it can be interpreted as a single-sample ELBO with an improved approximate posterior \citep{cremer2017reinterpreting, bachman_training_2015}, and the single-sample ELBO directly measures the discrepancy between the true and approximate posterior:
\begin{align}
  \text{ELBO} = \log \P{x} - D_{\text{KL}}\b{\Q{z} \,\|\, \Pc{z}{x}}.
\end{align}

%For global and massively parallel importance sampling, we use the prior as a proposal.
%For VI, IWAE and RWS, we used a factorised approximate posterior and optimized using Adam with learning rates ranging from $0.3$ to $0.01$ (learning rates faster than $0.3$ were unstable).
%For HMC, we used the GPU-enabled JAX-based \cite{jax2018github} NUTS \citep{hoffman2014no} implementation from PyMC \citep{salvatier_probabilistic_2016} with 10 warmup iterations and a target acceptance rate of 0.8. 
%All experimental results are taken as a mean over 100 runs each with different random seeds, with iterative models performing 250 iterations in each run.

We began by comparing massively parallel and global importance sampling (Fig.~\ref{fig:IS_per_K}) on each dataset using the pre-defined (and fixed) generative and proposal distributions.
We found that for a given value of $K$, the ELBO (top) and predictive log-likelihood (bottom) are far better for massively parallel than global importance sampling, as massively parallel sampling considers all $K^n$ combinations of all samples of all latent variables, while global importance sampling considers only $K$ samples from the full joint latent space.

However, for a fixed $K$, the time taken for massively parallel importance sampling is much larger than the time taken for global importance sampling (and this also meant that we used a far smaller range of $K$ for massively parallel importance sampling).
We therefore also plotted the ELBO (top) and predictive log-likelihood (bottom) with time on the x-axis (Fig.~\ref{fig:IS_per_K_TIME}).
This again shows that massively parallel importance sampling gives large improvements in ELBO and predictive log-likelihood for a fixed time budget.

%Note that for massively parallel importance sampling, the maximum 
%In the first two plots, we see dramatic improvements against standard ``global'' importance sampling for a fixed $K$ (Fig.~\ref{fig:IS_per_K}), and for a fixed time (Fig.~\ref{fig:IS_per_K_TIME}).
%We ran these experiments for \(K\in\{3,5,10,15,20,25,30,45,60,80,10^2,3\times10^2,10^3,3\times10^3,10^4,3\times10^4,10^5,3\times10^5,10^6,3\times10^6,10^7\}\), however, only global IS on the chimpanzee dataset was able to run with each of these values (and no higher) --- failures were mostly due to the large amount of memory required or due to numerical instability in the models.
%A table of the maximum K values achieved by each IS method on each model is presented below in Table~\ref{tab:IS_K}. 
%Note that whilst e.g. Chimpanzee was able to reach the highest K for global IS, it only reached \(K=20\) for MP IS.
%This is due in large part to the fact that the Chimpanzee model has multiple latent variables which `skip' plates (see Appendix section \ref{app:chimpanzees}, Fig.~\ref{fig:chimpanzees_gm} for the full specification) which can be costly in terms of memory.
%
%\begin{table}[h]
%    \centering
%    \begin{tabular}{lll}
%        \hline
%        Model         & Global IS  & MP IS \\ \hline
%        Bus Breakdown & \(3\times10^6\) & 45    \\
%        Chimpanzees   & \(10^7\)   & 20    \\
%        MovieLens     & \(3\times10^5\) & 100   \\
%        Occupancy     & \(3\times10^4\) & 15    \\ \hline
%    \end{tabular}
%    \caption{Maximum K values achieved on each dataset.}
%    \label{tab:IS_K}
%\end{table}    



Finally, we compared massively parallel importance sampling against iterative methods such as VI, IWAE and RWS that learn a better proposal (Fig.~\ref{fig:results_summary}).
These methods use the factorised proposal distribution $\Q{z^\k}$ as their initial approximate posterior and iteratively update the parameters of each latent variable's proposal distribution (in most cases this is the mean and variance of a Gaussian---see Appendix section \ref{app:exps} for full details).
VI and IWAE update the proposal parameters in order to maximise the global ELBO, the former with $K=1$ and the latter with $K>1$ (in particular we used $K=10$). 
RWS performs a maximum-likelihood update on the parameters of $Q$ using posterior samples obtained by reweighting the proposal samples with the importance weights $r_k(z)$ given in Eq.~\eqref{eq:rglobal}, also with $K=10$.
In each case, this update is performed using the Adam optimizer with default hyperparameters and tuned learning rates which are discussed in Appendix sections \ref{app:exp_det} and \ref{app:more_results}.

Of course, learning a good proposal is critical to the effectiveness of practical inference methods.
Thus, comparing massively parallel importance sampling against these methods is not really fair, as massively parallel importance sampling is a ``one-shot'' method that is forced to use an extremely poor, fixed proposal (the simple factorised proposal described above).
To do a fair comparison, we would need to build massively parallel importance sampling into an iterative method that progressively improves the proposal (but this is a considerable endeavour which is out of scope for the present work).
Even with these fundamental limitations, massively parallel importance sampling seems to fill out an important part of the speed-accuracy tradeoff.
In particular, massively parallel importance sampling gives us good results extremely rapidly, while iterative methods take at least an order of magnitude longer to reach similar ELBOs and predictive log-likelihoods (Fig.~\ref{fig:results_summary}).
Here, we plotted a single line for each method (VI, IWAE and RWS), corresponding to the best learning rate (see Appendix~\ref{app:more_results}).
We considered including HMC in Fig.~\ref{fig:results_summary}, but it turned out not to be possible because HMC methods took longer to return even a single sample than the time plotted. 
These long timescales arise because HMC requires many gradient evaluations for a single HMC step, and these steps are inherently sequential and therefore cannot make good use of GPU parallelism. 
This is even before we account for the adaptation and burn-in phases required by MCMC methods.

%In the third plot (Fig.~\ref{fig:results_summary}) we see that MP IS is able to achieve surprisingly competitive results in an extremely small amount of time when compared to the iterative methods. 
%In particular, note that whilst HMC is able to achieve very good predictive log-likelihoods, it takes a (comparably) very long time to begin producing samples, having to perform 10 warm-up iterations to tune itself.
%In Appendix section \ref{app:more_results}, we present more experiment results with every learning rate tried and consider other metrics for assessing the quality of posterior moment estimates computed using the approach laid out in the Methods section.


% Note the extremely poor performance of VI and HMC is because these are iterative methods that take many steps, and hence a long time to reach good performance.
% HMC is especially problematic on these short timescales, because it needs many gradient steps before it returns even a single posterior sample (and this is ignoring burn-in and adaptation phases).
% In contrast, importance sampling gets a reasonable answer much more quickly, as it just a single, non-iterative computation.


%
% \begin{figure*}[!ht]
% \begin{center}
%   \tikz{
%     % nodes
%      \node[obs] (rating) {\(\mathrm{Rating}_{mn}\)};%
%      \node[latent,left=of rating] (peruser) {\(\mathrm{Per \ user \ mean \ } \mathbf{z}_m\)}; %
%      \node[latent, left=of peruser, yshift=-.8cm] (psi) {\(\psi\)};
%     \node[latent, left=of peruser, yshift=.8cm] (mu) {\(\mu\)};
%     % plate
%      \plate [inner sep=.2cm] {plate2} {(rating)} {\(\mathrm{N}\) Films}; %
%      \plate [inner sep=.2cm] {plate1} {(peruser)(plate2)} {\(\mathrm{M}\) Users}; %
%     edges
%      \edge {psi,mu} {peruser}
%      \edge {peruser} {rating}}
% \caption{Graphical model for the MovieLens dataset}
% \label{fig:movielens_gm}
% \end{center}
% \end{figure*}



%In order to obtain results with comparable computation times we test the massively parallel importance samples for \(K \in \{3,10,30\}\) and the global importance samples for \(K \in \{3,10,30,100,300,1000,3000,10000\}\).

%Our results (Fig.~\ref{fig:movielens_mom}) show tighter evidence lower bounds and higher predictive log-likelihoods using massively parallel importance samples than using global importance samples for a given \(K {>} 1\), with better results obtained in less time once \(K\) is large enough in the massively parallel approach (in this case, for \(K {\geq} 10\)).
%We also see much lower variance in our importance weighted estimates for the expectation of per user means, \(\mathbf{z}_m\), for every $K {>} 1$ and again see this reduced variance achieved in less time than the global method for large enough $K$ (here, \(K {\geq} 3\)).
%Using data sampled from the model in order to obtain ground truth values of \(\mathbf{z}_m\), we found the mean squared error of the \(\mathbf{z}_m\) expectation estimates was far lower for all \(K {>} 1\) when those estimates were calculated using massively parallel importance weights rather than global importance weights, with similar behaviour in the time required to reduce the MSE as with reducing the variance.


% We want to compare our methods vs standard global methods when using the prior as a proposal for:
% \begin{itemize}
%   \item The quality of importance weighted estimates of the posterior expectations.
%   We can use plots with K or time on the x-axis.
%   (As \(K\) increases, the time required for our methods will increase more rapidly than the time required for global methods, but our estimates will improve more rapidly too).
%   On the y-axis, we could have:
%   \begin{itemize}
%     \item the posterior moment estimate itself (this should converge to a fixed value as we increase \(K\)).
%     \item the variance in the estimator.
%   \end{itemize}
%   \item The quality of the importance samples, as estimated by predictive log-likelihoods.
% \end{itemize}

% Datasets could include:
% \begin{itemize}
%   \item Movielens
%   \item Radon
%   \item COVID model
% \end{itemize}

%However, importance weighting breaks down in larger models.
%\citet{chatterjee2018sample} considered a model with data, $x$, latent variables $z$ a true posterior, $\Pc{z}{x}$ and a proposal $\Q{z}$. 
%They showed that the number of samples required to accurately approximate the true posterior scales as \(\exp\b{D_{\text{KL}}\b{\Pc{z}{x}|| \Q{z}}}\).

\section{Related work}
\label{sec:related}
There is a considerable body of work in the discrete graphical model setting that computes posterior expectations, marginals and samples \citep{dawid1992applications,pfeffer2005design,bidyuk2007cutset,geldenhuys2012probabilistic,gogate2012importance,claret2013bayesian,sankaranarayanan2013static,goodman2014design,gehr2016psi,narayanan2016probabilistic,albarghouthi2017fairsquare,dechter2018abstraction,wang2018pmaf,obermeyer2019tensor,holtzen2020scaling}.
Our work differs in two respects.
First, our massively parallel methods are not restricted to discrete graphical models, but can operate with arbitrary continuous latent variables and graphs with a mixture of continuous and discrete latent variables.
Second, this prior work involves complex implementations that, in one sense or another, ``proceed by recording an adjoint compute graph alongside the forward computation and then traversing the adjoint graph backwards starting from the final result of the forward computation'' \citep{obermeyer2019tensor}.
The forward computation is reasonably straightforward: it is just a big tensor product that can be computed efficiently using pre-existing libraries such as opt-einsum, and results in (an estimate of) the marginal likelihood.
However, the backward traversal is much more complex, if for no other reason than the need to implement separate traversals for each operation of interest (computing posterior expectations, marginals and samples). 
Additionally, these traversals need to correctly handle all special cases, including optimized implementations of plates and timeseries.
Importantly, optimizing the forward computation is usually quite straightforward while implementing an optimized backward traversal is far more complex.
For instance, the forward computation for a timeseries involves a product of \(T\) matrices arranged in a chain.
Naively computing this product on GPUs is very slow, as it requires \(T\) separate matrix multiplications.
However, it is possible to massively optimize this forward computation, converting \(\mathcal{O}(T)\) to \(\mathcal{O}(\log(T))\) tensor operations by multiplying adjacent pairs of matrices in a single batched matrix multiplication operation.
This optimization is straightforward in the forward computation.
However, applying this optimization as part of the backward computation is far more complex (see \citealp{corenflos2022sequentialized} for details).
This complexity, along with similar complexity for other important optimizations such as plates, is prohibitive for academic teams implementing e.g.\ new probabilistic programming languages.
Our key contribution is thus to provide a much simpler approach to directly compute posterior expectations, marginals and samples by differentiating through the forward computation, without having to hand-write and hand-optimize backward traversals.

There is work on fitting importance weighted autoencoders \citep[IWAE; ][]{burda2015importance} and reweighted wake-sleep \citerws[RWS; ]{} in the massively parallel setting \citep{aitchison2019tensor,geffner2022variational,ml} for general probabilistic models.
However, this work only provides methods for performing massively parallel updates to approximate posteriors (e.g.\ by optimizing a massively parallel ELBO). 
This work does not provide a method to individually reweight the samples to provide accurate posterior expectations, marginals and samples.
Instead, this previous work simply takes the learned approximate posterior as an estimate of the true posterior, and does not attempt to correct for inevitable differences between approximate and true posterior.

Massively parallel importance sampling itself bears similarities to e.g.\ particle filtering/SMC methods \citepf{} that have been generalised to arbitrary graphical models and where the resampling step has been eliminated.
However, our contribution is not massively parallel importance sampling in itself.
Instead, our contribution is the simple method, using autodiff to differentiate through a marginal likelihood estimator, for computing posterior expectations, marginals and samples without requiring the implementation of complex backwards traversals, and this has not appeared in past work.
%However, our massively parallel methods allow a broad class of proposal distributions in a broad class of probabilistic models. The divide-and-conquer algorithm \citep{kuntz2021divide} proposes a sequential monte-carlo estimator that is similar to ours, but requires a more complex implementation involving reweighting steps.
%In contrast to the previously mentioned methods, we develop a new approach to computing posterior expectations, marginals and samples by differentiating through a slightly modified marginal likelihood estimator.


\section{Conclusion}
We gave a new and far simpler method for computing posterior moments, marginals and samples in massively parallel importance sampling based on differentiating a slightly modified marginal likelihood estimator.

The method has limitations, in that while it is considerably more effective than e.g.\ VI, RWS and global importance sampling, it is more complex.
Additionally, at least a naive implementation may be quite costly in terms of memory consumption, limiting the number of importance samples we can draw for each variable. 
That said, it should be possible to eliminate almost all of this overhead by careful optimizations to avoid allocating large intermediate tensors, following the strategy in KeOps \citep{charlier2021keops}.

For future work, we intend to use the contributions from this paper in developing an iterative inference algorithm, similar to VI and RWS, that uses massively parallel importance sampling, as well as developing a massively parallel probabilistic programming language.

%Furthermore, we have demonstrated a simple method for computing posterior expectations, marginals and samples using automatic differentiation to avoid complicated explicit backward computations that often rely on restricting the class of probabilistic model at hand.
%We showed that using massively parallel importance weights within this method led to more accurate and lower variance posterior expectation estimates, again with fewer latent samples needing to be drawn and in less time than when using global importance samples.



% \begin{contributions} % will be removed in pdf for initial submission 
% 					  % (without ‘accepted’ option in \documentclass)
%                       % so you can already fill it to test with the
%                       % ‘accepted’ class option
%     Briefly list author contributions. 
%     This is a nice way of making clear who did what and to give proper credit.
%     This section is optional.

%     H.~Q.~Bovik conceived the idea and wrote the paper.
%     Coauthor One created the code.
%     Coauthor Two created the figures.
% \end{contributions}

\begin{acknowledgements} % will be removed in pdf for initial submission,
						 % (without ‘accepted’ option in \documentclass)
                         % so you can already fill it to test with the
                         % ‘accepted’ class option
    Sam Bowyer is supported by the UKRI Engineering and Physical Sciences Research Council via the COMPASS Centre for Doctoral Training at the University of Bristol (EP/S023569/1). 
    Thomas Heap is also supported by the UKRI Engineering and Physical Sciences Research Council. 

    This work was possible thanks to the computational facilities of the University of Bristol's Advanced Computing Research Centre---\url{http://www.bris.ac.uk/acrc/}.
    We would like to thank Dr.~Stewart for funding compute resources used in this project.

\end{acknowledgements}

% References
\bibliography{refs}


\newpage
\onecolumn
\appendix
\title{Using Autodiff to Estimate Posterior Moments, Marginals and Samples\\(Supplementary Material)}
\maketitle

\section{Derivations}
\label{app:deriv}
\subsection{\textGlob{} Importance Sampling}
\label{app:is_glob}
Here, we give the derivation for standard global importance sampling.
Ideally we would compute moments using the true posterior, $\P{z| x}$,
\begin{align}
  m_\post &= \E[{\P{z^k| x}}]{m(z^k)}.
  \intertext{However, the true posterior is not known.
    Instead, we write down the moment under the true posterior as an integral,}
  m_\post &= \int dz^k \P{z^k| x}m(z^k).
  \intertext{Next, we multiply the integrand by $1=\Q{z^k}/\Q{z^k}$,}
  m_\post &= \int dz^k \Q{z^k} \frac{\P{z^k| x}}{\Q{z^k}} m(z^k).
  \intertext{Next, the integral can be written as an expectation,}
  m_\post &= \E[\Q{z^k}]{\frac{\P{z^k| x}}{\Q{z^k}} m(z^k)}.
\end{align}
It looks like we should be able to estimate $m_\post$ by sampling from our approximate posterior, $\Q{z^k}$.
However, this is not yet possible, as we are not able to compute the true posterior, $\P{z^k| x}$.
We might consider using Bayes' theorem,
\begin{align}
  \P{z^k| x} &= \frac{\P{z^k, x}}{\P{x}},
\end{align}
but this requires computing an intractable normalizing constant, 
\begin{align}
  \P{x} &= \int dz^k \P{z^k, x}.
\end{align}
Instead, we use an unbiased, importance-sampled estimate of the normalizing constant, $\Pglob(z)$, given in Eq.~\eqref{eq:Pglobal}.
\citet{burda2015importance} showed that in the limit as $K \rightarrow \infty$, $\Pglob(z)$ approaches $\P{x}$.
Using this estimate of the marginal likelihood our moment estimate becomes,
\begin{align}
  m_\glob &= \E[\Q{z^k}]{\frac{\frac{\P{z^k, x}}{\Q{z^k}}}{\Pglob(z)} m(z^k)}.
\end{align}
Using $r_k(z)$ (Eq.~\ref{eq:rglobal}), we can write this expression as,
\begin{align}
  m_\glob &= \E[{\Q[\phi]{z}}]{\frac{r_k(z)}{\Pglob(z)}m(z^k)}.
  %\Dglob &= \E[{\Q[\phi]{z|x}}]{\frac{1}{K} \sum_k \frac{r_k(z)}{\Pglob(z)}\Delta(z^k)}.
\end{align}
The approximate posterior and generative probabilities are the same for different values of $k$, so we can average over $k$, which gives Eq.~\eqref{eq:mglobal} in the main text.

\subsection{Massively Parallel Importance Sampling}
\label{app:is_tmc}

Inspired by the global importance sampling derivation, we consider massively parallel importance sampling.
In the global importance sampling derivation, the key idea was to show that the estimator was unbiased for each of the $K$ samples, $z^k$, in which case the average over all $K$ samples is also unbiased.
In massively parallel importance sampling, we use the same idea, except that we now have $K^n$ samples, denoted $z^\k$.
As before,
\begin{align}
  m_\post &= \E[\P{z^\k| x}]{m(z^\k)} = \int dz^\k \P{z^\k| x}m(z^\k).\\
  \intertext{Again, we multiply and divide by the approximate posterior for $z^\k$.  In the massively parallel setting, we specifically use $\prod_i \Qc{z_i^{k_i}}{z_\qa{i}}$,}
  \label{eq:dpost_int}
  m_\post &= \int dz^\k \b{\prod_i \Qc{z_i^{k_i}}{z_\qa{i}}} \frac{\P{z^\k| x}}{\prod_i \Qc{z_i^{k_i}}{z_\qa{i}}}  m(z^\k).
\end{align}
Overall, our goal is to convert the integral over the indexed latent variables in Eq.~\eqref{eq:dpost_int} into an integral over the full latent space, $z$, so that it can be written as an expectation over the proposal, $\Q{z}$.
To do that, we need to introduce the concept of non-indexed latent variables.
These are all samples of the latent variables, except for the ``indexed'', or $k$th sample.
For the $i$th latent variable, the non-indexed samples are,
\begin{align}
  z_i^{/k_i} &= \b{z_i^{1},\dotsc,z_i^{k_i-1},z_i^{k_i+1},\dotsc,z_i^K} \in \mathcal{Z}_i^{K-1}.
  \intertext{We can also succinctly write the non-indexed samples of all latent variables as,}
  z^{/\k} &= \b{z_1^{/k_1}, z_2^{/k_2}, \dotsc, z_n^{/k_n}} \in \mathcal{Z}^{K-1}.
\end{align}

The joint distribution over the non-indexed latent variables, conditioned on the indexed latent variables, integrates to $1$,
\begin{align}
  \label{eq:intQz/k}
  1 &= \int dz^{/\k} \prod_i \Qc{z_i^{/k_i}}{z_i^{k_i}, z_\qa{i}}.
\end{align}
We use this to multiply the integrand in Eq.~\eqref{eq:dpost_int},
\begin{align}
  m_\post &= \int dz^\k \b{\prod_i \Qc{z_i^{k_i}}{z_\qa{i}}} \frac{\P{z^\k| x}}{\prod_i \Qc{z_i^{k_i}}{z_\qa{i}}}  m(z^\k) \int dz^{/\k} \prod_i \Qc{z_i^{/k_i}}{z_i^{k_i}, z_\qa{i}}.
\end{align}
Next, we merge the integrals over $z^{\k}$ and $z^{/\k}$ to form one integral over $z$, 
\begin{align}
  m_\post &= \int dz \; \Q{z} \frac{\P{z^\k| x}}{\prod_i \Qc{z_i^{k_i}}{z_\qa{i}}} m(z^\k).
\end{align}
This integral can be written as an expectation,
\begin{align}
  m_\post &= \E[{\Q{z}}]{\frac{\P{z^\k| x}}{\prod_i \Qc{z_i^{k_i}}{z_\qa{i}}} m(z^\k)}.
\end{align}
As in the derivation for global importance sampling, it looks like we might be able to estimate this by sampling from $\Q{z}$, but this does not yet work as we cannot compute the true posterior, $\P{z^\k| x}$.
Again, we could compute the posterior using Bayes' theorem,
\begin{align}
  \P{z^\k| x} &= \frac{\P{z^\k, x}}{\P{x}},
\end{align}
but we cannot compute the model evidence,
\begin{align}
  \P{x} &= \int dz^\k \P{z^\k, x}.
\end{align}
As in the global importance sampling section, we instead use an estimate of the marginal likelihood.
Here, we use a massively parallel estimate, $\Pmp(z)$,
\begin{align}
  m_\mp &= \E[{\Q{z}}]{\frac{\frac{\P{z^\k, x}}{\prod_i \Qc{z_i^{k_i}}{z_\qa{i}}}}{\Pmp(z)} m(z^\k)}.
\end{align}
Again, we use $r_\k(z)$ (Eq.~\eqref{eq:Prmp}),
\begin{align}
  m_\mp &= \E[{\Q{z}}]{\frac{r_\k(z)}{\Pmp(z)}m(z^\k)}.
\end{align}
So the estimator for a single setting of the indices, $\k$, has the right expectation, $m_\mp$.
Thus, averaging over all $K^n$ settings of $\k$, we obtain the estimator in the main text, Eq.~\eqref{eq:mmp}, which is therefore also an unbiased estimate of $m_\mp$.

% \newpage
\section{Algorithms}
\label{app:algo}

\begin{algorithm}[t]
\caption{Expectations using the source term trick}\label{alg:mom_Exp}
\begin{algorithmic}
\Require{Data $x$, Generative Model $\mathrm{P}(\mathbf{x} | z)$, Prior $\mathrm{P}(\mathbf{z})$, Proposal $\mathrm{Q}_\mathrm{MP}$, $K \geq 1$, function whose expectation we want to compute $m(z^\mathbf{k})$}

\For{$i \gets 1$ to $n$}
    \State Sample $z_i \sim \Qmpc{z_i}{z_j \text{ for all } j \in \qa{i}}$
    \State $z \gets \{z_1,...,z_{i-1}\} \cup z_i$
    \State $f^i_{k_i, \k_{\pa{i}}}(z) \gets \frac{\Pc{z^{k_i}_i}{z_j^{k_j} \text{ for all } j \in \pa{i}}}{\Qmp{z_i^{k_i}| x, z_j \text{ for all } j \in \qa{i}}}$
\EndFor
\State $f^x_{\k_{\pa{x}}}(z) \gets \Pc{x}{z_j^{k_j} \text{ for all } j \in \pa{x}}$
\State $r_\mathbf{k}(z) \gets f^x_{\k_{\pa{x}}}(z) \prod_i f^i_{k_i,\k_{\pa{i}}}(z)$
\State $J \gets 0$
\State $\Pmpexp(z, J) \gets \tfrac{1}{K^n} \sum_{\k} r_\mathbf{k}(z) e^{Jm(z^\mathbf{k})}$
\State \Return $m_\mp(z) \gets \at{\dd{J}}_{J=0} \log \Pmpexp(z, J)$
\end{algorithmic}
\end{algorithm}

\begin{algorithm}[t]
\caption{Marginal importance weights using the source term trick}\label{alg:mom_MIW}
\begin{algorithmic}
\Require{Data $x$, Generative Model $\mathrm{P}(\mathbf{x} | z)$, Prior $\mathrm{P}(\mathbf{z})$, Proposal $\mathrm{Q}_\mathrm{MP}$, $K \geq 1$, index of latent for which to calculate weight: $i$}

\For{$i \gets 1$ to $n$}
    \State Sample $z_i \sim \Qmpc{z_i}{z_j \text{ for all } j \in \qa{i}}$
    \State $z \gets \{z_1,...,z_{i-1}\} \cup z_i$
    \State $f^i_{k_i, \k_{\pa{i}}}(z) \gets \frac{\Pc{z^{k_i}_i}{z_j^{k_j} \text{ for all } j \in \pa{i}}}{\Qmp{z_i^{k_i}| x, z_j \text{ for all } j \in \qa{i}}}$
\EndFor
\State $f^x_{\k_{\pa{x}}}(z) \gets \Pc{x}{z_j^{k_j} \text{ for all } j \in \pa{x}}$
\State $r_\mathbf{k}(z) \gets f^x_{\k_{\pa{x}}}(z) \prod_i f^i_{k_i,\k_{\pa{i}}}(z)$
\State $\mathbf{J} \gets \mathbf{0} \in \mathbb{R}^K$
\State $\Pmpmarg(z, \mathbf{J}) \gets \tfrac{1}{K^n} \sum_{\k} r_\mathbf{k}(z) e^{J_{k_{i}}}$
\State \Return $w^i_{k_i} \gets \at{\dd{J_{k_{i}}}}_{\mathbf{J}=\mathbf{0}} \log \Pmpmarg(z, \mathbf{J})$
\end{algorithmic}
\end{algorithm}

\begin{algorithm}[t]
\caption{Joint distributions using the source term trick}\label{alg:mom_CDIS}
\begin{algorithmic}
\Require{Data $x$, Generative Model $\mathrm{P}(\mathbf{x} | z)$, Prior $\mathrm{P}(\mathbf{z})$, Proposal $\mathrm{Q}_\mathrm{MP}$, $K \geq 1$, index of latent for which to compute conditional distribution: $i$}

\For{$i \gets 1$ to $n$}
    \State Sample $z_i \sim \Qmpc{z_i}{z_j \text{ for all } j \in \qa{i}}$
    \State $z \gets \{z_1,...,z_{i-1}\} \cup z_i$
    \State $f^i_{k_i, \k_{\pa{i}}}(z) \gets \frac{\Pc{z^{k_i}_i}{z_j^{k_j} \text{ for all } j \in \pa{i}}}{\Qmp{z_i^{k_i}| x, z_j \text{ for all } j \in \qa{i}}}$
\EndFor
\State $f^x_{\k_{\pa{x}}}(z) \gets \Pc{x}{z_j^{k_j} \text{ for all } j \in \pa{x}}$
\State $\mathbf{J} \gets \mathbf{0} \in \mathbb{R}^{K^{1+|\text{pa}(i)|}}$
\State $r_\mathbf{k}(z) \gets f^x_{\k_{\pa{x}}}(z) \prod_i f^i_{k_i,\k_{\pa{i}}}(z)$
\State $\Pmpsamp(z, \mathbf{J}) \gets \tfrac{1}{K^n} \sum_{\k} r_\mathbf{k}(z) e^{J_{k_i, \k_{\pa{i}}}}$
\State \Return $\P{k_i, \k_{\pa{i}}} \gets \at{\dd{J_{k_i, \k_\pa{i}}}}_{\mathbf{J}=\mathbf{0}} \log \Pmpsamp(z, \mathbf{J})$
\end{algorithmic}
\end{algorithm}

%\begin{algorithm}
%\caption{Expectations using Source Trick}\label{alg:mom_Exp}
%\begin{algorithmic}
%\Require{Data $x$, Generative Model $\mathrm{P}(\mathbf{x} | z)$, Prior $\mathrm{P}(\mathbf{z})$, Proposal $\mathrm{Q}_\mathrm{MP}$, $K \geq 1$, \\ \hspace{30pt} function whose expectation we want to compute: $m(z^\mathbf{k})$}
%\FOR{$i \gets 1$ to $n$}
%    \STATE Sample $z_i \sim \Qmpc{z_i}{z_j \text{ for all } j \in \qa{i}}$
%    \STATE $z \gets \{z_1,...,z_{i-1}\} \cup z_i$
%    \STATE $f^i_{k_i, \k_{\pa{i}}}(z) \gets \frac{\Pc{z^{k_i}_i}{z_j^{k_j} \text{ for all } j \in \pa{i}}}{\Qmp{z_i^{k_i}| x, z_j \text{ for all } j \in \qa{i}}}$
%\ENDFOR
%\STATE $f^x_{\k_{\pa{x}}}(z) \gets \Pc{x}{z_j^{k_j} \text{ for all } j \in \pa{x}}$
%\STATE $r_\mathbf{k}(z) \gets f^x_{\k_{\pa{x}}}(z) \prod_i f^i_{k_i,\k_{\pa{i}}}(z)$
%\STATE $J \gets 0$
%\STATE $\Pmpexp(z, J) \gets \tfrac{1}{K^n} \sum_{\k^n} r_\mathbf{k}(z) e^{Jm(z^\mathbf{k})}$
%\STATE $m_\mp(z) \gets \at{\dd{J}}_{J=0} \log \Pmpexp(z, J)$
%\end{algorithmic}
%\end{algorithm}



%\begin{algorithm}[tb]
%\caption{Marginal importance weights using Source Trick}\label{alg:mom_MIW}
%\begin{algorithmic}
%\Require{Data $x$, Generative Model $\mathrm{P}(\mathbf{x} | z)$, Prior $\mathrm{P}(\mathbf{z})$, Proposal $\mathrm{Q}_\mathrm{MP}$, $K \geq 1$, \\ \hspace{30pt} index of latent for which to calculate weight: $i$}
%
%\FOR{$i \gets 1$ to $n$}
%    \STATE Sample $z_i \sim \Qmpc{z_i}{z_j \text{ for all } j \in \qa{i}}$
%    \STATE $z \gets \{z_1,...,z_{i-1}\} \cup z_i$
%    \STATE $f^i_{k_i, \k_{\pa{i}}}(z) \gets \frac{\Pc{z^{k_i}_i}{z_j^{k_j} \text{ for all } j \in \pa{i}}}{\Qmp{z_i^{k_i}| x, z_j \text{ for all } j \in \qa{i}}}$
%\ENDFOR
%\STATE $f^x_{\k_{\pa{x}}}(z) \gets \Pc{x}{z_j^{k_j} \text{ for all } j \in \pa{x}}$
%\STATE $r_\mathbf{k}(z) \gets f^x_{\k_{\pa{x}}}(z) \prod_i f^i_{k_i,\k_{\pa{i}}}(z)$
%\STATE $\mathbf{J} \gets \mathbf{0} \in \mathbb{R}^K$
%\STATE $\Pmpmarg(z, \mathbf{J}) \gets \tfrac{1}{K^n} \sum_{\k^n} r_\mathbf{k}(z) e^{J_{k_{i}}}$
%\STATE $w^i_{k_i} \gets \at{\dd{J_{k_{i}}}}_{J=0} \log \Pmpmarg(z, \mathbf{J})$
%\end{algorithmic}
%\end{algorithm}


%\begin{algorithm}[tb]
%\caption{Conditional distributions using Source Trick}\label{alg:mom_CDIS}
%\begin{algorithmic}
%\Require{Data $x$, Generative Model $\mathrm{P}(\mathbf{x} | z)$, Prior $\mathrm{P}(\mathbf{z})$, Proposal $\mathrm{Q}_\mathrm{MP}$, $K \geq 1$, \\ \hspace{30pt} index of latent for which to compute conditional distribution: $i$}
%
%\FOR{$i \gets 1$ to $n$}
%    \STATE Sample $z_i \sim \Qmpc{z_i}{z_j \text{ for all } j \in \qa{i}}$
%    \STATE $z \gets \{z_1,...,z_{i-1}\} \cup z_i$
%    \STATE $f^i_{k_i, \k_{\pa{i}}}(z) \gets \frac{\Pc{z^{k_i}_i}{z_j^{k_j} \text{ for all } j \in \pa{i}}}{\Qmp{z_i^{k_i}| x, z_j \text{ for all } j \in \qa{i}}}$
%\ENDFOR
%\STATE $f^x_{\k_{\pa{x}}}(z) \gets \Pc{x}{z_j^{k_j} \text{ for all } j \in \pa{x}}$
%\STATE $\mathbf{J} \gets \mathbf{0} \in \mathbb{R}^{K^{1+|\text{pa}(i)|}}$
%\STATE $r_\mathbf{k}(z) \gets f^x_{\k_{\pa{x}}}(z) \prod_i f^i_{k_i,\k_{\pa{i}}}(z)$
%\STATE $\Pmpsamp(z, \mathbf{J}) \gets \frac{1}{K^n} \sum_{\k} \frac{\P{x, z^\k}}{\Q{z^\k}} e^{J_{k_i, \k_{\pa{i}}}}$
%\State $\P{k_i, \k_{\pa{i}}} \gets \at{\dd{J_{k_i, \k_\pa{i}}}}_{\J=\0} \log \Pmpsamp(z, \mathbf{J})$
%\end{algorithmic}
%\end{algorithm}

\newpage
\section{Experimental Datasets And Models}
\label{app:exps}

\subsection{Bus Delay Dataset}
\label{app:bus}

In this experiment, we model the length of delays to New York school bus journeys,\footnote{Dataset: \url{https://data.cityofnewyork.us/Transportation/Bus-Breakdown-and-Delays/ez4e-fazm}\\ Terms of use: \url{https://opendata.cityofnewyork.us/overview/\#termsofuse}} working with a dataset supplied by the City of New York \citep{nyc2023bus}.
Our goal is to predict the length of a delay, based on the year, $y$, borough, $b$, bus company, $c$ and journey type, $j$.
Specifically, our data includes the years 2015--2022 inclusive and covers the five New York boroughs (Brooklyn, Manhattan, The Bronx, Queens, Staten Island) as well as some surrounding areas (Nassau County, New Jersey, Connecticut, Rockland County, Westchester). 
There are 57 bus companies and 6 different journey types (e.g.\ pre-K/elementary school route, general education AM/PM route, etc.).
We take $I=60$ delayed buses in each borough and year, and take $Y=3$ years and $B=3$ boroughs.
We then split along the \(I\) dimension to get two equally sized train and test sets.
Thus, each delay is uniquely identified by the year, $y$, the borough, $b$, and the index, $i$, giving $\mathrm{delay}_{ybi}$.
The delays are recorded as an integer number of minutes and we discard any entries greater than \(130\) minutes. 

We have a hierarchical latent variable model describing the impact of each of the features (year, borough, bus company and journey type) on the length of delays.
Specifically, the integer delay is modelled with a Negative Binomial distribution, with fixed total count of 131.
The expected delay length is controlled by a logit, $\mathrm{logits}_{ybi}$, with one logit for each delayed bus.
The logit is a sum of three terms: one for the borough and year jointly, one for the bus company and one for the journey type.
Each of these three terms is itself a latent variable that must be inferred.

First, we have a term for the year and borough, $\mathrm{YearBoroughWeight}_{yb}$, which has a hierarchical prior.
Specifically, we begin by sampling a global mean and variance, $\mathrm{GlobalMean}$ and $\mathrm{GlobalVariance}$.
Then for each year, we use $\mathrm{GlobalMean}$ and $\mathrm{GlobalVariance}$ to sample a mean for each year, $\mathrm{YearMean}_y$. 
Additionally, we sample a variance for each year, $\mathrm{YearVariance}_y$.
Then we sample $\mathrm{YearBoroughWeight}_{yb}$ from a Gaussian distribution with a year-dependent mean, $\mathrm{YearMean}_y$, and variance $\exp(\mathrm{YearVariance}_y)$.

Next, the weights for the bus company and journey type are very similar.
Specifically, we have one latent weight for each bus company, $\mathrm{CompanyWeight}_c$, with $c\in\{1,\dotsc,57\}$, and for each journey type, $\mathrm{JourneyTypeWeight}_j$, with $j\in\{1,\dotsc,6\}$.
We have a table identifying the bus company, $c_{ybi}$, and journey type, $j_{ybi}$, for each delayed bus journey (remember that a particular delayed bus journey is uniquely identified by the year, $y$, borough, $b$, and index, $i$).
In $\mathrm{logits_{ybi}}$ we use these tables to pick out the right company and journey type weight for that particular delayed bus journey, $\mathrm{CompanyWeight}_{c_{ybi}}$ and $\mathrm{JourneyTypeWeight}_{j_{ybi}}$.
The final generative model is defined by
%
%
%
%, based on the school year, borough, and the ID of the bus in question.
%These three levels motivate a hierarchical model in which a global mean and variance is sampled, which are used to sample a yearly mean which, along with a sampled borough variance, is used to sample an Borough mean for each year and borough.
%It is these Borough means, denoted by \(\mathrm{BoroughMean}_{mj}\) for year \(m\) and borough \(j\), over IDs \(i = 1, ..., I\), for which we estimate the posterior expectation.
%After this, the model samples a final variance that is used to sample two weight vectors which are multiplied with two vectors of covariates indicating the bus company and journey type.
%Finally, the predicted delay is sampled from a negative binomial distribution with logits given by the sum of these two multiplied terms and the ID mean.
%The full specification of the model can be found in Appendix \ref{bus}, with a corresponding graphical model provided in Appendix \ref{bus_gm}.
%The experiments were run using a subset of the data with \(I=30\) IDs from \(J=3\) boroughs in \(M=3\) years.
%As with the previous experiment, an equally sized but disjoint subset of the data is used to calculate predictive log-likelihoods.
%
%We run the same experiments as on the MovieLens model, but using posterior expectation estimates for \(\mathrm{BoroughMean}_{mj}\), as displayed in Fig.~\ref{fig:bus_mom}.
%We again observe improved performance from massively parallel importance sampling and weighting compared to global importance sampling and weighting.
%The Bus Delay dataset as retrieved from the City of New York open data sharing platform consists of tabular data with each row entry holding information about a particular bus delay occurring in New York. Alongside other information that we discard each entry contains information about the type of bus route (pre-K/elementary school route, general education AM/PM route etc.) of which there are \(6\) types and the bus company that was running the delayed bus, of which there are \(57\). The delays are recorded in full minutes and we discard any entries that have a delay of greater than \(130\) minutes. We organise the dataset into a nested hierarchy as so: the delay of the \(i\)'th delayed bus in the \(j\)'th borough in the \(m\)'th year is given as \(\mathrm{Delay}_{ybi}\) and is the observed variable in our model. The dataset (as retrieved) includes the years \(2015-2022\) inclusive and covers all five New York boroughs (Brooklyn, Manhatten, The Bronx, Queens, Staten Island) as well as some of the surrounding counties and areas (Nassau County, New Jersey, Connecticut, Rockland County, Westchester). We sample \(M=3\) years, \(J=3\) boroughs in each year and \(I=60\) delayed buses in each borough, and split along the \(I\) dimension to get two equally sized train and test splits.
% \begin{equation}
% \label{eq:bus model}
% \begin{split}
% \mathrm{GlobalVariance} &\sim \Normal(0,10^{-4}) \\
% \mathrm{GlobalMean} &\sim \Normal(0,10^{-4})  \\
% \mathrm{YearMean}_y &\sim \Normal(\mathrm{GlobalMean},\exp(\mathrm{GlobalVariance})), \ y\in\{1,...,Y\} \\
% \mathrm{BoroughVariance}_b &\sim \Normal(0, 0.25), \ b\in\{1,...,B\} \\
% \mathrm{YearBoroughWeight}_{yb} &\sim \Normal(\mathrm{YearMean}_y,\exp(\mathrm{BoroughVariance}_b)),\\
% \mathrm{WeightVariance} &\sim \Normal(0,10^{-4}),  \\
% \mathrm{CompanyWeight}_{c} &\sim \mathcal{N}(0, \exp(\mathrm{WeightVariance})), \ c\in\{1,...,C\} \\
% \mathrm{JourneyTypeWeight}_{j} &\sim \mathcal{N}(0, \exp(\mathrm{WeightVariance})), \ j\in\{1,...,J\} \\
% \mathrm{logits}_{ybi} &= \mathrm{YearBoroughWeight}_{yb} + \mathrm{CompanyWeight}_{c_{ybi}} + 
% \mathrm{JourneyTypeWeight}_{j_{ybi}}\\
% \mathrm{delay}_{ybi} &\sim \mathrm{NegativeBinomial}(\mathrm{total \ count}=130, \mathrm{logits}_{ybi}), 
% \end{split}
% \end{equation}
% and the graphical model is given in Fig.~\ref{fig:bus_gm}

% \begin{figure}[!htb]
% \begin{center}
% \resizebox{0.65\textwidth}{!}{%
%   \begin{tikzpicture}
%     % nodes
 
%     %I
%     \node[obs] (delay) {\(\mathrm{Delay}_{ybi}\)};%
    
%     \node[latent, left=of delay, xshift=-1.5cm, yshift=2.5cm]  (weightvariance) {\(\mathrm{WeightVariance}\)}; 
%     \node[latent, below=of weightvariance,yshift=-1.5cm, xshift=-0.5cm] (companyweight) {\(\mathrm{CompanyWeight}_c\)};  
%     \node[latent, below right=of companyweight, xshift=0.5cm]  (journeyweight) {\(\mathrm{JourneyTypeWeight_j}\)};  

%     %J
%     \node[latent,above=of delay] (idmean) {\(\mathrm{YearBoroughWeight}_{yb}\)}; %
%     \node[latent, right=of idmean, xshift=0.8cm] (boroughvariance) {\(\mathrm{BoroughVariance}_b\)};
%     %M
%     \node[latent,above=of idmean, yshift=0.5cm] (boroughmean) {\(\mathrm{YearMean}_y\)}; %
%     \node[latent, left=of boroughmean, xshift=-1.25cm] (yearmean) {\(\mathrm{GlobalMean}\)};
%     \node[latent, above=of yearmean, yshift=-0.5cm] (yearvariance) {\(\mathrm{GlobalVariance}\)};
%     % plate
%      \plate [inner sep=.35cm,] {platedelay} {(delay)} {\(\mathrm{I}\) Ids}; %
%      \tikzset{plate caption/.append style={below right=0pt and 0pt of #1.south east}}
%      \plate [inner sep=.6cm] {plateboroughs} {(idmean)(boroughvariance)(delay)} {\(B\) Boroughs}; 
%      \plate [inner sep=.6cm] {plateyears} {(boroughmean)(idmean)(delay)(plateboroughs)} {\(Y\) Years}; 
%     %edges
%      \edge {yearmean,yearvariance} {boroughmean}
%      \edge {boroughmean,boroughvariance} {idmean}
%      \edge {idmean} {delay}
%      \edge {weightvariance} {journeyweight,companyweight}
%      \edge {journeyweight,companyweight} {delay}
%      \end{tikzpicture} }
% \caption{Graphical model for the NYC Bus Breakdown dataset}
% \label{fig:bus_gm}
% \end{center}
% \end{figure}

% \begin{equation}
% \label{eq:bus model}
% \begin{split}
% \mathrm{GlobalVariance} &\sim \Normal(0,1) \\
% \mathrm{GlobalMean} &\sim \Normal(0,1)  \\
% \mathrm{YearMean}_y &\sim \Normal(\mathrm{GlobalMean},\exp(\mathrm{GlobalVariance})), \ y\in\{1,...,Y\} \\
% \mathrm{YearVariance}_b &\sim \Normal(0, 1), \ y\in\{1,...,Y\} \\
% \mathrm{YearBoroughWeight}_{yb} &\sim \Normal(\mathrm{YearMean}_y,\exp(\mathrm{YearVariance}_b)),\ b\in\{1,...,B\}\\
% % \mathrm{WeightVariance} &\sim \Normal(0,10^{-4}),  \\
% \mathrm{CompanyWeight}_{c} &\sim \mathcal{N}(0, 1), \ c\in\{1,...,C\} \\
% \mathrm{JourneyTypeWeight}_{j} &\sim \mathcal{N}(0, 1), \ j\in\{1,...,J\} \\
% \mathrm{logits}_{ybi} &= \mathrm{YearBoroughWeight}_{yb} + \mathrm{CompanyWeight}_{c_{ybi}} + 
% \mathrm{JourneyTypeWeight}_{j_{ybi}}\\
% \mathrm{delay}_{ybi} &\sim \mathrm{NegativeBinomial}(\mathrm{total \ count}=131, \mathrm{logits}_{ybi}), 
% \end{split}
% \end{equation}


\begin{equation}
\label{eq:bus model}
\begin{split}
\P{\mathrm{GlobalVariance}} =& \ \Normal(\mathrm{GlobalVariance}; 0,1) \\
\P{\mathrm{GlobalMean}} =& \ \Normal(\mathrm{GlobalMean}; 0,1)  \\
\Pc{\mathrm{YearMean}_y}{\mathrm{GlobalMean},\mathrm{GlobalVariance}} =& \ \Normal(\mathrm{YearMean}_y; \mathrm{GlobalMean},\exp(\mathrm{GlobalVariance})), \\ & \ y\in\{1,...,Y\} \\
\P{\mathrm{YearVariance}_y} =& \ \Normal(\mathrm{YearVariance}_y; 0, 1), \\ & \ y\in\{1,...,Y\} \\
\Pc{\mathrm{YearBoroughWeight}_{yb}}{\mathrm{YearMean}_y, \mathrm{YearVariance}_y} =& \ \Normal(\mathrm{YearBoroughWeight}_{yb}; \mathrm{YearMean}_y,\exp(\mathrm{YearVariance}_y)),\\ & \ b\in\{1,...,B\}\\
% \mathrm{WeightVariance} &\sim \Normal(0,10^{-4}),  \\
\P{\mathrm{CompanyWeight}_{c}} =& \ \mathcal{N}(\mathrm{CompanyWeight}_{c}; 0, 1), \ c\in\{1,...,C\} \\
\P{\mathrm{JourneyTypeWeight}_{j}} =& \ \mathcal{N}(\mathrm{JourneyTypeWeight}_{j}; 0, 1), \ j\in\{1,...,J\} \\
\mathrm{logits}_{ybi} =& \ \mathrm{YearBoroughWeight}_{yb} + \mathrm{CompanyWeight}_{c_{ybi}} \\ & \ + 
\mathrm{JourneyTypeWeight}_{j_{ybi}}\\
\Pc{\mathrm{delay}_{ybi}}{\mathrm{logits}_{ybi}} =& \ \mathrm{NegativeBinomial}(\mathrm{delay}_{ybi}; \mathrm{total \ count}=131, \mathrm{logits}_{ybi}), 
\end{split}
\end{equation}
and the corresponding graphical model is given in Fig.~\ref{fig:bus_gm}. We define the factorised proposal distribution $Q$ (which also acts as the initial approximate posterior for the iterative methods in Fig.~\ref{fig:results_summary}) in a very similar fashion:
\begin{equation}
\label{eq:bus model_Q}
\begin{split}
\Q{\mathrm{GlobalVariance}} =& \ \Normal(\mathrm{GlobalVariance}; 0,1) \\
\Q{\mathrm{GlobalMean}} =& \ \Normal(\mathrm{GlobalMean}; 0,1)  \\
\Q{\mathrm{YearMean}_y} =& \ \Normal(\mathrm{YearMean}_y; 0,1), \ y\in\{1,...,Y\} \\
\Q{\mathrm{YearVariance}_y} =& \ \Normal(\mathrm{YearVariance}_y; 0, 1), \ y\in\{1,...,Y\} \\
\Q{\mathrm{YearBoroughWeight}_{yb}} =& \ \Normal(\mathrm{YearBoroughWeight}_{yb}; 0, 1), \ b\in\{1,...,B\}\\
% \mathrm{WeightVariance} &\sim \Normal(0,10^{-4}),  \\
\Q{\mathrm{CompanyWeight}_{c}} =& \ \mathcal{N}(\mathrm{CompanyWeight}_{c}; 0, 1), \ c\in\{1,...,C\} \\
\Q{\mathrm{JourneyTypeWeight}_{j}} =& \ \mathcal{N}(\mathrm{JourneyTypeWeight}_{j}; 0, 1), \ j\in\{1,...,J\}.
\end{split}
\end{equation}

\begin{figure}[!htb]
\begin{center}
\resizebox{0.9\textwidth}{!}{%
  \begin{tikzpicture}
    % nodes
 
    %I
    \node[obs] (delay) {\(\mathrm{Delay}_{ybi}\)};%
 
    \node[latent, left=of delay, xshift=-1.25cm] (companyweight) {\(\mathrm{CompanyWeight}_c\)};  
    \node[latent, above=of companyweight]  (journeyweight) {\(\mathrm{JourneyTypeWeight_j}\)};  

    %J
    \node[latent,right=of delay] (yearboroughweight) {\(\mathrm{YearBoroughWeight}_{yb}\)}; %
    
    \node[latent, above=of yearboroughweight] (yearvariance) {\(\mathrm{YearVariance}_y\)};
    %M
    \node[latent,left=of yearvariance] (yearmean) {\(\mathrm{YearMean}_y\)}; %
    
    \node[latent, above=of yearmean, yshift=0.5cm] (globalmean) {\(\mathrm{GlobalMean}\)};
    \node[latent, right=of globalmean] (globalvariance) {\(\mathrm{GlobalVariance}\)};
    
    % plate
     \plate [inner sep=.35cm,] {plateID} {(delay)} {\(\mathrm{I}\) Ids}; %
     \tikzset{plate caption/.append style={below right=0pt and 0pt of #1.south east}}
     \plate [inner sep=.6cm] {plateboroughs} {(yearboroughweight)(plateID)} {\(B\) Boroughs}; 
     \plate [inner sep=.6cm] {plateyears} {(yearmean)(yearvariance)(plateboroughs)} {\(Y\) Years}; 
     
    %edges
     \edge {globalmean,globalvariance} {yearmean}
     \edge {yearmean,yearvariance} {yearboroughweight}
     \edge {yearboroughweight} {delay}
     \edge {journeyweight,companyweight} {delay}
     \end{tikzpicture} }
\caption{Graphical model for the NYC Bus Breakdown dataset}
\label{fig:bus_gm}
\end{center}
\end{figure}

\subsection{Chimpanzee Prosociality Dataset}
\label{app:chimpanzees}
The Chimpanzee Prosociality dataset\footnote{Dataset: \url{https://rdrr.io/github/rmcelreath/rethinking/man/chimpanzees.html}\\ License (GPL V3): \url{https://github.com/rmcelreath/rethinking}} consists of repeated experiments in a controlled setting for testing the prosocial tendencies of seven chimpanzees \citep{silk_chimpanzees_2005}.
In each experiment, the chimpanzee being tested, the focal chimpanzee, would be sat at the end of a table with two levers in front of them, one on their left and one on their right. 
Each lever was connected to two dishes on the corresponding left or right side of the table, one closer to the focal chimpanzee than the other.
Pulling the lever would slide those two dishes to opposite ends of the table, one towards the focal chimpanzee and the other away. 
In every repeat of the experiment, the two dishes closest to the focal chimpanzee (one on the left and one on the right) contained food; however, only one of the two dishes towards the far end of the table contained any food. 
This meant that the focal chimpanzee received food no matter which lever was pulled, but had the option of whether to send food to the other end of the table or not.
This experiment was repeated $R=12$ times on each of $B=6$ separate occasions (``blocks'') for each of the $A=7$ chimpanzees (``actors''), each time potentially with another chimpanzee sat at the opposite end of the table (with no levers), and with the empty dish sometimes connected to the left lever and sometimes to the right lever.

We base our hierarchical model on the cross-classified varying slopes model presented in \citet{mcelreath_statistical_2016}, where the observation \(y_{abr}\) indicates whether a given chimpanzee \(a \in \{1,...,A\}\) during a repeat \(r \in \{1,...,R\}\) inside a particular block \(b \in \{1,...,B\}\) of experiments pulled the left lever. 
To help identify prosocial behaviour, we work with two binary covariates for a given combination \((a,b,r)\): \(\mathrm{Condition}_{abr} \in \{0,1\}\) indicating whether or not another chimpanzee sat at the other end of the table; and \(\mathrm{ProsocLeft}_{abr} \in \{0,1\}\) indicating whether or not the prosocial lever (connected to two full dishes, rather than just one) was on the left.
We model the observations via a Bernoulli variable whose logit is given by the sum of two quantities: an intercept model \(\alpha + \alpha_a + \alpha_{ab}\) and a slope model \((\beta_\mathrm{P} + \beta_{\mathrm{PC}} \cdot \mathrm{Condition}_{abr}) \cdot \mathrm{ProsocLeft}_{abr}\). 

The intercept model comprises three quantities: a global intercept $\alpha$ with a zero-mean Normal prior with variance 10, a per-actor intercept $\alpha_a$, and a per-actor-block intercept $\alpha_{ab}$. The latter two have zero-mean Normal priors with variances $\sigma_{\mathrm{ACTOR}}^2$ and $\sigma_{\mathrm{BLOCK}}^2$ respectively, each of which is sampled from a \(\mathrm{HalfCauchy}(1)\) hyperprior.
The slope model consists of two latent variables, $\beta_{\mathrm{PC}}$ and $\beta_{\mathrm{P}}$, representing the effect on the chimpanzee of the presence of another chimpanzee at the end of the table and whether the prosocial choice came from the left lever or not, respectively. 
Both of these variables are sampled from a Normal distribution with mean zero and variance 10. 
The full specification of the generative model is given by

% \begin{equation}
% \label{eq:chimpanzees model}
% \begin{split}
% \sigma_\mathrm{ACTOR}^2 &\sim \mathrm{HalfCauchy}(1), \\
% \sigma_\mathrm{BLOCK}^2 &\sim \mathrm{HalfCauchy}(1), \\
% \beta_\mathrm{PC} &\sim \mathcal{N}(0,10), \\
% \beta_\mathrm{P} &\sim \mathcal{N}(0,10), \\
% \alpha &\sim \mathcal{N}(0,10), \\
% \alpha_{a} &\sim \mathcal{N}(0,\sigma_\mathrm{ACTOR}^2), \ \ a\in\{1,...,A\} \\
% \alpha_{ab} &\sim \mathcal{N}(0,\sigma_\mathrm{BLOCK}^2), \ \ b\in\{1,...,B\} \\
% y_{abr} &\sim \mathrm{Bernoulli}(\mathrm{logits}=\alpha + \alpha_{a} + \alpha_{ab} + (\beta_\mathrm{P} + \beta_\mathrm{PC}*\mathrm{Condition}_{abr})*\mathrm{ProsocLeft}_{abr} , \ \ r\in\{1,...,R\} \\
% % \Phi_{mji} &= \beta_{mji} * \mathrm{Weather}_{mji} \\
% % z_{mji} &\sim \mathrm{Bernoulli}(\mathrm{logits}=\Phi_{mji}), \ \  \\
% % p_{mji} &= \alpha_{mji} * \mathrm{Quality}_{mji} \\
% % y_{mjir} &\sim \mathrm{Bernoulli}(\mathrm{logits}=p_{mji} * z_{mji}), r\in\{1,...,\mathrm{R}\}
% \end{split}
% \end{equation}

\begin{equation}
\label{eq:chimpanzees model}
\begin{split}
\P{\sigma_\mathrm{ACTOR}^2} &= \mathrm{HalfCauchy}(\sigma_\mathrm{ACTOR}^2; 1), \\
\P{\sigma_\mathrm{BLOCK}^2} &= \mathrm{HalfCauchy}(\sigma_\mathrm{BLOCK}^2; 1), \\
\P{\beta_\mathrm{PC}} &= \mathcal{N}(\beta_\mathrm{PC}; 0,10), \\
\P{\beta_\mathrm{P}} &= \mathcal{N}(\beta_\mathrm{P}; 0,10), \\
\P{\alpha} &= \mathcal{N}(\alpha; 0,10), \\
\Pc{\alpha_{a}}{\sigma_\mathrm{ACTOR}^2} &= \mathcal{N}(\alpha_{a}; 0,\sigma_\mathrm{ACTOR}^2), \ \ a\in\{1,...,A\} \\
\Pc{\alpha_{ab}}{\sigma_\mathrm{BLOCK}^2} &= \mathcal{N}(\alpha_{ab}; 0,\sigma_\mathrm{BLOCK}^2), \ \ b\in\{1,...,B\} \\
\mathrm{logits}_{abr} &= \alpha + \alpha_{a} + \alpha_{ab} + (\beta_\mathrm{P} + \beta_\mathrm{PC} \cdot \mathrm{Condition}_{abr}) \cdot \mathrm{ProsocLeft}_{abr}, \ r\in\{1,...,R\} \\
\Pc{y_{abr}}{\mathrm{logits}_{abr}} &= \mathrm{Bernoulli}(y_{abr}; \mathrm{logits}=\mathrm{logits}_{abr}),
% \Phi_{mji} &= \beta_{mji} * \mathrm{Weather}_{mji} \\
% z_{mji} &\sim \mathrm{Bernoulli}(\mathrm{logits}=\Phi_{mji}), \ \  \\
% p_{mji} &= \alpha_{mji} * \mathrm{Quality}_{mji} \\
% y_{mjir} &\sim \mathrm{Bernoulli}(\mathrm{logits}=p_{mji} * z_{mji}), r\in\{1,...,\mathrm{R}\}
\end{split}
\end{equation}
and the graphical model is given in Fig.~\ref{fig:chimpanzees_gm}. 

The factorised proposal distribution $Q$ is defined similarly:
\begin{equation}
\label{eq:chimpanzees model_Q}
\begin{split}
\Q{\sigma_\mathrm{ACTOR}^2} &= \mathrm{HalfCauchy}(\sigma_\mathrm{ACTOR}^2; 1), \\
\Q{\sigma_\mathrm{BLOCK}^2} &= \mathrm{HalfCauchy}(\sigma_\mathrm{BLOCK}^2; 1), \\
\Q{\beta_\mathrm{PC}} &= \mathcal{N}(\beta_\mathrm{PC}; 0,10), \\
\Q{\beta_\mathrm{P}} &= \mathcal{N}(\beta_\mathrm{P}; 0,10), \\
\Q{\alpha} &= \mathcal{N}(\alpha; 0,10), \\
\Q{\alpha_{a}} &= \mathcal{N}(\alpha_{a}; 0,1), \ \ a\in\{1,...,A\} \\
\Q{\alpha_{ab}} &= \mathcal{N}(\alpha_{ab}; 0,1), \ \ b\in\{1,...,B\}
\end{split}
\end{equation}

In our experiments, we split the data into a training set that takes 10 of the \(R=12\) repeats for each actor-block combination, and a test set that takes the remaining 2.

\begin{figure}[!htb]
\begin{center}
\resizebox{0.6\textwidth}{!}{%
 \begin{tikzpicture}
   % nodes

   %R
   \node[obs] (y) {\(\mathrm{y}_{abr}\)};%
   %B
   \node[latent, left=of y] (alphaab) {\(\alpha_{ab}\)};   

   %A
   \node[latent, above=of alphaab] (alphaa) {\(\alpha_{a}\)}; %
   
   %global
   \node[latent, left=of alphaab] (sigmab) {\(\sigma_{\mathrm{BLOCK}}\)};
   
   \node[latent, left=of alphaa] (sigmaa) {\(\sigma_{\mathrm{ACTOR}}\)};

   \node[latent, above=of y, yshift=1.5cm] (alpha) {\(\alpha\)}; %

   \node[latent, right=of y]  (betaPC) {\(\beta_{\mathrm{PC}}\)}; 
   \node[latent, above=of betaPC]  (betaP) {\(\beta_{\mathrm{P}}\)};


   
   % plate
    \plate [] {platerepeat} {(y)} {\(\mathrm{R}\) Repeats}; %
    \plate [] {plateblock} {(alphaab)(platerepeat)} {\(\mathrm{B}\) Blocks}; 
    \plate [] {plateactor} {(alphaa)(plateblock)} {\(\mathrm{A}\) Actors}; 
   %edges
    \edge {alpha, alphaa, alphaab, betaP, betaPC} {y}
    \edge {sigmaa} {alphaa}
    \edge {sigmab} {alphaab}
    
    \end{tikzpicture} }
\caption{Graphical model for the Chimpanzees dataset}
\label{fig:chimpanzees_gm}
\end{center}
\end{figure}



\subsection{MovieLens Dataset}
\label{app:movielens}
The MovieLens100K\footnote{Dataset + License: \url{https://files.grouplens.org/datasets/movielens/ml-100k-README.txt}} \citep{harper2015movielens} dataset contains 100k ratings of $N{=}1682$ films from $M{=}943$ users.
The original user ratings run from 0 to 5.
Following \citet{geffner2022variational}, we binarise the ratings into just likes/dislikes, by taking user-ratings of \(\{0,1,2,3\}\) as a binarised-rating of \(0\) (dislikes) and user-ratings of \(\{4,5\}\) as a binarised-rating of \(1\) (likes).
We assume binarised ratings of 0 for films which users have not previously rated.

The probabilistic graphical model of the full generative distribution is given in Fig.~\ref{fig:movielens_pgm}. 
We use \(n\) to index films, and \(m\) to index users. 
Each film has a known feature vector, \(\mathbf{x}_n\), while each user has a latent weight-vector, $\mathbf{z}_m$, of the same length, describing whether or not they like any given feature.
There are 18 features, indicating which genre tags the film has (Action, Adventure, Animation, Children's, \dots). 
Each film may have more than one tag.
The probability of the user liking a film is given by taking the dot product of the film's feature vector with the user's latent weight-vector and applying a sigmoid, $\sigma(\cdot)$ (final line of Eq.~\ref{eq:movielens model}):
% \begin{align}
% \nonumber
% \m &\sim \Normal(\mathbf{0}_{18}, \I) \\
% \nonumber
% \boldsymbol{\psi} &\sim \Normal(\mathbf{0}_{18}, \I) \\
% \nonumber
% \mathbf{z}_m &\sim \Normal(\m,\exp(\boldsymbol{\psi}) \I), \ m=1,\dotsc,M \\
% \label{eq:movielens_vi_rws}
% \mathrm{Rating}_{mn} &\sim \operatorname{Bernoulli}(\sigma(\mathbf{z}_m^\intercal \mathbf{x}_n)), \ n=1,\dotsc,N
% \end{align}

\begin{equation}
\label{eq:movielens model}
\begin{split}
\P{\m} &= \Normal(\m; \mathbf{0}_{18}, \I), \\
\P{\boldsymbol{\psi}} &= \Normal(\boldsymbol{\psi}; \mathbf{0}_{18}, \I), \\
\Pc{\mathbf{z}_m}{\m, \boldsymbol{\psi}} &= \Normal(\mathbf{z}_m; \m,\exp(\boldsymbol{\psi}) \I), \ m=1,\dotsc,M \\
\Pc{\mathrm{Rating}_{mn}}{\mathbf{z}_m, \mathbf{x}_n} &= \operatorname{Bernoulli}(\mathrm{Rating}_{mn}; \sigma(\mathbf{z}_m^\intercal \mathbf{x}_n)), \ n=1,\dotsc,N
\end{split}
\end{equation}

Additionally, we have latent vectors for the global mean, $\m$, and log-variance, $\boldsymbol{\psi}$, of the weight vectors.

The factorised proposal distribution $Q$ (also used as the initial approximate posterior for experiments using iterative methods) is given by
\begin{equation}
\label{eq:movielens model_Q}
\begin{split}
\Q{\m} &= \Normal(\m; \mathbf{0}_{18}, \I), \\
\Q{\boldsymbol{\psi}} &= \Normal(\boldsymbol{\psi}; \mathbf{0}_{18}, \I), \\
\Q{\mathbf{z}_m} &= \Normal(\mathbf{z}_m; \mathbf{0}_{18}, \I), \ m=1,\dotsc,M 
\end{split}
\end{equation}

We use a random subset of \(N=20\) films and \(M=450\) users for our experiment to ensure high levels of uncertainty. 
We hold aside an equally sized but disjoint subset of users for calculating the predictive log-likelihood. 

\begin{figure*}[t]
\centering
\resizebox{0.2\textwidth}{!}{%
  \tikz{
    % nodes
     \node[obs] (rating) {\(\mathrm{Rating}_{mn}\)};%
     \node[latent,above=of rating] (peruser) {\(\mathbf{z}_m\)}; %
     \node[latent, above=of peruser, xshift=-1cm] (psi) {\(\boldsymbol{\psi}\)};
    \node[latent, above=of peruser, xshift=1cm] (mu) {\(\boldsymbol{\mu}\)};
    % plate
     \plate [inner sep=.3cm] {plate2} {(rating)} {\(\mathrm{N}\) Films}; %
     \plate [inner sep=.3cm] {plate1} {(peruser)(plate2)} {\(\mathrm{M}\) Users}; %
    % edges
     \edge {psi,mu} {peruser}
     \edge {peruser} {rating}}
     }
  \caption{Graphical model for the MovieLens dataset}
  \label{fig:movielens_pgm}
\end{figure*}


\subsection{Occupancy Dataset}
\label{app:occupancy}
Occupancy models aim to infer the true presence of a bird at a given observation site from repeated samples. The nature of collecting bird occupancy data for detection-nondetection datasets means that false detection or nondetection must be accounted for \citep{doser2022spoccupancy}. We fit a modified multi-species occupancy model to the North American Breeding Bird Survey data\footnote{Dataset: \url{https://www.sciencebase.gov/catalog/item/625f151ed34e85fa62b7f926}; licensing information is included with the dataset.}. This dataset records over 700 species of bird along thousands of randomly selected roadside routes, with readings taken at half-mile intervals along each 24.5-mile route for a total of 50 readings per route. The dataset covers the contiguous United States and the 13 provinces and territories of Canada.


The readings are taken during peak breeding season, usually in June, once per year. 
The presently available dataset covers the years 1966--2021, except for the year 2020.
For our purposes we take the unit step function of the sum of every 10 readings to give \(\mathrm{R}=5\) repeated samples from each route. 
Each route in each year has associated weather and quality covariates which are used in our model. 
We take a subset of \(J=12\) bird species, \(M=6\) years and \(I=200\) routes to form a training set on which to run our experiments, with a distinct test set sampling the same values of \(J\) and \(M\), but a different set of \(I=100\) routes. 
The \(\mathrm{Weather}_{jmi}\) covariate is the temperature at site \(i\) in year \(m\), replicated for each bird species. 
The \(\mathrm{Quality}_{jmi}\) covariate indicates whether a particular series of readings at site \(i\) in year \(m\) followed all the recommended guidelines for recording birds.

We model the recording of a bird along a particular route in a particular repeated measurement as arising from a Bernoulli distribution, with logits given by the product of the weighted quality covariate and an inferred latent variable \(z_{jmi}\) that indicates the true presence of a particular bird species \(j\) along route \(i\) in year \(m\), as well as including some probability of a false-positive bird sighting. 
We model this latent variable \(z_{jmi}\) as also arising from a Bernoulli distribution, with logits given by weighted Weather covariates multiplied by a variable \(\mathrm{BirdYearMean}_{jm}\) representing the mean frequency of a specific species in a given year. 
The priors for the Weather and Quality weights, \(\mathrm{WeatherWeight}_{j}\) and \(\mathrm{QualityWeight}_{j}\) respectively, are normal, with mean and log-variance sampled from standard normal priors. 
The variable \(\mathrm{BirdYearMean}_{jm}\) also has a hierarchical prior: first, for each bird species, we sample a \(\mathrm{BirdMean}_j\) from a normal distribution whose mean and log-variance are sampled from standard normal priors; then, for each year and bird species, we sample \(\mathrm{BirdYearMean}_{jm}\) from a normal distribution with mean \(\mathrm{BirdMean}_j\) and unit variance. 
The full model may be written as

% \begin{equation}
% \label{eq:occupancy model}
% \begin{split}
% \mu_\mathrm{BirdMean} &\sim \mathcal{N}(0,1), \\
% \sigma_\mathrm{BirdMean} &\sim \mathcal{N}(0,1), \\
% \mu_\mathrm{QualityWeight} &\sim \mathcal{N}(0,1), \\
% \sigma_\mathrm{QualityWeight} &\sim \mathcal{N}(0,1), \\
% \mu_\mathrm{WeatherWeight} &\sim \mathcal{N}(0,1), \\
% \sigma_\mathrm{WeatherWeight} &\sim \mathcal{N}(0,1), \\
% \mathrm{QualityWeight}_{j} &\sim \mathcal{N}(\mu_\mathrm{QualityWeight}, \exp(\sigma_\mathrm{QualityWeight})), \ \ j\in\{1,...,J\} \\
% \mathrm{WeatherWeight}_{j} &\sim \mathcal{N}(\mu_\mathrm{WeatherWeight}, \exp(\sigma_\mathrm{WeatherWeight})), \ \ j\in\{1,...,J\} \\
% \mathrm{BirdMean}_{j} &\sim \mathcal{N}(\mu_\mathrm{BirdMean},\exp(\sigma_\mathrm{BirdMean})), \ \ j\in\{1,...,J\} \\
% \mathrm{BirdYearMean}_{jm} &\sim \mathcal{N}(\mathrm{BirdMean}_{jm},1), \ \ m\in\{1,...,M\} \\
% z_{jmi} &\sim \mathrm{Bernoulli}(\mathrm{logits}= \mathrm{BirdYearMean}_{jm}*\beta_{j}*\mathrm{Weather}_{jmi} ) , \ \ i\in\{1,...,I\} \\
% y_{jmir} &\sim \mathrm{Bernoulli}(\mathrm{logits}= z_{jmi} * \mathrm{QualityWeight}_j  *\mathrm{Quality}_{jmir} + (1-z_{jmi})*(-10) ) , \ \ r\in\{1,...,R\} 
% \end{split}
% \end{equation}


\begin{equation}
\label{eq:occupancy model}
\begin{split}
\P{\mu_\mathrm{BirdMean}} =& \  \mathcal{N}(\mu_\mathrm{BirdMean}; 0,1), \\
\P{\sigma_\mathrm{BirdMean}} =& \  \mathcal{N}(\sigma_\mathrm{BirdMean}; 0,1), \\
\P{\mu_\mathrm{QualityWeight}} =& \  \mathcal{N}(\mu_\mathrm{QualityWeight}; 0,1), \\
\P{\sigma_\mathrm{QualityWeight}} =& \  \mathcal{N}(\sigma_\mathrm{QualityWeight}; 0,1), \\
\P{\mu_\mathrm{WeatherWeight}} =& \  \mathcal{N}(\mu_\mathrm{WeatherWeight}; 0,1), \\
\P{\sigma_\mathrm{WeatherWeight}} =& \  \mathcal{N}(\sigma_\mathrm{WeatherWeight}; 0,1), \\
\Pc{\mathrm{QualityWeight}_{j}}{\mu_\mathrm{QualityWeight}, \sigma_\mathrm{QualityWeight}} =& \  \mathcal{N}(\mathrm{QualityWeight}_{j}; \mu_\mathrm{QualityWeight}, \exp(\sigma_\mathrm{QualityWeight})), \\ & \ j\in\{1,\dotsc,J\} \\
\Pc{\mathrm{WeatherWeight}_{j}}{\mu_\mathrm{WeatherWeight}, \sigma_\mathrm{WeatherWeight}} =& \  \mathcal{N}(\mathrm{WeatherWeight}_{j}; \mu_\mathrm{WeatherWeight}, \exp(\sigma_\mathrm{WeatherWeight})), \\ & \ j\in\{1,\dotsc,J\} \\
\Pc{\mathrm{BirdMean}_{j}}{\mu_\mathrm{BirdMean}, \sigma_\mathrm{BirdMean}} =& \  \mathcal{N}(\mathrm{BirdMean}_{j}; \mu_\mathrm{BirdMean},\exp(\sigma_\mathrm{BirdMean})), \ j\in\{1,\dotsc,J\} \\
\Pc{\mathrm{BirdYearMean}_{jm}}{\mathrm{BirdMean}_{j}} =& \  \mathcal{N}(\mathrm{BirdYearMean}_{jm}; \mathrm{BirdMean}_{j},1), \ m\in\{1,\dotsc,M\} \\
\mathrm{logits}^z_{jmi} =& \ \mathrm{BirdYearMean}_{jm} \cdot \mathrm{WeatherWeight}_{j} \cdot \mathrm{Weather}_{jmi}, \\ & \ i\in\{1,\dotsc,I\} \\
\Pc{z_{jmi}}{\mathrm{logits}^z_{jmi}} =& \  \mathrm{Bernoulli}(z_{jmi}; \mathrm{logits}= \mathrm{logits}^z_{jmi}), \ i\in\{1,\dotsc,I\} \\
% \Pc{z_{jmi}}{\mathrm{BirdYearMean}_{jm}, \mathrm{WeatherWeight}_{j}, \mathrm{Weather}_{jmi}} =& \  \mathrm{Bernoulli}(z_{jmi}; \mathrm{logits}= \mathrm{BirdYearMean}_{jm}*\mathrm{WeatherWeight}_{j}*\mathrm{Weather}_{jmi} ) , \\ & \ i\in\{1,...,I\} \\
\mathrm{logits}^y_{jmir} =& \ z_{jmi} \cdot \mathrm{QualityWeight}_j \cdot \mathrm{Quality}_{jmir} + (1-z_{jmi}) \cdot (-10), \\ & \ r\in\{1,\dotsc,R\} \\
\Pc{y_{jmir}}{\mathrm{logits}^y_{jmir}} =& \  \mathrm{Bernoulli}(y_{jmir}; \mathrm{logits}= \mathrm{logits}^y_{jmir}), \ r\in\{1,\dotsc,R\},
% \Pc{y_{jmir}}{z_{jmi}, \mathrm{QualityWeight}_j, \mathrm{Quality}_{jmir}} =& \ \mathrm{Bernoulli}(y_{jmir}; \mathrm{logits}= z_{jmi} * \mathrm{QualityWeight}_j  *\mathrm{Quality}_{jmir} + (1-z_{jmi})*(-10) ) , \\ & \ r\in\{1,...,R\} 
\end{split}
\end{equation}

and the graphical model is presented in Fig.~\ref{fig:occ_gm}.

The factorised proposal distribution $Q$ is given by
\begin{equation}
\label{eq:occupancy model_Q}
\begin{split}
\Q{\mu_\mathrm{BirdMean}} =& \  \mathcal{N}(\mu_\mathrm{BirdMean}; 0,1), \\
\Q{\sigma_\mathrm{BirdMean}} =& \  \mathcal{N}(\sigma_\mathrm{BirdMean}; 0,1), \\
\Q{\mu_\mathrm{QualityWeight}} =& \  \mathcal{N}(\mu_\mathrm{QualityWeight}; 0,1), \\
\Q{\sigma_\mathrm{QualityWeight}} =& \  \mathcal{N}(\sigma_\mathrm{QualityWeight}; 0,1), \\
\Q{\mu_\mathrm{WeatherWeight}} =& \  \mathcal{N}(\mu_\mathrm{WeatherWeight}; 0,1), \\
\Q{\sigma_\mathrm{WeatherWeight}} =& \  \mathcal{N}(\sigma_\mathrm{WeatherWeight}; 0,1), \\
\Q{\mathrm{QualityWeight}_{j}} =& \  \mathcal{N}(\mathrm{QualityWeight}_{j}; 0,1), \ j\in\{1,\dotsc,J\} \\
\Q{\mathrm{WeatherWeight}_{j}} =& \  \mathcal{N}(\mathrm{WeatherWeight}_{j}; 0,1), \ j\in\{1,\dotsc,J\} \\
\Q{\mathrm{BirdMean}_{j}} =& \  \mathcal{N}(\mathrm{BirdMean}_{j}; 0,1), \ j\in\{1,\dotsc,J\} \\
\Q{\mathrm{BirdYearMean}_{jm}} =& \  \mathcal{N}(\mathrm{BirdYearMean}_{jm}; 0,1), \ m\in\{1,\dotsc,M\} \\
\mathrm{logits}^z_{jmi} =& \ \mathrm{BirdYearMean}_{jm} \cdot \mathrm{WeatherWeight}_{j} \cdot \mathrm{Weather}_{jmi}, \ i\in\{1,\dotsc,I\} \\
\Qc{z_{jmi}}{\mathrm{logits}^z_{jmi}} =& \  \mathrm{Bernoulli}(z_{jmi}; \mathrm{logits}= \mathrm{logits}^z_{jmi}), \ i\in\{1,\dotsc,I\}.
\end{split}
\end{equation}



\begin{figure}[!htb]
\begin{center}
\resizebox{0.95\textwidth}{!}{%
 \begin{tikzpicture}
   % nodes

   %R
   \node[obs] (y) {\(\mathrm{y}_{jmir}\)};%
   %I 
   \node[latent, left=of y] (z) {\(z_{jmi}\)};   

   %M
   \node[latent, above=of z, xshift=-1cm] (birdyearmean) {\(\mathrm{BirdYearMean}_{jm}\)}; %
   %J
   \node[latent, above=of birdyearmean, xshift=-1cm] (birdmean) {\(\mathrm{BirdMean}_j\)}; %

   \node[latent, right=of birdmean, xshift=1cm]  (beta) {\(\mathrm{WeatherWeight}_{j}\)}; 
   \node[latent, right=of beta]  (alpha) {\(\mathrm{QualityWeight}_{j}\)};

   %global
    \node[latent, above=of beta, yshift=0.75cm]  (betalogvar) {\(\sigma_\mathrm{WeatherWeight}\)}; 
   \node[latent, left=of betalogvar]  (betamean) {\(\mu_\mathrm{WeatherWeight}\)}; 

    \node[latent, left=of birdmean]  (birdmeanlogvar) {\(\sigma_\mathrm{BirdMean}\)}; 
   \node[latent, above=of birdmeanlogvar, yshift=-0.5cm]  (birdmeanmean) {\(\mu_\mathrm{BirdMean}\)}; 

   
    \node[latent, right=of alpha]  (alphalogvar) {\(\sigma_\mathrm{QualityWeight}\)}; 
    \node[latent, above=of alphalogvar]  (alphamean) {\(\mu_\mathrm{QualityWeight}\)}; 

   
   % plate
    \plate [] {platerepeat} {(y)} {\(\mathrm{R}\) Repeats}; %
    \plate [] {plateroutes} {(z)(platerepeat)} {\(\mathrm{I}\) Routes}; 
    \plate [] {plateyears} {(birdyearmean)(plateroutes)} {\(\mathrm{M}\) Years}; 
    \plate [] {platebirds} {(birdmean)(alpha)(beta)(plateyears)} {\(\mathrm{J}\) Bird Species}; 
   %edges
    \edge {z} {y}
    \edge {alpha} {y}
    \edge {birdyearmean} {z}
    \edge {birdmean} {birdyearmean}
    \edge {beta} {z}
    \edge {birdmeanmean, birdmeanlogvar} {birdmean}
    \edge {alphamean, alphalogvar} {alpha}
    \edge {betamean, betalogvar} {beta}

    \end{tikzpicture} }
\caption{Graphical model for the Bird Occupancy dataset}
\label{fig:occ_gm}
\end{center}
\end{figure}


\subsection{Experiment Details}
\label{app:exp_det}

In each experiment, we use a graphical model (specified above in Sections~\ref{app:bus}--\ref{app:occupancy}) to define a prior/generative distribution and a proposal distribution over one of four datasets. 
In particular, the proposal distribution has the same structure as the prior, but with each latent variable parameterised independently of the others (cf.\ the dependencies shown between latents of the prior/generative distributions in Figs.~\ref{fig:bus_gm}--\ref{fig:occ_gm}).

In our first set of experiments (Figs.~\ref{fig:IS_per_K} and~\ref{fig:IS_per_K_TIME}) we draw $K$ samples from the proposal and compute both the global and massively parallel estimators of the marginal likelihood (Eqs.~\ref{eq:Pglobal} and~\ref{eq:Prmp}). 
Since these estimators of the marginal likelihood are unbiased, taking their logarithm gives, in expectation (via Jensen's inequality), a lower bound on the log-marginal likelihood, referred to as the ELBO. 
% One of our main results is that we find the massively parallel ELBO to be a consistently tighter bound on the log-marginal likelihood than the global ELBO (see top row of Figs. 1 and 2). 

With these $K$ samples ($K^N$ if we consider all possible combinations of samples across the $N$ latent variables), we obtain 100 posterior samples of all latents via the importance sampling mechanism described in Section~\ref{sec:methods} (see Eq.~\ref{eq:Pk}). 
Then, we use these posterior latent samples to obtain predictive samples on an unseen test set of the data. We report the predictive log-likelihood of the test data given the predicted latent samples. 
% Again, we find that using the massively parallel importance sampling technique achieves higher predictive log-likelihoods than using the corresponding global techniques (see bottom row of Figs. 1 and 2).

In our second set of experiments, we compare the `one-shot' global and massively parallel importance sampling approaches with iterative methods---namely VI, IWAE and RWS. 
For each iteration of these methods, we obtain ELBOs and predictive log-likelihoods in the same manner as discussed above for global importance sampling; however, we also update the parameters of the proposal distribution at each iteration so as to maximise an objective function via the Adam optimiser. 
In the VI results, this objective function is the global ELBO, calculated with a single sample, i.e. $K=1$. 
The IWAE results also use the global ELBO but with $K=10$. The third method we consider is RWS in which we perform a maximum likelihood update of $Q$ using the posterior samples obtained by reweighting the proposal samples with the weights given in Eq.~\ref{eq:rglobal}, also with $K=10$. 
% In Fig. 3 we see that the massively parallel importance sampling method is able to achieve competitive ELBOs and predictive log-likelihoods in far less time than the iterative methods.

We ran all experiments on an 80GB NVIDIA A100 GPU.
For each experiment, we used global IS, MP IS, VI, IWAE, and RWS methods to obtain (for each value of $K$ in the IS methods, or each iteration of the iterative methods) 100 values each of the ELBO and the predictive log-likelihood (on a test set that was disjoint from the training set), which were averaged to produce the results presented in this work. % and posterior mean estimates for each latent variable in the model.
For global IS and MP IS, we additionally ran a single warmup iteration before the final 100 to allow for memory management optimisation---the results of these warmup runs were then discarded.
% These 100 values were averaged for the ELBO and predictive log-likelihood, whilst the 100 mean estimates were used to calculate mean squared error or variance per variable, based on whether the observed data had been sampled from the model itself---in which case the true variable values were known---or not.
We present error bars in Figs.~\ref{fig:IS_per_K} and~\ref{fig:IS_per_K_TIME} representing the standard error over these 100 (post-warmup) runs.

We ran the global IS and MP IS experiments for \(K\in\{3,5,10,15,20,25,30,45,60,80,10^2,3\times10^2,10^3,3\times10^3,10^4,3\times10^4,10^5,3\times10^5,10^6,3\times10^6,10^7\}\); however, only global IS on the chimpanzees dataset was able to run with every one of these values (no higher values were tried). In the other cases, failures were mostly due to the large amount of memory required or to numerical instability in the models.


For VI, IWAE and RWS, we optimised our approximate posterior using Adam with learning rates of $0.3$, $0.1$, $0.03$ and $0.01$ (we found that learning rates larger than $0.3$ were unstable on all models), and ran 100 optimisation steps for each method on each dataset/model. 
In Fig.~\ref{fig:results_summary} we plot only the best-performing learning rate for each method, where the best-performing learning rate is the one that achieves the highest ELBO after 75 iterations.
In Appendix~\ref{app:all_lr_results} we present plots of every learning rate for each method individually, compared against MP IS and global IS on each model.
% For HMC, we ran the PyMC \citep{salvatier_probabilistic_2016}  No U-Turn Sampler \citep{hoffman2014no} on our GPU using JAX \citep{jax2018github} to obtain 250 consecutive samples after 10 tuning iterations and with all other parameters set to their default values, including a target acceptance rate of 0.8.

\section{Further Results}\label{app:more_results}

% \subsection{Posterior Moment Quality}
% Along with the ELBO and predicitve-log-likelihood (as presented in the plots of the main text), here we consider two measures of the quality of our posterior mean estimator.
% First, we consider the total variance of the estimator (shown in the third column of the following figures); in the ideal case, when we have exactly computed the posterior mean, this variance should be zero.
% While this variance can be zero in other cases (e.g.\ when the mean estimator always returns a constant), it does provide a lower-bound on the expected squared error between the estimator and the true value.
% Second, we directly considered the MSE between our posterior mean estimator and the true latent variable (shown in column d of the following figures).
% However, this quantity requires us to know the true value for the latent variable, which requires us to generate data from the model (we only generated data for the fourth column of the following plots, the other columns use real data).
% In both measures we report the sum of variances/MSEs over all latent variables in the model.
% Again, in both of these cases, we see dramatic improvements for massively parallel against standard ``global'' importance sampling and competitive results in significantly less time than the iterative methods.

% \begin{figure*}[t]
% \begin{center}
% \includegraphics[width=0.9\textwidth]{figures/all_models_summary.pdf}
% \end{center}
% \caption{A summary of results for every method on each dataset. The first column presents the ELBO, and the second presents predicitve log-likelihood, as given in Fig~\ref{fig:results_summary}. The third and fourth column show the total variation and MSE respectively.}
% \label{fig:all_models_summary}
% \end{figure*}


\subsection{The Effect of Learning Rate on Iterative Methods}\label{app:all_lr_results}
Here we present plots (Figs.~\ref{fig:vi_summary}--\ref{fig:rws_summary}) showing the results obtained on each model-method combination for every stable learning rate that we tried on the iterative methods (excluding HMC). 
Note that no models were numerically stable with a learning rate higher than 0.3 for any of the iterative methods, and further that many of the following plots do not include learning rates of 0.3 or even 0.1 (for example MovieLens with RWS) as these were numerically unstable as well.

% \subsubsection{VI}

\begin{figure*}[t]
\begin{center}
\includegraphics[width=0.9\textwidth]{figures/VI_summary.pdf}
\end{center}
\caption{A comparison of results for VI on each dataset with varying learning rates.}
\label{fig:vi_summary}
\end{figure*}


% \subsubsection{IWAE}

\begin{figure*}[t]
\begin{center}
\includegraphics[width=0.9\textwidth]{figures/IWAE_summary.pdf}
\end{center}
\caption{A comparison of results for IWAE on each dataset with varying learning rates.}
\label{fig:iwae_summary}
\end{figure*}

% \subsubsection{RWS}

\begin{figure*}[t]
\begin{center}
\includegraphics[width=0.9\textwidth]{figures/RWS_summary.pdf}
\end{center}
\caption{A comparison of results for RWS on each dataset with varying learning rates.}
\label{fig:rws_summary}
\end{figure*}

\end{document}
