{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Benchmark NumPyro on a large dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook uses `numpyro` to replicate the experiments in reference [1], which evaluates the performance of NUTS in various frameworks. The benchmark is run with CUDA 10.1 on an NVIDIA RTX 2070."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from jax import random\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import numpyro\n",
    "import numpyro.distributions as dist\n",
    "from numpyro.examples.datasets import COVTYPE, load_dataset\n",
    "from numpyro.infer import HMC, MCMC, NUTS\n",
    "\n",
    "assert numpyro.__version__.startswith(\"0.15.0\")\n",
    "\n",
    "# NB: replace gpu with cpu to run this notebook on a CPU\n",
    "numpyro.set_platform(\"gpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We apply the same preprocessing steps as the [source code](https://github.com/google-research/google-research/blob/master/simple_probabilistic_programming/no_u_turn_sampler/logistic_regression.py) of reference [1]:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading - https://d2hg8soec8ck9v.cloudfront.net/datasets/covtype.zip.\n",
      "Download complete.\n",
      "Data shape: (581012, 55)\n",
      "Label distribution: 211840 has label 1, 369172 has label 0\n"
     ]
    }
   ],
   "source": [
    "_, fetch = load_dataset(COVTYPE, shuffle=False)\n",
    "features, labels = fetch()\n",
    "\n",
    "# standardize features (zero mean, unit variance) and add an intercept column\n",
    "features = (features - features.mean(0)) / features.std(0)\n",
    "features = jnp.hstack([features, jnp.ones((features.shape[0], 1))])\n",
    "\n",
    "# make the labels binary, following the reference code\n",
    "_, counts = np.unique(labels, return_counts=True)\n",
    "specific_category = jnp.argmax(counts)\n",
    "labels = labels == specific_category\n",
    "\n",
    "N, dim = features.shape\n",
    "print(\"Data shape:\", features.shape)\n",
    "print(\n",
    "    \"Label distribution: {} has label 1, {} has label 0\".format(\n",
    "        labels.sum(), N - labels.sum()\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we construct the model, a Bayesian logistic regression with a standard normal prior on the coefficients:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(data, labels):\n",
    "    coefs = numpyro.sample(\"coefs\", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))\n",
    "    logits = jnp.dot(data, coefs)\n",
    "    return numpyro.sample(\"obs\", dist.Bernoulli(logits=logits), obs=labels)"
   ]
  },
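  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, HMC and NUTS simulate dynamics on the potential energy $U$, which is the negative log joint density of this model. Writing $x_n$ for the $n$-th row of `features`, $y_n \\in \\{0, 1\\}$ for its label and $\\beta$ for `coefs`, we have, up to an additive constant,\n",
    "\n",
    "$$U(\\beta) = \\frac{1}{2}\\lVert\\beta\\rVert^2 - \\sum_{n=1}^{N}\\left[y_n\\, x_n^\\top\\beta - \\log\\left(1 + e^{x_n^\\top\\beta}\\right)\\right],$$\n",
    "\n",
    "$$\\nabla U(\\beta) = \\beta - \\sum_{n=1}^{N}\\left(y_n - \\sigma(x_n^\\top\\beta)\\right) x_n,$$\n",
    "\n",
    "where $\\sigma$ is the logistic sigmoid. Every gradient evaluation touches all $N$ rows of the data, and so does every leapfrog step. Its cost therefore grows linearly in $N$."
   ]
  },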
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Benchmark HMC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of leapfrog steps: 5000\n",
      "avg. time for each step : 0.0015881952285766601\n",
      "\n",
      "                mean       std    median      5.0%     95.0%     n_eff     r_hat\n",
      "  coefs[0]      1.99      0.00      1.99      1.98      1.99      4.53      1.49\n",
      "  coefs[1]     -0.03      0.00     -0.03     -0.03     -0.03      4.26      1.49\n",
      "  coefs[2]     -0.12      0.00     -0.12     -0.12     -0.12      5.57      1.10\n",
      "  coefs[3]     -0.29      0.00     -0.29     -0.29     -0.29      4.77      1.40\n",
      "  coefs[4]     -0.09      0.00     -0.09     -0.10     -0.09      5.13      1.04\n",
      "  coefs[5]     -0.15      0.00     -0.15     -0.15     -0.15      2.61      3.11\n",
      "  coefs[6]     -0.02      0.00     -0.02     -0.02     -0.02      2.68      2.54\n",
      "  coefs[7]     -0.50      0.00     -0.50     -0.50     -0.50     11.32      1.00\n",
      "  coefs[8]      0.27      0.00      0.27      0.27      0.27      3.25      2.03\n",
      "  coefs[9]     -0.02      0.00     -0.02     -0.02     -0.02      6.34      1.42\n",
      " coefs[10]     -0.23      0.00     -0.23     -0.23     -0.22      3.76      1.50\n",
      " coefs[11]     -0.31      0.00     -0.31     -0.31     -0.31      3.51      1.40\n",
      " coefs[12]     -0.54      0.00     -0.54     -0.54     -0.54      2.64      2.52\n",
      " coefs[13]     -1.94      0.00     -1.94     -1.94     -1.93      2.54      2.75\n",
      " coefs[14]      0.24      0.00      0.24      0.24      0.24      9.69      1.08\n",
      " coefs[15]     -1.07      0.00     -1.07     -1.07     -1.07      3.85      1.85\n",
      " coefs[16]     -1.26      0.00     -1.26     -1.26     -1.26      5.80      1.07\n",
      " coefs[17]     -0.22      0.00     -0.22     -0.22     -0.22      4.45      1.33\n",
      " coefs[18]     -0.08      0.00     -0.08     -0.08     -0.08      2.45      2.88\n",
      " coefs[19]     -0.68      0.00     -0.68     -0.69     -0.68      2.72      2.12\n",
      " coefs[20]     -0.13      0.00     -0.13     -0.13     -0.13      2.79      2.30\n",
      " coefs[21]     -0.02      0.00     -0.02     -0.02     -0.02      8.65      1.15\n",
      " coefs[22]      0.02      0.00      0.02      0.02      0.02      2.73      2.32\n",
      " coefs[23]     -0.15      0.00     -0.15     -0.15     -0.15      2.75      2.56\n",
      " coefs[24]     -0.12      0.00     -0.12     -0.12     -0.12      3.92      1.31\n",
      " coefs[25]     -0.32      0.00     -0.32     -0.32     -0.32      5.25      1.31\n",
      " coefs[26]     -0.17      0.00     -0.17     -0.17     -0.17      4.08      1.13\n",
      " coefs[27]     -1.19      0.00     -1.19     -1.19     -1.19      3.22      1.85\n",
      " coefs[28]     -0.05      0.00     -0.05     -0.05     -0.05      7.87      1.01\n",
      " coefs[29]     -0.03      0.00     -0.03     -0.03     -0.03      7.36      1.17\n",
      " coefs[30]     -0.04      0.00     -0.04     -0.04     -0.04      2.88      2.06\n",
      " coefs[31]     -0.06      0.00     -0.06     -0.06     -0.06      6.43      1.23\n",
      " coefs[32]     -0.02      0.00     -0.02     -0.02     -0.02      6.80      1.03\n",
      " coefs[33]     -0.03      0.00     -0.03     -0.03     -0.03      6.47      1.26\n",
      " coefs[34]      0.11      0.00      0.11      0.10      0.11      6.67      1.22\n",
      " coefs[35]      0.08      0.00      0.08      0.08      0.08      2.49      2.80\n",
      " coefs[36]     -0.00      0.00     -0.00     -0.00     -0.00      6.23      1.31\n",
      " coefs[37]     -0.07      0.00     -0.07     -0.07     -0.07      2.72      2.36\n",
      " coefs[38]     -0.03      0.00     -0.03     -0.03     -0.03      3.97      1.52\n",
      " coefs[39]     -0.06      0.00     -0.06     -0.06     -0.06      6.16      1.26\n",
      " coefs[40]     -0.01      0.00     -0.01     -0.01     -0.01      2.86      2.07\n",
      " coefs[41]     -0.06      0.00     -0.06     -0.06     -0.06      3.02      1.98\n",
      " coefs[42]     -0.39      0.00     -0.39     -0.40     -0.39      2.67      2.45\n",
      " coefs[43]     -0.27      0.00     -0.27     -0.27     -0.27      5.15      1.33\n",
      " coefs[44]     -0.07      0.00     -0.07     -0.07     -0.07      5.75      1.30\n",
      " coefs[45]     -0.25      0.00     -0.25     -0.26     -0.25      2.57      2.50\n",
      " coefs[46]     -0.09      0.00     -0.09     -0.09     -0.09      8.72      1.00\n",
      " coefs[47]     -0.12      0.00     -0.12     -0.12     -0.12      3.10      1.73\n",
      " coefs[48]     -0.15      0.00     -0.15     -0.15     -0.15      4.95      1.33\n",
      " coefs[49]     -0.05      0.00     -0.05     -0.05     -0.05      2.99      2.32\n",
      " coefs[50]     -0.94      0.00     -0.94     -0.94     -0.94     10.08      1.00\n",
      " coefs[51]     -0.32      0.00     -0.32     -0.32     -0.32      3.90      1.75\n",
      " coefs[52]     -0.29      0.00     -0.29     -0.30     -0.29     13.85      1.05\n",
      " coefs[53]     -0.31      0.00     -0.31     -0.31     -0.31      8.21      1.01\n",
      " coefs[54]     -1.76      0.00     -1.76     -1.76     -1.76      3.24      1.54\n",
      "\n",
      "Number of divergences: 0\n"
     ]
    }
   ],
   "source": [
    "step_size = jnp.sqrt(0.5 / N)\n",
    "kernel = HMC(\n",
    "    model,\n",
    "    step_size=step_size,\n",
    "    trajectory_length=(10 * step_size),\n",
    "    adapt_step_size=False,\n",
    ")\n",
    "mcmc = MCMC(kernel, num_warmup=500, num_samples=500, progress_bar=False)\n",
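    "mcmc.warmup(random.PRNGKey(2019), features, labels, extra_fields=(\"num_steps\",))\n",
    "# JAX runs asynchronously: fetching a result blocks until warmup has finished,\n",
    "# so that warmup is not included in the timing below\n",
    "mcmc.get_extra_fields()[\"num_steps\"].sum().copy()\n",
    "tic = time.time()\n",
    "mcmc.run(random.PRNGKey(2020), features, labels, extra_fields=(\"num_steps\",))\n",
    "# fetching the result blocks until sampling has finished, before stopping the timer\n",
    "num_leapfrogs = mcmc.get_extra_fields()[\"num_steps\"].sum().copy()\n",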
    "toc = time.time()\n",
    "print(\"number of leapfrog steps:\", num_leapfrogs)\n",
    "print(\"avg. time for each step :\", (toc - tic) / num_leapfrogs)\n",
    "mcmc.print_summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On CPU, we get `avg. time for each step : 0.02782863507270813`."
   ]
  },
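  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The leapfrog count printed by the HMC run can be checked by hand. The sketch below assumes that HMC takes `trajectory_length / step_size` steps per trajectory, rounded to an integer. That assumption is consistent with the 5000 steps reported for 500 samples. The total timed run is then the step count times the measured average time per step on GPU:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "N = 581012  # number of rows in the covtype data\n",
    "step_size = math.sqrt(0.5 / N)\n",
    "trajectory_length = 10 * step_size\n",
    "\n",
    "# assumption: steps per trajectory = trajectory_length / step_size, rounded\n",
    "steps_per_trajectory = round(trajectory_length / step_size)\n",
    "num_samples = 500\n",
    "total_steps = steps_per_trajectory * num_samples\n",
    "\n",
    "# average GPU time per leapfrog step, copied from the HMC output\n",
    "avg_time_per_step = 0.0015881952285766601\n",
    "total_time = total_steps * avg_time_per_step\n",
    "\n",
    "print(steps_per_trajectory, total_steps)  # 10 5000\n",
    "print(f'{total_time:.1f} s')  # 7.9 s\n",
    "```"
   ]
  },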
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Benchmark NUTS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of leapfrog steps: 47406\n",
      "avg. time for each step : 0.0022662237908313812\n",
      "\n",
      "                mean       std    median      5.0%     95.0%     n_eff     r_hat\n",
      "  coefs[0]      1.97      0.01      1.97      1.95      1.98     74.56      1.05\n",
      "  coefs[1]     -0.04      0.00     -0.04     -0.05     -0.03     59.26      0.99\n",
      "  coefs[2]     -0.07      0.01     -0.06     -0.08     -0.05     35.80      1.12\n",
      "  coefs[3]     -0.30      0.00     -0.30     -0.31     -0.29     54.31      1.00\n",
      "  coefs[4]     -0.09      0.00     -0.09     -0.10     -0.09     38.45      0.99\n",
      "  coefs[5]     -0.14      0.00     -0.14     -0.15     -0.14     26.25      1.12\n",
      "  coefs[6]      0.23      0.04      0.24      0.19      0.30     11.98      1.18\n",
      "  coefs[7]     -0.65      0.02     -0.65     -0.69     -0.62     17.16      1.16\n",
      "  coefs[8]      0.57      0.04      0.57      0.48      0.62     12.71      1.18\n",
      "  coefs[9]     -0.01      0.00     -0.01     -0.02     -0.01     58.92      0.99\n",
      " coefs[10]      0.71      0.84      0.67     -0.76      2.04      7.17      0.98\n",
      " coefs[11]      0.08      0.38      0.06     -0.57      0.68      7.18      0.98\n",
      " coefs[12]      0.39      0.84      0.35     -1.09      1.72      7.18      0.98\n",
      " coefs[13]     -1.54      0.53     -1.56     -2.20     -0.65     10.23      0.99\n",
      " coefs[14]     -0.48      0.52     -0.45     -1.25      0.25     16.10      0.98\n",
      " coefs[15]     -1.83      0.31     -1.80     -2.34     -1.48      5.35      0.98\n",
      " coefs[16]     -1.06      0.52     -0.96     -1.88     -0.19     31.52      1.00\n",
      " coefs[17]     -0.17      0.08     -0.15     -0.30     -0.06     15.07      1.38\n",
      " coefs[18]     -0.64      0.64     -0.59     -1.50      0.25     18.98      1.03\n",
      " coefs[19]     -0.74      0.57     -0.71     -1.66      0.07     12.04      1.11\n",
      " coefs[20]     -1.04      0.64     -1.14     -1.80     -0.10     16.18      1.00\n",
      " coefs[21]     -0.01      0.01     -0.01     -0.02      0.01     12.68      1.42\n",
      " coefs[22]      0.03      0.02      0.04     -0.00      0.07     15.54      1.37\n",
      " coefs[23]     -0.10      0.12     -0.07     -0.27      0.09     15.48      1.39\n",
      " coefs[24]     -0.09      0.08     -0.07     -0.21      0.02     15.48      1.36\n",
      " coefs[25]     -0.26      0.12     -0.24     -0.46     -0.10     15.62      1.37\n",
      " coefs[26]     -0.12      0.09     -0.10     -0.25      0.03     15.71      1.37\n",
      " coefs[27]     -1.11      0.47     -1.11     -1.83     -0.30     17.62      1.08\n",
      " coefs[28]     -0.83      0.70     -0.54     -2.04      0.02     34.06      0.99\n",
      " coefs[29]     -0.01      0.04      0.00     -0.06      0.05     15.94      1.36\n",
      " coefs[30]     -0.02      0.04     -0.00     -0.08      0.04     15.02      1.44\n",
      " coefs[31]     -0.05      0.03     -0.04     -0.09      0.00     16.46      1.28\n",
      " coefs[32]      0.01      0.04      0.02     -0.06      0.07     15.28      1.36\n",
      " coefs[33]      0.04      0.07      0.05     -0.06      0.14     15.73      1.37\n",
      " coefs[34]      0.11      0.02      0.11      0.08      0.14     14.67      1.33\n",
      " coefs[35]      0.13      0.12      0.16     -0.05      0.32     15.43      1.38\n",
      " coefs[36]      0.07      0.16      0.11     -0.16      0.32     15.53      1.37\n",
      " coefs[37]      0.00      0.10      0.02     -0.16      0.14     15.53      1.38\n",
      " coefs[38]     -0.04      0.02     -0.04     -0.06     -0.02     17.43      1.33\n",
      " coefs[39]     -0.05      0.04     -0.04     -0.10      0.01     15.25      1.40\n",
      " coefs[40]      0.01      0.02      0.02     -0.02      0.05     15.66      1.35\n",
      " coefs[41]     -0.04      0.02     -0.04     -0.08     -0.00     11.32      1.38\n",
      " coefs[42]     -0.31      0.21     -0.26     -0.61      0.03     15.56      1.38\n",
      " coefs[43]     -0.20      0.12     -0.18     -0.40     -0.04     15.60      1.38\n",
      " coefs[44]     -0.01      0.11      0.02     -0.17      0.16     15.52      1.38\n",
      " coefs[45]     -0.15      0.15     -0.11     -0.37      0.09     15.46      1.38\n",
      " coefs[46]     -0.02      0.14      0.00     -0.23      0.20     15.83      1.37\n",
      " coefs[47]     -0.12      0.03     -0.11     -0.16     -0.07     16.20      1.38\n",
      " coefs[48]     -0.12      0.03     -0.12     -0.17     -0.08     16.26      1.36\n",
      " coefs[49]     -0.04      0.01     -0.04     -0.05     -0.03     14.31      1.28\n",
      " coefs[50]     -0.98      0.44     -0.94     -1.71     -0.33     12.09      0.98\n",
      " coefs[51]     -0.26      0.09     -0.24     -0.40     -0.14     15.53      1.38\n",
      " coefs[52]     -0.25      0.08     -0.23     -0.36     -0.12     15.81      1.37\n",
      " coefs[53]     -0.26      0.06     -0.25     -0.36     -0.16     15.99      1.36\n",
      " coefs[54]     -1.98      0.13     -1.96     -2.16     -1.81     44.87      0.98\n",
      "\n",
      "Number of divergences: 0\n"
     ]
    }
   ],
   "source": [
    "mcmc = MCMC(NUTS(model), num_warmup=50, num_samples=50, progress_bar=False)\n",
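    "mcmc.warmup(random.PRNGKey(2019), features, labels, extra_fields=(\"num_steps\",))\n",
    "# JAX runs asynchronously: fetching a result blocks until warmup has finished,\n",
    "# so that warmup is not included in the timing below\n",
    "mcmc.get_extra_fields()[\"num_steps\"].sum().copy()\n",
    "tic = time.time()\n",
    "mcmc.run(random.PRNGKey(2020), features, labels, extra_fields=(\"num_steps\",))\n",
    "# fetching the result blocks until sampling has finished, before stopping the timer\n",
    "num_leapfrogs = mcmc.get_extra_fields()[\"num_steps\"].sum().copy()\n",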
    "toc = time.time()\n",
    "print(\"number of leapfrog steps:\", num_leapfrogs)\n",
    "print(\"avg. time for each step :\", (toc - tic) / num_leapfrogs)\n",
    "mcmc.print_summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On CPU, we get `avg. time for each step : 0.028006251705287415`."
   ]
  },
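  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "NUTS needs far more leapfrog steps per sample than HMC here. The arithmetic below uses the numbers printed by the NUTS run. It also assumes NumPyro's default `max_tree_depth=10` for `NUTS`, which caps a single trajectory at $2^{10} - 1 = 1023$ leapfrog steps:\n",
    "\n",
    "```python\n",
    "num_leapfrogs = 47406  # from the NUTS output\n",
    "num_samples = 50\n",
    "avg_time_per_step = 0.0022662237908313812  # seconds per step on GPU\n",
    "\n",
    "steps_per_sample = num_leapfrogs / num_samples\n",
    "max_steps = 2**10 - 1  # assumed cap from max_tree_depth=10\n",
    "total_time = num_leapfrogs * avg_time_per_step\n",
    "\n",
    "print(f'{steps_per_sample:.1f} steps per sample (cap: {max_steps})')  # 948.1 steps per sample (cap: 1023)\n",
    "print(f'{total_time:.1f} s')  # 107.4 s\n",
    "```\n",
    "\n",
    "The average is close to the cap, which suggests that many trajectories reach the maximum tree depth. HMC uses 10 steps per sample, so each NUTS sample costs roughly 95 times as many gradient evaluations."
   ]
  },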
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare to other frameworks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "| Avg. time per leapfrog step |    HMC    |    NUTS   |\n",
    "| --------------------------- |----------:|----------:|\n",
    "| Edward2 (CPU)               |           |  56.1 ms  |\n",
    "| Edward2 (GPU)               |           |   9.4 ms  |\n",
    "| Pyro (CPU)                  |  35.4 ms  |  35.3 ms  |\n",
    "| Pyro (GPU)                  |   3.5 ms  |   4.2 ms  |\n",
    "| NumPyro (CPU)               |  27.8 ms  |  28.0 ms  |\n",
    "| NumPyro (GPU)               |   1.6 ms  |   2.2 ms  |"
   ]
  },
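  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The GPU speedups implied by the comparison table can be computed directly from it. The sketch below copies the numbers from the table and skips the entries that are missing:\n",
    "\n",
    "```python\n",
    "# average time per leapfrog step in ms, copied from the table (None = no entry)\n",
    "times_ms = {\n",
    "    'Edward2': {'CPU': {'HMC': None, 'NUTS': 56.1}, 'GPU': {'HMC': None, 'NUTS': 9.4}},\n",
    "    'Pyro': {'CPU': {'HMC': 35.4, 'NUTS': 35.3}, 'GPU': {'HMC': 3.5, 'NUTS': 4.2}},\n",
    "    'NumPyro': {'CPU': {'HMC': 27.8, 'NUTS': 28.0}, 'GPU': {'HMC': 1.6, 'NUTS': 2.2}},\n",
    "}\n",
    "\n",
    "speedups = {}\n",
    "for framework, devices in times_ms.items():\n",
    "    for algo in ('HMC', 'NUTS'):\n",
    "        cpu, gpu = devices['CPU'][algo], devices['GPU'][algo]\n",
    "        if cpu is None or gpu is None:\n",
    "            continue  # no measurement in the table\n",
    "        speedups[(framework, algo)] = cpu / gpu\n",
    "        print(f'{framework:8s}{algo:5s} GPU speedup: {cpu / gpu:.1f}x')\n",
    "```\n",
    "\n",
    "With these numbers, the GPU speedup ranges from about 6x (Edward2, NUTS) to about 17x (NumPyro, HMC)."
   ]
  },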
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
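    "Note that in some situations (e.g. Pyro on CPU), HMC takes longer per leapfrog step than NUTS. The reason is that the number of leapfrog steps in each HMC trajectory is fixed to $10$, so any fixed per-iteration cost is amortized over only 10 steps. In NUTS the number of steps is not fixed: the NumPyro run above used about 950 per sample (47406 / 50)."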
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Some takeaways:**\n",
    "+ The overhead of iterative NUTS is small, so most of the computation time is spent evaluating the potential function and its gradient.\n",
    "+ The GPU outperforms the CPU by a large margin. The dataset is large (581,012 rows and 55 columns after adding the intercept), so evaluating the potential function is much faster on the GPU than on the CPU."
   ]
  },
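  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to check the first takeaway is to time the gradient of the potential function on its own and compare it with the per-step times above. The sketch below is minimal and uses a hand-written `potential` that mirrors `model`. It is not NumPyro's internal potential function, and `potential`, `grad_potential` and `time_grad` are hypothetical names:\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def potential(coefs, data, labels):\n",
    "    # negative log joint of the logistic regression model, up to a constant:\n",
    "    # standard normal prior on coefs plus Bernoulli likelihood with logits = data @ coefs\n",
    "    logits = jnp.dot(data, coefs)\n",
    "    labels = labels.astype(logits.dtype)\n",
    "    log_lik = jnp.sum(labels * logits - jnp.logaddexp(0.0, logits))\n",
    "    log_prior = -0.5 * jnp.sum(coefs**2)\n",
    "    return -(log_lik + log_prior)\n",
    "\n",
    "\n",
    "# compiled gradient of the potential with respect to coefs\n",
    "grad_potential = jax.jit(jax.grad(potential))\n",
    "\n",
    "\n",
    "def time_grad(data, labels, num_reps=100):\n",
    "    coefs = jnp.zeros(data.shape[1])\n",
    "    # the first call compiles the function; keep it out of the timing\n",
    "    grad_potential(coefs, data, labels).block_until_ready()\n",
    "    tic = time.perf_counter()\n",
    "    for _ in range(num_reps):\n",
    "        g = grad_potential(coefs, data, labels)\n",
    "    # calls are dispatched asynchronously; wait for the last one to finish\n",
    "    g.block_until_ready()\n",
    "    return (time.perf_counter() - tic) / num_reps\n",
    "```\n",
    "\n",
    "For example, after running the preprocessing cell, `time_grad(features, labels)` returns the average number of seconds per gradient evaluation on the current platform. If most of the time is spent in the gradient, this should be of the same order as the per-step times in the table."
   ]
  },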
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "1. *Simple, Distributed, and Accelerated Probabilistic Programming*, [arXiv:1811.02091](https://arxiv.org/abs/1811.02091)<br>\n",
    "Dustin Tran, Matthew D. Hoffman, Dave Moore, Christopher Suter, Srinivas Vasudevan, Alexey Radul, Matthew Johnson, Rif A. Saurous"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
