{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: AutoDAIS\n\nAutoDAIS constructs a guide that combines elements of Hamiltonian Monte Carlo,\nAnnealed Importance Sampling, and Variational Inference.\n\nIn this demo script we construct a somewhat artificial example involving a Gaussian\nprocess binary classifier. We aim to demonstrate that:\n\n- DAIS can achieve better ELBOs than, e.g., mean-field variational inference\n- DAIS can achieve better posterior approximations than, e.g., mean-field variational inference\n- DAIS improves as you increase K, the number of HMC steps used in the sampler\n\nReferences:\n\n[1] \"MCMC Variational Inference via Uncorrected Hamiltonian Annealing,\"\n    Tomas Geffner, Justin Domke.\n[2] \"Differentiable Annealed Importance Sampling and the Perils of Gradient Noise,\"\n    Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse.\n\n<img src=\"file://../_static/img/dais_demo.png\" align=\"center\">\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\n",
        "import sys\n",
        "\n",
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "from scipy.special import expit\n",
        "import seaborn as sns\n",
        "\n",
        "from jax import random\n",
        "import jax.numpy as jnp\n",
        "\n",
        "import numpyro\n",
        "import numpyro.distributions as dist\n",
        "from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, autoguide\n",
        "from numpyro.util import enable_x64\n",
        "\n",
        "matplotlib.use(\"Agg\")  # noqa: E402\n",
        "\n",
        "\n",
        "# squared exponential kernel\n",
        "def kernel(X, Z, length, jitter=1.0e-6):\n",
        "    \"\"\"Squared-exponential kernel matrix between 1-D inputs ``X`` and ``Z``.\n",
        "\n",
        "    ``jitter`` is added to the diagonal for numerical stability, which assumes\n",
        "    ``X`` and ``Z`` have the same length (``model`` always passes ``Z = X``).\n",
        "    \"\"\"\n",
        "    deltaXsq = jnp.power((X[:, None] - Z) / length, 2.0)\n",
        "    k = jnp.exp(-0.5 * deltaXsq) + jitter * jnp.eye(X.shape[0])\n",
        "    return k\n",
        "\n",
        "\n",
        "def model(X, Y, length=0.2):\n",
        "    \"\"\"Gaussian process binary classifier with a cubic logit link.\n",
        "\n",
        "    The latent function values ``f`` at the inputs ``X`` get a zero-mean GP\n",
        "    prior with a squared-exponential kernel of length scale ``length``; the\n",
        "    labels ``Y`` are Bernoulli with logits ``f ** 3``.\n",
        "    \"\"\"\n",
        "    # compute kernel\n",
        "    k = kernel(X, X, length)\n",
        "\n",
        "    # sample from gaussian process prior\n",
        "    f = numpyro.sample(\n",
        "        \"f\",\n",
        "        dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), covariance_matrix=k),\n",
        "    )\n",
        "    # we use a non-standard link function to induce extra non-gaussianity\n",
        "    numpyro.sample(\"obs\", dist.Bernoulli(logits=jnp.power(f, 3.0)), obs=Y)\n",
        "\n",
        "\n",
        "# create artificial binary classification dataset\n",
        "def get_data(N=16):\n",
        "    \"\"\"Return ``N`` evenly spaced inputs on [-1, 1] and binary labels.\n",
        "\n",
        "    Each label is drawn with probability ``expit`` of a standardized nonlinear\n",
        "    function of the input. The global numpy seed is fixed, so the dataset is\n",
        "    reproducible.\n",
        "    \"\"\"\n",
        "    np.random.seed(0)\n",
        "    X = np.linspace(-1, 1, N)\n",
        "    Y = X + 0.2 * np.power(X, 3.0) + 0.5 * np.power(0.5 + X, 2.0) * np.sin(4.0 * X)\n",
        "    Y -= np.mean(Y)\n",
        "    Y /= np.std(Y)\n",
        "    Y = np.random.binomial(1, expit(Y))\n",
        "\n",
        "    assert X.shape == (N,)\n",
        "    assert Y.shape == (N,)\n",
        "\n",
        "    return X, Y\n",
        "\n",
        "\n",
        "# helper function for running SVI with a particular autoguide\n",
        "def run_svi(\n",
        "    rng_key,\n",
        "    X,\n",
        "    Y,\n",
        "    guide_family=\"AutoDiagonalNormal\",\n",
        "    K=8,\n",
        "    num_svi_steps=80 * 1000,\n",
        "    num_samples=10 * 1000,\n",
        "):\n",
        "    \"\"\"Fit ``model`` with SVI using an autoguide and draw posterior samples.\n",
        "\n",
        "    Args:\n",
        "        rng_key: JAX PRNG key; split internally into a training key and an\n",
        "            ELBO-evaluation key.\n",
        "        X, Y: inputs and binary labels, e.g. from ``get_data``.\n",
        "        guide_family: ``AutoDiagonalNormal`` or ``AutoDAIS``.\n",
        "        K: number of HMC steps in the AutoDAIS guide (unused otherwise).\n",
        "        num_svi_steps: number of SVI optimization steps.\n",
        "        num_samples: number of posterior samples drawn from the fitted guide.\n",
        "\n",
        "    Returns:\n",
        "        Posterior samples from the fitted guide, keyed by site name (e.g. ``f``).\n",
        "    \"\"\"\n",
        "    assert guide_family in [\"AutoDiagonalNormal\", \"AutoDAIS\"]\n",
        "\n",
        "    if guide_family == \"AutoDAIS\":\n",
        "        guide = autoguide.AutoDAIS(model, K=K, eta_init=0.02, eta_max=0.5)\n",
        "        step_size = 5e-4\n",
        "    elif guide_family == \"AutoDiagonalNormal\":\n",
        "        guide = autoguide.AutoDiagonalNormal(model)\n",
        "        step_size = 3e-3\n",
        "\n",
        "    # JAX keys must not be reused, so fitting and the final ELBO estimate\n",
        "    # each get their own key\n",
        "    svi_key, elbo_key = random.split(rng_key)\n",
        "\n",
        "    optimizer = numpyro.optim.Adam(step_size=step_size)\n",
        "    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())\n",
        "    svi_result = svi.run(svi_key, num_svi_steps, X, Y)\n",
        "    params = svi_result.params\n",
        "\n",
        "    # many particles give a low-variance estimate of the final ELBO\n",
        "    final_elbo = -Trace_ELBO(num_particles=1000).loss(\n",
        "        elbo_key, params, model, guide, X, Y\n",
        "    )\n",
        "\n",
        "    guide_name = guide_family\n",
        "    if guide_family == \"AutoDAIS\":\n",
        "        guide_name += \"-{}\".format(K)\n",
        "\n",
        "    print(\"[{}] final elbo: {:.2f}\".format(guide_name, final_elbo))\n",
        "\n",
        "    return guide.sample_posterior(\n",
        "        random.PRNGKey(1), params, sample_shape=(num_samples,)\n",
        "    )\n",
        "\n",
        "\n",
        "# helper function for running mcmc\n",
        "def run_nuts(mcmc_key, args, X, Y):\n",
        "    \"\"\"Run NUTS on ``model``, print its summary and return the samples.\n",
        "\n",
        "    ``args`` supplies ``num_warmup`` and ``num_samples``.\n",
        "    \"\"\"\n",
        "    mcmc = MCMC(NUTS(model), num_warmup=args.num_warmup, num_samples=args.num_samples)\n",
        "    mcmc.run(mcmc_key, X, Y)\n",
        "    mcmc.print_summary()\n",
        "    return mcmc.get_samples()\n",
        "\n",
        "\n",
        "def main(args):\n",
        "    \"\"\"Compare AutoDAIS (K=8, K=128), AutoDiagonalNormal and NUTS on the toy GP.\n",
        "\n",
        "    Prints each SVI run's final ELBO and the NUTS summary. When\n",
        "    ``args.num_samples >= 1000`` it also saves 2D density plots of the\n",
        "    (f_0, f_1) marginal posteriors to ``dais_demo.png``.\n",
        "    \"\"\"\n",
        "    X, Y = get_data()\n",
        "\n",
        "    rng_keys = random.split(random.PRNGKey(0), 4)\n",
        "    svi_kwargs = dict(num_svi_steps=args.num_svi_steps, num_samples=args.num_samples)\n",
        "\n",
        "    # run SVI with an AutoDAIS guide for two values of K\n",
        "    dais8_samples = run_svi(\n",
        "        rng_keys[1], X, Y, guide_family=\"AutoDAIS\", K=8, **svi_kwargs\n",
        "    )\n",
        "    dais128_samples = run_svi(\n",
        "        rng_keys[2], X, Y, guide_family=\"AutoDAIS\", K=128, **svi_kwargs\n",
        "    )\n",
        "\n",
        "    # run SVI with an AutoDiagonalNormal guide\n",
        "    meanfield_samples = run_svi(\n",
        "        rng_keys[3], X, Y, guide_family=\"AutoDiagonalNormal\", **svi_kwargs\n",
        "    )\n",
        "\n",
        "    # run MCMC inference\n",
        "    nuts_samples = run_nuts(rng_keys[0], args, X, Y)\n",
        "\n",
        "    # make 2d density plots of the (f_0, f_1) marginal posterior; skipped for\n",
        "    # small sample counts (presumably the KDEs would be too noisy to be useful)\n",
        "    if args.num_samples >= 1000:\n",
        "        sns.set_style(\"white\")\n",
        "\n",
        "        coord1, coord2 = 0, 1\n",
        "\n",
        "        fig, axes = plt.subplots(\n",
        "            2, 2, sharex=True, figsize=(6, 6), constrained_layout=True\n",
        "        )\n",
        "\n",
        "        xlim = (-3, 3)\n",
        "        ylim = (-3, 3)\n",
        "\n",
        "        def add_fig(samples, title, ax):\n",
        "            sns.kdeplot(x=samples[\"f\"][:, coord1], y=samples[\"f\"][:, coord2], ax=ax)\n",
        "            ax.set(title=title, xlim=xlim, ylim=ylim)\n",
        "\n",
        "        add_fig(dais8_samples, \"AutoDAIS (K=8)\", axes[0][0])\n",
        "        add_fig(dais128_samples, \"AutoDAIS (K=128)\", axes[0][1])\n",
        "        add_fig(meanfield_samples, \"AutoDiagonalNormal\", axes[1][0])\n",
        "        add_fig(nuts_samples, \"NUTS\", axes[1][1])\n",
        "\n",
        "        plt.savefig(\"dais_demo.png\")\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    parser = argparse.ArgumentParser(\"Usage example for AutoDAIS guide.\")\n",
        "    parser.add_argument(\"--num-svi-steps\", type=int, default=80 * 1000)\n",
        "    parser.add_argument(\"--num-warmup\", type=int, default=2000)\n",
        "    parser.add_argument(\"--num-samples\", type=int, default=10 * 1000)\n",
        "    parser.add_argument(\"--device\", default=\"cpu\", type=str, choices=[\"cpu\", \"gpu\"])\n",
        "\n",
        "    # Inside a Jupyter kernel sys.argv holds the kernel's own launch flags\n",
        "    # (typically -f <connection file>), which argparse would reject; use the\n",
        "    # defaults there and parse the real command line otherwise.\n",
        "    argv = [] if \"ipykernel\" in sys.modules else None\n",
        "    args = parser.parse_args(argv)\n",
        "\n",
        "    enable_x64()\n",
        "    numpyro.set_platform(args.device)\n",
        "\n",
        "    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}