{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Stochastic Volatility\n\nGenerative model:\n\n\\begin{align}\n    \\sigma & \\sim \\text{Exponential}(50) \\\\\n    \\nu & \\sim \\text{Exponential}(.1) \\\\\n    s_i & \\sim \\text{Normal}(s_{i-1}, \\sigma^{-2}) \\\\\n    r_i & \\sim \\text{StudentT}(\\nu, 0, \\exp(s_i))\n\\end{align}\n\nThis example is from PyMC3 [1], which itself is adapted from the original experiment\nfrom [2]. A discussion about translating this to Pyro appears in [3].\n\nWe use this example to illustrate the functional interface `hmc`. However,\nwe recommend that readers use the `MCMC` class, as in the other examples, because it is\nmore stable and supports more features.\n\n**References:**\n\n1. *Stochastic Volatility Model*, https://docs.pymc.io/notebooks/stochastic_volatility.html\n2. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,\n   https://arxiv.org/pdf/1111.4246.pdf\n3. Pyro forum discussion, https://forum.pyro.ai/t/problems-transforming-a-pymc3-model-to-pyro-mcmc/208/14\n\n<img src=\"file://../_static/img/examples/stochastic_volatility.png\" align=\"center\">\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\nimport os\n\nimport matplotlib\nimport matplotlib.dates as mdates\nimport matplotlib.pyplot as plt\n\nimport jax.numpy as jnp\nimport jax.random as random\n\nimport numpyro\nimport numpyro.distributions as dist\nfrom numpyro.examples.datasets import SP500, load_dataset\nfrom numpyro.infer.hmc import hmc\nfrom numpyro.infer.util import initialize_model\nfrom numpyro.util import fori_collect\n\nmatplotlib.use(\"Agg\")  # noqa: E402\n\n\ndef model(returns):\n    \"\"\"Stochastic volatility model for a 1-D series of returns.\n\n    Samples the random-walk scale ``sigma``, a latent path ``s`` with one entry per\n    observation, and the Student-T degrees of freedom ``nu``; the observed site\n    ``r`` uses ``exp(s)`` as its scale.\n    \"\"\"\n    step_size = numpyro.sample(\"sigma\", dist.Exponential(50.0))\n    s = numpyro.sample(\n        \"s\", dist.GaussianRandomWalk(scale=step_size, num_steps=jnp.shape(returns)[0])\n    )\n    nu = numpyro.sample(\"nu\", dist.Exponential(0.1))\n    return numpyro.sample(\n        \"r\", dist.StudentT(df=nu, loc=0.0, scale=jnp.exp(s)), obs=returns\n    )\n\n\ndef print_results(posterior, dates):\n    \"\"\"Print posterior quantiles of ``sigma``, ``nu`` and ``exp(s)`` at every 180th date.\"\"\"\n\n    def _print_row(values, row_name=\"\"):\n        quantiles = jnp.array([0.2, 0.4, 0.5, 0.6, 0.8])\n        row_name_fmt = \"{:>8}\"\n        header_format = row_name_fmt + \"{:>12}\" * 5\n        row_format = row_name_fmt + \"{:>12.3f}\" * 5\n        columns = [\"(p{})\".format(int(q * 100)) for q in quantiles]\n        q_values = jnp.quantile(values, quantiles, axis=0)\n        print(header_format.format(\"\", *columns))\n        print(row_format.format(row_name, *q_values))\n        print(\"\\n\")\n\n    print(\"=\" * 20, \"sigma\", \"=\" * 20)\n    _print_row(posterior[\"sigma\"])\n    print(\"=\" * 20, \"nu\", \"=\" * 20)\n    _print_row(posterior[\"nu\"])\n    print(\"=\" * 20, \"volatility\", \"=\" * 20)\n    for i in range(0, len(dates), 180):\n        _print_row(jnp.exp(posterior[\"s\"][:, i]), dates[i])\n\n\ndef main(args):\n    \"\"\"Sample the model with NUTS via the functional ``hmc`` interface, print\n    posterior summaries and save the plot to ``stochastic_volatility_plot.pdf``.\n\n    ``args`` must provide ``rng_seed``, ``num_warmup`` and ``num_samples``.\n    \"\"\"\n    _, fetch = load_dataset(SP500, shuffle=False)\n    dates, returns = fetch()\n    init_rng_key, sample_rng_key = random.split(random.PRNGKey(args.rng_seed))\n    model_info = initialize_model(init_rng_key, model, model_args=(returns,))\n    init_kernel, sample_kernel = hmc(model_info.potential_fn, algo=\"NUTS\")\n    hmc_state = init_kernel(\n        model_info.param_info, args.num_warmup, rng_key=sample_rng_key\n    )\n    hmc_states = fori_collect(\n        args.num_warmup,\n        args.num_warmup + args.num_samples,\n        sample_kernel,\n        hmc_state,\n        transform=lambda hmc_state: model_info.postprocess_fn(hmc_state.z),\n        # The progress bar is disabled only when NUMPYRO_SPHINXBUILD is set.\n        progbar=\"NUMPYRO_SPHINXBUILD\" not in os.environ,\n    )\n    print_results(hmc_states, dates)\n\n    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)\n    dates = mdates.num2date(mdates.datestr2num(dates))\n    ax.plot(dates, returns, lw=0.5)\n    # format the ticks\n    ax.xaxis.set_major_locator(mdates.YearLocator())\n    ax.xaxis.set_major_formatter(mdates.DateFormatter(\"%Y\"))\n    ax.xaxis.set_minor_locator(mdates.MonthLocator())\n\n    ax.plot(dates, jnp.exp(hmc_states[\"s\"].T), \"r\", alpha=0.01)\n    legend = ax.legend([\"returns\", \"volatility\"], loc=\"upper right\")\n    # Legend.legendHandles was renamed to legend_handles in Matplotlib 3.7 and the\n    # old name was removed in 3.9; fall back to the old name on older versions.\n    handles = getattr(legend, \"legend_handles\", None)\n    if handles is None:\n        handles = legend.legendHandles\n    # The volatility lines are drawn almost transparent; make the legend entry visible.\n    handles[1].set_alpha(0.6)\n    ax.set(xlabel=\"time\", ylabel=\"returns\", title=\"Volatility of S&P500 over time\")\n\n    plt.savefig(\"stochastic_volatility_plot.pdf\")\n\n\nif __name__ == \"__main__\":\n    assert numpyro.__version__.startswith(\"0.13.2\")\n    parser = argparse.ArgumentParser(description=\"Stochastic Volatility Model\")\n    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=600, type=int)\n    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=600, type=int)\n    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n    parser.add_argument(\n        \"--rng_seed\", default=21, type=int, help=\"random number generator seed\"\n    )\n    # Inside a Jupyter kernel sys.argv carries the kernel launcher's own options\n    # (e.g. -f <connection-file>); parse_args() would exit on them, so unknown\n    # arguments are ignored instead.\n    args, _ = parser.parse_known_args()\n\n    numpyro.set_platform(args.device)\n\n    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}