Abstract: With the advent of automatic vectorization tools (e.g., JAX's vmap), writing multi-chain MCMC algorithms is often now as simple as invoking those tools on single-chain code. Whilst convenient, for various MCMC algorithms this results in a synchronization problem: loosely speaking, at each iteration all chains running in parallel must wait until the last chain has finished drawing its sample. In this work, we show how to design single-chain MCMC algorithms in a way that avoids synchronization overheads when vectorizing with tools like vmap, by using the framework of finite state machines (FSMs). Using a simplified model, we derive an exact theoretical form of the obtainable speed-ups using our approach, and use it to make principled recommendations for optimal algorithm design. We implement several popular MCMC algorithms as FSMs, including Elliptical Slice Sampling, HMC-NUTS, and Delayed Rejection, demonstrating speed-ups of up to an order of magnitude in experiments.
Lay Summary: Many important scientific tasks, like estimating how a new drug will affect patients or predicting future climate patterns, rely on computing quantities of a distribution that is not known in closed form. In practice, one often uses a class of simulation techniques called “Markov Chain Monte Carlo” (MCMC) to draw samples from this unknown distribution and uses them to estimate these quantities (e.g. its mean or mode). Modern hardware such as GPUs can run multiple MCMC simulations (often called “chains”) in parallel to save time. Several “automatic vectorization” tools have been developed that take code for running a single chain and transform it into code that runs a batch of chains together in lockstep. These tools have made it very convenient for researchers to leverage the power of this modern hardware. However, a limitation arises whenever each chain does a random amount of work to produce each sample. Since the chains must run in lockstep, all chains must wait at every sampling step until the slowest chain has returned its sample. In other words, the time taken to draw every sample is driven by the slowest chain, rather than the typical chain.
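The lockstep effect described above can be sketched in JAX. This is a minimal illustration under our own assumptions, not code from the paper: `draw_sample` is a hypothetical toy rejection loop, and the acceptance probability of 0.2 and the batch of 64 chains are arbitrary choices. When `jax.vmap` is applied to a function containing `lax.while_loop`, the batched loop keeps iterating until every chain's condition is false, so the cost of one batched draw tracks the largest per-chain iteration count rather than the average.

```python
import jax
import jax.numpy as jnp
from jax import lax


def draw_sample(key, accept_prob=0.2):
    """Hypothetical toy sampler: retry until a uniform draw is accepted.

    Returns the number of loop iterations this chain needed.
    """
    def cond(state):
        _, accepted, _ = state
        return ~accepted

    def body(state):
        key, _, n_iter = state
        key, sub = jax.random.split(key)
        accepted = jax.random.uniform(sub) < accept_prob
        return key, accepted, n_iter + 1

    init = (key, jnp.zeros((), dtype=bool), jnp.zeros((), dtype=jnp.int32))
    _, _, n_iter = lax.while_loop(cond, body, init)
    return n_iter


# One key per chain; vmap turns the single-chain sampler into a batched one.
keys = jax.random.split(jax.random.PRNGKey(0), 64)
iters = jax.vmap(draw_sample)(keys)

mean_iters = float(iters.mean())  # work needed by a typical chain
max_iters = int(iters.max())      # iterations the batched loop actually runs
print(f"mean iterations per chain: {mean_iters:.2f}, slowest chain: {max_iters}")
```

The gap between the two printed numbers is the waiting time this toy batch incurs for a single sample; it is paid again at every sampling step.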
In our work, we break each chain’s logic into just a few simple stages (e.g. “propose a candidate,” “compute its score,” “decide to accept or reject,” and “record the result”) and keep track of which stage each chain is in. We then pack all of these stages into a single, GPU-friendly function that can still be conveniently transformed with existing tools to run batches of simulations. At each step, the function looks at every chain’s current stage and does the required computation, allowing all chains in the batch to move forward immediately no matter what stage they are currently in. This “finite-state machine” approach removes thousands of tiny synchronization pauses and keeps the hardware busy much more effectively. On real problems such as modeling house prices, analyzing stock price trends, and modeling predator-prey patterns, we see up to 10× more samples per second, even though the mathematical steps of each simulation remain the same. As a result, users can reach the same statistical conclusions in potentially a fraction of the wall-clock time, allowing faster progress in fields like medicine, climate science, economics and engineering.
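The stage-tracking idea can be sketched as follows. This is a toy illustration under our own assumptions, not the authors' implementation: the three stages (`PROPOSE`, `DECIDE`, `RECORD`), the acceptance threshold of 0.2, and the names `fsm_step` and `run_chain` are all hypothetical. Each chain carries an integer stage, one call to `fsm_step` advances it by exactly one stage, and there is no inner loop for other chains to wait on.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical stage labels for a toy rejection sampler.
PROPOSE, DECIDE, RECORD = 0, 1, 2


def propose(state):
    # Draw a new candidate, then move on to the decision stage.
    key, _, _, n_samples = state
    key, sub = jax.random.split(key)
    cand = jax.random.uniform(sub)
    return key, jnp.int32(DECIDE), cand, n_samples


def decide(state):
    # Accept with probability 0.2 (arbitrary); otherwise propose again.
    key, _, cand, n_samples = state
    accepted = cand < 0.2
    next_stage = jnp.where(accepted, jnp.int32(RECORD), jnp.int32(PROPOSE))
    return key, next_stage, cand, n_samples


def record(state):
    # Count the accepted sample and start over.
    key, _, cand, n_samples = state
    return key, jnp.int32(PROPOSE), cand, n_samples + 1


def fsm_step(state, _):
    # Dispatch on the chain's current stage; exactly one stage per step.
    stage = state[1]
    return lax.switch(stage, (propose, decide, record), state), None


def run_chain(key, n_steps=200):
    init = (key, jnp.int32(PROPOSE), jnp.float32(0.0), jnp.int32(0))
    final, _ = lax.scan(fsm_step, init, None, length=n_steps)
    return final[3]  # number of samples recorded in n_steps FSM steps


keys = jax.random.split(jax.random.PRNGKey(0), 64)
counts = jax.vmap(run_chain)(keys)
print("samples per chain:", counts)
```

One caveat on this sketch: when the stage index is batched, `jax.vmap` lowers `lax.switch` to a selection over all branches, so every branch is evaluated for every chain at each step. The toy therefore shows only the control-flow structure; keeping the per-step cost low when some stages are expensive needs further design work that is not attempted here.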
Link To Code: https://github.com/hwdance/jax-fsm-mcmc
Primary Area: Probabilistic Methods->Monte Carlo and Sampling Methods
Keywords: Markov Chain Monte Carlo, MCMC, Parallel Computing
Submission Number: 12121