{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Holt-Winters Exponential Smoothing\n\nIn this example we show how to implement Exponential Smoothing.\nThis is intended to be a simple counterpart to the\n[Time Series Forecasting](https://num.pyro.ai/en/stable/tutorials/time_series_forecasting.html)\nnotebook.\n\nThe idea is that we have some time series\n\n\\begin{align}y_1, \\ldots, y_T, y_{T+1}, \\ldots, y_{T+H}\\end{align}\n\nwhere we train on $y_1, \\ldots, y_T$ and predict for $y_{T+1}, \\ldots, y_{T+H}$,\nwhere $T$ is the last training time step and $H$ is the maximum number of\nfuture time steps for which we want to forecast.\n\nWe will be using the update equations from the excellent book\n[Forecasting: Principles and Practice](https://otexts.com/fpp3/holt-winters.html):\n\n\\begin{align}\n\\hat{y}_{t+h|t} &= l_t + hb_t + s_{t+h-m(k+1)} \\\\\nl_t &= \\alpha(y_t - s_{t-m}) + (1-\\alpha)(l_{t-1} + b_{t-1}) \\\\\nb_t &= \\beta^*(l_t-l_{t-1}) + (1-\\beta^*)b_{t-1} \\\\\ns_t &= \\gamma(y_t-l_{t-1}-b_{t-1})+(1-\\gamma)s_{t-m}\n\\end{align}\n\nwhere\n\n* $\\hat{y}_{t+h|t}$ is the forecast for time $t+h$ made with the data up to time $t$;\n* $h$ is the number of time steps into the future for which we want to predict;\n* $l_t$ is the level, $b_t$ is the trend, and $s_t$ is the seasonality;\n* $\\alpha$ is the level smoothing, $\\beta^*$ is the trend\n  smoothing, and $\\gamma$ is the seasonality smoothing;\n* $m$ is the seasonal period, i.e. the number of time steps in one season;\n* $k$ is the integer part of $(h-1)/m$ (this looks more complicated than it is,\n  it just takes the latest seasonality estimate for this time point).\n\n<img src=\"file://../_static/img/examples/holt_winters.png\" align=\"center\">\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\n",
        "import os\n",
        "import time\n",
        "\n",
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "\n",
        "import jax\n",
        "from jax import random\n",
        "import jax.numpy as jnp\n",
        "\n",
        "import numpyro\n",
        "from numpyro.contrib.control_flow import scan\n",
        "from numpyro.diagnostics import hpdi\n",
        "import numpyro.distributions as dist\n",
        "from numpyro.infer import MCMC, NUTS, Predictive\n",
        "\n",
        "matplotlib.use(\"Agg\")\n",
        "\n",
        "N_POINTS_PER_UNIT = 10  # number of points to plot for each unit interval\n",
        "\n",
        "\n",
        "def holt_winters(y, n_seasons, future=0):\n",
        "    \"\"\"Additive Holt-Winters exponential smoothing model.\n",
        "\n",
        "    The state (level, trend, seasonality) is propagated with ``scan``. Each\n",
        "    step emits a ``\"pred\"`` sample site that is conditioned on ``y`` for the\n",
        "    first ``T = y.shape[0]`` steps. After the training window the state is\n",
        "    frozen and the last trend is extrapolated ``t - T + 1`` steps ahead.\n",
        "\n",
        "    :param y: observed series, indexed by time along its first axis.\n",
        "    :param int n_seasons: length of one seasonal cycle, in time steps.\n",
        "    :param int future: number of steps to forecast past the end of ``y``.\n",
        "        When positive, the forecasts are stored in the ``\"y_forecast\"``\n",
        "        deterministic site.\n",
        "    \"\"\"\n",
        "    # ``y`` is indexed with the traced ``t`` inside ``scan``, which a JAX array\n",
        "    # supports but a NumPy array does not.\n",
        "    y = jnp.asarray(y)\n",
        "    T = y.shape[0]\n",
        "    level_smoothing = numpyro.sample(\"level_smoothing\", dist.Beta(1, 1))\n",
        "    trend_smoothing = numpyro.sample(\"trend_smoothing\", dist.Beta(1, 1))\n",
        "    seasonality_smoothing = numpyro.sample(\"seasonality_smoothing\", dist.Beta(1, 1))\n",
        "    # Restricts the effective seasonal smoothing to [0, 1 - level_smoothing].\n",
        "    adj_seasonality_smoothing = seasonality_smoothing * (1 - level_smoothing)\n",
        "    noise = numpyro.sample(\"noise\", dist.HalfNormal(1))\n",
        "    level_init = numpyro.sample(\"level_init\", dist.Normal(0, 1))\n",
        "    trend_init = numpyro.sample(\"trend_init\", dist.Normal(0, 1))\n",
        "    with numpyro.plate(\"n_seasons\", n_seasons):\n",
        "        seasonality_init = numpyro.sample(\"seasonality_init\", dist.Normal(0, 1))\n",
        "\n",
        "    def transition_fn(carry, t):\n",
        "        # previous_seasonality holds the last n_seasons seasonal terms; its\n",
        "        # first element is s_{t-m}, the term that applies to step t.\n",
        "        previous_level, previous_trend, previous_seasonality = carry\n",
        "        # For t < T the state is updated from y[t]. For t >= T it is carried\n",
        "        # over unchanged: y[t] is then out of range (JAX clamps the index) and\n",
        "        # jnp.where discards that branch.\n",
        "        level = jnp.where(\n",
        "            t < T,\n",
        "            level_smoothing * (y[t] - previous_seasonality[0])\n",
        "            + (1 - level_smoothing) * (previous_level + previous_trend),\n",
        "            previous_level,\n",
        "        )\n",
        "        trend = jnp.where(\n",
        "            t < T,\n",
        "            trend_smoothing * (level - previous_level)\n",
        "            + (1 - trend_smoothing) * previous_trend,\n",
        "            previous_trend,\n",
        "        )\n",
        "        new_season = jnp.where(\n",
        "            t < T,\n",
        "            adj_seasonality_smoothing * (y[t] - (previous_level + previous_trend))\n",
        "            + (1 - adj_seasonality_smoothing) * previous_seasonality[0],\n",
        "            previous_seasonality[0],\n",
        "        )\n",
        "        # One-step-ahead prediction inside the training window; h-step-ahead\n",
        "        # prediction (h = t - T + 1) once past the last observation.\n",
        "        step = jnp.where(t < T, 1, t - T + 1)\n",
        "        mu = previous_level + step * previous_trend + previous_seasonality[0]\n",
        "        pred = numpyro.sample(\"pred\", dist.Normal(mu, noise))\n",
        "\n",
        "        # Rotate the buffer: drop s_{t-m} and append the newest seasonal term.\n",
        "        seasonality = jnp.concatenate(\n",
        "            [previous_seasonality[1:], new_season[None]], axis=0\n",
        "        )\n",
        "        return (level, trend, seasonality), pred\n",
        "\n",
        "    with numpyro.handlers.condition(data={\"pred\": y}):\n",
        "        _, preds = scan(\n",
        "            transition_fn,\n",
        "            (level_init, trend_init, seasonality_init),\n",
        "            jnp.arange(T + future),\n",
        "        )\n",
        "\n",
        "    if future > 0:\n",
        "        numpyro.deterministic(\"y_forecast\", preds[-future:])\n",
        "\n",
        "\n",
        "def run_inference(model, args, rng_key, y, n_seasons):\n",
        "    \"\"\"Fit ``model`` to ``y`` with NUTS and return the posterior samples.\n",
        "\n",
        "    Prints the MCMC summary and the elapsed wall-clock time. The progress bar\n",
        "    is disabled when the ``NUMPYRO_SPHINXBUILD`` environment variable is set.\n",
        "    \"\"\"\n",
        "    start = time.time()\n",
        "    sampler = NUTS(model)\n",
        "    mcmc = MCMC(\n",
        "        sampler,\n",
        "        num_warmup=args.num_warmup,\n",
        "        num_samples=args.num_samples,\n",
        "        num_chains=args.num_chains,\n",
        "        progress_bar=\"NUMPYRO_SPHINXBUILD\" not in os.environ,\n",
        "    )\n",
        "    mcmc.run(rng_key, y=y, n_seasons=n_seasons)\n",
        "    mcmc.print_summary()\n",
        "    print(\"\\nMCMC elapsed time:\", time.time() - start)\n",
        "    return mcmc.get_samples()\n",
        "\n",
        "\n",
        "def predict(model, args, samples, rng_key, y, n_seasons):\n",
        "    \"\"\"Forecast ``args.future`` unit intervals past the end of ``y``.\n",
        "\n",
        "    Returns the ``\"y_forecast\"`` site of the posterior predictive, with one\n",
        "    row per posterior sample in ``samples``.\n",
        "    \"\"\"\n",
        "    predictive = Predictive(model, samples, return_sites=[\"y_forecast\"])\n",
        "    return predictive(\n",
        "        rng_key, y=y, n_seasons=n_seasons, future=args.future * N_POINTS_PER_UNIT\n",
        "    )[\"y_forecast\"]\n",
        "\n",
        "\n",
        "def main(args):\n",
        "    \"\"\"Fit the model to a synthetic seasonal series and plot the forecast.\n",
        "\n",
        "    The figure is saved to ``holt_winters_plot.pdf`` in the working directory.\n",
        "\n",
        "    :raises ValueError: if ``args.T`` or ``args.future`` is not positive.\n",
        "    \"\"\"\n",
        "    if args.T < 1:\n",
        "        raise ValueError(\"--T must be a positive integer\")\n",
        "    if args.future < 1:\n",
        "        # With future == 0, y[:-0] would be empty and no forecast is produced.\n",
        "        raise ValueError(\"--future must be a positive integer\")\n",
        "    n_future = args.future * N_POINTS_PER_UNIT\n",
        "\n",
        "    # generate artificial dataset\n",
        "    rng_key, _ = random.split(random.PRNGKey(0))\n",
        "    T = args.T\n",
        "    n_total = (T + args.future) * N_POINTS_PER_UNIT\n",
        "    # endpoint=False makes the spacing exactly 1 / N_POINTS_PER_UNIT, so one\n",
        "    # period of the sine spans exactly n_seasons points, as the model assumes.\n",
        "    t = jnp.linspace(0, T + args.future, n_total, endpoint=False)\n",
        "    y = jnp.sin(2 * np.pi * t) + 0.3 * t + jax.random.normal(rng_key, t.shape) * 0.1\n",
        "    n_seasons = N_POINTS_PER_UNIT\n",
        "    y_train = y[:-n_future]\n",
        "    t_test = t[-n_future:]\n",
        "\n",
        "    # do inference\n",
        "    rng_key, _ = random.split(random.PRNGKey(1))\n",
        "    samples = run_inference(holt_winters, args, rng_key, y_train, n_seasons)\n",
        "\n",
        "    # do prediction\n",
        "    rng_key, _ = random.split(random.PRNGKey(2))\n",
        "    preds = predict(holt_winters, args, samples, rng_key, y_train, n_seasons)\n",
        "    mean_preds = preds.mean(axis=0)\n",
        "    hpdi_preds = hpdi(preds)  # default prob=0.9, hence the \"90% CI\" label\n",
        "\n",
        "    # make plots\n",
        "    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)\n",
        "\n",
        "    # plot true data and predictions\n",
        "    ax.plot(t, y, color=\"blue\", label=\"True values\")\n",
        "    ax.plot(t_test, mean_preds, color=\"orange\", label=\"Mean predictions\")\n",
        "    ax.fill_between(t_test, *hpdi_preds, color=\"orange\", alpha=0.2, label=\"90% CI\")\n",
        "    ax.set(xlabel=\"time\", ylabel=\"y\", title=\"Holt-Winters Exponential Smoothing\")\n",
        "    ax.legend()\n",
        "\n",
        "    fig.savefig(\"holt_winters_plot.pdf\")\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    assert numpyro.__version__.startswith(\"0.13.2\")\n",
        "    parser = argparse.ArgumentParser(description=\"Holt-Winters\")\n",
        "    parser.add_argument(\"--T\", nargs=\"?\", default=6, type=int, help=\"training length in unit intervals\")\n",
        "    parser.add_argument(\"--future\", nargs=\"?\", default=1, type=int, help=\"forecast length in unit intervals\")\n",
        "    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=1000, type=int)\n",
        "    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=1000, type=int)\n",
        "    parser.add_argument(\"--num-chains\", nargs=\"?\", default=1, type=int)\n",
        "    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n",
        "    # parse_known_args ignores extra argv entries such as the \"-f <connection\n",
        "    # file>\" a Jupyter kernel is launched with; parse_args would exit with\n",
        "    # \"unrecognized arguments\" when this cell runs inside a notebook.\n",
        "    args, _ = parser.parse_known_args()\n",
        "\n",
        "    numpyro.set_platform(args.device)\n",
        "    numpyro.set_host_device_count(args.num_chains)\n",
        "\n",
        "    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}