
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/sparse_regression.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_sparse_regression.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_sparse_regression.py:


Example: Sparse Regression
==========================

We demonstrate how to do (fully Bayesian) sparse linear regression using the
approach described in [1]. This approach is particularly suitable for situations
with many feature dimensions (large P) but not too many datapoints (small N).
In particular, we consider a quadratic regressor of the form:

.. math::

    f(X) = \text{constant} + \sum_i \theta_i X_i + \sum_{i<j} \theta_{ij} X_i X_j + \text{observation noise}

**References:**

    1. Raj Agrawal, Jonathan H. Huggins, Brian Trippe, Tamara Broderick (2019),
       "The Kernel Interaction Trick: Fast Bayesian Discovery of Pairwise Interactions in High Dimensions",
       (https://arxiv.org/abs/1905.06501)

.. GENERATED FROM PYTHON SOURCE LINES 23-402

.. code-block:: default


    import argparse
    import itertools
    import os
    import time

    import numpy as np

    from jax import vmap
    import jax.numpy as jnp
    import jax.random as random
    from jax.scipy.linalg import cho_factor, cho_solve, solve_triangular

    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS


    def dot(X, Z):
        """Contract the last axis of ``X`` with the last axis of ``Z``.

        A trailing singleton axis is appended to ``Z`` so that ``jnp.dot`` sums
        over the last axis of ``X`` and what was the last axis of ``Z``; the
        singleton axis is then dropped from the result. For 2-D inputs of shapes
        ``(N, P)`` and ``(M, P)`` the result has shape ``(N, M)``.
        """
        Z_expanded = Z[..., None]
        contracted = jnp.dot(X, Z_expanded)
        return contracted[..., 0]


    # The kernel that corresponds to our quadratic regressor.
    def kernel(X, Z, eta1, eta2, c, jitter=1.0e-4):
        """Evaluate the quadratic-regressor kernel between the rows of ``X`` and ``Z``.

        The value is the sum of four terms: a squared ``(1 + <x, z>)`` term and a
        correction built from the element-wise squared inputs (both scaled by
        ``eta2 ** 2 / 2``), a term linear in ``<x, z>`` scaled by
        ``eta1 ** 2 - eta2 ** 2``, and the constant ``c ** 2 - eta2 ** 2 / 2``.

        NOTE: ``jitter`` times the identity is added whenever ``X`` and ``Z`` have
        the same shape; shape equality is the only check performed, the array
        contents are not compared.
        """
        eta1_sq, eta2_sq = jnp.square(eta1), jnp.square(eta2)
        inner = dot(X, Z)

        quadratic = 0.5 * eta2_sq * jnp.square(1.0 + inner)
        squares = -0.5 * eta2_sq * dot(jnp.square(X), jnp.square(Z))
        linear = (eta1_sq - eta2_sq) * inner
        constant = jnp.square(c) - 0.5 * eta2_sq

        if X.shape == Z.shape:
            # Diagonal jitter for the (assumed) symmetric case.
            constant = constant + jitter * jnp.eye(X.shape[0])

        return quadratic + squares + linear + constant


    # Most of the model code is concerned with constructing the sparsity inducing prior.
    def model(X, Y, hypers):
        """NumPyro model for sparse quadratic regression with the kernel marginal likelihood.

        Sample sites, in order: ``sigma`` (HalfNormal), ``eta1`` (HalfCauchy whose
        scale depends on ``sigma``, the expected sparsity and the data dimensions),
        ``msq`` and ``xisq`` (InverseGamma), ``lambda`` (one HalfCauchy scale per
        feature) and the observed ``Y`` (zero-mean multivariate normal whose
        covariance is the kernel matrix plus ``sigma ** 2`` on the diagonal).

        :param X: design matrix; ``X.shape[0]`` data points, ``X.shape[1]`` features.
        :param Y: observed responses.
        :param hypers: dict providing ``expected_sparsity``, ``alpha1``, ``beta1``,
            ``alpha2``, ``beta2``, ``alpha3`` and ``c``.
        """
        num_data, num_features = X.shape[0], X.shape[1]
        sparsity = hypers["expected_sparsity"]

        # Observation noise and the global scale of the singleton coefficients.
        sigma = numpyro.sample("sigma", dist.HalfNormal(hypers["alpha3"]))
        phi = sigma * (sparsity / jnp.sqrt(num_data)) / (num_features - sparsity)
        eta1 = numpyro.sample("eta1", dist.HalfCauchy(phi))

        msq = numpyro.sample("msq", dist.InverseGamma(hypers["alpha1"], hypers["beta1"]))
        xisq = numpyro.sample("xisq", dist.InverseGamma(hypers["alpha2"], hypers["beta2"]))

        # Scale of the pairwise terms, derived from the sampled quantities above.
        eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq

        # Per-feature local scales and the resulting per-feature weights kappa.
        local_scales = numpyro.sample("lambda", dist.HalfCauchy(jnp.ones(num_features)))
        kappa = (
            jnp.sqrt(msq) * local_scales / jnp.sqrt(msq + jnp.square(eta1 * local_scales))
        )

        # Kernel matrix of the rescaled inputs plus observation noise on the diagonal.
        scaled_X = kappa * X
        noise_term = sigma**2 * jnp.eye(num_data)
        covariance = kernel(scaled_X, scaled_X, eta1, eta2, hypers["c"]) + noise_term
        assert covariance.shape == (num_data, num_data)

        # Likelihood of the observed responses under the zero-mean Gaussian.
        likelihood = dist.MultivariateNormal(
            loc=jnp.zeros(num_data), covariance_matrix=covariance
        )
        numpyro.sample("Y", likelihood, obs=Y)


    # Compute the mean and variance of coefficient theta_i (where i = dimension) for an
    # MCMC sample of the kernel hyperparameters (eta1, xisq, ...).
    # Compare to theorem 5.1 in reference [1].
    def compute_singleton_mean_variance(X, Y, dimension, msq, lam, eta1, xisq, c, sigma):
        """Posterior mean and variance of the singleton coefficient ``theta_i``.

        The regressor is evaluated at the two probe points ``+e_i`` and ``-e_i``
        (``i = dimension``) and the results are combined with weights
        ``(0.5, -0.5)``, i.e. ``theta_i = (f(+e_i) - f(-e_i)) / 2``.

        The linear systems involving the training covariance are solved through a
        Cholesky factorization instead of forming an explicit inverse, in the same
        way as ``sample_theta_space`` does.

        :param X: design matrix of shape ``(N, P)``.
        :param Y: responses, one per row of ``X``.
        :param dimension: index ``i`` of the coefficient.
        :param msq, lam, eta1, xisq, sigma: one MCMC sample of the hyperparameters.
        :param c: constant kernel hyperparameter.
        :return: tuple ``(mu, var)`` of scalars.
        """
        P, N = X.shape[1], X.shape[0]

        # Rows are +e_i and -e_i.
        probe = jnp.zeros((2, P))
        probe = probe.at[:, dimension].set(jnp.array([1.0, -1.0]))

        eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
        kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

        kX = kappa * X
        kprobe = kappa * probe

        k_xx = kernel(kX, kX, eta1, eta2, c) + sigma**2 * jnp.eye(N)
        k_probeX = kernel(kprobe, kX, eta1, eta2, c)
        k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

        # Lower Cholesky factor of the training covariance: k_xx = L L^T.
        L = cho_factor(k_xx, lower=True)[0]

        vec = jnp.array([0.50, -0.50])
        mu = jnp.matmul(k_probeX, cho_solve((L, True), Y))
        mu = jnp.dot(mu, vec)

        # k_probeX k_xx^{-1} k_probeX^T written as (L^{-1} k_probeX^T)^T (L^{-1} k_probeX^T).
        Linv_k_probeX = solve_triangular(L, jnp.transpose(k_probeX), lower=True)
        var = k_prbprb - jnp.matmul(jnp.transpose(Linv_k_probeX), Linv_k_probeX)
        var = jnp.matmul(var, vec)
        var = jnp.dot(var, vec)

        return mu, var


    # Compute the mean and variance of coefficient theta_ij for an MCMC sample of the
    # kernel hyperparameters (eta1, xisq, ...). Compare to theorem 5.1 in reference [1].
    def compute_pairwise_mean_variance(X, Y, dim1, dim2, msq, lam, eta1, xisq, c, sigma):
        """Posterior mean and variance of the pairwise coefficient ``theta_ij``.

        The regressor is evaluated at four probe points whose ``dim1`` / ``dim2``
        entries are ``(+1, +1)``, ``(+1, -1)``, ``(-1, +1)`` and ``(-1, -1)``; the
        results are combined with weights ``(0.25, -0.25, -0.25, 0.25)``.

        The linear systems involving the training covariance are solved through a
        Cholesky factorization instead of forming an explicit inverse, in the same
        way as ``sample_theta_space`` does.

        :param X: design matrix of shape ``(N, P)``.
        :param Y: responses, one per row of ``X``.
        :param dim1, dim2: indices ``i`` and ``j`` of the interacting dimensions.
            If they coincide, the second probe column overwrites the first.
        :param msq, lam, eta1, xisq, sigma: one MCMC sample of the hyperparameters.
        :param c: constant kernel hyperparameter.
        :return: tuple ``(mu, var)`` of scalars.
        """
        P, N = X.shape[1], X.shape[0]

        # Rows are the four sign combinations on (dim1, dim2).
        probe = jnp.zeros((4, P))
        probe = probe.at[:, dim1].set(jnp.array([1.0, 1.0, -1.0, -1.0]))
        probe = probe.at[:, dim2].set(jnp.array([1.0, -1.0, 1.0, -1.0]))

        eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
        kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

        kX = kappa * X
        kprobe = kappa * probe

        k_xx = kernel(kX, kX, eta1, eta2, c) + sigma**2 * jnp.eye(N)
        k_probeX = kernel(kprobe, kX, eta1, eta2, c)
        k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

        # Lower Cholesky factor of the training covariance: k_xx = L L^T.
        L = cho_factor(k_xx, lower=True)[0]

        vec = jnp.array([0.25, -0.25, -0.25, 0.25])
        mu = jnp.matmul(k_probeX, cho_solve((L, True), Y))
        mu = jnp.dot(mu, vec)

        # k_probeX k_xx^{-1} k_probeX^T written as (L^{-1} k_probeX^T)^T (L^{-1} k_probeX^T).
        Linv_k_probeX = solve_triangular(L, jnp.transpose(k_probeX), lower=True)
        var = k_prbprb - jnp.matmul(jnp.transpose(Linv_k_probeX), Linv_k_probeX)
        var = jnp.matmul(var, vec)
        var = jnp.dot(var, vec)

        return mu, var


    # Sample coefficients theta from the posterior for a given MCMC sample.
    # The first P returned values are {theta_1, theta_2, ...., theta_P}, while
    # the remaining values are {theta_ij} for i,j in the list `active_dims`,
    # sorted so that i < j.
    def sample_theta_space(X, Y, active_dims, msq, lam, eta1, xisq, c, sigma):
        """Draw one joint sample of the coefficients for a single MCMC sample.

        The first ``P`` entries of the result are the singleton coefficients
        ``theta_1 .. theta_P``; the remaining entries are the pairwise coefficients
        ``theta_ij`` for ``i < j`` taken from ``active_dims``, in the order in which
        those pairs are met when iterating ``active_dims`` in a nested loop.

        The probe matrix and the weight matrix are filled in mutable NumPy buffers
        and converted to JAX arrays once, rather than through one functional
        ``.at[].set`` update (each producing a new array) per entry.

        The Gaussian noise is drawn from NumPy's global random state
        (``np.random.randn``), so the result depends on that state.

        :param X: design matrix of shape ``(N, P)``.
        :param Y: responses, one per row of ``X``.
        :param active_dims: sequence of dimension indices considered for pairwise terms.
        :param msq, lam, eta1, xisq, sigma: one MCMC sample of the hyperparameters.
        :param c: constant kernel hyperparameter.
        :return: array with ``P + M * (M - 1) // 2`` entries, ``M = len(active_dims)``.
        """
        P, N, M = X.shape[1], X.shape[0], len(active_dims)
        # the total number of coefficients we return
        num_coefficients = P + M * (M - 1) // 2
        # two probe points per singleton coefficient, four per pairwise coefficient
        num_probes = 2 * P + 2 * M * (M - 1)

        probe = np.zeros((num_probes, P))
        vec = np.zeros((num_coefficients, num_probes))

        # Singleton block: rows 2i and 2i + 1 are +e_i and -e_i, weighted by +-0.5.
        dims = np.arange(P)
        probe[2 * dims, dims] = 1.0
        probe[2 * dims + 1, dims] = -1.0
        vec[dims, 2 * dims] = 0.5
        vec[dims, 2 * dims + 1] = -0.5

        # Pairwise block: four sign combinations per pair, weighted by +-0.25.
        row, coef = 2 * P, P
        pairs = [(d1, d2) for d1 in active_dims for d2 in active_dims if d1 < d2]
        for dim1, dim2 in pairs:
            rows = slice(row, row + 4)
            probe[rows, dim1] = (1.0, 1.0, -1.0, -1.0)
            probe[rows, dim2] = (1.0, -1.0, 1.0, -1.0)
            vec[coef, rows] = (0.25, -0.25, -0.25, 0.25)
            row += 4
            coef += 1

        probe = jnp.asarray(probe)
        vec = jnp.asarray(vec)

        eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
        kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

        kX = kappa * X
        kprobe = kappa * probe

        k_xx = kernel(kX, kX, eta1, eta2, c) + sigma**2 * jnp.eye(N)
        L_xx = cho_factor(k_xx, lower=True)[0]
        k_probeX = kernel(kprobe, kX, eta1, eta2, c)
        k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

        # Posterior mean at the probe points, then one weighted sum per coefficient
        # (``mu`` broadcasts against the rows of ``vec``).
        mu = jnp.matmul(k_probeX, cho_solve((L_xx, True), Y))
        mu = jnp.sum(mu * vec, axis=-1)

        Linv_k_probeX = solve_triangular(L_xx, jnp.transpose(k_probeX), lower=True)
        covar = k_prbprb - jnp.matmul(jnp.transpose(Linv_k_probeX), Linv_k_probeX)
        covar = jnp.matmul(vec, jnp.matmul(covar, jnp.transpose(vec)))

        # sample from N(mu, covar)
        L_covar = jnp.linalg.cholesky(covar)
        sample = mu + jnp.matmul(L_covar, np.random.randn(num_coefficients))

        return sample


    # Helper function for doing HMC inference
    def run_inference(model, args, rng_key, X, Y, hypers):
        """Run NUTS on ``model`` and return the posterior samples.

        The sampler is configured from ``args.num_warmup``, ``args.num_samples`` and
        ``args.num_chains``. The progress bar is disabled when the environment
        variable ``NUMPYRO_SPHINXBUILD`` is set. A summary and the elapsed wall-clock
        time are printed before the samples are returned.
        """
        start = time.time()

        show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ
        sampler = MCMC(
            NUTS(model),
            num_warmup=args.num_warmup,
            num_samples=args.num_samples,
            num_chains=args.num_chains,
            progress_bar=show_progress,
        )
        sampler.run(rng_key, X, Y, hypers)
        sampler.print_summary()

        elapsed = time.time() - start
        print("\nMCMC elapsed time:", elapsed)
        return sampler.get_samples()


    # Get the mean and variance of a gaussian mixture
    def gaussian_mixture_stats(mus, variances):
        """Return the mean and variance of an equally weighted Gaussian mixture.

        The mixture mean is the average of ``mus``. The mixture variance is the
        average component variance plus the average of the squared component means
        minus the square of the mixture mean.
        """
        mean_mu = jnp.mean(mus)
        avg_component_variance = jnp.mean(variances)
        avg_squared_mu = jnp.mean(jnp.square(mus))
        mean_var = avg_component_variance + avg_squared_mu - jnp.square(mean_mu)
        return mean_mu, mean_var


    # Create artificial regression dataset where only S out of P feature
    # dimensions contain signal and where there is a single pairwise interaction
    # between the first and second dimensions.
    def get_data(N=20, S=2, P=10, sigma_obs=0.05):
        """Generate a synthetic regression dataset with a fixed NumPy seed.

        The responses are a linear combination of the first ``S`` feature columns,
        plus the product of the first two columns, plus Gaussian noise scaled by
        ``sigma_obs``. They are then centered and divided by their standard
        deviation.

        Note that this reseeds NumPy's global random state with ``0``.

        :return: tuple ``(X, Y, W, pairwise)`` holding the features, the
            standardized responses, the linear coefficients divided by the response
            standard deviation, and the reciprocal of that standard deviation (the
            rescaled coefficient of the single pairwise term).
        """
        assert 0 < S < P and P > 1
        np.random.seed(0)

        # The three draws below must stay in this order to reproduce the same data.
        X = np.random.randn(N, P)
        # generate S coefficients with non-negligible magnitude
        W = 0.5 + 2.5 * np.random.rand(S)
        linear_term = np.sum(X[:, 0:S] * W, axis=-1)
        interaction_term = X[:, 0] * X[:, 1]
        noise_term = sigma_obs * np.random.randn(N)

        Y = linear_term + interaction_term + noise_term
        # In-place update: Y stays a NumPy array at this point.
        Y -= jnp.mean(Y)
        Y_std = jnp.std(Y)

        assert X.shape == (N, P)
        assert Y.shape == (N,)

        pairwise_coefficient = 1.0 / Y_std
        return X, Y / Y_std, W / Y_std, pairwise_coefficient


    # Helper function for analyzing the posterior statistics for coefficient theta_i
    def analyze_dimension(samples, X, Y, dimension, hypers):
        """Posterior mean and standard deviation of the singleton coefficient ``theta_i``.

        ``compute_singleton_mean_variance`` is mapped over the MCMC samples of
        ``msq``, ``lambda``, ``eta1``, ``xisq`` and ``sigma``; the per-sample results
        are then pooled with ``gaussian_mixture_stats``.
        """

        def per_sample(msq, lam, eta1, xisq, sigma):
            return compute_singleton_mean_variance(
                X, Y, dimension, msq, lam, eta1, xisq, hypers["c"], sigma
            )

        site_names = ("msq", "lambda", "eta1", "xisq", "sigma")
        mus, variances = vmap(per_sample)(*(samples[name] for name in site_names))

        mean, variance = gaussian_mixture_stats(mus, variances)
        return mean, jnp.sqrt(variance)


    # Helper function for analyzing the posterior statistics for coefficient theta_ij
    def analyze_pair_of_dimensions(samples, X, Y, dim1, dim2, hypers):
        """Posterior mean and standard deviation of the pairwise coefficient ``theta_ij``.

        ``compute_pairwise_mean_variance`` is mapped over the MCMC samples of
        ``msq``, ``lambda``, ``eta1``, ``xisq`` and ``sigma``; the per-sample results
        are then pooled with ``gaussian_mixture_stats``.
        """

        def per_sample(msq, lam, eta1, xisq, sigma):
            return compute_pairwise_mean_variance(
                X, Y, dim1, dim2, msq, lam, eta1, xisq, hypers["c"], sigma
            )

        site_names = ("msq", "lambda", "eta1", "xisq", "sigma")
        mus, variances = vmap(per_sample)(*(samples[name] for name in site_names))

        mean, variance = gaussian_mixture_stats(mus, variances)
        return mean, jnp.sqrt(variance)


    def main(args):
        """Run the full example: generate data, infer, and report the coefficients.

        Steps:

        1. Build the synthetic dataset from ``args.num_data``,
           ``args.num_dimensions`` and ``args.active_dimensions``.
        2. Run MCMC on ``model``.
        3. Report the posterior mean and standard deviation of every singleton
           coefficient and mark a dimension as active when the interval
           ``mean +- 3 * std`` does not contain zero.
        4. For every pair ``i < j`` of active dimensions, report the pairwise
           coefficient when its ``mean +- 3 * std`` interval excludes zero.
        5. If at least one dimension is active, draw and print a single joint
           sample of the coefficients using the last MCMC sample.
        """
        X, Y, expected_thetas, expected_pairwise = get_data(
            N=args.num_data, P=args.num_dimensions, S=args.active_dimensions
        )

        # setup hyperparameters
        hypers = {
            "expected_sparsity": max(1.0, args.num_dimensions / 10),
            "alpha1": 3.0,
            "beta1": 1.0,
            "alpha2": 3.0,
            "beta2": 1.0,
            "alpha3": 1.0,
            "c": 1.0,
        }

        # do inference
        rng_key = random.PRNGKey(0)
        samples = run_inference(model, args, rng_key, X, Y, hypers)

        # compute the mean and square root variance of each coefficient theta_i
        means, stds = vmap(lambda dim: analyze_dimension(samples, X, Y, dim, hypers))(
            jnp.arange(args.num_dimensions)
        )

        print(
            "Coefficients theta_1 to theta_%d used to generate the data:"
            % args.active_dimensions,
            expected_thetas,
        )
        print(
            "The single quadratic coefficient theta_{1,2} used to generate the data:",
            expected_pairwise,
        )

        # Collected in increasing order of dimension index.
        active_dimensions = []
        for dim, (mean, std) in enumerate(zip(means, stds)):
            # we mark the dimension as inactive if the interval [mean - 3 * std, mean + 3 * std] contains zero
            lower, upper = mean - 3.0 * std, mean + 3.0 * std
            is_active = not (lower < 0.0 and upper > 0.0)
            if is_active:
                active_dimensions.append(dim)
            print(
                "[dimension %02d/%02d]  %s:\t%.2e +- %.2e"
                % (
                    dim + 1,
                    args.num_dimensions,
                    "active" if is_active else "inactive",
                    mean,
                    std,
                )
            )

        print(
            "Identified a total of %d active dimensions; expected %d."
            % (len(active_dimensions), args.active_dimensions)
        )

        if not active_dimensions:
            return

        # Compute the mean and square root variance of coefficients theta_ij for i,j active dimensions.
        # Only the pairs with i < j are evaluated: active_dimensions is in increasing order, so
        # itertools.combinations yields exactly those pairs. The list is empty when there is a
        # single active dimension, in which case there is nothing to map over.
        dim_pairs = list(itertools.combinations(active_dimensions, 2))
        if dim_pairs:
            pair_means, pair_stds = vmap(
                lambda dim_pair: analyze_pair_of_dimensions(
                    samples, X, Y, dim_pair[0], dim_pair[1], hypers
                )
            )(jnp.array(dim_pairs))
            for (dim1, dim2), mean, std in zip(dim_pairs, pair_means, pair_stds):
                lower, upper = mean - 3.0 * std, mean + 3.0 * std
                if not (lower < 0.0 and upper > 0.0):
                    format_str = "Identified pairwise interaction between dimensions %d and %d: %.2e +- %.2e"
                    print(format_str % (dim1 + 1, dim2 + 1, mean, std))

        # Draw a single sample of coefficients theta from the posterior, where we return all singleton
        # coefficients theta_i and pairwise coefficients theta_ij for i, j active dimensions. We use the
        # final MCMC sample obtained from the HMC sampler.
        thetas = sample_theta_space(
            X,
            Y,
            active_dimensions,
            samples["msq"][-1],
            samples["lambda"][-1],
            samples["eta1"][-1],
            samples["xisq"][-1],
            hypers["c"],
            samples["sigma"][-1],
        )
        print("Single posterior sample theta:\n", thetas)


    if __name__ == "__main__":
        # Fail fast unless the installed NumPyro version string starts with 0.13.2.
        assert numpyro.__version__.startswith("0.13.2")

        parser = argparse.ArgumentParser(description="Gaussian Process example")
        # (flags, default) for every integer-valued option, in registration order.
        integer_options = (
            (("-n", "--num-samples"), 1000),
            (("--num-warmup",), 500),
            (("--num-chains",), 1),
            (("--num-data",), 100),
            (("--num-dimensions",), 20),
            (("--active-dimensions",), 3),
        )
        for flags, default in integer_options:
            parser.add_argument(*flags, nargs="?", default=default, type=int)
        parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
        args = parser.parse_args()

        # Configure the platform and host device count before running the example.
        numpyro.set_platform(args.device)
        numpyro.set_host_device_count(args.num_chains)

        main(args)


.. _sphx_glr_download_examples_sparse_regression.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example




    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: sparse_regression.py <sparse_regression.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: sparse_regression.ipynb <sparse_regression.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
