{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Sparse Regression\n\nWe demonstrate how to do (fully Bayesian) sparse linear regression using the\napproach described in [1]. This approach is particularly suitable for situations\nwith many feature dimensions (large P) but not too many data points (small N).\nIn particular, we consider a quadratic regressor (linear in the coefficients $\\theta$) of the form:\n\n\\begin{align}Y = \\theta_0 + \\sum_i \\theta_i X_i + \\sum_{i<j} \\theta_{ij} X_i X_j + \\epsilon\\end{align}\n\nwhere $\\theta_0$ is a constant offset and $\\epsilon$ is observation noise. The sparsity-inducing prior favors solutions in which most of the $\\theta_i$ and $\\theta_{ij}$ are close to zero.\n\n**References:**\n\n1. Raj Agrawal, Jonathan H. Huggins, Brian Trippe, Tamara Broderick (2019),\n   \"The Kernel Interaction Trick: Fast Bayesian Discovery of Pairwise Interactions in High Dimensions\",\n   [arXiv:1905.06501](https://arxiv.org/abs/1905.06501)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\n",
        "import itertools\n",
        "import os\n",
        "import time\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "from jax import vmap\n",
        "import jax.numpy as jnp\n",
        "import jax.random as random\n",
        "from jax.scipy.linalg import cho_factor, cho_solve, solve_triangular\n",
        "\n",
        "import numpyro\n",
        "import numpyro.distributions as dist\n",
        "from numpyro.infer import MCMC, NUTS\n",
        "\n",
        "\n",
        "def dot(X, Z):\n",
        "    \"\"\"Pairwise dot products between the rows of X and the rows of Z.\n",
        "\n",
        "    For 2-D inputs X (n, d) and Z (m, d), entry [i, j] of the (n, m) result is X[i] . Z[j].\n",
        "    \"\"\"\n",
        "    return jnp.dot(X, Z[..., None])[..., 0]\n",
        "\n",
        "\n",
        "def kernel(X, Z, eta1, eta2, c, jitter=1.0e-4):\n",
        "    \"\"\"The kernel that corresponds to our quadratic regressor.\n",
        "\n",
        "    When X and Z have the same shape, `jitter` is added to the diagonal to keep the\n",
        "    self-covariance matrix well conditioned. NOTE: the check compares shapes only,\n",
        "    so a cross-kernel between two distinct inputs of identical shape also gets jitter.\n",
        "    \"\"\"\n",
        "    eta1sq = jnp.square(eta1)\n",
        "    eta2sq = jnp.square(eta2)\n",
        "    k1 = 0.5 * eta2sq * jnp.square(1.0 + dot(X, Z))\n",
        "    k2 = -0.5 * eta2sq * dot(jnp.square(X), jnp.square(Z))\n",
        "    k3 = (eta1sq - eta2sq) * dot(X, Z)\n",
        "    k4 = jnp.square(c) - 0.5 * eta2sq\n",
        "    if X.shape == Z.shape:\n",
        "        k4 += jitter * jnp.eye(X.shape[0])\n",
        "    return k1 + k2 + k3 + k4\n",
        "\n",
        "\n",
        "def _kernel_scales(msq, lam, eta1, xisq):\n",
        "    \"\"\"Derive the kernel scale eta2 and the per-dimension scales kappa from the hyperparameters.\"\"\"\n",
        "    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq\n",
        "    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))\n",
        "    return eta2, kappa\n",
        "\n",
        "\n",
        "def model(X, Y, hypers):\n",
        "    \"\"\"Sparse quadratic regression model.\n",
        "\n",
        "    Most of the model code is concerned with constructing the sparsity inducing prior;\n",
        "    Y is then observed under a zero-mean Gaussian process whose covariance is `kernel`\n",
        "    plus sigma**2 observation noise on the diagonal.\n",
        "    \"\"\"\n",
        "    S, P, N = hypers[\"expected_sparsity\"], X.shape[1], X.shape[0]\n",
        "\n",
        "    sigma = numpyro.sample(\"sigma\", dist.HalfNormal(hypers[\"alpha3\"]))\n",
        "    phi = sigma * (S / jnp.sqrt(N)) / (P - S)\n",
        "    eta1 = numpyro.sample(\"eta1\", dist.HalfCauchy(phi))\n",
        "\n",
        "    msq = numpyro.sample(\"msq\", dist.InverseGamma(hypers[\"alpha1\"], hypers[\"beta1\"]))\n",
        "    xisq = numpyro.sample(\"xisq\", dist.InverseGamma(hypers[\"alpha2\"], hypers[\"beta2\"]))\n",
        "\n",
        "    lam = numpyro.sample(\"lambda\", dist.HalfCauchy(jnp.ones(P)))\n",
        "    eta2, kappa = _kernel_scales(msq, lam, eta1, xisq)\n",
        "\n",
        "    # compute kernel\n",
        "    kX = kappa * X\n",
        "    k = kernel(kX, kX, eta1, eta2, hypers[\"c\"]) + sigma**2 * jnp.eye(N)\n",
        "    assert k.shape == (N, N)\n",
        "\n",
        "    # sample Y according to the standard gaussian process formula\n",
        "    numpyro.sample(\n",
        "        \"Y\",\n",
        "        dist.MultivariateNormal(loc=jnp.zeros(N), covariance_matrix=k),\n",
        "        obs=Y,\n",
        "    )\n",
        "\n",
        "\n",
        "def _probe_posterior(X, Y, probe, msq, lam, eta1, xisq, c, sigma):\n",
        "    \"\"\"GP posterior mean (R,) and covariance (R, R) of f at the R rows of `probe`.\n",
        "\n",
        "    Uses a Cholesky factorization of the training covariance instead of an explicit\n",
        "    matrix inverse, which is cheaper and numerically more stable.\n",
        "    \"\"\"\n",
        "    N = X.shape[0]\n",
        "    eta2, kappa = _kernel_scales(msq, lam, eta1, xisq)\n",
        "\n",
        "    kX = kappa * X\n",
        "    kprobe = kappa * probe\n",
        "\n",
        "    k_xx = kernel(kX, kX, eta1, eta2, c) + sigma**2 * jnp.eye(N)\n",
        "    L = cho_factor(k_xx, lower=True)[0]\n",
        "    k_probeX = kernel(kprobe, kX, eta1, eta2, c)\n",
        "    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)\n",
        "\n",
        "    mean = jnp.matmul(k_probeX, cho_solve((L, True), Y))\n",
        "    Linv_k_Xprobe = solve_triangular(L, jnp.transpose(k_probeX), lower=True)\n",
        "    cov = k_prbprb - jnp.matmul(jnp.transpose(Linv_k_Xprobe), Linv_k_Xprobe)\n",
        "    return mean, cov\n",
        "\n",
        "\n",
        "def compute_singleton_mean_variance(X, Y, dimension, msq, lam, eta1, xisq, c, sigma):\n",
        "    \"\"\"Mean and variance of coefficient theta_i (i = `dimension`) for one MCMC sample\n",
        "    of the kernel hyperparameters (eta1, xisq, ...).\n",
        "\n",
        "    theta_i = (f(e_i) - f(-e_i)) / 2, so it is read off the GP posterior at the probe\n",
        "    points +e_i and -e_i. Compare to theorem 5.1 in reference [1].\n",
        "    \"\"\"\n",
        "    P = X.shape[1]\n",
        "    probe = jnp.zeros((2, P)).at[:, dimension].set(jnp.array([1.0, -1.0]))\n",
        "    vec = jnp.array([0.50, -0.50])\n",
        "\n",
        "    mean, cov = _probe_posterior(X, Y, probe, msq, lam, eta1, xisq, c, sigma)\n",
        "    return jnp.dot(mean, vec), jnp.dot(jnp.matmul(cov, vec), vec)\n",
        "\n",
        "\n",
        "def compute_pairwise_mean_variance(X, Y, dim1, dim2, msq, lam, eta1, xisq, c, sigma):\n",
        "    \"\"\"Mean and variance of coefficient theta_ij (i = `dim1`, j = `dim2`) for one MCMC\n",
        "    sample of the kernel hyperparameters (eta1, xisq, ...).\n",
        "\n",
        "    theta_ij = (f(1, 1) - f(1, -1) - f(-1, 1) + f(-1, -1)) / 4, where f(a, b) is the GP\n",
        "    posterior at the probe point with value a in dimension i and b in dimension j.\n",
        "    Compare to theorem 5.1 in reference [1].\n",
        "    \"\"\"\n",
        "    P = X.shape[1]\n",
        "    probe = jnp.zeros((4, P))\n",
        "    probe = probe.at[:, dim1].set(jnp.array([1.0, 1.0, -1.0, -1.0]))\n",
        "    probe = probe.at[:, dim2].set(jnp.array([1.0, -1.0, 1.0, -1.0]))\n",
        "    vec = jnp.array([0.25, -0.25, -0.25, 0.25])\n",
        "\n",
        "    mean, cov = _probe_posterior(X, Y, probe, msq, lam, eta1, xisq, c, sigma)\n",
        "    return jnp.dot(mean, vec), jnp.dot(jnp.matmul(cov, vec), vec)\n",
        "\n",
        "\n",
        "def sample_theta_space(X, Y, active_dims, msq, lam, eta1, xisq, c, sigma):\n",
        "    \"\"\"Sample coefficients theta from the posterior for a given MCMC sample.\n",
        "\n",
        "    The first P returned values are {theta_1, theta_2, ...., theta_P}, while the\n",
        "    remaining values are {theta_ij} for i, j in the list `active_dims`, with i < j.\n",
        "    The Gaussian draw uses NumPy's global random state (which `get_data` seeds).\n",
        "    \"\"\"\n",
        "    P, M = X.shape[1], len(active_dims)\n",
        "    # the total number of coefficients we return\n",
        "    num_coefficients = P + M * (M - 1) // 2\n",
        "    # two probe points per singleton coefficient, four per pairwise coefficient\n",
        "    num_probes = 2 * P + 2 * M * (M - 1)\n",
        "\n",
        "    probe = jnp.zeros((num_probes, P))\n",
        "    vec = jnp.zeros((num_coefficients, num_probes))\n",
        "    start1 = 0  # next free probe row\n",
        "    start2 = 0  # next free coefficient row\n",
        "\n",
        "    for dim in range(P):\n",
        "        probe = probe.at[start1 : start1 + 2, dim].set(jnp.array([1.0, -1.0]))\n",
        "        vec = vec.at[start2, start1 : start1 + 2].set(jnp.array([0.5, -0.5]))\n",
        "        start1 += 2\n",
        "        start2 += 1\n",
        "\n",
        "    for dim1 in active_dims:\n",
        "        for dim2 in active_dims:\n",
        "            if dim1 >= dim2:\n",
        "                continue\n",
        "            probe = probe.at[start1 : start1 + 4, dim1].set(\n",
        "                jnp.array([1.0, 1.0, -1.0, -1.0])\n",
        "            )\n",
        "            probe = probe.at[start1 : start1 + 4, dim2].set(\n",
        "                jnp.array([1.0, -1.0, 1.0, -1.0])\n",
        "            )\n",
        "            vec = vec.at[start2, start1 : start1 + 4].set(\n",
        "                jnp.array([0.25, -0.25, -0.25, 0.25])\n",
        "            )\n",
        "            start1 += 4\n",
        "            start2 += 1\n",
        "\n",
        "    mean, cov = _probe_posterior(X, Y, probe, msq, lam, eta1, xisq, c, sigma)\n",
        "    mu = jnp.matmul(vec, mean)\n",
        "    covar = jnp.matmul(vec, jnp.matmul(cov, jnp.transpose(vec)))\n",
        "\n",
        "    # sample from N(mu, covar)\n",
        "    L = jnp.linalg.cholesky(covar)\n",
        "    return mu + jnp.matmul(L, np.random.randn(num_coefficients))\n",
        "\n",
        "\n",
        "def run_inference(model, args, rng_key, X, Y, hypers):\n",
        "    \"\"\"Run NUTS on `model`, print a summary and the elapsed time, and return the samples.\"\"\"\n",
        "    start = time.time()\n",
        "    # named so it does not shadow the module-level `kernel` function\n",
        "    nuts_kernel = NUTS(model)\n",
        "    mcmc = MCMC(\n",
        "        nuts_kernel,\n",
        "        num_warmup=args.num_warmup,\n",
        "        num_samples=args.num_samples,\n",
        "        num_chains=args.num_chains,\n",
        "        # the progress bar is disabled when NUMPYRO_SPHINXBUILD is set\n",
        "        progress_bar=\"NUMPYRO_SPHINXBUILD\" not in os.environ,\n",
        "    )\n",
        "    mcmc.run(rng_key, X, Y, hypers)\n",
        "    mcmc.print_summary()\n",
        "    print(\"\\nMCMC elapsed time:\", time.time() - start)\n",
        "    return mcmc.get_samples()\n",
        "\n",
        "\n",
        "def gaussian_mixture_stats(mus, variances):\n",
        "    \"\"\"Mean and variance of an equally weighted mixture of Gaussians.\n",
        "\n",
        "    The variance follows the law of total variance: mean within-component variance\n",
        "    plus the variance of the component means.\n",
        "    \"\"\"\n",
        "    mean_mu = jnp.mean(mus)\n",
        "    mean_var = jnp.mean(variances) + jnp.mean(jnp.square(mus)) - jnp.square(mean_mu)\n",
        "    return mean_mu, mean_var\n",
        "\n",
        "\n",
        "def get_data(N=20, S=2, P=10, sigma_obs=0.05):\n",
        "    \"\"\"Create an artificial regression dataset.\n",
        "\n",
        "    Only S out of P feature dimensions contain signal, and there is a single pairwise\n",
        "    interaction between the first and second dimensions. Y is centered and divided by\n",
        "    its standard deviation; the returned coefficients are divided by the same factor.\n",
        "\n",
        "    Returns:\n",
        "        X (N, P), Y (N,), the S singleton coefficients and the pairwise coefficient.\n",
        "    \"\"\"\n",
        "    assert S < P and P > 1 and S > 0\n",
        "    np.random.seed(0)\n",
        "\n",
        "    X = np.random.randn(N, P)\n",
        "    # generate S coefficients with non-negligible magnitude\n",
        "    W = 0.5 + 2.5 * np.random.rand(S)\n",
        "    # generate data using the S coefficients and a single pairwise interaction\n",
        "    Y = (\n",
        "        np.sum(X[:, 0:S] * W, axis=-1)\n",
        "        + X[:, 0] * X[:, 1]\n",
        "        + sigma_obs * np.random.randn(N)\n",
        "    )\n",
        "    Y -= jnp.mean(Y)\n",
        "    Y_std = jnp.std(Y)\n",
        "\n",
        "    assert X.shape == (N, P)\n",
        "    assert Y.shape == (N,)\n",
        "\n",
        "    return X, Y / Y_std, W / Y_std, 1.0 / Y_std\n",
        "\n",
        "\n",
        "def _hyper_samples(samples):\n",
        "    \"\"\"Per-sample hyperparameter arrays, in the order the compute_* helpers expect.\"\"\"\n",
        "    return (\n",
        "        samples[\"msq\"],\n",
        "        samples[\"lambda\"],\n",
        "        samples[\"eta1\"],\n",
        "        samples[\"xisq\"],\n",
        "        samples[\"sigma\"],\n",
        "    )\n",
        "\n",
        "\n",
        "def analyze_dimension(samples, X, Y, dimension, hypers):\n",
        "    \"\"\"Posterior mean and standard deviation of theta_i, mixing over all MCMC samples.\"\"\"\n",
        "    mus, variances = vmap(\n",
        "        lambda msq, lam, eta1, xisq, sigma: compute_singleton_mean_variance(\n",
        "            X, Y, dimension, msq, lam, eta1, xisq, hypers[\"c\"], sigma\n",
        "        )\n",
        "    )(*_hyper_samples(samples))\n",
        "    mean, variance = gaussian_mixture_stats(mus, variances)\n",
        "    return mean, jnp.sqrt(variance)\n",
        "\n",
        "\n",
        "def analyze_pair_of_dimensions(samples, X, Y, dim1, dim2, hypers):\n",
        "    \"\"\"Posterior mean and standard deviation of theta_ij, mixing over all MCMC samples.\"\"\"\n",
        "    mus, variances = vmap(\n",
        "        lambda msq, lam, eta1, xisq, sigma: compute_pairwise_mean_variance(\n",
        "            X, Y, dim1, dim2, msq, lam, eta1, xisq, hypers[\"c\"], sigma\n",
        "        )\n",
        "    )(*_hyper_samples(samples))\n",
        "    mean, variance = gaussian_mixture_stats(mus, variances)\n",
        "    return mean, jnp.sqrt(variance)\n",
        "\n",
        "\n",
        "def main(args):\n",
        "    \"\"\"Generate data, run NUTS, and report the active dimensions and interactions found.\"\"\"\n",
        "    X, Y, expected_thetas, expected_pairwise = get_data(\n",
        "        N=args.num_data, P=args.num_dimensions, S=args.active_dimensions\n",
        "    )\n",
        "\n",
        "    # setup hyperparameters\n",
        "    hypers = {\n",
        "        \"expected_sparsity\": max(1.0, args.num_dimensions / 10),\n",
        "        \"alpha1\": 3.0,\n",
        "        \"beta1\": 1.0,\n",
        "        \"alpha2\": 3.0,\n",
        "        \"beta2\": 1.0,\n",
        "        \"alpha3\": 1.0,\n",
        "        \"c\": 1.0,\n",
        "    }\n",
        "\n",
        "    # do inference\n",
        "    rng_key = random.PRNGKey(0)\n",
        "    samples = run_inference(model, args, rng_key, X, Y, hypers)\n",
        "\n",
        "    # compute the mean and square root variance of each coefficient theta_i\n",
        "    means, stds = vmap(lambda dim: analyze_dimension(samples, X, Y, dim, hypers))(\n",
        "        jnp.arange(args.num_dimensions)\n",
        "    )\n",
        "\n",
        "    print(\n",
        "        \"Coefficients theta_1 to theta_%d used to generate the data:\"\n",
        "        % args.active_dimensions,\n",
        "        expected_thetas,\n",
        "    )\n",
        "    print(\n",
        "        \"The single quadratic coefficient theta_{1,2} used to generate the data:\",\n",
        "        expected_pairwise,\n",
        "    )\n",
        "\n",
        "    active_dimensions = []\n",
        "    for dim, (mean, std) in enumerate(zip(means, stds)):\n",
        "        # we mark the dimension as inactive if the interval [mean - 3 * std, mean + 3 * std] contains zero\n",
        "        lower, upper = mean - 3.0 * std, mean + 3.0 * std\n",
        "        status = \"inactive\" if lower < 0.0 and upper > 0.0 else \"active\"\n",
        "        if status == \"active\":\n",
        "            active_dimensions.append(dim)\n",
        "        print(\n",
        "            \"[dimension %02d/%02d]  %s:\\t%.2e +- %.2e\"\n",
        "            % (dim + 1, args.num_dimensions, status, mean, std)\n",
        "        )\n",
        "\n",
        "    print(\n",
        "        \"Identified a total of %d active dimensions; expected %d.\"\n",
        "        % (len(active_dimensions), args.active_dimensions)\n",
        "    )\n",
        "\n",
        "    if not active_dimensions:\n",
        "        return\n",
        "\n",
        "    # Compute the mean and square root variance of coefficients theta_ij for i < j, both\n",
        "    # active. active_dimensions is built in ascending order, so combinations() yields\n",
        "    # exactly the pairs with dim1 < dim2 and skips the meaningless i >= j cases.\n",
        "    if len(active_dimensions) > 1:\n",
        "        dim_pairs = jnp.array(list(itertools.combinations(active_dimensions, 2)))\n",
        "        means, stds = vmap(\n",
        "            lambda dim_pair: analyze_pair_of_dimensions(\n",
        "                samples, X, Y, dim_pair[0], dim_pair[1], hypers\n",
        "            )\n",
        "        )(dim_pairs)\n",
        "        for (dim1, dim2), mean, std in zip(dim_pairs, means, stds):\n",
        "            lower, upper = mean - 3.0 * std, mean + 3.0 * std\n",
        "            if not (lower < 0.0 and upper > 0.0):\n",
        "                format_str = \"Identified pairwise interaction between dimensions %d and %d: %.2e +- %.2e\"\n",
        "                print(format_str % (dim1 + 1, dim2 + 1, mean, std))\n",
        "\n",
        "    # Draw a single sample of coefficients theta from the posterior, where we return all singleton\n",
        "    # coefficients theta_i and pairwise coefficients theta_ij for i, j active dimensions. We use the\n",
        "    # final MCMC sample obtained from the HMC sampler.\n",
        "    thetas = sample_theta_space(\n",
        "        X,\n",
        "        Y,\n",
        "        active_dimensions,\n",
        "        samples[\"msq\"][-1],\n",
        "        samples[\"lambda\"][-1],\n",
        "        samples[\"eta1\"][-1],\n",
        "        samples[\"xisq\"][-1],\n",
        "        hypers[\"c\"],\n",
        "        samples[\"sigma\"][-1],\n",
        "    )\n",
        "    print(\"Single posterior sample theta:\\n\", thetas)\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    assert numpyro.__version__.startswith(\"0.13.2\")\n",
        "    parser = argparse.ArgumentParser(description=\"Sparse Regression example\")\n",
        "    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=1000, type=int)\n",
        "    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=500, type=int)\n",
        "    parser.add_argument(\"--num-chains\", nargs=\"?\", default=1, type=int)\n",
        "    parser.add_argument(\"--num-data\", nargs=\"?\", default=100, type=int)\n",
        "    parser.add_argument(\"--num-dimensions\", nargs=\"?\", default=20, type=int)\n",
        "    parser.add_argument(\"--active-dimensions\", nargs=\"?\", default=3, type=int)\n",
        "    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n",
        "    # parse_known_args (rather than parse_args) ignores unrecognized command-line\n",
        "    # arguments, such as the `-f <connection file>` a Jupyter kernel is launched\n",
        "    # with, so this cell also runs inside a notebook.\n",
        "    args, _ = parser.parse_known_args()\n",
        "\n",
        "    numpyro.set_platform(args.device)\n",
        "    numpyro.set_host_device_count(args.num_chains)\n",
        "\n",
        "    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}