{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bad posterior geometry and how to deal with it"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "HMC and its variant NUTS use gradient information to draw (approximate) samples from a posterior distribution. \n",
    "These gradients are computed in a *particular coordinate system*, and different choices of coordinate system can make HMC more or less efficient. \n",
    "This is analogous to the situation in constrained optimization problems where, for example, parameterizing a positive quantity via an exponential versus softplus transformation results in distinct optimization dynamics.\n",
    "\n",
    "Consequently it is important to pay attention to the *geometry* of the posterior distribution. \n",
    "Reparameterizing the model (i.e. changing the coordinate system) can make a big practical difference for many complex models. \n",
    "For the most complex models it can be absolutely essential. For the same reason it can be important to pay attention to some of the hyperparameters that control HMC/NUTS (in particular the `max_tree_depth` and `dense_mass`). \n",
    "\n",
    "In this tutorial we explore models with bad posterior geometries---and what one can do to get achieve better performance---with a few concrete examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from jax import random\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import numpyro\n",
    "from numpyro.diagnostics import summary\n",
    "import numpyro.distributions as dist\n",
    "from numpyro.infer import MCMC, NUTS\n",
    "\n",
    "assert numpyro.__version__.startswith(\"0.15.0\")\n",
    "\n",
    "# NB: replace cpu by gpu to run this notebook on gpu\n",
    "numpyro.set_platform(\"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We begin by writing a helper function to do NUTS inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_inference(\n",
    "    model, num_warmup=1000, num_samples=1000, max_tree_depth=10, dense_mass=False\n",
    "):\n",
    "    kernel = NUTS(model, max_tree_depth=max_tree_depth, dense_mass=dense_mass)\n",
    "    mcmc = MCMC(\n",
    "        kernel,\n",
    "        num_warmup=num_warmup,\n",
    "        num_samples=num_samples,\n",
    "        num_chains=1,\n",
    "        progress_bar=False,\n",
    "    )\n",
    "    mcmc.run(random.PRNGKey(0))\n",
    "    summary_dict = summary(mcmc.get_samples(), group_by_chain=False)\n",
    "\n",
    "    # print the largest r_hat for each variable\n",
    "    for k, v in summary_dict.items():\n",
    "        spaces = \" \" * max(12 - len(k), 0)\n",
    "        print(\"[{}] {} \\t max r_hat: {:.4f}\".format(k, spaces, np.max(v[\"r_hat\"])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluating HMC/NUTS\n",
    "\n",
    "In general it is difficult to assess whether the samples returned from HMC or NUTS represent accurate (approximate) samples from the posterior. \n",
    "Two general rules of thumb, however, are to look at the effective sample size (ESS) and `r_hat` diagnostics returned by `mcmc.print_summary()`.\n",
    "If we see values of `r_hat` in the range `(1.0, 1.05)` and effective sample sizes that are comparable to the total number of samples `num_samples` (assuming `thinning=1`) then we have good reason to believe that HMC is doing a good job. \n",
    "If, however, we see low effective sample sizes or large `r_hat`s for some of the variables (e.g. `r_hat = 1.15`) then HMC is likely struggling with the posterior geometry. \n",
    "In the following we will use `r_hat` as our primary diagnostic metric."
   ]
  },
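  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the `r_hat` diagnostic more concrete, here is a minimal NumPy sketch of a split Gelman-Rubin statistic for a single chain of scalar draws. The helper name `split_r_hat` is hypothetical, and this is an illustration of the idea rather than NumPyro's implementation, which may differ in details: the chain is split into two halves and the variance between the halves is compared to the variance within them.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def split_r_hat(chain):\n",
    "    # hypothetical helper: split one chain of scalar draws into two halves\n",
    "    chain = np.asarray(chain, dtype=float)\n",
    "    n = chain.shape[0] // 2\n",
    "    halves = np.stack([chain[:n], chain[n : 2 * n]])\n",
    "    # W: average within-half variance\n",
    "    within = halves.var(axis=1, ddof=1).mean()\n",
    "    # B / n: variance of the two half means\n",
    "    between_over_n = halves.mean(axis=1).var(ddof=1)\n",
    "    var_estimate = (n - 1) / n * within + between_over_n\n",
    "    return np.sqrt(var_estimate / within)\n",
    "\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "well_mixed = rng.randn(1000)\n",
    "# a chain whose second half has drifted to a different region\n",
    "drifting = np.concatenate([rng.randn(500), rng.randn(500) + 3.0])\n",
    "print('well mixed r_hat: {:.3f}'.format(split_r_hat(well_mixed)))\n",
    "print('drifting r_hat:   {:.3f}'.format(split_r_hat(drifting)))\n",
    "```\n",
    "\n",
    "A value far above `1.0` for the drifting chain reflects that its two halves explore different regions, which is the kind of failure that a large `r_hat` flags."
   ]
  },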
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model reparameterization\n",
    "\n",
    "### Example #1\n",
    "\n",
    "We begin with an example (horseshoe regression; see [examples/horseshoe_regression.py](https://github.com/pyro-ppl/numpyro/blob/master/examples/horseshoe_regression.py) for a complete example script) where reparameterization helps a lot. \n",
    "This particular example demonstrates a general reparameterization strategy that is useful in many models with hierarchical/multi-level structure. \n",
    "For more discussion of some of the issues that can arise in hierarchical models see reference [1]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# In this unreparameterized model some of the parameters of the distributions\n",
    "# explicitly depend on other parameters (in particular beta depends on lambdas and tau).\n",
    "# This kind of coordinate system can be a challenge for HMC.\n",
    "def _unrep_hs_model(X, Y):\n",
    "    lambdas = numpyro.sample(\"lambdas\", dist.HalfCauchy(jnp.ones(X.shape[1])))\n",
    "    tau = numpyro.sample(\"tau\", dist.HalfCauchy(jnp.ones(1)))\n",
    "    betas = numpyro.sample(\"betas\", dist.Normal(scale=tau * lambdas))\n",
    "    mean_function = jnp.dot(X, betas)\n",
    "    numpyro.sample(\"Y\", dist.Normal(mean_function, 0.05), obs=Y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To deal with the bad geometry that results form this coordinate system we change coordinates using the following re-write logic.\n",
    "Instead of \n",
    "\n",
    "$$ \\beta \\sim {\\rm Normal}(0, \\lambda \\tau) $$\n",
    "\n",
    "we write\n",
    "\n",
    "$$ \\beta^\\prime \\sim {\\rm Normal}(0, 1) $$\n",
    "\n",
    "and\n",
    "\n",
    "$$ \\beta \\equiv \\lambda \\tau \\beta^\\prime  $$\n",
    "\n",
    "where $\\beta$ is now defined *deterministically* in terms of $\\lambda$, $\\tau$,\n",
    "and $\\beta^\\prime$. \n",
    "In effect we've changed to a coordinate system where the different\n",
    "latent variables are less correlated with one another. \n",
    "In this new coordinate system we can expect HMC with a diagonal mass matrix to behave much better than it would in the original coordinate system.\n",
    "\n",
    "There are basically two ways to implement this kind of reparameterization in NumPyro:\n",
    "\n",
    "- manually (i.e. by hand)\n",
    "- using [numpyro.infer.reparam](http://num.pyro.ai/en/stable/reparam.html), which automates a few common reparameterization strategies\n",
    "\n",
    "To begin with let's do the reparameterization by hand."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# In this reparameterized model none of the parameters of the distributions\n",
    "# explicitly depend on other parameters. This model is exactly equivalent\n",
    "# to _unrep_hs_model but is expressed in a different coordinate system.\n",
    "def _rep_hs_model1(X, Y):\n",
    "    lambdas = numpyro.sample(\"lambdas\", dist.HalfCauchy(jnp.ones(X.shape[1])))\n",
    "    tau = numpyro.sample(\"tau\", dist.HalfCauchy(jnp.ones(1)))\n",
    "    unscaled_betas = numpyro.sample(\n",
    "        \"unscaled_betas\", dist.Normal(scale=jnp.ones(X.shape[1]))\n",
    "    )\n",
    "    scaled_betas = numpyro.deterministic(\"betas\", tau * lambdas * unscaled_betas)\n",
    "    mean_function = jnp.dot(X, scaled_betas)\n",
    "    numpyro.sample(\"Y\", dist.Normal(mean_function, 0.05), obs=Y)"
   ]
  },
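  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the rewrite logic used in `_rep_hs_model1`, the following NumPy sketch compares the two parameterizations for a single coefficient. The fixed values of `lam` and `tau` are made up for illustration only; in the model they are themselves random variables.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "# made-up values of the local and global scales, for illustration only\n",
    "lam, tau = 2.0, 0.25\n",
    "num_draws = 200000\n",
    "\n",
    "# centered: sample beta directly with scale lam * tau\n",
    "beta_centered = rng.normal(0.0, lam * tau, size=num_draws)\n",
    "# non-centered: sample a standard Normal and rescale deterministically\n",
    "beta_prime = rng.normal(0.0, 1.0, size=num_draws)\n",
    "beta_noncentered = lam * tau * beta_prime\n",
    "\n",
    "print('centered std:     {:.3f}'.format(beta_centered.std()))\n",
    "print('non-centered std: {:.3f}'.format(beta_noncentered.std()))\n",
    "```\n",
    "\n",
    "Both parameterizations define the same distribution over `beta`. What changes is the set of variables that HMC actually samples (`beta_prime` rather than `beta`), and hence the geometry it has to navigate."
   ]
  },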
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we do the reparameterization using [numpyro.infer.reparam](http://num.pyro.ai/en/stable/reparam.html). \n",
    "There are at least two ways to do this. \n",
    "First let's use [LocScaleReparam](http://num.pyro.ai/en/stable/reparam.html#numpyro.infer.reparam.LocScaleReparam)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from numpyro.infer.reparam import LocScaleReparam\n",
    "\n",
    "# LocScaleReparam with centered=0 fully \"decenters\" the prior over betas.\n",
    "config = {\"betas\": LocScaleReparam(centered=0)}\n",
    "# The coordinate system of this model is equivalent to that in _rep_hs_model1 above.\n",
    "_rep_hs_model2 = numpyro.handlers.reparam(_unrep_hs_model, config=config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To show the versatility of the [numpyro.infer.reparam](http://num.pyro.ai/en/stable/reparam.html) library let's do the reparameterization using [TransformReparam](http://num.pyro.ai/en/stable/reparam.html#numpyro.infer.reparam.TransformReparam) instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from numpyro.distributions.transforms import AffineTransform\n",
    "from numpyro.infer.reparam import TransformReparam\n",
    "\n",
    "\n",
    "# In this reparameterized model none of the parameters of the distributions\n",
    "# explicitly depend on other parameters. This model is exactly equivalent\n",
    "# to _unrep_hs_model but is expressed in a different coordinate system.\n",
    "def _rep_hs_model3(X, Y):\n",
    "    lambdas = numpyro.sample(\"lambdas\", dist.HalfCauchy(jnp.ones(X.shape[1])))\n",
    "    tau = numpyro.sample(\"tau\", dist.HalfCauchy(jnp.ones(1)))\n",
    "\n",
    "    # instruct NumPyro to do the reparameterization automatically.\n",
    "    reparam_config = {\"betas\": TransformReparam()}\n",
    "    with numpyro.handlers.reparam(config=reparam_config):\n",
    "        betas_root_variance = tau * lambdas\n",
    "        # in order to use TransformReparam we have to express the prior\n",
    "        # over betas as a TransformedDistribution\n",
    "        betas = numpyro.sample(\n",
    "            \"betas\",\n",
    "            dist.TransformedDistribution(\n",
    "                dist.Normal(0.0, jnp.ones(X.shape[1])),\n",
    "                AffineTransform(0.0, betas_root_variance),\n",
    "            ),\n",
    "        )\n",
    "\n",
    "    mean_function = jnp.dot(X, betas)\n",
    "    numpyro.sample(\"Y\", dist.Normal(mean_function, 0.05), obs=Y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally we verify that `_rep_hs_model1`, `_rep_hs_model2`, and  `_rep_hs_model3` do indeed achieve better `r_hat`s than `_unrep_hs_model`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "unreparameterized model (very bad r_hats)\n",
      "[betas]         \t max r_hat: 1.0775\n",
      "[lambdas]       \t max r_hat: 3.2450\n",
      "[tau]           \t max r_hat: 2.1926\n",
      "\n",
      "reparameterized model with manual reparameterization (good r_hats)\n",
      "[betas]         \t max r_hat: 1.0074\n",
      "[lambdas]       \t max r_hat: 1.0146\n",
      "[tau]           \t max r_hat: 1.0036\n",
      "[unscaled_betas]  \t max r_hat: 1.0059\n",
      "\n",
      "reparameterized model with LocScaleReparam (good r_hats)\n",
      "[betas]         \t max r_hat: 1.0103\n",
      "[betas_decentered]  \t max r_hat: 1.0060\n",
      "[lambdas]       \t max r_hat: 1.0124\n",
      "[tau]           \t max r_hat: 0.9998\n",
      "\n",
      "reparameterized model with TransformReparam (good r_hats)\n",
      "[betas]         \t max r_hat: 1.0087\n",
      "[betas_base]    \t max r_hat: 1.0080\n",
      "[lambdas]       \t max r_hat: 1.0114\n",
      "[tau]           \t max r_hat: 1.0029\n"
     ]
    }
   ],
   "source": [
    "# create fake dataset\n",
    "X = np.random.RandomState(0).randn(100, 500)\n",
    "Y = X[:, 0]\n",
    "\n",
    "print(\"unreparameterized model (very bad r_hats)\")\n",
    "run_inference(partial(_unrep_hs_model, X, Y))\n",
    "\n",
    "print(\"\\nreparameterized model with manual reparameterization (good r_hats)\")\n",
    "run_inference(partial(_rep_hs_model1, X, Y))\n",
    "\n",
    "print(\"\\nreparameterized model with LocScaleReparam (good r_hats)\")\n",
    "run_inference(partial(_rep_hs_model2, X, Y))\n",
    "\n",
    "print(\"\\nreparameterized model with TransformReparam (good r_hats)\")\n",
    "run_inference(partial(_rep_hs_model3, X, Y))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Aside: numpyro.deterministic\n",
    "\n",
    "In `_rep_hs_model1` above we used [numpyro.deterministic](http://num.pyro.ai/en/stable/primitives.html?highlight=deterministic#numpyro.primitives.deterministic) to define `scaled_betas`.\n",
    "We note that using this primitive is not strictly necessary; however, it has the consequence that `scaled_betas` will appear in the trace and will thus appear in the summary reported by `mcmc.print_summary()`. \n",
    "In other words we could also have written:\n",
    "\n",
    "```\n",
    "scaled_betas = tau * lambdas * unscaled_betas\n",
    "```\n",
    "\n",
    "without invoking the `deterministic` primitive."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mass matrices\n",
    "By default HMC/NUTS use diagonal mass matrices. \n",
    "For models with complex geometries it can pay to use a richer set of mass matrices.\n",
    "\n",
    "\n",
    "### Example #2\n",
    "In this first simple example we show that using a full-rank (i.e. dense) mass matrix leads to a better `r_hat`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dense_mass = False (bad r_hat)\n",
      "[x]             \t max r_hat: 1.3810\n",
      "dense_mass = True (good r_hat)\n",
      "[x]             \t max r_hat: 0.9992\n"
     ]
    }
   ],
   "source": [
    "# Because rho is very close to 1.0 the posterior geometry\n",
    "# is extremely skewed and using the \"diagonal\" coordinate system\n",
    "# implied by dense_mass=False leads to bad results\n",
    "rho = 0.9999\n",
    "cov = jnp.array([[10.0, rho], [rho, 0.1]])\n",
    "\n",
    "\n",
    "def mvn_model():\n",
    "    numpyro.sample(\"x\", dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=cov))\n",
    "\n",
    "\n",
    "print(\"dense_mass = False (bad r_hat)\")\n",
    "run_inference(mvn_model, dense_mass=False, max_tree_depth=3)\n",
    "\n",
    "print(\"dense_mass = True (good r_hat)\")\n",
    "run_inference(mvn_model, dense_mass=True, max_tree_depth=3)"
   ]
  },
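  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why a dense mass matrix helps in Example #2, the NumPy sketch below compares how ill-conditioned the covariance remains after rescaling each coordinate separately (roughly the best a diagonal mass matrix can do) with what remains after a full linear change of coordinates via the Cholesky factor (roughly the best a dense mass matrix can do). This is an idealized illustration under those assumptions, not a description of NumPyro's adaptation procedure.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rho = 0.9999\n",
    "cov = np.array([[10.0, rho], [rho, 0.1]])\n",
    "\n",
    "# diagonal preconditioning: rescale each coordinate by its standard deviation\n",
    "inv_std = 1.0 / np.sqrt(np.diag(cov))\n",
    "cov_diag = cov * np.outer(inv_std, inv_std)\n",
    "\n",
    "# dense preconditioning: whiten with the Cholesky factor of the covariance\n",
    "chol_inv = np.linalg.inv(np.linalg.cholesky(cov))\n",
    "cov_dense = chol_inv @ cov @ chol_inv.T\n",
    "\n",
    "print('condition number, original: {:.1f}'.format(np.linalg.cond(cov)))\n",
    "print('condition number, diagonal: {:.1f}'.format(np.linalg.cond(cov_diag)))\n",
    "print('condition number, dense:    {:.1f}'.format(np.linalg.cond(cov_dense)))\n",
    "```\n",
    "\n",
    "Even with ideal per-coordinate scales the problem remains badly conditioned, whereas a full-rank change of coordinates removes the correlation entirely."
   ]
  },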
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example #3\n",
    "\n",
    "Using `dense_mass=True` can be very expensive when the dimension of the latent space `D` is very large. \n",
    "In addition it can be difficult to estimate a full-rank mass matrix with `D^2` parameters using a moderate number of samples if `D` is large. In these cases `dense_mass=True` can be a poor choice. \n",
    "Luckily, the argument `dense_mass` can also be used to specify structured mass matrices that are richer than a diagonal mass matrix but more constrained (i.e. have fewer parameters) than a full-rank mass matrix ([see the docs](http://num.pyro.ai/en/stable/mcmc.html#hmc)).\n",
    "In this second example we show how we can use `dense_mass` to specify such a structured mass matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "rho = 0.9\n",
    "cov = jnp.array([[10.0, rho], [rho, 0.1]])\n",
    "\n",
    "\n",
    "# In this model x1 and x2 are highly correlated with one another\n",
    "# but not correlated with y at all.\n",
    "def partially_correlated_model():\n",
    "    x1 = numpyro.sample(\n",
    "        \"x1\", dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=cov)\n",
    "    )\n",
    "    x2 = numpyro.sample(\n",
    "        \"x2\", dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=cov)\n",
    "    )\n",
    "    numpyro.sample(\"y\", dist.Normal(jnp.zeros(100), 1.0))\n",
    "    numpyro.sample(\"obs\", dist.Normal(x1 - x2, 0.1), jnp.ones(2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's compare two choices of `dense_mass`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dense_mass = False (very bad r_hats)\n",
      "[x1]            \t max r_hat: 1.5882\n",
      "[x2]            \t max r_hat: 1.5410\n",
      "[y]             \t max r_hat: 2.0179\n",
      "\n",
      "dense_mass = True (bad r_hats)\n",
      "[x1]            \t max r_hat: 1.0697\n",
      "[x2]            \t max r_hat: 1.0738\n",
      "[y]             \t max r_hat: 1.2746\n",
      "\n",
      "structured mass matrix (good r_hats)\n",
      "[x1]            \t max r_hat: 1.0023\n",
      "[x2]            \t max r_hat: 1.0024\n",
      "[y]             \t max r_hat: 1.0030\n"
     ]
    }
   ],
   "source": [
    "print(\"dense_mass = False (very bad r_hats)\")\n",
    "run_inference(partially_correlated_model, dense_mass=False, max_tree_depth=3)\n",
    "\n",
    "print(\"\\ndense_mass = True (bad r_hats)\")\n",
    "run_inference(partially_correlated_model, dense_mass=True, max_tree_depth=3)\n",
    "\n",
    "# We use dense_mass=[(\"x1\", \"x2\")] to specify\n",
    "# a structured mass matrix in which the y-part of the mass matrix is diagonal\n",
    "# and the (x1, x2) block of the mass matrix is full-rank.\n",
    "\n",
    "# Graphically:\n",
    "#\n",
    "#       x1 x2 y\n",
    "#   x1 | * * 0 |\n",
    "#   x2 | * * 0 |\n",
    "#   y  | 0 0 * |\n",
    "\n",
    "print(\"\\nstructured mass matrix (good r_hats)\")\n",
    "run_inference(partially_correlated_model, dense_mass=[(\"x1\", \"x2\")], max_tree_depth=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `max_tree_depth`\n",
    "\n",
    "The hyperparameter `max_tree_depth` can play an important role in determining the quality of posterior samples generated by NUTS. The default value in NumPyro is `max_tree_depth=10`. In some models, in particular those with especially difficult geometries, it may be necessary to increase `max_tree_depth` above `10`. In other cases where computing the gradient of the model log density is particularly expensive, it may be necessary to decrease `max_tree_depth` below `10` to reduce compute. As an example where large `max_tree_depth` is essential, we return to a variant of example #2. (We note that in this particular case another way to improve performance would be to use `dense_mass=True`).\n",
    "\n",
    "### Example #4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_tree_depth = 5 (bad r_hat)\n",
      "[x]             \t max r_hat: 1.1159\n",
      "max_tree_depth = 10 (good r_hat)\n",
      "[x]             \t max r_hat: 1.0166\n"
     ]
    }
   ],
   "source": [
    "# Because rho is very close to 1.0 the posterior geometry is extremely\n",
    "# skewed and using small max_tree_depth leads to bad results.\n",
    "rho = 0.999\n",
    "dim = 200\n",
    "cov = rho * jnp.ones((dim, dim)) + (1 - rho) * jnp.eye(dim)\n",
    "\n",
    "\n",
    "def mvn_model():\n",
    "    numpyro.sample(\"x\", dist.MultivariateNormal(jnp.zeros(dim), covariance_matrix=cov))\n",
    "\n",
    "\n",
    "print(\"max_tree_depth = 5 (bad r_hat)\")\n",
    "run_inference(mvn_model, max_tree_depth=5)\n",
    "\n",
    "print(\"max_tree_depth = 10 (good r_hat)\")\n",
    "run_inference(mvn_model, max_tree_depth=10)"
   ]
  },
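  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for the compute trade-off, recall that NUTS builds its trajectory by repeated doubling, so a tree of depth `d` involves up to `2**d - 1` leapfrog steps, each requiring a gradient evaluation. The sketch below simply tabulates this upper bound; the helper name `max_leapfrog_steps` is hypothetical, and the number of steps actually taken per iteration is often smaller because NUTS can terminate before reaching the maximum depth.\n",
    "\n",
    "```python\n",
    "def max_leapfrog_steps(max_tree_depth):\n",
    "    # upper bound on leapfrog steps (gradient evaluations) per NUTS iteration\n",
    "    return 2**max_tree_depth - 1\n",
    "\n",
    "\n",
    "for depth in (3, 5, 10):\n",
    "    steps = max_leapfrog_steps(depth)\n",
    "    print('max_tree_depth = {:2d} -> at most {} leapfrog steps'.format(depth, steps))\n",
    "```\n",
    "\n",
    "Each unit increase in `max_tree_depth` roughly doubles the worst-case cost per sample, which is why lowering it can matter when gradients are expensive."
   ]
  },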
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Other strategies\n",
    "\n",
    "- In some cases it can make sense to use variational inference to *learn* a new coordinate system. For details see [examples/neutra.py](https://github.com/pyro-ppl/numpyro/blob/master/examples/neutra.py) and reference [2]."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "[1] \"Hamiltonian Monte Carlo for Hierarchical Models,\"\n",
    "    M. J. Betancourt, Mark Girolami.\n",
    "\n",
    "[2] \"NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport,\"\n",
    "    Matthew Hoffman, Pavel Sountsov, Joshua V. Dillon, Ian Langmore, Dustin Tran, Srinivas Vasudevan.\n",
    "    \n",
    "[3] \"Reparameterization\" in the Stan user's manual.\n",
    "    https://mc-stan.org/docs/2_27/stan-users-guide/reparameterization-section.html"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
