{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Hilbert space approximation for Gaussian processes.\n\nThis example replicates the model in the excellent case\nstudy by Aki Vehtari [1] (originally written using R and Stan).\nThe case study uses approximate Gaussian processes [2] to model the\nrelative number of births per day in the US from 1969 to 1988.\nThe Hilbert space approximation is way faster than the exact Gaussian\nprocesses because it circumvents the need for inverting the\ncovariance matrix.\n\nThe original case study also emphasizes the iterative\nprocess of building a Bayesian model, which is excellent as a pedagogical\nresource. Here, however, we replicate only the model that includes all\ncomponents (long term trend, smooth year seasonality, slowly varying day of week effect,\nday of the year effect and special floating days effects).\n\nThe different components of the model are isolated into separate functions\nso that they can easily be reused in different contexts. To combine the\nmultiple components into a single birthdays model, here we make use of Numpyro's\n`scope` handler which modifies the site names of the components by adding\na prefix to them. By doing this, we avoid duplication of site names\nwithin the model. Following this pattern, it is straightforward to construct the\nother models in [1] with the code provided here.\n\nThere are a few minor differences in the mathematical details of our models,\nwhich we had to make for the chains to mix properly or for ease of\nimplementation. We have commented on the places where our models are different.\n\nThe periodic kernel approximation requires tensorflow-probability on a jax backend.\nSee <https://www.tensorflow.org/probability/examples/TensorFlow_Probability_on_JAX>\nfor installation instructions.\n\n**References:**\n    1. Gelman, Vehtari, Simpson, et al (2020), `\"Bayesian workflow book - Birthdays\"\n       <https://avehtari.github.io/casestudies/Birthdays/birthdays.html>`.\n    2. Riutort-Mayol G, B\u00fcrkner PC, Andersen MR, et al (2020),\n       \"Practical hilbert space approximate bayesian gaussian processes for probabilistic programming\".\n\n<img src=\"file://../_static/img/examples/hsgp.png\" align=\"center\">\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\nimport os\n\nimport matplotlib.pyplot as plt\nimport pandas as pd\n\nimport jax\nimport jax.numpy as jnp\nfrom tensorflow_probability.substrates import jax as tfp\n\nimport numpyro\nfrom numpyro import deterministic, plate, sample\nimport numpyro.distributions as dist\nfrom numpyro.handlers import scope\nfrom numpyro.infer import MCMC, NUTS, init_to_median\n\n\n# --- Data processing functions\ndef get_labour_days(dates):\n    \"\"\"\n    First monday of September\n    \"\"\"\n    is_september = dates.dt.month.eq(9)\n    is_monday = dates.dt.weekday.eq(0)\n    is_first_week = dates.dt.day.le(7)\n\n    is_labour_day = is_september & is_monday & is_first_week\n    is_day_after = is_labour_day.shift(fill_value=False)\n\n    return is_labour_day | is_day_after\n\n\ndef get_memorial_days(dates):\n    \"\"\"\n    Last monday of May\n    \"\"\"\n    is_may = dates.dt.month.eq(5)\n    is_monday = dates.dt.weekday.eq(0)\n    is_last_week = dates.dt.day.ge(25)\n\n    is_memorial_day = is_may & is_monday & is_last_week\n    is_day_after = is_memorial_day.shift(fill_value=False)\n\n    return is_memorial_day | is_day_after\n\n\ndef get_thanksgiving_days(dates):\n    \"\"\"\n    Third thursday of November\n    \"\"\"\n    is_november = dates.dt.month.eq(11)\n    is_thursday = dates.dt.weekday.eq(3)\n    is_third_week = dates.dt.day.between(22, 28)\n\n    is_thanksgiving = is_november & is_thursday & is_third_week\n    is_day_after = is_thanksgiving.shift(fill_value=False)\n\n    return is_thanksgiving | is_day_after\n\n\ndef get_floating_days_indicators(dates):\n    def encode(x):\n        return jnp.array(x.values, dtype=jnp.result_type(int))\n\n    return {\n        \"labour_days_indicator\": encode(get_labour_days(dates)),\n        \"memorial_days_indicator\": encode(get_memorial_days(dates)),\n        \"thanksgiving_days_indicator\": encode(get_thanksgiving_days(dates)),\n    }\n\n\ndef load_data():\n    URL = \"https://raw.githubusercontent.com/avehtari/casestudies/master/Birthdays/data/births_usa_1969.csv\"\n    data = pd.read_csv(URL, sep=\",\")\n    day0 = pd.to_datetime(\"31-Dec-1968\")\n    dates = [day0 + pd.Timedelta(f\"{i}d\") for i in data[\"id\"]]\n    data[\"date\"] = dates\n    data[\"births_relative\"] = data[\"births\"] / data[\"births\"].mean()\n    return data\n\n\ndef make_birthdays_data_dict(data):\n    x = data[\"id\"].values\n    y = data[\"births_relative\"].values\n    dates = data[\"date\"]\n\n    xsd = jnp.array((x - x.mean()) / x.std())\n    ysd = jnp.array((y - y.mean()) / y.std())\n    day_of_week = jnp.array((data[\"day_of_week\"] - 1).values)\n    day_of_year = jnp.array((data[\"day_of_year\"] - 1).values)\n    floating_days = get_floating_days_indicators(dates)\n    period = 365.25\n    w0 = x.std() * (jnp.pi * 2 / period)\n    L = 1.5 * max(xsd)\n    M1 = 10\n    M2 = 10  # 20 in original case study\n    M3 = 5\n\n    return {\n        \"x\": xsd,\n        \"day_of_week\": day_of_week,\n        \"day_of_year\": day_of_year,\n        \"w0\": w0,\n        \"L\": L,\n        \"M1\": M1,\n        \"M2\": M2,\n        \"M3\": M3,\n        **floating_days,\n        \"y\": ysd,\n    }\n\n\n# --- Modelling utility functions --- #\ndef spectral_density(w, alpha, length):\n    c = alpha * jnp.sqrt(2 * jnp.pi) * length\n    e = jnp.exp(-0.5 * (length**2) * (w**2))\n    return c * e\n\n\ndef diag_spectral_density(alpha, length, L, M):\n    sqrt_eigenvalues = jnp.arange(1, 1 + M) * jnp.pi / 2 / L\n    return spectral_density(sqrt_eigenvalues, alpha, length)\n\n\ndef eigenfunctions(x, L, M):\n    \"\"\"\n    The first `M` eigenfunctions of the laplacian operator in `[-L, L]`\n    evaluated at `x`. These are used for the approximation of the\n    squared exponential kernel.\n    \"\"\"\n    m1 = (jnp.pi / (2 * L)) * jnp.tile(L + x[:, None], M)\n    m2 = jnp.diag(jnp.linspace(1, M, num=M))\n    num = jnp.sin(m1 @ m2)\n    den = jnp.sqrt(L)\n    return num / den\n\n\ndef modified_bessel_first_kind(v, z):\n    v = jnp.asarray(v, dtype=float)\n    return jnp.exp(jnp.abs(z)) * tfp.math.bessel_ive(v, z)\n\n\ndef diag_spectral_density_periodic(alpha, length, M):\n    \"\"\"\n    Not actually a spectral density but these are used in the same\n    way. These are simply the first `M` coefficients of the low rank\n    approximation for the periodic kernel.\n    \"\"\"\n    a = length ** (-2)\n    J = jnp.arange(0, M)\n    c = jnp.where(J > 0, 2, 1)\n    q2 = (c * alpha**2 / jnp.exp(a)) * modified_bessel_first_kind(J, a)\n    return q2\n\n\ndef eigenfunctions_periodic(x, w0, M):\n    \"\"\"\n    Basis functions for the approximation of the periodic kernel.\n    \"\"\"\n    m1 = jnp.tile(w0 * x[:, None], M)\n    m2 = jnp.diag(jnp.arange(M, dtype=jnp.float32))\n    mw0x = m1 @ m2\n    cosines = jnp.cos(mw0x)\n    sines = jnp.sin(mw0x)\n    return cosines, sines\n\n\n# --- Approximate Gaussian processes --- #\ndef approx_se_ncp(x, alpha, length, L, M):\n    \"\"\"\n    Hilbert space approximation for the squared\n    exponential kernel in the non-centered parametrisation.\n    \"\"\"\n    phi = eigenfunctions(x, L, M)\n    spd = jnp.sqrt(diag_spectral_density(alpha, length, L, M))\n    with plate(\"basis\", M):\n        beta = sample(\"beta\", dist.Normal(0, 1))\n\n    f = deterministic(\"f\", phi @ (spd * beta))\n    return f\n\n\ndef approx_periodic_gp_ncp(x, alpha, length, w0, M):\n    \"\"\"\n    Low rank approximation for the periodic squared\n    exponential kernel in the non-centered parametrisation.\n    \"\"\"\n    q2 = diag_spectral_density_periodic(alpha, length, M)\n    cosines, sines = eigenfunctions_periodic(x, w0, M)\n\n    with plate(\"cos_basis\", M):\n        beta_cos = sample(\"beta_cos\", dist.Normal(0, 1))\n\n    with plate(\"sin_basis\", M - 1):\n        beta_sin = sample(\"beta_sin\", dist.Normal(0, 1))\n\n    # The first eigenfunction for the sine component\n    # is zero, so the first parameter wouldn't contribute to the approximation.\n    # We set it to zero to identify the model and avoid divergences.\n    zero = jnp.array([0.0])\n    beta_sin = jnp.concatenate((zero, beta_sin))\n\n    f = deterministic(\"f\", cosines @ (q2 * beta_cos) + sines @ (q2 * beta_sin))\n    return f\n\n\n# --- Components of the Birthdays model --- #\ndef trend_gp(x, L, M):\n    alpha = sample(\"alpha\", dist.HalfNormal(1.0))\n    length = sample(\"length\", dist.InverseGamma(10.0, 2.0))\n    f = approx_se_ncp(x, alpha, length, L, M)\n    return f\n\n\ndef year_gp(x, w0, M):\n    alpha = sample(\"alpha\", dist.HalfNormal(1.0))\n    length = sample(\"length\", dist.HalfNormal(0.2))  # scale=0.1 in original\n    f = approx_periodic_gp_ncp(x, alpha, length, w0, M)\n    return f\n\n\ndef weekday_effect(day_of_week):\n    with plate(\"plate_day_of_week\", 6):\n        weekday = sample(\"_beta\", dist.Normal(0, 1))\n\n    monday = jnp.array([-jnp.sum(weekday)])  # Monday = 0 in original\n    beta = deterministic(\"beta\", jnp.concatenate((monday, weekday)))\n    return beta[day_of_week]\n\n\ndef yearday_effect(day_of_year):\n    slab_df = 50  # 100 in original case study\n    slab_scale = 2\n    scale_global = 0.1\n    tau = sample(\n        \"tau\", dist.HalfNormal(2 * scale_global)\n    )  # Original uses half-t with 100df\n    c_aux = sample(\"c_aux\", dist.InverseGamma(0.5 * slab_df, 0.5 * slab_df))\n    c = slab_scale * jnp.sqrt(c_aux)\n\n    # Jan 1st:  Day 0\n    # Feb 29th: Day 59\n    # Dec 31st: Day 365\n    with plate(\"plate_day_of_year\", 366):\n        lam = sample(\"lam\", dist.HalfCauchy(scale=1))\n        lam_tilde = jnp.sqrt(c) * lam / jnp.sqrt(c + (tau * lam) ** 2)\n        beta = sample(\"beta\", dist.Normal(loc=0, scale=tau * lam_tilde))\n\n    return beta[day_of_year]\n\n\ndef special_effect(indicator):\n    beta = sample(\"beta\", dist.Normal(0, 1))\n    return beta * indicator\n\n\n# --- Model --- #\ndef birthdays_model(\n    x,\n    day_of_week,\n    day_of_year,\n    memorial_days_indicator,\n    labour_days_indicator,\n    thanksgiving_days_indicator,\n    w0,\n    L,\n    M1,\n    M2,\n    M3,\n    y=None,\n):\n    intercept = sample(\"intercept\", dist.Normal(0, 1))\n    f1 = scope(trend_gp, \"trend\")(x, L, M1)\n    f2 = scope(year_gp, \"year\")(x, w0, M2)\n    g3 = scope(trend_gp, \"week-trend\")(\n        x, L, M3\n    )  # length ~ lognormal(-1, 1) in original\n    weekday = scope(weekday_effect, \"week\")(day_of_week)\n    yearday = scope(yearday_effect, \"day\")(day_of_year)\n\n    # # --- special days\n    memorial = scope(special_effect, \"memorial\")(memorial_days_indicator)\n    labour = scope(special_effect, \"labour\")(labour_days_indicator)\n    thanksgiving = scope(special_effect, \"thanksgiving\")(thanksgiving_days_indicator)\n\n    day = yearday + memorial + labour + thanksgiving\n    # --- Combine components\n    f = deterministic(\"f\", intercept + f1 + f2 + jnp.exp(g3) * weekday + day)\n    sigma = sample(\"sigma\", dist.HalfNormal(0.5))\n    with plate(\"obs\", x.shape[0]):\n        sample(\"y\", dist.Normal(f, sigma), obs=y)\n\n\n# --- plotting function --- #\nDATA_STYLE = dict(marker=\".\", alpha=0.8, lw=0, label=\"data\", c=\"lightgray\")\nMODEL_STYLE = dict(lw=2, color=\"k\")\n\n\ndef plot_trend(data, samples, ax=None):\n    y = data[\"births_relative\"]\n    x = data[\"date\"]\n    fsd = samples[\"intercept\"][:, None] + samples[\"trend/f\"]\n    f = jnp.quantile(fsd * y.std() + y.mean(), 0.50, axis=0)\n\n    if ax is None:\n        ax = plt.gca()\n\n    ax.plot(x, y, **DATA_STYLE)\n    ax.plot(x, f, **MODEL_STYLE)\n    return ax\n\n\ndef plot_seasonality(data, samples, ax=None):\n    y = data[\"births_relative\"]\n    sdev = y.std()\n    mean = y.mean()\n    baseline = (samples[\"intercept\"][:, None] + samples[\"trend/f\"]) * sdev\n    y_detrended = y - baseline.mean(0)\n    y_year_mean = y_detrended.groupby(data[\"day_of_year\"]).mean()\n    x = y_year_mean.index\n\n    f_median = (\n        pd.DataFrame(samples[\"year/f\"] * sdev + mean, columns=data[\"day_of_year\"])\n        .melt(var_name=\"day_of_year\")\n        .groupby(\"day_of_year\")[\"value\"]\n        .median()\n    )\n\n    if ax is None:\n        ax = plt.gca()\n\n    ax.plot(x, y_year_mean, **DATA_STYLE)\n    ax.plot(x, f_median, **MODEL_STYLE)\n    return ax\n\n\ndef plot_week(data, samples, ax=None):\n    if ax is None:\n        ax = plt.gca()\n\n    weekdays = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n    y = data[\"births_relative\"]\n    x = data[\"day_of_week\"] - 1\n    f = jnp.median(samples[\"week/beta\"] * y.std() + y.mean(), 0)\n\n    ax.plot(x, y, **DATA_STYLE)\n    ax.plot(range(7), f, **MODEL_STYLE)\n    ax.set_xticks(range(7))\n    ax.set_xticklabels(weekdays)\n    return ax\n\n\ndef plot_weektrend(data, samples, ax=None):\n    dates = data[\"date\"]\n    weekdays = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n    y = data[\"births_relative\"]\n    mean, sdev = y.mean(), y.std()\n    intercept = samples[\"intercept\"][:, None]\n    f1 = samples[\"trend/f\"]\n    f2 = samples[\"year/f\"]\n    g3 = samples[\"week-trend/f\"]\n    baseline = ((intercept + f1 + f2) * y.std()).mean(0)\n\n    if ax is None:\n        ax = plt.gca()\n\n    ax.plot(dates, y - baseline, **DATA_STYLE)\n    for n, day in enumerate(weekdays):\n        week_beta = samples[\"week/beta\"][:, n][:, None]\n        fsd = jnp.exp(g3) * week_beta\n        f = jnp.quantile(fsd * sdev + mean, 0.50, axis=0)\n        ax.plot(dates, f, **MODEL_STYLE)\n        ax.text(dates.iloc[-1], f[-1], day)\n\n    return ax\n\n\ndef plot_1988(data, samples, ax=None):\n    indicators = get_floating_days_indicators(data[\"date\"])\n    memorial_beta = samples[\"memorial/beta\"][:, None]\n    labour_beta = samples[\"labour/beta\"][:, None]\n    thanks_beta = samples[\"thanksgiving/beta\"][:, None]\n\n    memorials = indicators[\"memorial_days_indicator\"] * memorial_beta\n    labour = indicators[\"labour_days_indicator\"] * labour_beta\n    thanksgiving = indicators[\"thanksgiving_days_indicator\"] * thanks_beta\n    floating_days = memorials + labour + thanksgiving\n\n    is_1988 = data[\"date\"].dt.year == 1988\n    days_in_1988 = data[\"day_of_year\"][is_1988] - 1\n    days_effect = samples[\"day/beta\"][:, days_in_1988.values]\n    floating_effect = floating_days[:, jnp.argwhere(is_1988.values).ravel()]\n\n    y = data[\"births_relative\"]\n    f = (days_effect + floating_effect) * y.std() + y.mean()\n    f_median = jnp.median(f, axis=0)\n\n    special_days = {\n        \"Valentine's\": \"1988-02-14\",\n        \"Leap day\": \"1988-02-29\",\n        \"Halloween\": \"1988-10-31\",\n        \"Christmas eve\": \"1988-12-24\",\n        \"Christmas day\": \"1988-12-25\",\n        \"New year\": \"1988-01-01\",\n        \"New year's eve\": \"1988-12-31\",\n        \"April 1st\": \"1988-04-01\",\n        \"Independence day\": \"1988-07-04\",\n        \"Labour day\": \"1988-09-05\",\n        \"Memorial day\": \"1988-05-30\",\n        \"Thanksgiving\": \"1988-11-24\",\n    }\n\n    if ax is None:\n        ax = plt.gca()\n\n    ax.plot(days_in_1988, f_median, color=\"k\", lw=2)\n\n    for name, date in special_days.items():\n        xs = pd.to_datetime(date).day_of_year - 1\n        ys = f_median[xs]\n        text = ax.text(xs - 3, ys, name, horizontalalignment=\"right\")\n        text.set_bbox(dict(facecolor=\"white\", alpha=0.5, edgecolor=\"none\"))\n\n    is_day_13 = data[\"date\"].dt.day == 13\n    bad_luck_days = data.loc[is_1988 & is_day_13, \"day_of_year\"] - 1\n    ax.plot(\n        bad_luck_days,\n        f_median[bad_luck_days.values],\n        marker=\"o\",\n        mec=\"gray\",\n        c=\"none\",\n        ms=10,\n        lw=0,\n    )\n\n    return ax\n\n\ndef make_figure(data, samples):\n    import matplotlib.ticker as mtick\n\n    fig = plt.figure(figsize=(15, 9))\n    grid = plt.GridSpec(2, 3, wspace=0.1, hspace=0.25)\n    axes = (\n        plt.subplot(grid[0, :]),\n        plt.subplot(grid[1, 0]),\n        plt.subplot(grid[1, 1]),\n        plt.subplot(grid[1, 2]),\n    )\n    plot_1988(data, samples, ax=axes[0])\n    plot_trend(data, samples, ax=axes[1])\n    plot_seasonality(data, samples, ax=axes[2])\n    plot_week(data, samples, ax=axes[3])\n\n    for ax in axes:\n        ax.axhline(y=1, linestyle=\"--\", color=\"gray\", lw=1)\n        if not ax.get_subplotspec().is_first_row():\n            ax.set_ylim(0.65, 1.35)\n\n        if not ax.get_subplotspec().is_first_col():\n            ax.set_yticks([])\n            ax.set_ylabel(\"\")\n        else:\n            ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1))\n            ax.set_ylabel(\"Relative number of births\")\n\n    axes[0].set_title(\"Special day effect\")\n    axes[0].set_xlabel(\"Day of year\")\n    axes[1].set_title(\"Long term trend\")\n    axes[1].set_xlabel(\"Year\")\n    axes[2].set_title(\"Year seasonality\")\n    axes[2].set_xlabel(\"Day of year\")\n    axes[3].set_title(\"Day of week effect\")\n    axes[3].set_xlabel(\"Day of week\")\n    return fig\n\n\n# --- functions for running the model --- #\ndef parse_arguments():\n    parser = argparse.ArgumentParser(description=\"Hilbert space approx for GPs\")\n    parser.add_argument(\"--num-samples\", nargs=\"?\", default=1000, type=int)\n    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=1000, type=int)\n    parser.add_argument(\"--num-chains\", nargs=\"?\", default=1, type=int)\n    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n    parser.add_argument(\"--x64\", action=\"store_true\", help=\"Enable double precision\")\n    parser.add_argument(\n        \"--save-figure\",\n        default=\"\",\n        type=str,\n        help=\"Path where to save the plot with matplotlib.\",\n    )\n    args = parser.parse_args()\n    return args\n\n\ndef main(args):\n    is_sphinxbuild = \"NUMPYRO_SPHINXBUILD\" in os.environ\n    data = load_data()\n    data_dict = make_birthdays_data_dict(data)\n    mcmc = MCMC(\n        NUTS(birthdays_model, init_strategy=init_to_median),\n        num_warmup=args.num_warmup,\n        num_samples=args.num_samples,\n        num_chains=args.num_chains,\n        progress_bar=(not is_sphinxbuild),\n    )\n    mcmc.run(jax.random.PRNGKey(0), **data_dict)\n    if not is_sphinxbuild:\n        mcmc.print_summary()\n\n    if args.save_figure:\n        samples = mcmc.get_samples()\n        print(f\"Saving figure at {args.save_figure}\")\n        fig = make_figure(data, samples)\n        fig.savefig(args.save_figure)\n        plt.close()\n\n    return mcmc\n\n\nif __name__ == \"__main__\":\n    args = parse_arguments()\n    numpyro.enable_x64(args.x64)\n    numpyro.set_platform(args.device)\n    numpyro.set_host_device_count(args.num_chains)\n    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}