{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Horseshoe Regression\n\nWe demonstrate how to use NUTS to do sparse regression using\nthe Horseshoe prior [1] for both continuous- and binary-valued\nresponses. For a more complex modeling and inference approach\nthat also supports quadratic interaction terms in a way that\nis efficient in high dimensions see examples/sparse_regression.py.\n\nReferences:\n\n[1] \"Handling Sparsity via the Horseshoe,\"\n    Carlos M. Carvalho, Nicholas G. Polson, James G. Scott.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\nimport os\nimport time\n\nimport numpy as np\nfrom scipy.special import expit\n\nimport jax.numpy as jnp\nimport jax.random as random\n\nimport numpyro\nfrom numpyro.diagnostics import summary\nimport numpyro.distributions as dist\nfrom numpyro.infer import MCMC, NUTS\n\n\n# regression model with continuous-valued outputs/responses\ndef model_normal_likelihood(X, Y):\n    D_X = X.shape[1]\n\n    # sample from horseshoe prior\n    lambdas = numpyro.sample(\"lambdas\", dist.HalfCauchy(jnp.ones(D_X)))\n    tau = numpyro.sample(\"tau\", dist.HalfCauchy(jnp.ones(1)))\n\n    # note that in practice for a normal likelihood we would probably want to\n    # integrate out the coefficients (as is done for example in sparse_regression.py).\n    # however, this trick wouldn't be applicable to other likelihoods\n    # (e.g. bernoulli, see below) so we don't make use of it here.\n    unscaled_betas = numpyro.sample(\"unscaled_betas\", dist.Normal(0.0, jnp.ones(D_X)))\n    scaled_betas = numpyro.deterministic(\"betas\", tau * lambdas * unscaled_betas)\n\n    # compute mean function using linear coefficients\n    mean_function = jnp.dot(X, scaled_betas)\n\n    prec_obs = numpyro.sample(\"prec_obs\", dist.Gamma(3.0, 1.0))\n    sigma_obs = 1.0 / jnp.sqrt(prec_obs)\n\n    # observe data\n    numpyro.sample(\"Y\", dist.Normal(mean_function, sigma_obs), obs=Y)\n\n\n# regression model with binary-valued outputs/responses\ndef model_bernoulli_likelihood(X, Y):\n    D_X = X.shape[1]\n\n    # sample from horseshoe prior\n    lambdas = numpyro.sample(\"lambdas\", dist.HalfCauchy(jnp.ones(D_X)))\n    tau = numpyro.sample(\"tau\", dist.HalfCauchy(jnp.ones(1)))\n\n    # note that this reparameterization (i.e. coordinate transformation) improves\n    # posterior geometry and makes NUTS sampling more efficient\n    unscaled_betas = numpyro.sample(\"unscaled_betas\", dist.Normal(0.0, jnp.ones(D_X)))\n    scaled_betas = numpyro.deterministic(\"betas\", tau * lambdas * unscaled_betas)\n\n    # compute mean function using linear coefficients\n    mean_function = jnp.dot(X, scaled_betas)\n\n    # observe data\n    numpyro.sample(\"Y\", dist.Bernoulli(logits=mean_function), obs=Y)\n\n\n# helper function for HMC inference\ndef run_inference(model, args, rng_key, X, Y):\n    start = time.time()\n    kernel = NUTS(model)\n    mcmc = MCMC(\n        kernel,\n        num_warmup=args.num_warmup,\n        num_samples=args.num_samples,\n        num_chains=args.num_chains,\n        progress_bar=False if \"NUMPYRO_SPHINXBUILD\" in os.environ else True,\n    )\n\n    mcmc.run(rng_key, X, Y)\n    mcmc.print_summary(exclude_deterministic=False)\n\n    samples = mcmc.get_samples()\n    summary_dict = summary(samples, group_by_chain=False)\n\n    print(\"\\nMCMC elapsed time:\", time.time() - start)\n\n    return summary_dict\n\n\n# create artificial regression dataset with 3 non-zero regression coefficients\ndef get_data(N=50, D_X=3, sigma_obs=0.05, response=\"continuous\"):\n    assert response in [\"continuous\", \"binary\"]\n    assert D_X >= 3\n\n    np.random.seed(0)\n    X = np.random.randn(N, D_X)\n\n    # the response only depends on X_0, X_1, and X_2\n    W = np.array([2.0, -1.0, 0.50])\n    Y = jnp.dot(X[:, :3], W)\n    Y -= jnp.mean(Y)\n\n    if response == \"continuous\":\n        Y += sigma_obs * np.random.randn(N)\n    elif response == \"binary\":\n        Y = np.random.binomial(1, expit(Y))\n\n    assert X.shape == (N, D_X)\n    assert Y.shape == (N,)\n\n    return X, Y\n\n\ndef main(args):\n    N, D_X = args.num_data, 32\n\n    print(\"[Experiment with continuous-valued responses]\")\n    # first generate and analyze data with continuous-valued responses\n    X, Y = get_data(N=N, D_X=D_X, response=\"continuous\")\n\n    # do inference\n    rng_key, rng_key_predict = random.split(random.PRNGKey(0))\n    summary = run_inference(model_normal_likelihood, args, rng_key, X, Y)\n\n    # lambda should only be large for the first 3 dimensions, which\n    # correspond to relevant covariates (see get_data)\n    print(\"Posterior median over lambdas (leading 5 dimensions):\")\n    print(summary[\"lambdas\"][\"median\"][:5])\n    print(\"Posterior mean over betas (leading 5 dimensions):\")\n    print(summary[\"betas\"][\"mean\"][:5])\n\n    print(\"[Experiment with binary-valued responses]\")\n    # next generate and analyze data with binary-valued responses\n    # (note we use more data for the case of binary-valued responses,\n    # since each response carries less information than a real number)\n    X, Y = get_data(N=4 * N, D_X=D_X, response=\"binary\")\n\n    # do inference\n    rng_key, rng_key_predict = random.split(random.PRNGKey(0))\n    summary = run_inference(model_bernoulli_likelihood, args, rng_key, X, Y)\n\n    # lambda should only be large for the first 3 dimensions, which\n    # correspond to relevant covariates (see get_data)\n    print(\"Posterior median over lambdas (leading 5 dimensions):\")\n    print(summary[\"lambdas\"][\"median\"][:5])\n    print(\"Posterior mean over betas (leading 5 dimensions):\")\n    print(summary[\"betas\"][\"mean\"][:5])\n\n\nif __name__ == \"__main__\":\n    assert numpyro.__version__.startswith(\"0.13.2\")\n    parser = argparse.ArgumentParser(description=\"Horseshoe regression example\")\n    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=2000, type=int)\n    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=1000, type=int)\n    parser.add_argument(\"--num-chains\", nargs=\"?\", default=1, type=int)\n    parser.add_argument(\"--num-data\", nargs=\"?\", default=100, type=int)\n    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n    args = parser.parse_args()\n\n    numpyro.set_platform(args.device)\n    numpyro.set_host_device_count(args.num_chains)\n\n    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}