{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Bayesian Neural Network with SteinVI\nWe demonstrate how to use SteinVI to train a Bayesian neural network (BNN) that predicts housing prices on the Boston Housing dataset\nfrom the UCI regression benchmarks.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\n",
        "from collections import namedtuple\n",
        "import datetime\n",
        "from functools import partial\n",
        "from time import time\n",
        "\n",
        "from matplotlib.collections import LineCollection\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "import jax\n",
        "from jax import random\n",
        "import jax.numpy as jnp\n",
        "\n",
        "import numpyro\n",
        "from numpyro import deterministic\n",
        "from numpyro.contrib.einstein import IMQKernel, SteinVI\n",
        "from numpyro.contrib.einstein.mixture_guide_predictive import MixtureGuidePredictive\n",
        "from numpyro.distributions import Gamma, Normal\n",
        "from numpyro.examples.datasets import BOSTON_HOUSING, load_dataset\n",
        "from numpyro.infer import init_to_uniform\n",
        "from numpyro.infer.autoguide import AutoNormal\n",
        "from numpyro.optim import Adagrad\n",
        "\n",
        "DataState = namedtuple(\"data\", [\"xtr\", \"xte\", \"ytr\", \"yte\"])\n",
        "\n",
        "\n",
        "def load_data() -> DataState:\n",
        "    \"\"\"Fetch the Boston Housing data and split it into train (90%) and test (10%) sets.\n",
        "\n",
        "    The split uses a fixed ``random_state`` so it is reproducible; all four\n",
        "    arrays are returned as float JAX arrays.\n",
        "    \"\"\"\n",
        "    _, fetch = load_dataset(BOSTON_HOUSING, shuffle=False)\n",
        "    x, y = fetch()\n",
        "    xtr, xte, ytr, yte = train_test_split(x, y, train_size=0.90, random_state=1)\n",
        "\n",
        "    return DataState(*map(partial(jnp.array, dtype=float), (xtr, xte, ytr, yte)))\n",
        "\n",
        "\n",
        "def normalize(val, mean=None, std=None):\n",
        "    \"\"\"Normalize data to zero mean, unit variance.\n",
        "\n",
        "    Statistics that are not supplied are estimated from ``val`` along axis 0.\n",
        "    Pass the training-set ``mean`` and ``std`` to normalize held-out data.\n",
        "\n",
        "    :param val: array to normalize.\n",
        "    :param mean: mean to subtract, or ``None`` to estimate it from ``val``.\n",
        "    :param std: standard deviation to divide by, or ``None`` to estimate it\n",
        "        from ``val``.\n",
        "    :return: tuple ``(normalized, mean, std)``.\n",
        "    \"\"\"\n",
        "    # Fill in each missing statistic on its own, so that supplying only one\n",
        "    # of them does not end in arithmetic on ``None``.\n",
        "    if std is None:\n",
        "        std = jnp.std(val, 0, keepdims=True)\n",
        "        # Constant columns have zero spread; divide by 1 instead of 0.\n",
        "        std = jnp.where(std == 0, 1.0, std)\n",
        "    if mean is None:\n",
        "        mean = jnp.mean(val, 0, keepdims=True)\n",
        "    return (val - mean) / std, mean, std\n",
        "\n",
        "\n",
        "def model(x, y=None, hidden_dim=50, subsample_size=100):\n",
        "    \"\"\"BNN described in section 5 of [1].\n",
        "\n",
        "    :param x: 2-D array of inputs, one row per example.\n",
        "    :param y: observed targets, or ``None`` to leave the ``y`` site unobserved.\n",
        "    :param hidden_dim: number of units in the hidden layer.\n",
        "    :param subsample_size: number of examples drawn by the ``data`` plate.\n",
        "\n",
        "    **References:**\n",
        "        1. *Stein variational gradient descent: A general purpose bayesian inference algorithm*\n",
        "        Qiang Liu and Dilin Wang (2016).\n",
        "    \"\"\"\n",
        "\n",
        "    prec_nn = numpyro.sample(\n",
        "        \"prec_nn\", Gamma(1.0, 0.1)\n",
        "    )  # hyper prior for precision of nn weights and biases\n",
        "\n",
        "    n, m = x.shape\n",
        "\n",
        "    with numpyro.plate(\"l1_hidden\", hidden_dim, dim=-1):\n",
        "        # prior l1 bias term\n",
        "        b1 = numpyro.sample(\n",
        "            \"nn_b1\",\n",
        "            Normal(\n",
        "                0.0,\n",
        "                1.0 / jnp.sqrt(prec_nn),\n",
        "            ),\n",
        "        )\n",
        "        assert b1.shape == (hidden_dim,)\n",
        "\n",
        "        with numpyro.plate(\"l1_feat\", m, dim=-2):\n",
        "            w1 = numpyro.sample(\n",
        "                \"nn_w1\", Normal(0.0, 1.0 / jnp.sqrt(prec_nn))\n",
        "            )  # prior on l1 weights\n",
        "            assert w1.shape == (m, hidden_dim)\n",
        "\n",
        "    with numpyro.plate(\"l2_hidden\", hidden_dim, dim=-1):\n",
        "        w2 = numpyro.sample(\n",
        "            \"nn_w2\", Normal(0.0, 1.0 / jnp.sqrt(prec_nn))\n",
        "        )  # prior on output weights\n",
        "\n",
        "    b2 = numpyro.sample(\n",
        "        \"nn_b2\", Normal(0.0, 1.0 / jnp.sqrt(prec_nn))\n",
        "    )  # prior on output bias term\n",
        "\n",
        "    # precision prior on observations\n",
        "    prec_obs = numpyro.sample(\"prec_obs\", Gamma(1.0, 0.1))\n",
        "    with numpyro.plate(\n",
        "        \"data\",\n",
        "        x.shape[0],\n",
        "        subsample_size=subsample_size,\n",
        "        dim=-1,\n",
        "    ):\n",
        "        batch_x = numpyro.subsample(x, event_dim=1)\n",
        "        if y is not None:\n",
        "            batch_y = numpyro.subsample(y, event_dim=0)\n",
        "        else:\n",
        "            batch_y = y\n",
        "\n",
        "        # 1 hidden layer with ReLU activation\n",
        "        loc_y = deterministic(\"y_pred\", jnp.maximum(batch_x @ w1 + b1, 0) @ w2 + b2)\n",
        "\n",
        "        numpyro.sample(\n",
        "            \"y\",\n",
        "            Normal(loc_y, 1.0 / jnp.sqrt(prec_obs)),\n",
        "            obs=batch_y,\n",
        "        )\n",
        "\n",
        "\n",
        "def main(args):\n",
        "    \"\"\"Fit the BNN with SteinVI, print the test RMSE and save a prediction plot.\n",
        "\n",
        "    :param args: namespace providing ``rng_key``, ``repulsion``,\n",
        "        ``num_stein_particles``, ``num_elbo_particles``, ``max_iter``,\n",
        "        ``hidden_dim``, ``subsample_size`` and ``progress_bar``.\n",
        "    \"\"\"\n",
        "    data = load_data()\n",
        "\n",
        "    # Only two of the three keys are used.\n",
        "    inf_key, pred_key, _ = random.split(random.PRNGKey(args.rng_key), 3)\n",
        "    # normalize data and labels to zero mean unit variance!\n",
        "    x, xtr_mean, xtr_std = normalize(data.xtr)\n",
        "    y, ytr_mean, ytr_std = normalize(data.ytr)\n",
        "\n",
        "    rng_key, inf_key = random.split(inf_key)\n",
        "\n",
        "    guide = AutoNormal(model, init_loc_fn=partial(init_to_uniform, radius=0.1))\n",
        "\n",
        "    stein = SteinVI(\n",
        "        model,\n",
        "        guide,\n",
        "        Adagrad(0.05),\n",
        "        IMQKernel(),\n",
        "        # ProbabilityProductKernel(guide=guide, scale=1.),\n",
        "        repulsion_temperature=args.repulsion,\n",
        "        num_stein_particles=args.num_stein_particles,\n",
        "        num_elbo_particles=args.num_elbo_particles,\n",
        "    )\n",
        "    start = time()\n",
        "\n",
        "    # use keyword params for static (shape etc.)!\n",
        "    result = stein.run(\n",
        "        rng_key,\n",
        "        args.max_iter,\n",
        "        x,\n",
        "        y,\n",
        "        hidden_dim=args.hidden_dim,\n",
        "        subsample_size=args.subsample_size,\n",
        "        progress_bar=args.progress_bar,\n",
        "    )\n",
        "    time_taken = time() - start\n",
        "\n",
        "    pred = MixtureGuidePredictive(\n",
        "        model,\n",
        "        guide=stein.guide,\n",
        "        params=stein.get_params(result.state),\n",
        "        num_samples=100,\n",
        "        guide_sites=stein.guide_sites,\n",
        "    )\n",
        "    xte, _, _ = normalize(\n",
        "        data.xte, xtr_mean, xtr_std\n",
        "    )  # use train data statistics when assessing generalization\n",
        "    preds = pred(\n",
        "        pred_key, xte, subsample_size=xte.shape[0], hidden_dim=args.hidden_dim\n",
        "    )[\"y_pred\"]\n",
        "\n",
        "    # map predictions back to the original scale of the targets\n",
        "    y_pred = preds * ytr_std + ytr_mean\n",
        "    # average over the sampled predictions (axis 0)\n",
        "    mean_prediction = y_pred.mean(0)\n",
        "    rmse = jnp.sqrt(jnp.mean((mean_prediction - data.yte) ** 2))\n",
        "\n",
        "    print(rf\"Time taken: {datetime.timedelta(seconds=int(time_taken))}\")\n",
        "    print(rf\"RMSE: {rmse:.2f}\")\n",
        "\n",
        "    # the 5th and 95th percentiles bound the central 90% interval\n",
        "    ran = np.arange(mean_prediction.shape[0])\n",
        "    percentiles = np.percentile(y_pred, [5.0, 95.0], axis=0)\n",
        "\n",
        "    # make plots\n",
        "    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)\n",
        "    ax.add_collection(\n",
        "        LineCollection(\n",
        "            zip(zip(ran, percentiles[0]), zip(ran, percentiles[1])), colors=\"lightblue\"\n",
        "        )\n",
        "    )\n",
        "    ax.plot(data.yte, \"kx\", label=\"y true\")\n",
        "    ax.plot(mean_prediction, \"ko\", label=\"y pred\")\n",
        "    ax.set(xlabel=\"example\", ylabel=\"y\", title=\"Mean predictions with 90% CI\")\n",
        "    ax.legend()\n",
        "    fig.savefig(\"stein_bnn.pdf\")\n",
        "\n",
        "\n",
        "def _str2bool(value):\n",
        "    \"\"\"Convert a command-line value such as ``true``/``false`` or ``1``/``0`` to a bool.\n",
        "\n",
        "    :raises argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.\n",
        "    \"\"\"\n",
        "    if isinstance(value, bool):\n",
        "        return value\n",
        "    text = str(value).strip().lower()\n",
        "    if text in (\"true\", \"t\", \"yes\", \"y\", \"1\"):\n",
        "        return True\n",
        "    if text in (\"false\", \"f\", \"no\", \"n\", \"0\"):\n",
        "        return False\n",
        "    raise argparse.ArgumentTypeError(f\"expected a boolean value, got {value!r}\")\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    jax.config.update(\"jax_debug_nans\", True)\n",
        "\n",
        "    parser = argparse.ArgumentParser()\n",
        "    parser.add_argument(\"--subsample-size\", type=int, default=100)\n",
        "    parser.add_argument(\"--max-iter\", type=int, default=1000)\n",
        "    parser.add_argument(\"--repulsion\", type=float, default=1.0)\n",
        "    # ``type=bool`` would turn every non-empty string (even 'False') into True,\n",
        "    # so boolean options are parsed with an explicit converter.\n",
        "    parser.add_argument(\"--verbose\", type=_str2bool, default=True)\n",
        "    parser.add_argument(\"--num-elbo-particles\", type=int, default=50)\n",
        "    parser.add_argument(\"--num-stein-particles\", type=int, default=5)\n",
        "    parser.add_argument(\"--progress-bar\", type=_str2bool, default=True)\n",
        "    parser.add_argument(\"--rng-key\", type=int, default=142)\n",
        "    parser.add_argument(\"--device\", default=\"cpu\", choices=[\"gpu\", \"cpu\"])\n",
        "    parser.add_argument(\"--hidden-dim\", default=50, type=int)\n",
        "\n",
        "    # A Jupyter kernel is launched with ``-f <connection file>`` in sys.argv;\n",
        "    # parse_known_args skips such foreign arguments instead of exiting.\n",
        "    args, _ = parser.parse_known_args()\n",
        "\n",
        "    numpyro.set_platform(args.device)\n",
        "\n",
        "    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}