{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Proportion Test\nYou are managing a business and want to test if calling your customers will\nincrease their chance of making a purchase. You get 100,000 customer records and call\nroughly half of them and record if they make a purchase in the next three months.\nYou do the same for the half that did not get called. After three months, the data is in -\ndid calling help?\n\nThis example answers this question by estimating a logistic regression model where the\ncovariates are whether the customer got called and their gender. We place a multivariate\nnormal prior on the regression coefficients. We report the 95% highest posterior\ndensity interval for the effect of making a call.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\nimport os\n\nfrom jax import random\nimport jax.numpy as jnp\nfrom jax.scipy.special import expit\n\nimport numpyro\nfrom numpyro.diagnostics import hpdi\nimport numpyro.distributions as dist\nfrom numpyro.infer import MCMC, NUTS\n\n\ndef make_dataset(rng_key) -> tuple[jnp.ndarray, jnp.ndarray]:\n    \"\"\"\n    Make simulated dataset where potential customers who get a\n    sales calls have ~2% higher chance of making another purchase.\n    \"\"\"\n    key1, key2, key3 = random.split(rng_key, 3)\n\n    num_calls = 51342\n    num_no_calls = 48658\n\n    made_purchase_got_called = dist.Bernoulli(0.084).sample(\n        key1, sample_shape=(num_calls,)\n    )\n    made_purchase_no_calls = dist.Bernoulli(0.061).sample(\n        key2, sample_shape=(num_no_calls,)\n    )\n\n    made_purchase = jnp.concatenate([made_purchase_got_called, made_purchase_no_calls])\n\n    is_female = dist.Bernoulli(0.5).sample(\n        key3, sample_shape=(num_calls + num_no_calls,)\n    )\n    got_called = jnp.concatenate([jnp.ones(num_calls), jnp.zeros(num_no_calls)])\n    design_matrix = jnp.hstack(\n        [\n            jnp.ones((num_no_calls + num_calls, 1)),\n            got_called.reshape(-1, 1),\n            is_female.reshape(-1, 1),\n        ]\n    )\n\n    return design_matrix, made_purchase\n\n\ndef model(design_matrix: jnp.ndarray, outcome: jnp.ndarray = None) -> None:\n    \"\"\"\n    Model definition: Log odds of making a purchase is a linear combination\n    of covariates. Specify a Normal prior over regression coefficients.\n    :param design_matrix: Covariates. All categorical variables have been one-hot\n        encoded.\n    :param outcome: Binary response variable. In this case, whether or not the\n        customer made a purchase.\n    \"\"\"\n\n    beta = numpyro.sample(\n        \"coefficients\",\n        dist.MultivariateNormal(\n            loc=0.0, covariance_matrix=jnp.eye(design_matrix.shape[1])\n        ),\n    )\n    logits = design_matrix.dot(beta)\n\n    with numpyro.plate(\"data\", design_matrix.shape[0]):\n        numpyro.sample(\"obs\", dist.Bernoulli(logits=logits), obs=outcome)\n\n\ndef print_results(coef: jnp.ndarray, interval_size: float = 0.95) -> None:\n    \"\"\"\n    Print the confidence interval for the effect size with interval_size\n    probability mass.\n    \"\"\"\n\n    baseline_response = expit(coef[:, 0])\n    response_with_calls = expit(coef[:, 0] + coef[:, 1])\n\n    impact_on_probability = hpdi(\n        response_with_calls - baseline_response, prob=interval_size\n    )\n\n    effect_of_gender = hpdi(coef[:, 2], prob=interval_size)\n\n    print(\n        f\"There is a {interval_size * 100}% probability that calling customers \"\n        \"increases the chance they'll make a purchase by \"\n        f\"{(100 * impact_on_probability[0]):.2} to {(100 * impact_on_probability[1]):.2} percentage points.\"\n    )\n\n    print(\n        f\"There is a {interval_size * 100}% probability the effect of gender on the log odds of conversion \"\n        f\"lies in the interval ({effect_of_gender[0]:.2}, {effect_of_gender[1]:.2f}).\"\n        \" Since this interval contains 0, we can conclude gender does not impact the conversion rate.\"\n    )\n\n\ndef run_inference(\n    design_matrix: jnp.ndarray,\n    outcome: jnp.ndarray,\n    rng_key: jnp.ndarray,\n    num_warmup: int,\n    num_samples: int,\n    num_chains: int,\n    interval_size: float = 0.95,\n) -> None:\n    \"\"\"\n    Estimate the effect size.\n    \"\"\"\n\n    kernel = NUTS(model)\n    mcmc = MCMC(\n        kernel,\n        num_warmup=num_warmup,\n        num_samples=num_samples,\n        num_chains=num_chains,\n        progress_bar=False if \"NUMPYRO_SPHINXBUILD\" in os.environ else True,\n    )\n    mcmc.run(rng_key, design_matrix, outcome)\n\n    # 0th column is intercept (not getting called)\n    # 1st column is effect of getting called\n    # 2nd column is effect of gender (should be none since assigned at random)\n    coef = mcmc.get_samples()[\"coefficients\"]\n    print_results(coef, interval_size)\n\n\ndef main(args):\n    rng_key, _ = random.split(random.PRNGKey(3))\n    design_matrix, response = make_dataset(rng_key)\n    run_inference(\n        design_matrix,\n        response,\n        rng_key,\n        args.num_warmup,\n        args.num_samples,\n        args.num_chains,\n        args.interval_size,\n    )\n\n\nif __name__ == \"__main__\":\n    assert numpyro.__version__.startswith(\"0.13.2\")\n    parser = argparse.ArgumentParser(description=\"Testing whether  \")\n    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=500, type=int)\n    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=1500, type=int)\n    parser.add_argument(\"--num-chains\", nargs=\"?\", default=1, type=int)\n    parser.add_argument(\"--interval-size\", nargs=\"?\", default=0.95, type=float)\n    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n    args = parser.parse_args()\n\n    numpyro.set_platform(args.device)\n    numpyro.set_host_device_count(args.num_chains)\n\n    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}