\documentclass[accepted]{uai2023} % after acceptance, for a revised
                                    % version; also before submission to
                                    % see how the non-anonymous paper
                                    % would look like
%% There is a class option to choose the math font
% \documentclass[mathfont=ptmx]{uai2023} % ptmx math instead of Computer
                                         % Modern (has noticable issues)
% \documentclass[mathfont=newtx]{uai2023} % newtx fonts (improves upon
                                          % ptmx; less tested, no support)
% NOTE: Only keep *one* line above as appropriate, as it will be replaced
%       automatically for papers to be published. Do not make any other
%       change above this note for an accepted version.

%% Choose your variant of English; be consistent
%\usepackage[american]{babel}
\usepackage[british]{babel}

%% Some suggested packages, as needed:
\usepackage{natbib} % has a nice set of citation styles and commands
    \bibliographystyle{plainnat}
    \renewcommand{\bibsection}{\subsubsection*{References}}
\usepackage{mathtools} % amsmath with fixes and additions
\usepackage{algorithm}
\usepackage{algpseudocode}
% \usepackage{siunitx} % for proper typesetting of numbers and units
\usepackage{booktabs} % commands to create good-looking tables
\usepackage{amsmath,amsfonts,amssymb}
\usepackage{subcaption}
\usepackage{tikz}
\usetikzlibrary{bayesnet}
\usetikzlibrary{arrows} % nice language for creating drawings and diagrams
\usepackage{xr}
\externaldocument{Heap_216-supp}


% --- Auto-sized delimiter helpers -------------------------------------------
% \bracket{<left>}{<right>}{<body>}
%     typesets <body> between \left<left> and \right<right>.
% \mbracket{<left>}{<mid>}{<right>}{<a>}{<b>}
%     typesets <a> \middle<mid> <b> between \left<left> and \right<right>.
\newcommand{\bracket}[3]{\left#1 #3 \right#2}
\newcommand{\mbracket}[5]{\left#1 #4 \middle#2 #5 \right#3}
% The shorthands below supply only the delimiter arguments; the remaining
% body argument(s) are taken from whatever follows the shorthand at the call
% site, e.g. \b{x} -> \bracket{(}{)}{x} and \bc{a}{b} -> \mbracket{(}{\vert}{)}{a}{b}.
% NOTE(review): \b is already defined before this point (hence \renewcommand);
% in standard LaTeX it is a text accent, which is no longer available after
% this line -- confirm no bibliography entry relies on it.
\renewcommand{\b}{\bracket{(}{)}}
\newcommand{\bc}{\mbracket{(}{\vert}{)}}
\newcommand{\ab}{\bracket{\langle}{\rangle}}
\newcommand{\cb}{\bracket{\{}{\}}}
\newcommand{\abs}{\bracket{\lvert}{\rvert}}
\newcommand{\sqb}{\bracket{[}{]}}
% --- Expectations, variances and probability distributions -------------------
% Each macro takes an optional subscript and ends in a delimiter shorthand
% (\sqb, \b or \bc), so the mandatory argument(s) that follow are bracketed:
%   \E[s]{f}     -> E_s [ f ]
%   \P[s]{a}     -> P_s ( a )
%   \Pc[s]{a}{b} -> P_s ( a | b )
\newcommand{\E}[1][]{\mathrm{E}_{#1}\sqb}
\newcommand{\Var}[1][]{\mathrm{Var}_{#1}\sqb}
% \bareP / \bareQ: the distribution symbol alone, without an argument bracket.
\newcommand{\bareP}{\operatorname{P}}
% NOTE(review): \P is already defined before this point (hence \renewcommand);
% in standard LaTeX it is the pilcrow sign, which is lost after this line.
\renewcommand{\P}[1][]{\bareP_{#1}\b}
\newcommand{\Pc}[1][]{\bareP_{#1}\bc}
\newcommand{\Pef}{\P[\text{ef}]}
% \Qef refers to \Q, which is only defined a few lines below; this works
% because the replacement text is expanded at use time, not at definition time.
\newcommand{\Qef}{\Q[\text{ef}]}
\newcommand{\bareQ}{{\operatorname{Q}}}
\newcommand{\bareQt}{\tilde{\operatorname{Q}}}
\newcommand{\Q}[1][]{\bareQ_{#1}\b}
\newcommand{\Qc}[1][]{\bareQ_{#1}\bc}
% Fixed-subscript variants of \Q / \Qc; the trailing "c" marks the two-argument
% (conditional, "a | b") form.  The subscripts \glob, \mp and \tmc are the
% text labels defined later in this preamble.
\newcommand{\Qglob}{\bareQ_{\glob}\b}
\newcommand{\Qglobc}{\bareQ_{\glob}\bc}
\newcommand{\Qmp}{\bareQ_{\mp}\b}
\newcommand{\Qmpc}{\bareQ_{\mp}\bc}
\newcommand{\Qtmc}{\bareQ_{\tmc}\b}
\newcommand{\Qtmcc}{\bareQ_{\tmc}\bc}
% --- Bold symbols -------------------------------------------------------------
\newcommand{\Z}{\mathbf{Z}}
\newcommand{\0}{\mathbf{0}}
\newcommand{\T}{\mathbf{T}}
\newcommand{\I}{\mathbf{I}}
\newcommand{\J}{\mathbf{J}}
% NOTE(review): \u and \k (below) are already defined before this point (hence
% \renewcommand); in standard LaTeX both are text accents, which are lost after
% these lines -- confirm no bibliography entry relies on them.
\renewcommand{\u}{\mathbf{u}}
\newcommand{\m}{\boldsymbol{\mu}}
\newcommand{\n}{\boldsymbol{\eta}}
\newcommand{\np}{\boldsymbol{\eta}_{\bareP}}
\newcommand{\nq}{\boldsymbol{\eta}_{\bareQ}}
\renewcommand{\k}{\mathbf{k}}
\newcommand{\K}{\mathbf{K}}
% An equals sign wrapped in braces, so it is set without relation spacing.
\newcommand{\ceq}{{=}}
% Indexed sets: \cbi{a} -> {a}_{i=1}^n, \cbk{a} -> {a}_{k=1}^K, \cbik{a} -> {a}_{ik}.
\newcommand{\cbi}[1]{\{ #1\}_{i=1}^n}
\newcommand{\cbk}[1]{\{ #1\}_{k=1}^K}
\newcommand{\cbik}[1]{\{ #1\}_{ik}}
% Upright function-style labels with a parenthesised argument:
% \pa{i} -> pa(i), \pl{i} -> pl(i), \qa{i} -> qa(i).  \apl is the bare "pl".
\newcommand{\pa}[1]{{\textrm{pa}\b{#1}}}
\newcommand{\pl}[1]{{\textrm{pl}\b{#1}}}
\newcommand{\apl}{\textrm{pl}}
\newcommand{\pax}{\pa{x}}
\newcommand{\qa}[1]{{\textrm{qa}\b{#1}}}
% NOTE(review): \fa prints "qa", the same label as \qa (but without the outer
% group) -- confirm this is intended rather than a copy-paste slip.
\newcommand{\fa}{\textrm{qa}\b}
\newcommand{\zi}{z^{\text{ind}}}
\newcommand{\zn}{z^{\text{non-ind}}}
% \dd[f]{x} -> partial f / partial x; with no optional argument the numerator
% is a bare partial sign.
\newcommand{\dd}[2][]{\frac{\partial #1}{\partial #2}}
% \at{f}: f followed by an auto-sized right vertical bar (invisible left delimiter).
\newcommand{\at}{\bracket{.}{\rvert}}
% \Dkl{a}{b} -> D_KL ( a || b ).
\newcommand{\Dkl}{\operatorname{D}_\text{KL}\mbracket{(}{\Vert}{)}}
% Starred \operatorname*, so sub/superscripts are set as limits in display style.
\newcommand{\argmax}{\operatorname*{argmax}}

% --- Upright text labels, used mainly as subscripts ---------------------------
\newcommand{\tmc}{\textrm{TMC}}
\newcommand{\nis}{\textrm{NIS}}
\newcommand{\snis}{\textrm{SNIS}}
% NOTE(review): \mp is already defined before this point (hence \renewcommand);
% in standard LaTeX it is the minus-plus symbol, which is lost after this line.
\renewcommand{\mp}{\textrm{MP}}
\newcommand{\glob}{\textrm{global}}
\newcommand{\post}{\textrm{post}}
\newcommand{\old}{\textrm{old}}
% --- Calligraphic P and L, plus labelled variants -------------------------------
\newcommand{\Pe}{\mathcal{P}}
% NOTE(review): \L is already defined before this point (hence \renewcommand);
% in standard LaTeX it is the letter L-with-stroke, which is lost after this line.
\renewcommand{\L}{\mathcal{L}}
\newcommand{\Pmp}{\Pe_\mp}
\newcommand{\Pmpexp}{\Pmp^\text{exp}}
\newcommand{\Pmpmarg}{\Pmp^\text{marg}}
\newcommand{\Pmpsamp}{\Pmp^\text{samp}}
\newcommand{\Pglob}{\Pe_\glob}
\newcommand{\Pold}{\Pe_\old}
\newcommand{\Lmp}{\L_\mp}
\newcommand{\Lglob}{\L_\glob}
\newcommand{\const}{\operatorname{const}}

% Sum and product symbols forced to text style (usable inside display math).
\newcommand{\tsum}{{\textstyle \sum}}
\newcommand{\tprod}{{\textstyle \prod}}


% Parameter-update symbols: "Delta theta" / "Delta phi" with a label subscript
% (\glob, \mp or \post).
\newcommand{\thetaglob}{\Delta \theta_\glob}
\newcommand{\phiglob}{\Delta \phi_\glob}
\newcommand{\thetamp}{\Delta \theta_\mp}
\newcommand{\phimp}{\Delta \phi_\mp}
\newcommand{\thetapost}{\Delta \theta_\post}
\newcommand{\phipost}{\Delta \phi_\post}

% A bare "Delta" carrying the same label subscripts.
\newcommand{\Dpost}{\Delta_\post}
\newcommand{\Dglob}{\Delta_\glob}
\newcommand{\Dmp}{\Delta_\mp}

% Plain-text names (no font switch), as opposed to the \textrm labels \tmc and \glob.
\newcommand{\texttmc}{TMC}
\newcommand{\textglob}{global}
\newcommand{\textGlob}{Global}

% \cdo{x} -> do(x), using the parenthesis shorthand \b.
\newcommand{\cdo}{\operatorname{do}\b}

\newcommand{\Normal}{\mathcal{N}}

% --- Grouped citation shorthands ------------------------------------------------
% Each macro expands to a single \citep over a fixed list of keys.  Where an
% optional argument is accepted it is forwarded as the first optional argument
% of \citep (natbib's pre-note) with an empty second one, e.g. \citevi[VI; ].
% Call sites written as \citevi{} simply leave an empty group after the macro.
\newcommand{\citedeg}[1][]{\citep[#1][]{carpenter1999improved,li2012deterministic,li2014fight,zhou2016new,wang2017survey}}

\newcommand{\citepf}[1][]{\citep[#1][]{gordon1993novel,doucet2009tutorial,andrieu2010particle,maddison2017filtering,le2017auto,lindsten2017divide,naesseth2018variational,lai2022variational}}

% \citevpf takes no argument at all (no pre-note can be passed).
\newcommand{\citevpf}{\citep{maddison2017filtering,le2017auto,lindsten2017divide,naesseth2018variational,lai2022variational}}

\newcommand{\citevi}[1][]{\citep[#1][]{jordan1999introduction,wainwright2008graphical,kingma2013auto,rezende2014stochastic,blei2017variational,nguyen2017variational,zhang2018advances,kingma2019introduction,gayoso2021joint}}
\newcommand{\citeiwae}[1][]{\citep[#1][]{burda2015importance,cremer2017reinterpreting}}
\newcommand{\citerws}[1][]{\citep[#1][]{bornschein2014reweighted,le2020revisiting}}


\title{Massively Parallel Reweighted Wake-Sleep}

% The standard author block has changed for UAI 2023 to provide
% more space for long author lists and allow for complex affiliations
%
% All author information is automatically removed by the class for the
% anonymous submission version of your paper, so you can already add your
% information below.
%
% Add authors
\author[1]{Thomas Heap}
\author[1]{Gavin Leech}
\author[1]{\href{mailto:<laurence.aitchison@bristol.ac.uk>?Subject=Your UAI 2023 paper}{Laurence Aitchison}{}}



% Add affiliations after the authors
\affil[1]{%
    Department of Computer Science\\
    University of Bristol\\
    Bristol
}

  \begin{document}
\maketitle

\begin{abstract}
    Reweighted wake-sleep (RWS) is a machine learning method for performing Bayesian inference in a very general class of models.
    RWS draws $K$ samples from an underlying approximate posterior, then uses importance weighting to provide a better estimate of the true posterior.
    RWS then updates its approximate posterior towards the importance-weighted estimate of the true posterior.
    However, recent work \citep{chatterjee2018sample} indicates that the number of samples required for effective importance weighting is exponential in the number of latent variables.
    Attaining such a large number of importance samples is intractable in all but the smallest models.
    Here, we develop massively parallel RWS, which circumvents this issue by drawing $K$ samples of all $n$ latent variables, and individually reasoning about all $K^n$ possible combinations of samples.
    While reasoning about $K^n$ combinations might seem intractable, the required computations can be performed in polynomial time by exploiting conditional independencies in the generative model.
    We show considerable improvements over standard ``global'' RWS, which draws $K$ samples from the full joint.
\end{abstract}

\section{Introduction}
Many machine learning tasks involve inferring the latent variables from underlying observations \citep{jaynes2003probability,mackay2003information}.
One approach to inferring these latent variables from data is to use Bayesian inference.
In Bayesian inference, we define a generative model which consists of a prior, $\P{\text{latents}}$, describing the probability of the latent variable before seeing data, and a likelihood, $\Pc{\text{data}}{\text{latents}}$, describing the probability of the data given the latents.
The goal is then to compute the posterior using Bayes' theorem,
\begin{align}
  \Pc{\text{latents}}{\text{data}} \propto \Pc{\text{data}}{\text{latents}} \P{\text{latents}}.
\end{align}
However, computing this posterior is typically intractable, especially in more complex models where the likelihood or prior is parameterised by a neural network.

As an alternative, modern approaches such as  variational inference \citevi[VI; ] and reweighted wake-sleep \citerws[RWS; ] learn the parameters, $\phi$, of an approximate posterior, $\Qc[\phi]{\text{latents}}{\text{data}}$.
In VI, we learn this posterior by optimizing the evidence lower-bound objective (ELBO) using the reparameterisation trick \citep{kingma2013auto,rezende2014stochastic}.
This bound often has considerable slack, which can bias inferences.
To address this issue, importance weighted auto-encoders \citeiwae[IWAEs; ] draw multiple samples from the approximate posterior and use importance weighting to provide a tighter bound on the model evidence.
In RWS, we draw multiple samples from the approximate posterior, reweight those samples to approximate the true posterior, then update the approximate posterior towards the reweighted approximation of the true posterior (specifically, this is the wake-phase Q update; see \citealp{bornschein2014reweighted}).

However, recent work \citep{chatterjee2018sample} showed that the number of samples required to get accurate importance weighted estimates is very large.
Specifically, they showed that the required number of samples scales as $e^{\Dkl{\Pc{z}{x}}{\Qc{z}{x}}}$.
This is particularly problematic because we expect the KL divergence to scale linearly in the number of latent variables, $n$.
Indeed, if $\Pc{z}{x}$ and $\Qc{z}{x}$ are IID over the $n$ latent variables, then the KL-divergence is exactly proportional to $n$.
Overall, this implies that we expect the required number of samples to be exponential in the number of latent variables, which is clearly infeasible for larger models.

This problem has been addressed in the IWAE context using TMC \citep{aitchison2019tensor}, which draws $K$ samples for each of the $n$ latent variables, and individually reasons about each of the $K^n$ combinations of samples.
Here, we develop an analogous approach for RWS, which we call massively parallel (MP) RWS.
Critically, this is not a simple extension of the derivations in \citet{aitchison2019tensor}.
The derivations in \citet{aitchison2019tensor} are either restricted to factorised approximate posteriors, or use an augmented state-space viewpoint which cannot be applied to RWS.
We therefore give very different and considerably more general derivations in Sec.~\ref{sec:methods}.
Indeed, these more general derivations allow us to use a more general class of approximate posteriors, even in the original VI setting.

%Instead, we define massively parallel IWAE and RWS, which draw
%It might seem that we are left again with the problem of reasoning about exponentially many ($K^n$) combinations.
%However, it turns out that we can use conditional independencies in the generative model to compute the required quantities in polynomial time.
%As such, we can implicitly obtain an exponential number of importance samples, enough to get accurate importance weighted estimates, \citep{chatterjee2018sample} in polynomial time.

\section{Related Work}
Of course, our methods are based on fundamental work on VI \citevi{}, IWAE \citeiwae{} and RWS \citerws{}.
%More specifically, there have been a few papers involving similar schemes, albeit usually in a restricted class of models.

%There is a considerable body of work on more sophisticated schemes for VI which does not apply to RWS.
Perhaps the most obvious related work is TMC \citep{aitchison2019tensor}, which also draws $K$ samples for each of the $n$ latent variables, and considers all $K^n$ combinations.
The key difference to our work is that TMC only applies to VI, while our work applies to RWS.
However, our more general derivations improve on TMC itself.
Specifically, TMC is restricted to approximate posteriors that are IID across the $K$ particles for one latent variable.
In contrast, our derivations allow us to couple the distribution over $K$ particles for a single latent variable (Appendix~\ref{app:ap}), which gives scope for e.g.\ applying variance reduction strategies.

Further, there is a body of work improving VI, but not RWS, in specific restricted classes of model.

The first model class is a single-level hierarchical model, with a Bayesian parameter, $z_0$, common to all datapoints, and latent variables, $z_1,\dotsc,z_n$, each associated with a different datapoint.
\citet{geffner2022variational} propose a ``local'' importance weighting (LIW) scheme for this class of model, which contrasts with standard importance weighting schemes that they describe as ``global''.
We adopt their ``global'' terminology for standard IWAE and RWS, which draw $K$ samples from the full joint approximate posterior.
LIW in effect does IWAE separately for each datapoint: it separately draws $K$ IWAE samples for the latent variables, $z_1,\dotsc,z_n$, associated with each datapoint, $x_1,\dotsc,x_n$.
This looks very similar to TMC and massively parallel RWS, which draw $K$ samples for the Bayesian parameter, $z_0$, and the latent variables, $z_1,\dotsc,z_n$, and reason about all $K^{n+1}$ combinations of all samples on $z_0,z_1,\dotsc,z_n$.
However, LIW differs from TMC and massively parallel RWS in that LIW draws only a single sample of the Bayesian parameter, $z_0$.
Of course, there are additional differences.
In particular, LIW, like TMC, ultimately performs VI, while massively parallel RWS applies RWS.
Further, massively parallel RWS (like TMC) can be applied to a very broad class of models, while LIW is restricted to these single-level hierarchical models.
%As such, the differences between massively parallel RWS and \citet{geffner2022variational} mirror those between massively parallel and TMC.
%Specifically, both TMC and LIW are methods for VI and do not apply to RWS.
%Of course, our approach differs primarily in that it applies to RWS, whereas LIW applies to VI.
%However, there are two key additional differences between LIW and massively parallel-like methods (including TMC).
%First, massively parallel methods apply to a much more general class of generative model (including e.g.\ timeseries and multi-level hierarchical models amongst others).
%Second, LIW schemes draw only a single sample of $z_0$ (the latent variable common to all datapoints).
%As such, LIW in effect does IWAE separately for each datapoint; it separately draws $K$ IWAE samples for the latent variables associated with each datapoint, $z_1\dotsc z_N$.
%In contrast, massively parallel methods also draw $K$ samples for the Bayesian parameters, $z_0$, then reason over all $K^n$ combinations.
%Differences between LIW and massively parallel methods may be small if $z_0$ is small (e.g.\ a small dimensional vector).
%However, the differences are likely to be much larger if one tries to generalise the LIW approach to complex models with large, structured latent states.

A second class of models is timeseries models.
Massively parallel methods in timeseries may bear some resemblance to particle filtering/sequential Monte-Carlo (SMC) \citepf{}, in that SMC/particle filters also reason over multiple samples for each latent variable.
However, work which learns a proposal/approximate posterior in the particle filtering setting focuses on VI rather than RWS \citevpf{}.
Moreover, most work in SMC / particle filtering considers only a restrictive class of timeseries model, while massively parallel methods operate in a very general class of models.
While there is some work extending SMC to more general generative models \citep[e.g.\ ][]{lindsten2017divide}, this work does not, for instance, give a mechanism to learn an approximate posterior using e.g.\ IWAE or RWS, let alone to have an approximate posterior whose structure differs from that of the underlying generative model.

%Tensor Monte Carlo \citep[TMC;][]{aitchison2019tensor} is superficially similar to our approach, in the sense that it also considers all $K^n$ combinations of samples.
%However there are numerous differences between the approaches.
%Perhaps the most obvious difference is that the TMC scheme only operates for VI, whereas our massively parallel approach applies to other methods including RWS.
%Moreover, even for VI, our scheme offers benefits over TMC.
%Third, TMC approximate posteriors are vulnerable to particle degeneracy-like effects that reduce sample diversity \citedeg{}.

%However, more fundamental differences emerge when we look more closely (see Appendix~\ref{app:ap} for more details).
%In particular, our proposal draws $K$ samples independently from the full joint approximate posterior.
%In contrast, the TMC non-factorised proposal is inspired by particle filters, and it therefore does not draw $K$ full joint independent samples.
%Instead, in TMC, the samples for parent latent variables are selected randomly, in a manner equivalent to equally-weighted resampling in a particle filter.
%The massively parallel approach has four advantages over the TMC approach.
%First, our approach is simpler to implement, as it does not require randomly choosing parent particles when sampling from the approximate posterior.
%Second, TMC has a potentially higher computational complexity than our massively parallel approach.
%This is because TMC requires computing approximate posterior probabilities, marginalising over all possible parent samples.
%If $\abs{\qa{i}}$ denotes the number of parents of the $i$th latent variable under the approximate posterior, then there are $K^{\abs{\qa{i}}}$ possible mixture components, implying a total computational cost to compute the probabilities of $\mathcal{O}(K^{1+\abs{\qa{i}}})$ (as we also need to compute the probability of $K$ samples of the $i$th latent variable).
%In contrast, the computational cost for sampling our massively parallel approximate posterior is simply $\mathcal{O}(K)$, as we simply draw $K$ samples independently from the joint posterior.
%Third, TMC approximate posteriors are vulnerable to particle degeneracy-like effects that reduce sample diversity \citedeg{}.
%Finally, both the previous TMC and our massively parallel approximate posterior are defined by starting with a user-defined, base approximate posterior defined over a single sample of the full joint latent space.
%While the massively parallel proposal has to match the underlying user-defined proposal (as it just draws IID samples from that proposal), the TMC proposal does not.
%Specifically, the issue is that randomly sampling the parent particles breaks dependencies between proposal latent variables (see Appendix~\ref{app:tmc_ap_marg} for a worked example).
%This is highly problematic, as it

\section{Background}

Here, we give background on IWAE and RWS, which are methods for performing Bayesian inference in a probabilistic generative model.
Both IWAE and RWS work with a collection of $K$ samples of the latent variables.
The full collection of $K$ samples is denoted $z$, while an individual sample (specifically the $k$th sample) is denoted $z^k$,
%\textbf{General setup.} We start with a generative model, $\P[\theta]{x, z'}$, where $x$ is our observations and $z' \in \mathcal{Z}$ is our latents.
%Our goal is to learn the parameters, $\theta$, and to approximate the posterior, $\P[\theta]{z'| x}$.
%Note that we use $z'$ rather than $z$ because we use $z$ for a collection of $K$ samples from the full latent space,
\begin{align}
  \label{eq:z}
  z &= (z^1, z^2, \dotsc, z^K) \in \mathcal{Z}^K.
\end{align}
For standard global VI and RWS, $K$ samples are drawn by sampling $K$ times from the underlying single-sample approximate posterior, $\Q[\phi]{z^k| x}$, which has parameters, $\phi$,
\begin{align}
  \label{eq:Qglobal}
  \Qglob{z| x} &= \prod_{k\in\mathcal{K}} \Q[\phi]{z^k| x},
\end{align}
where $\mathcal{K} = \{1,\dotsc,K\}$.

IWAE and RWS can be written in terms of an unbiased estimator of the marginal likelihood (Appendix~\ref{app:iwae_glob}),
%Now, we highlight and unify aspects of IWAE and RWS in a manner that will be useful for developing our approach.
%In particular, both IWAE and RWS can be written in terms of, $\mathcal{P}_\text{global}(z)$,
\begin{align}
  \label{eq:Pglobal}
  \Pglob(z) &= \frac{1}{K} \sum_{k\in\mathcal{K}} r_k(z),\\
  \label{eq:rglobal}
  %r_k(z) &= \frac{\P[\theta]{x, z^k}}{\prod_i \Q[\phi]{z_i^k| x, z_{\pa{i}}^{k}}}.
  r_k(z) &= \frac{\P[\theta]{x, z^k}}{\Q[\phi]{z^k| x}},
\end{align}
where $\P[\theta]{x, z^k}$ is the generative probability, and $r_k(z)$ is the ratio of generative and approximate posterior probabilities for the $k$th sample, $r(z^k)$.

\subsection{Importance weighted autoencoder}
In IWAE \citep{burda2015importance}, we optimize $\phi$ and $\theta$ using the IWAE objective, $\Lglob$, which forms a lower-bound on the marginal likelihood, $\log \P[\theta]{x}$,
\begin{align}
  \label{eq:L:iwae}
  \log \P[\theta]{x} \geq \Lglob(\theta, \phi) &= \E[{\Qglob{z| x}}]{\log \Pglob(z)}.
\end{align}
Differentiating this objective wrt the parameters of the generative model is straightforward, as $\Q[\phi]{z| x}$ does not depend on $\theta$ so we can interchange the expectation and gradient operators.
In contrast, the distribution over which the expectation is taken does depend on $\phi$, so the $\phi$ update is more difficult to implement and requires reparameterisation \citep{kingma2013auto,rezende2014stochastic}.

%this objective
%We updated $\theta$ and $\phi$ by differentiating $\mathcal{L}$,
%\begin{subequations}
%\begin{align}
%  \Delta \theta_\text{IWAE} &= \nabla_\theta \mathcal{L} = \phantom{\nabla_\phi} \E[{\Q[\phi]{z| x}}]{\nabla_\theta \log \mathcal{P}_\text{global}(z)}\\
%  \Delta \phi_\text{IWAE}   &= \nabla_\phi   \mathcal{L} = \nabla_\phi \E[{\Q[\phi]{z| x}}]{\phantom{\nabla_\theta} \log \mathcal{P}_\text{global}(z)}
%  %\E[\epsilon]{\nabla_\phi \log \mathcal{P}_\text{global}(z(\epsilon; \phi))}
%\end{align}
%\end{subequations}
%Note that the $\theta$ update is relatively easy to implement, as the distribution over which the expectation, $\Q[\phi]{z| x}$, does not depend on $\theta$, so we can put the gradient wrt $\theta$ inside the expectation.
%In contrast, the distribution over which the expectation is taken does depend on $\phi$, so the $\phi$ update is more difficult to implement.

\subsection{Reweighted wake-sleep}
In RWS \citep{bornschein2014reweighted}, we do not have a single unified objective.
%the ideal case, we would update both $\theta$ and $\phi$ using the true posterior,
%\begin{subequations}
%\begin{align}
%  \E{\Delta \theta_\text{true post}} &= \E[{\P[\theta]{z| x}}]{\nabla_\theta \log \P[\theta]{z, x}}\\
%  \E{\Delta \phi_\text{true post}} &= \E[{\P[\theta]{z| x}}]{\nabla_\phi   \log \Q[\phi]{z| x}}
%\end{align}
%\end{subequations}
%The $\theta$ update resembles the M-step in EM, while the $\phi$ update trains the approximate posterior, $\Q[\phi]{z| x}$, using maximum likelihood on samples from the true posterior, $\P[\theta]{z| x}$.
%
Instead, we update the generative model and approximate posterior by drawing $K$ samples from an approximate posterior, $\Q[\phi]{z^k| x}$.
We then use importance reweighting to bring those samples closer to the true posterior, $\P[\theta]{z| x}$, and do a maximum likelihood-like update with those reweighted samples.
In particular, the $\bareP$ update resembles the M-step in EM, and maximizes $\log \P[\theta]{z^k, x}$ for the reweighted samples.
Likewise, the (wake-phase) $\bareQ$ update maximizes $\log \Q[\phi]{z^k| x}$ for the reweighted samples mirroring the true posterior, and therefore brings $\Q[\phi]{z^k| x}$ closer to the true posterior,
\begin{subequations}
\label{eq:rws_global:iw}
\begin{align}
  \label{eq:rws_global:iw:P}
  \thetaglob &= \\
  \nonumber
  &\E[{\Qglob{z| x}}]{\frac{1}{K} \sum_{k\in\mathcal{K}} \frac{r_k(z)}{\Pglob(z)} \nabla_\theta \log \P[\theta]{z^k, x}},\\
  \label{eq:rws_global:iw:Q}
  \phiglob &= \\
  \nonumber
  &\E[{\Qglob{z| x}}]{\frac{1}{K} \sum_{k\in\mathcal{K}} \frac{r_k(z)}{\Pglob(z)} \nabla_\phi \log \Q[\phi]{z^k| x}}.
\end{align}
\end{subequations}
See Appendix~\ref{app:rws_glob} for a derivation of these updates.
However, it turns out that implementing the updates in Eq.~\eqref{eq:rws_global:iw} directly is difficult, as it requires us to separately compute the gradients for each sample, $z^k$.
Instead, we typically use,
\begin{subequations}
\label{eq:rws_global:obj}
\begin{align}
  \label{eq:rws_global:obj:P}
  \thetaglob &= \E[{\Qglob{z| x}}]{\nabla_\theta \log \mathcal{P}_\text{global}(z)},\\
  \label{eq:rws_global:obj:Q}
  \phiglob   &= \E[{\Qglob{z| x}}]{\nabla_\phi \b{- \log \mathcal{P}_\text{global}(z)}}.
\end{align}
\end{subequations}
See Appendix~\ref{app:rws:updates} for a proof of equivalence.

\section{Methods}
\label{sec:methods}

% Full-width figure: massively parallel vs. global RWS on MovieLens subsets.
% \centering is used instead of a center environment because center adds
% extra vertical space above and below its contents inside a float.
\begin{figure*}
\centering
\includegraphics[width=0.9\textwidth]{chart_movielens_rws.pdf}
\caption{Results of massively parallel RWS and standard or ``global'' RWS for a hierarchical model on subsets of MovieLens with differing numbers of users and films per user, showing the predictive log likelihood after 25k training iterations.}
\label{fig:movielens_rws}
\end{figure*}


These previous approaches draw $K$ samples from the full joint latent space.
However, the required number of samples scales exponentially in the number of latent variables \citep{chatterjee2018sample}.
Thus, we define a massively parallel scheme in which we draw $K$ samples for each latent variable, then effectively obtain $K^n$ samples by considering all combinations of $K$ samples for each of the $n$ latent variables.
To that end, we denote each of the separate samples for separate latent variables $z_i^k\in \mathcal{Z}_i$, where $k$ indexes the sample and $i$ indexes the latent variable.
%The $k$th sample of the full joint latent space can therefore be written,
%\begin{align}
%  z^k &= (z^k_1, z^k_2, \dotsc, z^k_n) \in \mathcal{Z}_1 \times \mathcal{Z}_2 \times \dotsm \times \mathcal{Z}_n = \mathcal{Z}.
%\end{align}
%As before (Eq.~\ref{eq:z}), we write the collection of all $K$ samples of all $n$ latent variables as $z$.
We can write the collection of $K$ samples for a single latent variable (the $i$th) as,
\begin{align}
  z_i &= (z_i^1, z_i^2, \dotsc, z_i^K) \in \mathcal{Z}_i^K.
\end{align}
To sample all $K$ copies of the full joint latent space, TMC \citep{aitchison2019tensor} uses an IID distribution over the $K$ samples, $z_i^1,\dotsc,z_i^K$,
%In contrast, TMC proposals \citep{aitchison2019tensor} were restricted to only IID distributions over the $K$ samples,
\begin{align}
  \label{eq:Qtmc}
  \Qtmcc{z}{x} &= \prod_{i=1}^n \prod_{k\in \mathcal{K}} \Qtmcc{z_i^k}{z_j \text{ for all } j \in \qa{i}}.
\end{align}
Here, $\qa{i}$ are the indices of parents of the $i$th latent variable under the approximate posterior.
In contrast, massively parallel methods allow for dependencies between the $K$ samples for the $i$th latent variable, $z_i^1,\dotsc,z_i^K$ (Appendix~\ref{app:ap}),
\begin{align}
  %\Qc{z^k}{x} &= \prod_{i=1}^n \Qc{z_i^k}{z_{\qa{i}}^k}.\\
  %\label{eq:Qmp}
  %\Qc{z}{x} &= \prod_{i=1}^n \prod_{k\in \mathcal{K}} \Qc{z_i^k}{z_{\qa{i}}^k}
  \label{eq:Qmp}
  \Qmpc{z}{x} &= \prod_{i=1}^n \Qmpc{z_i}{z_j \text{ for all } j \in \qa{i}}.
\end{align}
There are no formal constraints on these dependencies.
However, there are practical constraints, namely that we need to be able to efficiently compute the single-particle marginals, $\Qc{z_i^k}{z_j \text{ for all } j \in \qa{i}}$.
In Appendix~\ref{app:ap}, we give more specifics about choices of $\Qmpc{z_i}{z_j \text{ for all } j \in \qa{i}}$ and $\Qtmcc{z_i^k}{z_j \text{ for all } j \in \qa{i}}$.
At a high level, these distributions are constructed by mixing the underlying single-sample approximate posterior, $\bareQ_\phi$, for different combinations of parent particles.

The generative model is more complicated, because we want to evaluate the generative probability for any of the $K^n$ possible combinations of the $K$ samples of the $n$ latent variables.
To facilitate writing down these generative probabilities, we begin by defining a vector of indices,
\begin{align}
  \k &= \b{k_1, k_2, \dotsc, k_n} \in \mathcal{K}^n,
\end{align}
which has one index, $k_i$, for each of the $n$ latent variables.
The latent variables specified by these indices are known as the ``indexed'' latent variables and can be written,
\begin{align}
  z^\k &= \b{z_1^{k_1}, z_2^{k_2}, \dotsc, z_n^{k_n}} \in \mathcal{Z}.
\end{align}
The generative probability can thus be written,
\begin{equation}
\label{eq:gen}
\begin{aligned}
  %\P[\theta]{x, z^\k} &= \P[\theta]{x| z^{\k_{\pa{x}}}_{\pa{x}}} \prod_{i=1}^n \P[\theta]{z^{k_i}_i| z^{\k_{\pa{i}}}_{\pa{i}}}.
  \P[\theta]{x, z^\k} = &\Pc[\theta]{x}{z_j^{k_j} \text{ for all } j \in \pa{x}} \\ &\prod_{i=1}^n \Pc[\theta]{z^{k_i}_i}{z_j^{k_j} \text{ for all } j \in \pa{i}}.
\end{aligned}
\end{equation}
Here, $\pa{i}$ are the indices of parents of the $i$th latent variable under the generative model, and $\pa{x}$ are the parents of the data under the generative model.
%so $z_{\pa{i}}^{\k_{\pa{i}}}$ denotes the indexed latent samples for all parents of $i$ under the generative model,
%\begin{align}
%  z_{\pa{i}}^{\k_{\pa{i}}} &= \cb{z_j^{k_j} \text{ for } j \in \pa{i}}.
%\end{align}
Our use of $\pa{i}$ mirrors our use of $\qa{i}$ to denote indices of parents of the $i$th latent variable under the approximate posterior.
Of course, the structure of the generative model and approximate posterior may differ, so $\qa{i}$ and $\pa{i}$ can also differ.

%\textbf{General setup.} We start with a generative model, $\P[\theta]{x, z'}$, where $x$ is our observations and $z' \in \mathcal{Z}$ is our latents.
%Our goal is to learn the parameters, $\theta$, and to approximate the posterior, $\P[\theta]{z'| x}$.
%In most cases, the full set of latents, $z'$, are structured, in the sense that $z'$ is made up of $n$ separate latent variables, $z'_i \in \mathcal{Z}_i$, indexed $i$,
%\begin{align}
%  z' &= (z'_1, z'_2, \dotsc, z'_n) \in \mathcal{Z}_1 \times \mathcal{Z}_2 \times \dotsm \times \mathcal{Z}_n = \mathcal{Z}.
%\end{align}
%Note that we use $z'$ rather than $z$ because we are reserving $z$ for later use.
%The generative model can be factorised according to its graphical model structure,
%\begin{align}
%  \label{eq:gen}
%  \P[\theta]{x, z'} &= \P[\theta]{x| z'_{\pa{x}}} \prod_{i=1}^n \P[\theta]{z'_i| z'_{\pa{i}}}
%\end{align}
%where $\pa{i}$ are the indices of parents of the $i$th latent variable, and $\pa{x}$ are the indicies of the parents of the data.
%All methods we consider use an approximate posterior, $\Q{z'| x}$, which factorises as,
%\begin{align}
%  \Q{z'| x} &= \prod_i \Q{z_i'| z_{\qa{i}}'}.
%\end{align}
%Here, $\qa{i}$ are the indices of parents of the $i$th latent variable under the approximate posterior.  As usual, the graphical model (and hence the parents of each latent variable) may be different for the generative model and approximate posterior.
%
%We focus on multi-sample methods, such as IWAE and RWS.
%These methods draw $K$ samples for each latent variable, $z_i^k$, with samples being indexed $k$.
%The collection of all $K$ samples for the $i$th latent variable is,
%\begin{align}
%  z_i &= (z_i^1, z_i^2, \dotsc, z_i^K) \in \mathcal{Z}_i^K.
%  %z^i &= (z^i_1, z^i_2, \dotsc, z^i_K) \in \mathcal{Z}_i^K.
%\end{align}
%The $k$th sample for all latent variables is,
%\begin{align}
%  z^k &= (z_1^k, z_2^k, \dotsc, z_n^k) \in \mathcal{Z}.
%\end{align}
%And all $K$ samples for all $n$ latent variables is,
%\begin{align}
%  z &= (z_1, z_2, \dotsc, z_n) \in \mathcal{Z}^K.
%\end{align}
%These $K$ samples are drawn by sampling $K$ times from the underlying approximate posterior,
%\begin{align}
%  \label{eq:Qglobal}
%  \Q{z| x} &= \prod_{k\in\mathcal{K}} \Q{z^k| x} = \prod_{k\in\mathcal{K}} \prod_{i=1}^n \Q{z_i^k| z_{\qa{i}}^k}.
%\end{align}
%Now, we highlight and unify aspects of IWAE and RWS in a manner that will be useful for developing our approach.
%As we will show, both IWAE and RWS can be written in terms of, $\mathcal{P}_\text{global}(z)$,
%\begin{align}
%  \label{eq:Pglobal}
%  \Pglob(z) &= \frac{1}{K} \sum_{k\in\mathcal{K}} r_k(z),\\
%  \label{eq:rglobal}
%  %r_k(z) &= \frac{\P[\theta]{x, z^k}}{\prod_i \Q[\phi]{z_i^k| x, z_{\pa{i}}^{k}}}.
%  r_k(z) &= \frac{\P[\theta]{x, z^k}}{\Q[\phi]{z^k| x}},
%\end{align}
%which is the average ratio of generative and approximate posterior probabilities, $r(z^k)$.
%When $z$ is sampled from $\Q[\phi]{z| x}$, then $\mathcal{P}_\text{global}(z)$ is an unbiased estimator of the marginal likelihood, $\P[\theta]{x}$ (Sec.~\ref{app:iwae_glob}).

\begin{figure*}
  \centering
  \includegraphics[width=0.45\linewidth]{chart_movielens_rwsN5M300.pdf}
  \includegraphics[width=0.45\linewidth]{chart_movielens_rws_timeN5M300.pdf}
  \caption{
    Comparison of performance between massively parallel RWS and global RWS, for the MovieLens dataset. \textbf{a} Same as Fig.~\ref{fig:movielens_rws} (top right), except that we use much higher values of $K$ for global RWS.
    \textbf{b} As massively parallel RWS may take longer than global RWS for a given $K$, we plotted time for a single training iteration against the predictive log-likelihood.
  }
 %
 % Results of inference for the movielens dataset showing predictive log likelihood after $25$k iterations. Note that Global only begins outperfoming massively parallel RWS with $K_\mp=3$ somewhere between $K_\glob=3000$ and $K_\glob=10000$. \textbf{b} Time per single training iteration plotted against predictive log likelihood for massively parallel RWS and globally importance sampled RWS for the Movielens dataset. Here we see that massively parallel RWS is able to attain near maximum performance for almost no additional compute time when increasing from $K$=3 to $K$=10.
  \label{fig:movielens_comp}
\end{figure*}

%\textbf{massively parallel IWAE and RWS.}
Looking at $\Pglob(z)$ (Eq.~\ref{eq:Pglobal}) we average only over $K$ terms, corresponding to our $K$ samples from the full joint latent space.
Our key contribution is to adapt RWS for the case where we average over all $K^n$ combinations of samples for each latent variable, indexed $\k$.

We can define an alternative unbiased marginal likelihood estimator, $\Pmp(z)$.
This estimator is obtained by averaging over all $K^n$ combinations of all samples of all latent variables,
\begin{align}
  \label{eq:Pmp}
  \Pmp(z) &= \frac{1}{K^n} \sum_{\k \in \mathcal{K}^n} r_\k(z),\\
  \label{eq:rtmc}
  r_\k(z) &= \frac{\P{x, z^\k}}{\prod_i \Qmpc{z_i^{k_i}}{z_j \text{ for all } j \in \qa{i}}}.%\Q{z_i^{k_i}| x, z^{k_i}_\qa{i}}}.
\end{align}
For the proof that $\Pmp(z)$ is an unbiased marginal likelihood estimator, see Appendix~\ref{app:iwae_tmc}.
By analogy with \textglob{} IWAE, we can define an objective for massively parallel IWAE,
\begin{align}
  \label{eq:L:tmc}
  \Lmp &= \E[{\Qmpc{z}{x}}]{\log \Pmp(z)}.
\end{align}
We prove that this quantity has the required properties (specifically, that it is a lower-bound on the log marginal likelihood) in Appendix~\ref{app:iwae_tmc}.
This quantity is very similar to that given by \citet{aitchison2019tensor}, except that it allows for a slightly more general proposal, $\bareQ_{\mp}$, which allows for dependencies between the $K$ samples for a single latent variable, $z_i^1,\dotsc,z_i^K$.
Our key contribution is to design massively parallel updates for RWS,
\begin{subequations}
\label{eq:rws_tmc:iw}
\begin{align}
  \label{eq:rws_tmc:iw:P}
  \thetamp &{=} \E[{\Qmpc{z}{x}}]{\frac{1}{K^n}\sum_{\k\in\mathcal{K}^n} \frac{r_\k(z)}{\Pmp(z)} \nabla_\theta \log \P[\theta]{z^\k, x}},\\
  \label{eq:rws_tmc:iw:Q}
  \phimp &{=} \E[{\Qmpc{z}{x}}]{\frac{1}{K^n}\sum_{\k\in\mathcal{K}^n} \frac{r_\k(z)}{\Pmp(z)} \nabla_\phi \log \Q[\phi]{z^\k, x}}.
\end{align}
\end{subequations}
These updates are derived in Appendix~\ref{app:rws_tmc}, and they can
be implemented using,
\begin{subequations}
\label{eq:rws_tmc:iw:imp}
  \begin{align}
    \label{eq:rws_tmc:iw:P:imp}
  \thetamp &= \E[{\Qmpc{z}{x}}]{\nabla_\theta \log \Pmp(z)},\\
    \label{eq:rws_tmc:iw:Q:imp}
  \phimp    &= \E[{\Qmpc{z}{x}}]{\nabla_\phi \b{- \log \Pmp(z)}},
  \end{align}
\end{subequations}
(see Appendix~\ref{app:rws_glob}).


\begin{algorithm}
\caption{Massively Parallel RWS}\label{alg:mprws}
\begin{algorithmic}
\Require{Data $x$, Prior $\mathrm{P}_\theta$, Proposal $\mathrm{Q}_\mathrm{MP}$, $K \geq 1$}

\For{$i \gets 1$ to $n$}
    \State Sample $z_i \sim \Qmpc{z_i}{z_j \text{ for all } j \in \qa{i}}$
    \State $z \gets \{z_1,\dotsc,z_{i-1}\} \cup \{z_i\}$
    \State $f^i_{k_i, \k_{\pa{i}}}(z) \gets \frac{\Pc[\theta]{z^{k_i}_i}{z_j^{k_j} \text{ for all } j \in \pa{i}}}{\Qmp{z_i^{k_i}| x, z_j \text{ for all } j \in \qa{i}}}$
\EndFor
\State $f^x_{\k_{\pa{x}}}(z) \gets \Pc[\theta]{x}{z_j^{k_j} \text{ for all } j \in \pa{x}}$
\State $\Pmp(z) \gets \tfrac{1}{K^n} \sum_{\k \in \mathcal{K}^n} f^x_{\k_{\pa{x}}}(z) \prod_i f^i_{k_i,\k_{\pa{i}}}(z)$
\State $\thetamp \gets {\nabla_\theta \log \Pmp(z)}$
\State $\phimp \gets {\nabla_\phi \b{- \log \Pmp(z)}}$
% \Ensure $y = x^n$
% \State $y \gets 1$
% \State $X \gets x$
% \State $N \gets n$
% \While{$N \neq 0$}
% \If{$N$ is even}
%     \State $X \gets X \times X$
%     \State $N \gets \frac{N}{2}$  \Comment{This is a comment}
% \ElsIf{$N$ is odd}
%     \State $y \gets y \times X$
%     \State $N \gets N - 1$
% \EndIf
% \EndWhile
\end{algorithmic}
\end{algorithm}


\subsection{Efficiently averaging exponentially many terms}

\begin{figure*}
  \centering
  \includegraphics[width=0.45\linewidth]{chart_bus.pdf}
  \includegraphics[width=0.45\linewidth]{chart_bus_time.pdf}
  \caption{Comparison of performance for massively parallel RWS and global RWS on the NYC bus breakdown dataset. \textbf{a} Predictive log-likelihood against $K$ after 75k training iterations. \textbf{b} Predictive log-likelihood against the time for a single training iteration.
  }
  \label{fig:bus_chart}
\end{figure*}

It should be surprising that we can compute $\Pmp(z)$ (Eq.~\ref{eq:Pmp}) efficiently, as it involves summing over exponentially many ($K^n$) terms.
However, it turns out that efficient computation is possible if we exploit structure in the generative model.
To exploit structure, we first need to write down the generative probability for the $\k$th sample of all latent variables, $z^\k$.
This looks a lot like Eq.~\eqref{eq:gen}, as it follows the same graphical model structure, with $\pa{x}$ and $\pa{i}$ giving the indices of parents of the data and the $i$th latent variable respectively,
%; the only difference is that we need to be careful about the indices of the samples, $\k$,
%\begin{align}
  %\P[\theta]{x, z^\k} &= \P[\theta]{x| z_{\pa{x}}^{\k_\pa{x}}} \prod_{i=1}^n \P[\theta]{z^{k_i}_i| z_{\pa{i}}^{\k_\pa{i}}}.
  %\P[\theta]{x, z^\k} &= \Pc[\theta]{x}{z_j^{k_j} \text{ for all } j \in \pa{x}} \prod_{i=1}^n \Pc[\theta]{z^{k_i}_i}{z_j^{k_j} \text{ for all } j \in \pa{i}}.
%\end{align}
%However, the key difference is that we now need to be very careful about the indicies, $\k$.
%In particular, remember that the indexed latents are $z^\k = (z_1^{k_1}, z_2^{k_2}, \dotsc, z_n^{k_n})$.
%As e.g.\ the distribution over $x$ depends on only a subset of the latent variables given by $\pa{x}$, we need to be able to write the indexed latents for only the parents of $x$,
%\begin{subequations}
%\begin{align}
%  z_{\pa{x}}^{\k_{\pa{x}}} &= \b{z_j^{k_j} \text{ for all } j \in \pa{x}},\\
%  z_{\pa{i}}^{\k_{\pa{i}}} &= \b{z_j^{k_j} \text{ for all } j \in \pa{i}}.
%\end{align}
%\end{subequations}
%In particular, the indices of samples for the parent latent variables are,
%\begin{subequations}
%\begin{align}
%  \k_{\pa{x}} &= \b{k_j \text{ for all } j \in \pa{x}} \in \mathcal{K}^{\abs{\pa{x}}}, \\
%  \k_{\pa{i}} &= \b{k_j \text{ for all } j \in \pa{i}} \in \mathcal{K}^{\abs{\pa{i}}},
%\end{align}
%\end{subequations}
%where $\abs{\pa{x}}$ and $\abs{\pa{i}}$ are the number of parents for the data and the $i$th latent respectively.
%Likewise, the actual value of the $\k_{\pa{x}}$th or $\k_{\pa{i}}$th combination of the parent latent variables is,
%$\Pmp(z)$ (Eq.~\ref{eq:Pmp}),
\begin{equation}
  \label{eq:tensor_product_pq}
\begin{aligned}
  %\Pmp(z) &= \frac{1}{K^n} \sum_{\k\in\mathcal{K}^n} \P[\theta]{x| z_{\pa{x}}^{\k_\pa{x}}} \prod_{i=1}^n \frac{\P[\theta]{z^{k_i}_i| z_{\pa{i}}^{\k_\pa{i}}}}{\Q{z_i^{k_i}| x, z_\qa{i}}}
  %\Pmp(z) &= \frac{1}{K^n} \sum_{\k\in\mathcal{K}^n} \Pc[\theta]{x}{z_j^{k_j} \text{ for all } j \in \pa{x}} \prod_{i=1}^n \frac{\Pc[\theta]{z^{k_i}_i}{z_j^{k_j} \text{ for all } j \in \pa{i}}.}{\Q{z_i^{k_i}| x, z_\qa{i}}}
  r_\k(z) = &\Pc[\theta]{x}{z_j^{k_j} \text{ for all } j \in \pa{x}} \\ &\prod_{i=1}^n \frac{\Pc[\theta]{z^{k_i}_i}{z_j^{k_j} \text{ for all } j \in \pa{i}}}{\Qmp{z_i^{k_i}| x, z_j \text{ for all } j \in \qa{i}}}.
\end{aligned}
\end{equation}
If we fix $z$ (i.e.\ all $K$ samples of all $n$ latent variables), then $r_\k(z)$ can be regarded as a big tensor with $K^n$ elements, indexed by $\k$.
In that case, each term in the product defining $r_\k(z)$ (Eq.~\ref{eq:tensor_product_pq}) can also be regarded as a tensor.
The key observation is that the individual tensors in the product typically have only a few indices.
For instance, the probability of the data, $x$, depends only on the indices of samples of the parents (i.e.\ $(k_j \text{ for all } j \in \pa{x})$).
These indices of the samples of the parents can be written,
\begin{subequations}
\begin{align}
  \k_{\pa{x}} &= \b{k_j \text{ for all } j \in \pa{x}} \in \mathcal{K}^{\abs{\pa{x}}}, \\
  \k_{\pa{i}} &= \b{k_j \text{ for all } j \in \pa{i}} \in \mathcal{K}^{\abs{\pa{i}}},
\end{align}
\end{subequations}
where $\abs{\pa{x}}$ and $\abs{\pa{i}}$ are the numbers of parent latent variables of the data and of the $i$th latent variable, respectively.
To make explicit the idea that the individual terms in Eq.~\eqref{eq:tensor_product_pq} can be understood as tensors, we define $f^x_{\k_{\pa{x}}}(z)$ as the tensor for the data and $f^i_{k_i, \k_{\pa{i}}}(z)$ as the tensor for the $i$th latent variable,
\begin{subequations}
\begin{align}
  \label{eq:fx}
  f^x_{\k_{\pa{x}}}(z) &= \Pc[\theta]{x}{z_j^{k_j} \text{ for all } j \in \pa{x}},\\%\in \mathbb{R}^{K^{\abs{\pa{x}}}}.\\
  \label{eq:fi}
  f^i_{k_i, \k_{\pa{i}}}(z) &= \frac{\Pc[\theta]{z^{k_i}_i}{z_j^{k_j} \text{ for all } j \in \pa{i}}}{\Qmp{z_i^{k_i}| x, z_j \text{ for all } j \in \qa{i}}}.% \in \mathbb{R}^{K^{1+\abs{\pa{i}}}}.
\end{align}
\end{subequations}
Thus, we can write $r_\k(z)$ as a product of these factors,
\begin{align}
  \label{eq:tensor_product}
  r_\k(z) &= f^x_{\k_{\pa{x}}}(z) \prod_i f^i_{k_i,\k_{\pa{i}}}(z),
\end{align}
and $\Pmp(z)$ can be understood as a big tensor product,
\begin{align}
  \Pmp(z) &= \tfrac{1}{K^n} \sum_{\k \in \mathcal{K}^n} f^x_{\k_{\pa{x}}}(z) \prod_i f^i_{k_i,\k_{\pa{i}}}(z).
\end{align}
This tensor product can be efficiently computed in polynomial time by ordering the sums and products using Python packages such as opt-einsum \citep{daniel2018opt}.

%There is one tensor for the data, $f^x_{\k_{\pa{x}}}(z)$, and one tensor for each latent variable, $f^i_{k_i, \k_{\pa{i}}}(z)$.
%The tensor for the data, $f^x_{\k_{\pa{x}}}(z)$ (Eq.~\ref{eq:fx}) gives the probability of the data for all possible combinations of samples of the parent latent variables.
%Remember that $\pa{x}$ is the set of indices of the latent varaibles that are parents of the data, so the number of parents is given by $\abs{\pa{x}}$.
%As there are $K$ samples for each parent latent variable, and $\abs{\pa{x}}$ parent latent variables, there are $K^{\abs{\pa{x}}}$ total combinations of all samples of all parents, and hence a total of $K^{\abs{\pa{x}}}$ elements in $f^x_{\k_{\pa{x}}}(z)$.
%To make this dependence only on the samples of the parent latent variables explicit, we take the full set of indicies for all latent variables, $\k\in\mathcal{K}^n$ (in the outer sum in Eq.~\ref{eq:tensor_product}), and extract only those indices corresponding to the parents of the data, $x$, or of the $i$th latent variable,
%\begin{align}
%  \k_{\pa{x}} &\in \mathcal{K}^{\abs{\pa{x}}} & \k_{\pa{i}} &\in \mathcal{K}^{\abs{\pa{i}}}.
%\end{align}
%Likewise, the tensor for the $i$th latent variable, $f^i_{k_i, \k_{\pa{i}}}(z)$ (Eq.~\ref{eq:fi}) gives the probability ratio for all possible combinations of samples of the parent latent variables, and all samples of the $i$th latent variable.
%This tensor therefore has $K^{1+\abs{\pa{i}}}$ elements in total.
%Ultimately, then the computation in Eq.~\eqref{eq:tensor_product} is a big tensor product, which takes the product of a number of tensors, $f^x_{\k_{\pa{x}}}(z)$ and $f^i_{\k_{\pa{i}}}(z)$, then sums out the indices.





%\subsection{Variational Inference}
%
%
%
%\subsection{Reweighted wake-sleep (RWS)}
%For an IWAE, there are two phases.
%The wake-phase $\bareQ$ update can be understood as learning $\Q[\phi]{z}$ by doing maximum likelihood on approximate samples from the true posterior, $\P[\theta]{z_k| x}$,
%\begin{align}
%  %\Delta \theta &= \E[{\Q[\phi]{z}}]{\nabla_\theta \tfrac{1}{K} \sum_k \log \frac{\P[\theta]{x, z_k}}{\Q[\phi]{z_k}}}\\
%  \Delta \phi &= \int dz \P[\theta]{z| x} \nabla_\phi \log \Q[\phi]{z}.\\
%  \intertext{Of course, we do not have samples from the true posterior.  Instead, we sample from $\Q[\phi]{z}$ and reweight,}
%  \Delta \phi &= \int dz \Q[\phi]{z} \frac{\P[\theta]{z| x}}{\Q[\phi]{z}} \nabla_\phi \log \Q[\phi]{z} = \E[{\Q[\phi]{z}}]{\frac{\P[\theta]{z| x}}{\Q[\phi]{z}} \nabla_\phi \log \Q[\phi]{z}}.
%  \intertext{As we do not have the posterior ...}
%\end{align}
%This can be understood as ...
%
%\subsection{Importance weighted autoencoder (IWAE)}
%For an IWAE, the objective is usually written,
%\begin{align}
%  \L &= \E{\tfrac{1}{K} \log \sum_k \frac{\P[\theta]{x, z_k}}{\Q[\phi]{z_k}}}
%\end{align}
%This can be understood as importance weighting,
%\begin{align}
%  \Delta \theta &= \nabla_{\theta} \L = \nabla_{\theta} \E{\tfrac{1}{K} \sum_k \log \frac{\P[\theta]{x, z_k}}{\Q[\phi]{z_k}}}
%\end{align}
%\begin{align}
%  \mathcal{P}(z_1,\dotsc,z_k) &= \sum_k \log \frac{\P[\theta]{x, z_k}}{\Q[\phi]{z_k}}
%\end{align}
%\begin{align}
%  \L &= \E[\Q{z}]{\mathcal{P}(z)}\\
%  \intertext{VI}
%  \Delta \theta &= \E[{\Q[\phi]{z}}]{\nabla_\theta \mathcal{P}(z)}\\
%  \Delta \phi &= \nabla_\phi \E[{\Q[\phi]{z}}]{\mathcal{P}(z)}
%  \intertext{RWS}
%  \Delta \theta &= \E[{\Q[\phi]{z}}]{\nabla_\theta \mathcal{P}(z)}\\
%  \Delta \phi &= -\E[{\Q[\phi]{z}}]{\nabla_\phi \mathcal{P}(z)}
%\end{align}




\section{Experiments}



We present an empirical evaluation of massively parallel RWS.\footnote{Code for reproducing the experiments can be found at \url{https://github.com/ThomasHeap/MPRW-S}.}
%We perform variational inference on multiple models with hierarchical latent structures.
%We compare massively parallel RWS against TMC \cite{aitchison2019tensor} and standard importance weighted RWS.
Since the RWS wake-phase $\bareQ$ update requires multiple importance samples, we test massively parallel RWS (MP RWS) with $K \in \{3, 10, 30\}$ and global RWS with $K \in \{3, 10, 30, 100, 300, 1000, 3000, 10000, 30000\}$.
Unless otherwise stated, our variational posterior is of the form $q_\phi(\mathbf{z}) = \prod_{i=1}^L q(z_{i})$, where $q(z_{i})$ is from the same family of distributions as $z_{i}$'s distribution in the generative model. We compare massively parallel RWS (Eq.~\ref{eq:rws_tmc:iw} and Eq.~\ref{eq:rws_tmc:iw:imp}) against standard ``global'' RWS (Eq.~\ref{eq:rws_global:iw} and Eq.~\ref{eq:rws_global:obj}).

Optimisation is done using Adam \citep{kingma2014adam} with $\beta = (0.9, 0.999)$, no weight decay, and a learning rate of $0.001$, which is decreased by a factor of $10$ every $10$k iterations.
%The optimiser is attempting to maximise the sum of the wake phase $\phimp$ and $\thetamp$ updates.
In all cases, we perform 5 runs with different random seeds and plot the mean and standard error.
All times are measured on a single Nvidia A100 GPU. %except for the timeseries model where times are measured on an Core i5-7200U CPU.
% \subsection{Synthetic model with one layer of latents}

% We use the hierarchical model from \cite{geffner2022variational} with synthetic data:

% \begin{equation}
% \begin{split}
% p(\mathbf{\mu_Z}, \mathbf{\psi_Z}, \mathbf{\psi_y}, \mathbf{y}, \mathbf{Z}) = \ & \mathcal{N}(\mathbf{\mu_Z} \vert 0, 1)
% \mathcal{N}(\mathbf{\psi_Z} \vert 0, 1)
% \mathcal{N}(\mathbf{\psi_y} \vert 0, 1) \\
% &\prod^M_{\mathrm{i}=1}\mathcal{N}(\mathbf{Z}_\mathrm{i} \vert \mathbf{\mu_Z}\mathbf{1}_{d_\mathbf{Z}}, \exp(\mathbf{\psi_Z})) \prod^N_{\mathrm{j}=1}\mathcal{N}(y_\mathrm{ij} \vert \mathbf{Z}^\intercal_\mathrm{i} \mathbf{X}_\mathrm{ij}), \exp(\mathbf{\psi_y}))
% \end{split}
% \end{equation}

% where $\mathbf{Z}_\mathrm{i} \in \mathbb{R}^{d_\mathbf{Z}}$, $\mathbf{X} \sim \mathcal{N}(\mathbf{0}^{(M,N,d_\mathbf{Z})}, \mathbf{1})$ and $d_\mathbf{Z} = 20$ if $N=30$, $d_\mathbf{Z} = 5$ otherwise.

% As in \cite{geffner2022variational} we generate datasets for all combinations of $M=\{10,50,100\}$ and $N=\{10,30\}$.

% Results are shown in \ref{fig:simple_hier}.
% It can be seen that for both methods increasing K improves performances (apart from in the case of $M=50$, $N=10$ and $M=100$, $N=10$ where for $K=10,15$ the lower bound doesn't change).
% It can also be seen that TPP outperforms LIW for all $K>1$.

% \begin{figure}
% \begin{center}
% \includegraphics{chart_hier.pdf}
% \end{center}
% \caption{Results of inference for the simple hierarchical model on synthetic datasets of multiple sizes.}
% \label{fig:simple_hier}
% \end{figure}

% \subsection{Synthetic model with four layers of latents}


% We use the following hierarchical model:

% \begin{equation}
% \begin{split}
% p(\mathbf{\mu_{Z^1}}, \mathbf{\mu_{Z^2}}, \mathbf{\mu_{Z^3}}, \mathbf{\mu_{Z^4}}, \mathbf{\psi_Z}, \mathbf{\psi_y}, \mathbf{y}, \mathbf{Z}) = \ & \mathcal{N}(\mathbf{\mu_{Z^1}} \vert 0, 1)
% \mathcal{N}(\mathbf{\psi_Z} \vert 0, 1)
% \mathcal{N}(\mathbf{\psi_y} \vert 0, 1) \\
% &\prod^2_{\mathrm{i}=1}
% \mathcal{N}(\mathbf{\mu_{Z^2_\mathrm{i}}} \vert \mathbf{\mu_{Z^1}}, 1)
% \prod^2_{\mathrm{j}=1}\mathcal{N}(\mathbf{\mu_{Z^3_\mathrm{ij}}} \vert \mathbf{\mu_{Z^2_\mathrm{i}}}, 1) \\
% &\prod^2_{\mathrm{k}=1}\mathcal{N}(\mathbf{\mu_{Z^4_\mathrm{ijk}}} \vert \mathbf{\mu_{Z^3_\mathrm{ij}}}, 1)
% \prod^M_{\mathrm{m}=1}\mathcal{N}(\mathbf{Z_\mathrm{ijkm}} \vert \mathbf{\mu_{Z^4_\mathrm{ijk}}}\mathbf{1}_{18}, \exp(\mathbf{\psi_Z})) \\
% &\prod^N_{\mathrm{n}=1}\mathcal{N}(y_\mathrm{ijkmn} \vert \mathbf{Z^\intercal_\mathrm{ijkm}} \mathbf{X}_\mathrm{ijkmn}), \exp(\mathbf{\psi_y}))
% \end{split}
% \end{equation}

% where $\mathbf{\mu_{Z^1}} \in \mathbb{R}$, $\mathbf{\mu_{Z^2}} \in \mathbb{R}^2$, $\mathbf{\mu_{Z^3}} \in \mathbb{R}^{(2,2)}$, $\mathbf{\mu_{Z^4}} \in \mathbb{R}^{(2,2,2)}$ $\mathbf{Z_\mathrm{ijkm}} \in \mathbb{R}^{5}$, $\mathbf{X} \sim \mathcal{N}(\mathbf{0}^{(2,2,2,M,N,5)}, \mathbf{1})$. Here the model is organised in layers: $\mathbf{\mu_{Z^i}}$ depends on $\mathbf{\mu_{Z^{i-1}}}$ for $i=2,3,4$, the superscipts refer to the place in this hierarchy, and the subscripts refer to the indexing implied by the plating.

% We generate datasets with $M \in \{2,4,10\}$ groups and $N=2$ observations per group.

% Results are shown in \ref{fig:deep_hier}.
% We see that increasing K improves the performance of both TPP and LIW, and that TPP outperforms LIW for all $K>1$.

% \begin{figure}
% \begin{center}
% \includegraphics{chart_deeper_regression.pdf}
% \end{center}
% \caption{Results of inference for the deeper hierarchical model on synthetic datasets of multiple sizes.}
% \label{fig:deep_hier}
% \end{figure}

\subsection{MovieLens dataset}



We show results on the MovieLens100K dataset \citep{harper2015movielens}.
This dataset consists of 100K ratings from $\mathrm{M}=943$ users (indexed $m$) of $\mathrm{N}=1682$ films (indexed $j$).
Each film, indexed $j$, has a feature vector, $\mathbf{x}_j$.
We observe user ratings and, following \citet{geffner2022variational}, binarise ratings of $(0,1,2,3)$ to $0$ and ratings of $(4,5)$ to $1$.
%
% \subsubsection{Variational Inference}
% For variational inference we use the following hierarchical model:
%
% \begin{equation}
% \label{eq:movielens_vi}
% \begin{split}
% \mu &\sim \Normal(\mathbf{0}_{18},1) \\
% \psi &\sim \Normal(\mathbf{0}_{18},1) \\
% \mathrm{Per \ user \ mean_m} &\sim \Normal(\mu,\exp(\psi)), \ \mathrm{m}=1,...,\mathrm{M} \\
% \mathrm{Rating_{mn}} &\sim \mathcal{B}(\mathrm{Per \ user \ mean_m}^\intercal \mathbf{X}_\mathrm{n}), \ \mathrm{n}=1,...,\mathrm{N} \\
% \end{split}
% \end{equation}
%
% where $\mathrm{Per \ user \ mean}_\mathrm{i} \in \mathbb{R}^{18}$, $\mathbf{X} \in \mathbb{R}^{(N,18)}$ is per film features corresponding to genre and $\mathcal{B}$ denotes a Bernoulli distribution. We assume a global mean and variance for the per user latent mean and then model the ratings as a logistic regression with weights given by the latent mean and features given by the aforementioned matrix of film genre tags. A corresponding graphical model can be seen in \ref{fig:movielens_gm}.
%
% % \begin{equation}
% % \begin{split}
% % p(\mathbf{\mu_Z}, \mathbf{\psi_Z}, \mathbf{\psi_y}, \mathbf{y}, \mathbf{Z}) = \mathcal{N}(\mathbf{\mu_Z} \vert 0, 1)
% % \mathcal{N}(\mathbf{\psi_Z} \vert 0, 1)
% % \prod^M_{\mathrm{i}=1}\mathcal{N}(\mathbf{Z}_\mathrm{i} \vert \mathbf{\mu_Z}\mathbf{1}_{d_\mathbf{Z}}, \exp(\mathbf{\psi_Z})) \prod^N_{\mathrm{j}=1}\mathcal{B}(y_\mathrm{ij} \vert \mathbf{Z}^\intercal_\mathrm{i}  \mathbf{X}_\mathrm{j}))
% % \end{split}
% % \end{equation}
%
% % where $\mathbf{Z}_\mathrm{i} \in \mathbb{R}^{18}$, $\mathbf{X} \in \mathbb{R}^{(N,18)}$ and $\mathcal{B}$ denotes a Bernoulli distribution.
%
% We use the MovieLens dataset to generate datasets of multiple sizes: we consider numbers of users $\mathrm{M} \in \{50,150,300\}$ and number of ratings per user of $\mathrm{N} \in \{5,10\}$.
%
% Results for variational inference are shown in figure \ref{fig:movielens_tpp}.
% All methods correspond to plain VI for $K=1$. It is unsurprising that all methods improve with increasing K, increasing the number of importance weights should always improve the ELBO. Similarly it should not be surprising that massively parallel outperforms both LIW and Global K: our massively parallel method is drawing $K^{(\mathrm{M}+2)}$ samples of the entire latent space, compared to Global K's $K$ samples and LIWs $K^M$ samples of only the $\mathrm{Per \ user \ mean}$ latent variable.
%
% \begin{figure*}
% \begin{center}
% \includegraphics[width=0.9\textwidth]{chart_movielens.pdf}
% \end{center}
% \caption{Results of variational inference for a hierarchical model on the movielens dataset with differing numbers of groups of users and observations per group, showing the final ELBO after 25k training iterations. All methods are equivalent to standard VI when $K=1$.}
% \label{fig:movielens_tpp}
% \end{figure*}
%
We use the following hierarchical model:
\begin{align}
\nonumber
\m &\sim \Normal(\mathbf{0}_{18},1) \\
\nonumber
\psi &\sim \operatorname{Categorical}([0.1,0.5,0.4,0.05,0.05]) \\
\nonumber
\mathbf{z}_m &\sim \Normal(\m,\exp(\psi) \I), \ m=1,\dotsc,\mathrm{M} \\
\label{eq:movielens_vi_rws}
\mathrm{Rating}_{mj} &\sim \operatorname{Bernoulli}(\sigma(\mathbf{z}_m^\intercal \mathbf{x}_j)), \ j=1,\dotsc,\mathrm{N}
\end{align}
This model first samples a global mean, $\m$, and a discrete variance, $\psi$.
We then sample a vector, $\mathbf{z}_m$, for each user, which describes the types of films that they will rate highly.
The probability of a high rating is then given by applying the sigmoid, $\sigma$, to the dot-product of the latent user-vector, $\mathbf{z}_m$, and the film's feature vector, $\mathbf{x}_j$. A corresponding graphical model can be seen in Appendix~\ref{movielens}.

Note that this model has a discrete latent variable $\psi$.
As RWS does not reparameterise gradients of the ELBO, inference can proceed straightforwardly, without needing any approaches to discrete latent variables from VI, such as summing out the latent variable, applying \textsc{REINFORCE} gradient estimators or using continuous relaxations \citep{le2020revisiting}.
We compare the two methods by calculating the predictive log-likelihood on a test set of the same size as the training set.

To evaluate inference methods effectively, it is important to ensure that the posterior distributions are broad, and have not collapsed to very narrow point-like distributions.
As such, we evaluate on subsets of the full MovieLens dataset, composed of either 5 or 10 films per user, and 50, 150 or 300 users.

\begin{figure*}[ht]
  \centering
  \includegraphics[width=0.45\linewidth]{timeseries.pdf}
  \includegraphics[width=0.45\linewidth]{timeseries_more_obs.pdf}
  \caption{Comparison of performance between massively parallel methods, TMC, particle filter and global importance weighting. \textbf{a} The timeseries model with one observation. \textbf{b} The timeseries model with multiple observations.}
  \label{fig:timeseries}
\end{figure*}
% \begin{equation}
% \begin{split}
% p(\mathbf{\mu_Z}, \mathbf{\psi_Z}, \mathbf{\psi_y}, \mathbf{y}, \mathbf{Z}) = \ & \mathcal{N}(\mathbf{\mu_Z} \vert 0, 1)
% \mathrm{Categorical}(\mathbf{\psi_Z} \vert [0.1,0.5,0.4,0.05,0.05])\\
% &\prod^M_{\mathrm{i}=1}\mathcal{N}(\mathbf{Z}_\mathrm{i} \vert \mathbf{\mu_Z}\mathbf{1}_{d_\mathbf{Z}}, \exp(\mathbf{\psi_Z})) \prod^N_{\mathrm{j}=1}\mathcal{B}(y_\mathrm{ij} \vert \mathbf{Z}^\intercal_\mathrm{i}  \mathbf{X}_\mathrm{j}))
% \end{split}
% \end{equation}

% This is the same as the model in \ref{eq:movielens_vi}, except for the discrete categorical prior over the $\mathrm{Per \ user \ mean}$ variance $\psi$.

Results are shown in Fig.~\ref{fig:movielens_rws} and Fig.~\ref{fig:movielens_comp}.
Massively parallel RWS gives considerably higher predictive log-likelihoods for all $K$ (Fig.~\ref{fig:movielens_rws},~\ref{fig:movielens_comp}a).
Importantly, the massively parallel RWS updates are more complex than the global RWS updates, so may take longer.
We therefore also considered the performance, measured as the predictive log-likelihood, against the time for a single training iteration.
We again found considerable, albeit less dramatic, improvements (Fig.~\ref{fig:movielens_comp}b).

% By comparing the performance of Global RWS and MP RWS in fig. \ref{fig:movielens_N5M300} it appears that the effective sample size of the massively parallel method with $K$=3 is somewhere between 3000 and 10000.




% \begin{figure}
% \begin{center}
% \includegraphics{chart_movielens_rws.pdf}
% \end{center}
% \caption{Results of inference for the deeper hierarchical model on synthetic datasets of multiple sizes.}
% \label{fig:movielens_rws}
% \end{figure}

% \subsection{Radon dataset}
% The radon dataset \cite{price1996bayesian, gelman2006data} consists of readings of radon levels in over 80,000 houses conducted by the Environmental Protection Agency in the United States stratified on the zipcode, county and state level. The dataset includes county level Uranium measurements, and whether each reading took place in a basement or not.

% To model the radon level we use the model described in equation \ref{eq:radon} in section \ref{radon} of the appendix.

% % \begin{equation}
% % \label{eq:radon}
% % \begin{split}
% % \mathrm{StateVariance} &\sim \Normal(0,1) \\
% % \mathrm{StateMean} &\sim \Normal(0,10^{-4}) \\
% % \mathrm{CountyMean_m} &\sim \Normal(\mathrm{StateMean},\exp(\mathrm{StateVariance})), \ \mathrm{m}=1,...,\mathrm{M} \\
% % \mathrm{CountyVariance_m} &\sim \Normal(0,1), \mathrm{m}=1,...,\mathrm{M} \\
% % \mathrm{w_j} &\sim \Normal(0,1), \ \mathrm{j}=1,...,\mathrm{J} \\
% % \mathrm{ZipMean_{mj}} &\sim \Normal(\mathrm{CountyMean_m} + \mathrm{w_j} * \mathrm{Uranium}_\mathrm{mj},\exp(\mathrm{CountyVariance_m})), \ \mathrm{j}=1,...,\mathrm{J} \\
% % \mathrm{ZipVariance_j} &\sim \Normal(0,1), \ \mathrm{j}=1,...,\mathrm{J} \\
% % \mathrm{ReadingMean_{mji}} &\sim \Normal(\mathrm{ZipMean_{mj}},\exp(\mathrm{ZipVariance_j})), \ \mathrm{i}=1,...,\mathrm{I} \\
% % \mathrm{ReadingVariance_i} &\sim \Normal(0,1), \ \mathrm{i}=1,...,\mathrm{I} \\
% % \mathrm{b_n} &\sim \Normal(0,1), \ \mathrm{n}=1,...,\mathrm{N} \\
% % \mathrm{Reading_{mjin}} &\sim \Normal(\mathrm{ReadingMean_{mji}} + b_{\mathrm{n}} * \mathrm{Basement_{mji}},\exp(\mathrm{ReadingVariance_i})), \ \mathrm{n}=1,...,\mathrm{N} \\
% % \end{split}
% % \end{equation}

% % \begin{equation}
% % \label{eq:radon}
% % \begin{split}
% % \mathrm{StateVariance} &\sim \mathcal{U}(0,100) \\
% % \mathrm{StateMean} &\sim \Normal(0,10^{-4}) \\
% % \mathrm{CountyMean_m} &\sim \Normal(\mathrm{StateMean_m},\mathrm{StateVariance_m}), \ \mathrm{m}=1,...,\mathrm{M} \\
% % \mathrm{CountyVariance_m} &\sim \mathcal{U}(0,100), \mathrm{m}=1,...,\mathrm{M} \\
% % \mathrm{w_j} &\sim \mathcal{U}(0,100), \ \mathrm{j}=1,...,\mathrm{J} \\
% % \mathrm{ZipMean_{mj}} &\sim \Normal(\mathrm{CountyMean_m} + \mathrm{w_j} * \mathrm{Uranium}_\mathrm{mj},\mathrm{CountyVariance_m}), \ \mathrm{j}=1,...,\mathrm{J} \\
% % \mathrm{ZipVariance_j} &\sim \mathcal{U}(0,100), \ \mathrm{j}=1,...,\mathrm{J} \\
% % \mathrm{ReadingMean_{mji}} &\sim \Normal(\mathrm{ZipMean_{mj}},\mathrm{ZipVariance_ji}), \ \mathrm{i}=1,...,\mathrm{I} \\
% % \mathrm{ReadingVariance_i} &\sim \mathcal{U}(0,100), \ \mathrm{i}=1,...,\mathrm{I} \\
% % \mathrm{b_n} &\sim \Normal(0,1), \ \mathrm{n}=1,...,\mathrm{N} \\
% % \mathrm{Reading_{mjin}} &\sim \Normal(\mathrm{ReadingMean_{mji}} + b_n * \mathrm{Basement_{mji}},\mathrm{ReadingVariance_i}), \ \mathrm{n}=1,...,\mathrm{N} \\
% % \end{split}
% % \end{equation}

% % \begin{equation}
% % \begin{split}
% % p(\sigma_\beta, \mu_\beta, \beta, \gamma, \sigma_\alpha, \alpha, \sigma_\omega, \omega, \sigma_\mathbf{y}, \psi_\mathrm{int}, \mathbf{y}) =& \
% % \prod^M_{\mathrm{m}=1}\mathcal{U}_{[0,100]}(\sigma_{\beta_m})
% % \mathcal{N}(\mu_{\beta_m} | 0, 10^{-4})\mathcal{N}(\beta_\mathrm{m} | \mu_{\beta_m}, \sigma_{\beta_m}) \\
% % &\prod^J_{\mathrm{j}=1}\mathcal{U}_{[0,100]}(\gamma_j)\mathcal{U}_{[0,100]}(\sigma_{\alpha_j})\mathcal{N}(\alpha_\mathrm{mj} | \beta_\mathrm{m} + \gamma_j * \mathrm{Uranium}_\mathrm{mj}, \sigma_{\alpha_j}) \\
% % &\prod^I_{\mathrm{i}=1}\mathcal{U}_{[0,100]}(\sigma_{\omega_i})\mathcal{N}(\omega_\mathrm{mji} | \alpha_\mathrm{mj}, \sigma_{\omega_i}) \\
% % &\prod^N_{\mathrm{n}=1}\mathcal{U}_{[0,100]}(\sigma_{\mathbf{y}_n})
% % \mathcal{N}(\psi_{\mathrm{int}_n} | 0, 1) \mathcal{N}(\mathbf{y}_\mathrm{mjin} | \omega_\mathrm{mji} + \psi_{\mathrm{int}_n} * \mathrm{Basement}_\mathrm{mji}, \sigma_{\mathbf{y}_n})
% % \end{split}
% % \end{equation}


% We consider a dataset of 128 readings, $\mathrm{N}=2$ readings in $\mathrm{J}=4$ zipcodes in $\mathrm{I}=4$ counties in $\mathrm{M}=4$ states. We optimise the evidence lower bound for $50$k training iterations.

% Results are shown in \ref{fig:radon_chart}.
% Unlike with the previous datasets, LIW shows almost no improvement in lower bound with increasing K and is outperformed for all $K>1$ by both the massively parallel method and Global K.
% By modifying LIW to permit multiple samples for the other latents we confirmed that large improvements in the lower bound only occur once multiple samples can be drawn of the $\mathrm{StateMean}$ latent, which explains the relative performance of the three methods.
% massively parallel on the other hand shows an improving lower bound with increased K as in other models.
% \begin{figure}
% \begin{center}
% \includegraphics[width=0.45\textwidth]{chart_radon_alongreadings.pdf}
% \end{center}
% \caption{Results of inference for the radon dataset showing final lower bound after $50$k iterations.}
% \label{fig:radon_chart}
% \end{figure}

% Figure \ref{fig:radon_chart_predll}, which shows the average predictive log likelihood of the posteriors
% \begin{figure}
% \begin{center}
% \includegraphics{chart_radon_alongreadings_predll.pdf}
% \end{center}
% \caption{Results of inference for the radon dataset showing predictive $\log$ likelihoods bound after $25$k iterations.}
% \label{fig:radon_chart_predll}
% \end{figure}

\subsection{NYC Bus breakdown dataset}
The city of New York releases data on the length of delays to school bus journeys \citep{nyc2023bus}. We model the length of the delay in terms of the type of journey, the school year in which the delay occurred, the borough the delay occurred in and the ID of the bus that was delayed.

To model delay time we use the model outlined in Appendix~\ref{bus}. Because the dataset can be stratified into a hierarchy with three levels (Year, Borough and ID), we want our model to reflect this and, inspired by attempts to use hierarchical regression to model indoor radon levels \citep{price1996bayesian}, we use a similar multi-level regression with three levels. This model first samples a variance and mean for each year, then uses these to sample a borough mean for each year. A variance is then sampled for each borough, which, together with the year-level borough mean, is used to sample an ID mean for each year and borough. Finally, a variance is sampled and used to sample two weight vectors, $\mathbf{C}_i$, which has length ``Number of bus companies'', and $\mathbf{J}_i$, which has length ``Number of types of journeys''. These are used to weight covariates that indicate which bus company was running a given ID's route and which type of journey was being undertaken, respectively. These are then summed with the sampled ID mean for that year and borough to get the logits for a negative binomial distribution that then gives the predicted delay for the $i$-th ID in the $j$-th borough in the $m$-th year. A corresponding graphical model can be seen in Appendix~\ref{bus_gm}.


We evaluate this model using a training dataset with $270$ observations: $I=30$ IDs from $J=3$ Boroughs in $M=3$ Years. We perform RWS for $75$k iterations, and evaluate the predictive log likelihood on a held-out test set of the same size as the training set.

% \begin{equation}
% \begin{split}
% p(\sigma_\beta, \mu_\beta, \beta, \sigma_\alpha, \alpha, \sigma_{\phi\psi}, \psi, \phi, \mathbf{y}) =& \
% \prod^M_{\mathrm{m}=1}\mathrm{Cat}(\mathbf{\sigma_{\beta_\mathrm{m}}} \vert [0.1,0.5,0.4,0.05,0.05])
% \mathcal{N}(\mu_{\beta_m} | 0, 10^{-4})\mathcal{N}(\beta_\mathrm{m} | \mu_{\beta_m}, \exp\sigma_{\beta_m}) \\
% &\prod^J_{\mathrm{j}=1}\mathrm{Cat}(\mathbf{\sigma_{\alpha_\mathrm{j}}} \vert [0.1,0.4,0.05,0.5,0.05])\mathcal{N}(\alpha_\mathrm{mj} | \beta_\mathrm{m}, \exp\sigma_{\alpha_j}) \\
% &\prod^I_{\mathrm{i}=1}\mathrm{Cat}(\mathbf{\sigma_{\phi\psi_\mathrm{i}}} \vert [0.1,0.4,0.5,0.05,0.05])\mathcal{N}(\psi_\mathrm{i} | 0, \sigma_{\phi\psi_\mathrm{i}})
% \mathcal{N}(\phi_\mathrm{i} | 0, \sigma_{\phi\psi_\mathrm{i}}) \\
% &\mathrm{NegativeBinomial}(\mathbf{y}_\mathrm{mji} | \alpha_\mathrm{mj} + \phi_\mathrm{i} * \mathrm{Bus \ company \ name} + \psi_\mathrm{i} * \mathrm{Journey \ type}, \sigma_{\phi\psi_\mathrm{i}})
% \end{split}
% \end{equation}

% \begin{equation}
% \label{eq:bus model}
% \begin{split}
% \mathrm{YearVariance_m} &\sim \mathrm{Cat}([0.1,0.5,0.4,0.05,0.05]) \ \mathrm{m}=1,...,\mathrm{M} \\
% \mathrm{YearMean_m} &\sim \Normal(0,10^{-4}) \ \mathrm{m}=1,...,\mathrm{M} \\
% \mathrm{BoroughMean_m} &\sim \Normal(\mathrm{YearMean_m},\exp(\mathrm{YearVariance_m})), \ \mathrm{m}=1,...,\mathrm{M} \\
% \mathrm{BoroughVariance_j} &\sim \mathrm{Cat}([0.1,0.4,0.05,0.5,0.05]), \mathrm{j}=1,...,\mathrm{J} \\
% \mathrm{IdMean_{mj}} &\sim \Normal(\mathrm{BoroughMean_m},\mathrm{BoroughVariance_m}), \ \mathrm{j}=1,...,\mathrm{J} \\
% \mathrm{WeightVariance_i} &\sim \mathrm{Cat}([0.1,0.4,0.5,0.05,0.05]), \ \mathrm{i}=1,...,\mathrm{I} \\
% \mathrm{Cn_i} &\sim \mathcal{N}(0, \mathrm{WeightVariance_i}), \ \mathrm{i}=1,...,\mathrm{I} \\
% \mathrm{Jt_i} &\sim \mathcal{N}(0, \mathrm{WeightVariance_i}), \ \mathrm{i}=1,...,\mathrm{I} \\
% \mathrm{logits_{mji}} &= \mathrm{IdMean_{mj}} + \mathrm{Cn_i} * \mathrm{Bus \ company \ name_{mji}} + \mathrm{Jt_i} * \mathrm{Journey \ type_{mji}} \\
% \mathrm{Delay_{mji}} &\sim \mathrm{NegativeBinomial}(\mathrm{total \ count}=130, \mathrm{logits_{mji}}), \ \mathrm{i}=1,...,\mathrm{I} \\
% \end{split}
% \end{equation}

%To model the radon level we use the model described in equation \ref{eq:bus model} in section \ref{bus} of the appendix. A corresponding graphical model can be seen in \ref{fig:bus_gm}.
%
%
%We consider a training dataset of 270 readings, $N=3$ Years in $I=3$ Boroughs and $J=30$ Ids. We perform RWS for $75$k iterations, and evalutate the predictive log likelihood on a held out test set the same size as the training set.

Results are shown in Fig.~\ref{fig:bus_chart}. Again we see that massively parallel RWS outperforms global RWS for all $K$.
% \begin{figure}[!h]
% \begin{center}
% \includegraphics[width=0.45\textwidth]{chart_bus.pdf}
% \end{center}
% \caption{Results of inference for the bus breakdown dataset showing predictive log likelihood after $75$k iterations.}
% \label{fig:bus_chart}
% \end{figure}



% \begin{figure}[!h]
% \begin{center}
% \includegraphics[width=0.45\textwidth]{chart_bus_time.pdf}
% \end{center}
% \caption{Time for a single training iteration plotted against predictive log likelihood for massively parallel RWS and globally importance sampled RWS for the NYC Bus Delay dataset.}
% \label{fig:bus_chart_time}
% \end{figure}

% \begin{figure*}[ht]
%   \centering
%   \begin{subfigure}[b]{.49\textwidth}
%     \includegraphics[width=\linewidth]{chart_bus.pdf}
%     \caption{Results of inference for the bus breakdown dataset showing predictive log likelihood after $75$k iterations.}
%     \label{fig:bus_chart}
%   \end{subfigure}%
%   ~
%   \begin{subfigure}[b]{.49\textwidth}
%     \includegraphics[width=\linewidth]{chart_bus_time.pdf}
%     \caption{Time for a single training iteration plotted against predictive log likelihood for massively parallel RWS and globally importance sampled RWS for the NYC Bus Delay dataset.}
%     \label{fig:bus_chart_time}
%   \end{subfigure}
%   \caption{Comparison of performance between massively parallel RWS, Global RWS, for the NYC bus breakdown dataset}
%   \label{fig:bus}
% \end{figure*}

\subsection{Comparing MP VI with TMC}

Even though our main contribution is in developing massively parallel RWS, our derivations also allow for slightly more general massively parallel approaches to VI.
In particular, our derivations allow us to couple the proposal for the $K$ samples of the $i$th latent variable, $z_i^1,\dotsc,z_i^K$, while TMC \citep{aitchison2019tensor} forces these $K$ samples to be IID.
This coupling in massively parallel methods allows us to introduce variance-reduction strategies inspired by methods for reducing particle degeneracy in particle filters \citedeg{} (see Appendix~\ref{app:ap} for further details).

To highlight these advantages, we considered two toy timeseries models: a single-observation model and a multi-observation model.
%For these timeseries models model the prior is used as a proposal. We compare both methods by computing the ELBO for $K \in \{1,3,10,30,300\}$ and plotting the average and standard error of 250 evaluations.

\subsubsection{Single Observation}
In the single observation model, there is a latent timeseries $z_1,\dotsc,z_{30}$ (we use $N=30$), and an observation, $x$, only at the last timestep,
\begin{equation}
\label{eq:timeseries}
\begin{aligned}
z_1 &= 0,\\
z_i  &\sim \Normal(z_{i-1}, 1/N),\\
x &\sim \Normal(z_N, 1).
\end{aligned}
\end{equation}
We use the prior to define the proposal (see Appendix~\ref{app:ap}).
Results can be seen in Fig.~\ref{fig:timeseries}a.
For large $K$, all methods converge to the same value, as the ELBOs are all bounded by the true model evidence.
To compare the methods, we therefore need to consider their relative performance for smaller values of $K$.
We can see that TMC (orange) \citep{aitchison2019tensor} performs considerably worse than massively parallel VI (red) and IWAE (blue) \citep{burda2015importance}.
We believe that TMC is performing poorly because of particle degeneracy \citedeg{}.
In particular, the TMC proposal for $z_i$ is given by a mixture of the prior, conditioned on particles from the previous timestep, $z_{i-1}$.
In sampling from this mixture, in essence, we first sample a parent particle, $z_{i-1}^{k_{i-1}}$, then we sample from the prior, conditioned on that parent sample, $\P{z_i^{k_i}| z_{i-1}^{k_{i-1}}}$.
In TMC, we choose these parent samples IID, which means that one parent particle, $z_{i-1}^{k_{i-1}}$, may have zero, one or multiple children.
This is problematic: whenever a parent sample has zero children, then this reduces diversity in the samples of $z_i$, and this issue builds up over timesteps.
Massively parallel methods circumvent this issue by ensuring that each parent sample has one and only one child sample (which requires us to couple the distribution over $z_i^1,\dotsc,z_i^K$), and IWAE avoids the issue by simply sampling $z_i^k$ conditioned on $z_{i-1}^{k}$.
Massively parallel VI is comparable to IWAE in this setting because we condition on only a single scalar value at the end of the timeseries.
These methods separate when we consider multiple observations (next).

%
%We can see that massively parallel VI (red) outperforms TMC (orange).
%on this example for all $K$, but performs comparably with global importance weighting. This pattern is expected as the benefits of massively parallel importance weighting are not fully apparent in the case of a single observation. This is because in the case with more observations we expect the posterior to diverge further from the prior, and the larger number of ``effective'' importance weights drawn by massively parallel methods should counterbalance this.

\subsubsection{Multiple Observations}
Next, we considered a more standard timeseries with multiple observations.
We are implementing these methods in the context of a new probabilistic programming language.
This language currently has limitations on the number of latent variables that are inherited from the opt-einsum implementation.
As such, we were not able to do the obvious thing of having one observation at every timestep.
Instead, we had an observation every third timestep.
\begin{equation}
\label{eq:timeseries more obs}
\begin{aligned}
z_1 &\sim \Normal(0, 1),\\
z_i &\sim \Normal((1-\tfrac{1}{\tau})z_{i-1}, 2/\tau), \\
x_i &\sim \Normal(z_{i}, 1) \quad\quad \text{if $i$ is divisible by 3}.
\end{aligned}
\end{equation}
Again, we use $N=30$.
Results can be seen in Fig.~\ref{fig:timeseries}b.
Again, the methods converge as $K$ increases, but this time, massively parallel VI (red) gives better performance than both alternatives for lower values of $K$.
%Note here that the massively parallel method outperforms all other methods here for all $K$ as previously predicted. As in the previous example all methods are converging as $K$ increases.
%
%These examples are designed to be problematic for TMC which suffers from particle degeneracy issues, and shows that massively parallel methods reduce the effect of this problem by drawing $K$ samples from the joint instead of sampling from a mixture over the parent samples.



%\begin{figure*}[h]
%  \centering
%  \begin{subfigure}[b]{.5\textwidth}
%    \includegraphics[width=\linewidth]{time_timeseries.pdf}
%    \caption{Time vs. ELBO for timeseries with a single observation}
%    \label{fig:time_timeseries-a}
%  \end{subfigure}%
%  ~
%  \begin{subfigure}[b]{.5\textwidth}
%    \includegraphics[width=\linewidth]{time_timeseries_more_obs.pdf}
%    \caption{Time vs. ELBO for timeseries with multiple observations}
%    \label{fig:time_timeseries-b}
%  \end{subfigure}
%  \caption{Comparison of time per ELBO evaluation plotted against performance between massively parallel methods, TMC, particle filter and global importance weighting for \hyperref[fig:timeseries-a]{(a)} the timeseries model with one observation, and \hyperref[fig:timeseries-b]{(b)} the timeseries model with multiple observations. Points at the lower part of the y-scale correspond to fewer $K$ samples. Note that the massively parallel method requires a shorter time to achieve a greater ELBO compared to TMC.}
%  \label{fig:time_timeseries}
%\end{figure*}

% \begin{figure*}
% \centering
%   \subcaptionbox{ELBO for timeseries with a single observation}{\includegraphics{timeseries.pdf}}\hspace{1em}%
%   \subcaptionbox{ELBO for timeseries with multiple observations}{\includegraphics{timeseries_more_obs.pdf}}

%   \caption{Comparison of performance between massively parallel methods, TMC, particle filter and global importance weighting for (a) the timeseries model with one observation, and (b) the timeseries model with multiple observations.}
%   \label{fig:timeseries}
% \end{figure*}
\section{Conclusion}
We introduced massively parallel RWS, in which we draw $K$ samples for each of the $n$ latent variables, and efficiently consider all $K^n$ combinations of samples by exploiting conditional independencies in the generative model.
We showed that massively parallel RWS represents a considerable improvement over previous RWS methods that draw $K$ samples from the full joint latent space.

\bibliography{Heap_216}

\end{document}
