{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bayesian Imputation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Real-world datasets often contain many missing values. In those situations, we have to either remove those missing data (also known as \"complete case\") or replace them by some values. Though using complete case is pretty straightforward, it is only applicable when the number of missing entries is so small that throwing away those entries would not affect much the power of the analysis we are conducting on the data. The second strategy, also known as [imputation](https://en.wikipedia.org/wiki/Imputation_%28statistics%29), is more applicable and will be our focus in this tutorial.\n",
    "\n",
    "Probably the most popular way to perform imputation is to fill a missing value with the mean, median, or mode of its corresponding feature. In that case, we implicitly assume that the feature containing missing values has no correlation with the remaining features of our dataset. This is a pretty strong assumption and might not be true in general. In addition, it does not encode any uncertainty that we might put on those values. Below, we will construct a *Bayesian* setting to resolve those issues. In particular, given a model on the dataset, we will\n",
    "\n",
    "+ create a generative model for the feature with missing value\n",
    "+ and consider missing values as unobserved latent variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# first, we need some imports\n",
    "import os\n",
    "\n",
    "from IPython.display import set_matplotlib_formats\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from jax import numpy as jnp\n",
    "from jax import random\n",
    "from jax.scipy.special import expit\n",
    "\n",
    "import numpyro\n",
    "from numpyro import distributions as dist\n",
    "from numpyro.distributions import constraints\n",
    "from numpyro.infer import MCMC, NUTS, Predictive\n",
    "\n",
    "plt.style.use(\"seaborn\")\n",
    "if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n",
    "    set_matplotlib_formats(\"svg\")\n",
    "\n",
    "assert numpyro.__version__.startswith(\"0.13.2\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The data is taken from the competition [Titanic: Machine Learning from Disaster](https://www.kaggle.com/c/titanic) hosted on [kaggle](https://www.kaggle.com/). It contains information of passengers in the [Titanic accident](https://en.wikipedia.org/wiki/Sinking_of_the_RMS_Titanic) such as name, age, gender,... And our target is to predict if a person is more likely to survive."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 891 entries, 0 to 890\n",
      "Data columns (total 12 columns):\n",
      " #   Column       Non-Null Count  Dtype  \n",
      "---  ------       --------------  -----  \n",
      " 0   PassengerId  891 non-null    int64  \n",
      " 1   Survived     891 non-null    int64  \n",
      " 2   Pclass       891 non-null    int64  \n",
      " 3   Name         891 non-null    object \n",
      " 4   Sex          891 non-null    object \n",
      " 5   Age          714 non-null    float64\n",
      " 6   SibSp        891 non-null    int64  \n",
      " 7   Parch        891 non-null    int64  \n",
      " 8   Ticket       891 non-null    object \n",
      " 9   Fare         891 non-null    float64\n",
      " 10  Cabin        204 non-null    object \n",
      " 11  Embarked     889 non-null    object \n",
      "dtypes: float64(2), int64(5), object(5)\n",
      "memory usage: 83.7+ KB\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PassengerId</th>\n",
       "      <th>Survived</th>\n",
       "      <th>Pclass</th>\n",
       "      <th>Name</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Age</th>\n",
       "      <th>SibSp</th>\n",
       "      <th>Parch</th>\n",
       "      <th>Ticket</th>\n",
       "      <th>Fare</th>\n",
       "      <th>Cabin</th>\n",
       "      <th>Embarked</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>Braund, Mr. Owen Harris</td>\n",
       "      <td>male</td>\n",
       "      <td>22.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>A/5 21171</td>\n",
       "      <td>7.2500</td>\n",
       "      <td>NaN</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td>\n",
       "      <td>female</td>\n",
       "      <td>38.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>PC 17599</td>\n",
       "      <td>71.2833</td>\n",
       "      <td>C85</td>\n",
       "      <td>C</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>Heikkinen, Miss. Laina</td>\n",
       "      <td>female</td>\n",
       "      <td>26.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>STON/O2. 3101282</td>\n",
       "      <td>7.9250</td>\n",
       "      <td>NaN</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td>\n",
       "      <td>female</td>\n",
       "      <td>35.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>113803</td>\n",
       "      <td>53.1000</td>\n",
       "      <td>C123</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>Allen, Mr. William Henry</td>\n",
       "      <td>male</td>\n",
       "      <td>35.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>373450</td>\n",
       "      <td>8.0500</td>\n",
       "      <td>NaN</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   PassengerId  Survived  Pclass  \\\n",
       "0            1         0       3   \n",
       "1            2         1       1   \n",
       "2            3         1       3   \n",
       "3            4         1       1   \n",
       "4            5         0       3   \n",
       "\n",
       "                                                Name     Sex   Age  SibSp  \\\n",
       "0                            Braund, Mr. Owen Harris    male  22.0      1   \n",
       "1  Cumings, Mrs. John Bradley (Florence Briggs Th...  female  38.0      1   \n",
       "2                             Heikkinen, Miss. Laina  female  26.0      0   \n",
       "3       Futrelle, Mrs. Jacques Heath (Lily May Peel)  female  35.0      1   \n",
       "4                           Allen, Mr. William Henry    male  35.0      0   \n",
       "\n",
       "   Parch            Ticket     Fare Cabin Embarked  \n",
       "0      0         A/5 21171   7.2500   NaN        S  \n",
       "1      0          PC 17599  71.2833   C85        C  \n",
       "2      0  STON/O2. 3101282   7.9250   NaN        S  \n",
       "3      0            113803  53.1000  C123        S  \n",
       "4      0            373450   8.0500   NaN        S  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_df = pd.read_csv(\n",
    "    \"https://raw.githubusercontent.com/agconti/kaggle-titanic/master/data/train.csv\"\n",
    ")\n",
    "train_df.info()\n",
    "train_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Look at the data info, we know that there are missing data at `Age`, `Cabin`, and `Embarked` columns. Although `Cabin` is an important feature (because the position of a cabin in the ship can affect the chance of people in that cabin to survive), we will skip it in this tutorial for simplicity. In the dataset, there are many categorical columns and two numerical columns `Age` and `Fare`. Let's first look at the distribution of those categorical columns:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0    549\n",
      "1    342\n",
      "Name: Survived, dtype: int64\n",
      "\n",
      "3    491\n",
      "1    216\n",
      "2    184\n",
      "Name: Pclass, dtype: int64\n",
      "\n",
      "male      577\n",
      "female    314\n",
      "Name: Sex, dtype: int64\n",
      "\n",
      "0    608\n",
      "1    209\n",
      "2     28\n",
      "4     18\n",
      "3     16\n",
      "8      7\n",
      "5      5\n",
      "Name: SibSp, dtype: int64\n",
      "\n",
      "0    678\n",
      "1    118\n",
      "2     80\n",
      "3      5\n",
      "5      5\n",
      "4      4\n",
      "6      1\n",
      "Name: Parch, dtype: int64\n",
      "\n",
      "S    644\n",
      "C    168\n",
      "Q     77\n",
      "Name: Embarked, dtype: int64\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for col in [\"Survived\", \"Pclass\", \"Sex\", \"SibSp\", \"Parch\", \"Embarked\"]:\n",
    "    print(train_df[col].value_counts(), end=\"\\n\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we will merge rare groups in `SibSp` and `Parch` columns together. In addition, we'll fill 2 missing entries in `Embarked` by the mode `S`. Note that we can make a generative model for those missing entries in `Embarked` but let's skip doing so for simplicity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df.SibSp.clip(0, 1, inplace=True)\n",
    "train_df.Parch.clip(0, 2, inplace=True)\n",
    "train_df.Embarked.fillna(\"S\", inplace=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Looking closer at the data, we can observe that each name contains a title. We know that age is correlated with the title of the name: e.g. those with Mrs. would be older than those with `Miss.` (on average) so it might be good to create that feature. The distribution of titles is:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Mr.          517\n",
       "Miss.        182\n",
       "Mrs.         125\n",
       "Master.       40\n",
       "Dr.            7\n",
       "Rev.           6\n",
       "Mlle.          2\n",
       "Col.           2\n",
       "Major.         2\n",
       "Lady.          1\n",
       "Sir.           1\n",
       "the            1\n",
       "Ms.            1\n",
       "Capt.          1\n",
       "Mme.           1\n",
       "Jonkheer.      1\n",
       "Don.           1\n",
       "Name: Name, dtype: int64"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_df.Name.str.split(\", \").str.get(1).str.split(\" \").str.get(0).value_counts()"
   ]
  },
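  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the chained string operations above more concrete, here is a minimal sketch on two names copied from the `head()` output. The variable names `names` and `titles` are only illustrative and are not used elsewhere in this tutorial.\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "names = pd.Series([\"Braund, Mr. Owen Harris\", \"Heikkinen, Miss. Laina\"])\n",
    "# keep the part after \"Surname, \" and then take its first word, which is the title\n",
    "titles = names.str.split(\", \").str.get(1).str.split(\" \").str.get(0)\n",
    "assert titles.tolist() == [\"Mr.\", \"Miss.\"]\n",
    "```\n",
    "Because only the first word is kept, a multi-word title would be cut down to its first word, which is presumably where the `the` entry in the counts above comes from."
   ]
  },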
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will make a new column `Title`, where rare titles are merged into one group `Misc.`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df[\"Title\"] = (\n",
    "    train_df.Name.str.split(\", \")\n",
    "    .str.get(1)\n",
    "    .str.split(\" \")\n",
    "    .str.get(0)\n",
    "    .apply(lambda x: x if x in [\"Mr.\", \"Miss.\", \"Mrs.\", \"Master.\"] else \"Misc.\")\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, it is ready to turn the dataframe, which includes categorical values, into numpy arrays. We also perform standardization (a good practice for regression models) for `Age` column."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "title_cat = pd.CategoricalDtype(\n",
    "    categories=[\"Mr.\", \"Miss.\", \"Mrs.\", \"Master.\", \"Misc.\"], ordered=True\n",
    ")\n",
    "embarked_cat = pd.CategoricalDtype(categories=[\"S\", \"C\", \"Q\"], ordered=True)\n",
    "age_mean, age_std = train_df.Age.mean(), train_df.Age.std()\n",
    "data = dict(\n",
    "    age=train_df.Age.pipe(lambda x: (x - age_mean) / age_std).values,\n",
    "    pclass=train_df.Pclass.values - 1,\n",
    "    title=train_df.Title.astype(title_cat).cat.codes.values,\n",
    "    sex=(train_df.Sex == \"male\").astype(int).values,\n",
    "    sibsp=train_df.SibSp.values,\n",
    "    parch=train_df.Parch.values,\n",
    "    embarked=train_df.Embarked.astype(embarked_cat).cat.codes.values,\n",
    ")\n",
    "survived = train_df.Survived.values\n",
    "# compute the age mean for each title\n",
    "age_notnan = data[\"age\"][jnp.isfinite(data[\"age\"])]\n",
    "title_notnan = data[\"title\"][jnp.isfinite(data[\"age\"])]\n",
    "age_mean_by_title = jnp.stack([age_notnan[title_notnan == i].mean() for i in range(5)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Modelling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we want to note that in NumPyro, the following models\n",
    "```python\n",
    "def model1a():\n",
    "    x = numpyro.sample(\"x\", dist.Normal(0, 1).expand([10]))\n",
    "```\n",
    "and\n",
    "```python\n",
    "def model1b():\n",
    "    x = numpyro.sample(\"x\", dist.Normal(0, 1).expand([10]).mask(False))\n",
    "    numpyro.sample(\"x_obs\", dist.Normal(0, 1).expand([10]), obs=x)\n",
    "```\n",
    "are equivalent in the sense that both of them have\n",
    "\n",
    "+ the same latent sites `x` drawn from `dist.Normal(0, 1)` prior,\n",
    "+ and the same log densities `dist.Normal(0, 1).log_prob(x)`.\n",
    "\n",
    "Now, assume that we observed the last 6 values of `x` (non-observed entries take value `NaN`), the typical model will be\n",
    "```python\n",
    "def model2a(x):\n",
    "    x_impute = numpyro.sample(\"x_impute\", dist.Normal(0, 1).expand([4]))\n",
    "    x_obs = numpyro.sample(\"x_obs\", dist.Normal(0, 1).expand([6]), obs=x[4:])\n",
    "    x_imputed = jnp.concatenate([x_impute, x_obs])\n",
    "```\n",
    "or with the usage of `mask`,\n",
    "```python\n",
    "def model2b(x):\n",
    "    x_impute = numpyro.sample(\"x_impute\", dist.Normal(0, 1).expand([4]).mask(False))\n",
    "    x_imputed = jnp.concatenate([x_impute, x[4:]])\n",
    "    numpyro.sample(\"x\", dist.Normal(0, 1).expand([10]), obs=x_imputed)\n",
    "```"
   ]
  },
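  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, the claimed equivalence of `model1a` and `model1b` can be verified numerically. The sketch below assumes that the utility `numpyro.infer.util.log_density` is available in the installed NumPyro version; the test vector `x` is arbitrary.\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "from jax import random\n",
    "\n",
    "import numpyro\n",
    "from numpyro import distributions as dist\n",
    "from numpyro.infer.util import log_density\n",
    "\n",
    "\n",
    "def model1a():\n",
    "    numpyro.sample(\"x\", dist.Normal(0, 1).expand([10]))\n",
    "\n",
    "\n",
    "def model1b():\n",
    "    x = numpyro.sample(\"x\", dist.Normal(0, 1).expand([10]).mask(False))\n",
    "    numpyro.sample(\"x_obs\", dist.Normal(0, 1).expand([10]), obs=x)\n",
    "\n",
    "\n",
    "x = random.normal(random.PRNGKey(0), (10,))\n",
    "# evaluate the joint log density of each model at the same value of the latent site `x`\n",
    "logp_a, _ = log_density(model1a, (), {}, {\"x\": x})\n",
    "logp_b, _ = log_density(model1b, (), {}, {\"x\": x})\n",
    "expected = dist.Normal(0, 1).log_prob(x).sum()\n",
    "assert jnp.allclose(logp_a, expected) and jnp.allclose(logp_b, expected)\n",
    "```\n",
    "In `model1b`, the masked site `x` contributes zero log density, so the whole contribution comes from the observed site `x_obs`."
   ]
  },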
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both approaches to model the partial observed data `x` are equivalent. For the model below, we will use the latter method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(\n",
    "    age, pclass, title, sex, sibsp, parch, embarked, survived=None, bayesian_impute=True\n",
    "):\n",
    "    b_pclass = numpyro.sample(\"b_Pclass\", dist.Normal(0, 1).expand([3]))\n",
    "    b_title = numpyro.sample(\"b_Title\", dist.Normal(0, 1).expand([5]))\n",
    "    b_sex = numpyro.sample(\"b_Sex\", dist.Normal(0, 1).expand([2]))\n",
    "    b_sibsp = numpyro.sample(\"b_SibSp\", dist.Normal(0, 1).expand([2]))\n",
    "    b_parch = numpyro.sample(\"b_Parch\", dist.Normal(0, 1).expand([3]))\n",
    "    b_embarked = numpyro.sample(\"b_Embarked\", dist.Normal(0, 1).expand([3]))\n",
    "\n",
    "    # impute age by Title\n",
    "    isnan = np.isnan(age)\n",
    "    age_nanidx = np.nonzero(isnan)[0]\n",
    "    if bayesian_impute:\n",
    "        age_mu = numpyro.sample(\"age_mu\", dist.Normal(0, 1).expand([5]))\n",
    "        age_mu = age_mu[title]\n",
    "        age_sigma = numpyro.sample(\"age_sigma\", dist.Normal(0, 1).expand([5]))\n",
    "        age_sigma = age_sigma[title]\n",
    "        age_impute = numpyro.sample(\n",
    "            \"age_impute\",\n",
    "            dist.Normal(age_mu[age_nanidx], age_sigma[age_nanidx]).mask(False),\n",
    "        )\n",
    "        age = jnp.asarray(age).at[age_nanidx].set(age_impute)\n",
    "        numpyro.sample(\"age\", dist.Normal(age_mu, age_sigma), obs=age)\n",
    "    else:\n",
    "        # fill missing data by the mean of ages for each title\n",
    "        age_impute = age_mean_by_title[title][age_nanidx]\n",
    "        age = jnp.asarray(age).at[age_nanidx].set(age_impute)\n",
    "\n",
    "    a = numpyro.sample(\"a\", dist.Normal(0, 1))\n",
    "    b_age = numpyro.sample(\"b_Age\", dist.Normal(0, 1))\n",
    "    logits = a + b_age * age\n",
    "    logits = logits + b_title[title] + b_pclass[pclass] + b_sex[sex]\n",
    "    logits = logits + b_sibsp[sibsp] + b_parch[parch] + b_embarked[embarked]\n",
    "    numpyro.sample(\"survived\", dist.Bernoulli(logits=logits), obs=survived)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that in the model, the prior for `age` is `dist.Normal(age_mu, age_sigma)`, where the values of `age_mu` and `age_sigma` depend on `title`. Because there are missing values in `age`, we will encode those missing values in the latent parameter `age_impute`. Then we can replace `NaN` entries in `age` with the vector `age_impute`."
   ]
  },
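  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The replacement of `NaN` entries can be illustrated in isolation. This is a minimal sketch with made-up numbers: the toy `age` vector and the values standing in for a draw of `age_impute` are assumptions for illustration only.\n",
    "```python\n",
    "import numpy as np\n",
    "from jax import numpy as jnp\n",
    "\n",
    "age = np.array([0.5, np.nan, -1.0, np.nan])\n",
    "age_nanidx = np.nonzero(np.isnan(age))[0]  # positions of the missing entries\n",
    "# stand-in for one draw of the latent site `age_impute` (values are made up)\n",
    "age_impute = jnp.array([0.1, 0.2])\n",
    "age_filled = jnp.asarray(age).at[age_nanidx].set(age_impute)\n",
    "assert not jnp.isnan(age_filled).any()\n",
    "```\n",
    "JAX arrays are immutable, so `.at[...].set(...)` returns a new array instead of modifying `age` in place."
   ]
  },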
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sampling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use MCMC with NUTS kernel to sample both regression coefficients and imputed values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "sample: 100%|██████████| 2000/2000 [00:15<00:00, 132.15it/s, 63 steps of size 5.68e-02. acc. prob=0.95]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "                     mean       std    median      5.0%     95.0%     n_eff     r_hat\n",
      "              a      0.12      0.82      0.11     -1.21      1.49    887.50      1.00\n",
      "  age_impute[0]      0.20      0.84      0.18     -1.22      1.53   1346.09      1.00\n",
      "  age_impute[1]     -0.06      0.86     -0.08     -1.41      1.26   1057.70      1.00\n",
      "  age_impute[2]      0.38      0.73      0.39     -0.80      1.58   1570.36      1.00\n",
      "  age_impute[3]      0.25      0.84      0.23     -0.99      1.86   1027.43      1.00\n",
      "  age_impute[4]     -0.63      0.91     -0.59     -1.99      0.87   1183.66      1.00\n",
      "  age_impute[5]      0.21      0.89      0.19     -1.02      1.97   1456.79      1.00\n",
      "  age_impute[6]      0.45      0.82      0.46     -0.90      1.73   1239.22      1.00\n",
      "  age_impute[7]     -0.62      0.86     -0.62     -2.13      0.72   1406.09      1.00\n",
      "  age_impute[8]     -0.13      0.90     -0.14     -1.64      1.38   1905.07      1.00\n",
      "  age_impute[9]      0.24      0.84      0.26     -1.06      1.77   1471.12      1.00\n",
      " age_impute[10]      0.20      0.89      0.21     -1.26      1.65   1588.79      1.00\n",
      " age_impute[11]      0.17      0.91      0.19     -1.59      1.48   1446.52      1.00\n",
      " age_impute[12]     -0.65      0.89     -0.68     -2.12      0.77   1457.47      1.00\n",
      " age_impute[13]      0.21      0.85      0.18     -1.24      1.53   1057.77      1.00\n",
      " age_impute[14]      0.05      0.92      0.05     -1.40      1.65   1207.08      1.00\n",
      " age_impute[15]      0.37      0.94      0.37     -1.02      1.98   1326.55      1.00\n",
      " age_impute[16]     -1.74      0.26     -1.74     -2.13     -1.32   1320.08      1.00\n",
      " age_impute[17]      0.21      0.89      0.22     -1.30      1.60   1545.73      1.00\n",
      " age_impute[18]      0.18      0.90      0.18     -1.26      1.58   2013.12      1.00\n",
      " age_impute[19]     -0.67      0.86     -0.66     -1.97      0.85   1499.50      1.00\n",
      " age_impute[20]      0.23      0.89      0.27     -1.19      1.71   1712.24      1.00\n",
      " age_impute[21]      0.21      0.87      0.20     -1.11      1.68   1400.55      1.00\n",
      " age_impute[22]      0.19      0.90      0.18     -1.26      1.63   1400.37      1.00\n",
      " age_impute[23]     -0.15      0.85     -0.15     -1.57      1.24   1205.10      1.00\n",
      " age_impute[24]     -0.71      0.89     -0.73     -2.05      0.82   1085.52      1.00\n",
      " age_impute[25]      0.20      0.85      0.19     -1.20      1.62   1708.01      1.00\n",
      " age_impute[26]      0.21      0.88      0.21     -1.20      1.68   1363.75      1.00\n",
      " age_impute[27]     -0.69      0.91     -0.73     -2.20      0.77   1224.06      1.00\n",
      " age_impute[28]      0.60      0.77      0.60     -0.61      1.95   1312.44      1.00\n",
      " age_impute[29]      0.20      0.89      0.17     -1.23      1.71    938.19      1.00\n",
      " age_impute[30]      0.24      0.87      0.23     -1.14      1.60   1324.50      1.00\n",
      " age_impute[31]     -1.72      0.26     -1.72     -2.11     -1.28   1425.46      1.00\n",
      " age_impute[32]      0.44      0.77      0.43     -0.83      1.58   1587.41      1.00\n",
      " age_impute[33]      0.34      0.89      0.32     -1.14      1.73   1375.14      1.00\n",
      " age_impute[34]     -1.72      0.26     -1.71     -2.11     -1.26   1007.71      1.00\n",
      " age_impute[35]     -0.45      0.90     -0.47     -2.06      0.92   1329.44      1.00\n",
      " age_impute[36]      0.30      0.84      0.30     -1.03      1.73   1080.80      1.00\n",
      " age_impute[37]      0.33      0.88      0.32     -1.10      1.81   1033.30      1.00\n",
      " age_impute[38]      0.33      0.76      0.35     -0.94      1.56   1550.68      1.00\n",
      " age_impute[39]      0.19      0.93      0.21     -1.32      1.82   1203.79      1.00\n",
      " age_impute[40]     -0.67      0.88     -0.69     -1.94      0.88   1382.98      1.00\n",
      " age_impute[41]      0.17      0.89      0.14     -1.30      1.43   1438.18      1.00\n",
      " age_impute[42]      0.23      0.82      0.25     -1.12      1.48   1499.59      1.00\n",
      " age_impute[43]      0.22      0.82      0.21     -1.19      1.45   1236.67      1.00\n",
      " age_impute[44]     -0.41      0.85     -0.42     -1.96      0.78    812.53      1.00\n",
      " age_impute[45]     -0.36      0.89     -0.35     -2.01      0.94   1488.83      1.00\n",
      " age_impute[46]     -0.33      0.91     -0.32     -1.76      1.27   1628.61      1.00\n",
      " age_impute[47]     -0.71      0.85     -0.69     -2.12      0.64   1363.89      1.00\n",
      " age_impute[48]      0.21      0.85      0.24     -1.21      1.64   1552.65      1.00\n",
      " age_impute[49]      0.42      0.82      0.41     -0.83      1.77    754.08      1.00\n",
      " age_impute[50]      0.26      0.86      0.24     -1.18      1.63   1155.49      1.00\n",
      " age_impute[51]     -0.29      0.91     -0.30     -1.83      1.15   1212.08      1.00\n",
      " age_impute[52]      0.36      0.85      0.34     -1.12      1.68   1190.99      1.00\n",
      " age_impute[53]     -0.68      0.89     -0.65     -2.09      0.75   1104.75      1.00\n",
      " age_impute[54]      0.27      0.90      0.25     -1.24      1.68   1331.19      1.00\n",
      " age_impute[55]      0.36      0.89      0.36     -0.96      1.86   1917.52      1.00\n",
      " age_impute[56]      0.38      0.86      0.40     -1.00      1.75   1862.00      1.00\n",
      " age_impute[57]      0.01      0.91      0.03     -1.33      1.56   1285.43      1.00\n",
      " age_impute[58]     -0.69      0.91     -0.66     -2.13      0.78   1438.41      1.00\n",
      " age_impute[59]     -0.14      0.85     -0.16     -1.44      1.37   1135.79      1.00\n",
      " age_impute[60]     -0.59      0.94     -0.61     -2.19      0.93   1222.88      1.00\n",
      " age_impute[61]      0.24      0.92      0.25     -1.35      1.65   1341.95      1.00\n",
      " age_impute[62]     -0.55      0.91     -0.57     -2.01      0.96    753.85      1.00\n",
      " age_impute[63]      0.21      0.90      0.19     -1.42      1.60   1238.50      1.00\n",
      " age_impute[64]     -0.66      0.88     -0.68     -2.04      0.73   1214.85      1.00\n",
      " age_impute[65]      0.44      0.78      0.48     -0.93      1.57   1174.41      1.00\n",
      " age_impute[66]      0.22      0.94      0.20     -1.35      1.69   1910.00      1.00\n",
      " age_impute[67]      0.33      0.76      0.34     -0.85      1.63   1210.24      1.00\n",
      " age_impute[68]      0.31      0.84      0.33     -1.08      1.60   1756.60      1.00\n",
      " age_impute[69]      0.26      0.91      0.25     -1.29      1.75   1155.87      1.00\n",
      " age_impute[70]     -0.67      0.86     -0.70     -2.02      0.70   1186.22      1.00\n",
      " age_impute[71]     -0.70      0.90     -0.69     -2.21      0.75   1469.35      1.00\n",
      " age_impute[72]      0.24      0.86      0.24     -1.07      1.66   1604.16      1.00\n",
      " age_impute[73]      0.34      0.72      0.35     -0.77      1.55   1144.55      1.00\n",
      " age_impute[74]     -0.64      0.85     -0.64     -2.10      0.77   1513.79      1.00\n",
      " age_impute[75]      0.41      0.78      0.42     -0.96      1.60    796.47      1.00\n",
      " age_impute[76]      0.18      0.89      0.21     -1.19      1.74    755.44      1.00\n",
      " age_impute[77]      0.21      0.84      0.22     -1.22      1.63   1371.73      1.00\n",
      " age_impute[78]     -0.36      0.87     -0.33     -1.81      1.01   1017.23      1.00\n",
      " age_impute[79]      0.20      0.84      0.19     -1.35      1.37   1677.57      1.00\n",
      " age_impute[80]      0.23      0.84      0.24     -1.09      1.61   1545.61      1.00\n",
      " age_impute[81]      0.28      0.90      0.32     -1.08      1.83   1735.91      1.00\n",
      " age_impute[82]      0.61      0.80      0.60     -0.61      2.03   1353.67      1.00\n",
      " age_impute[83]      0.24      0.89      0.26     -1.22      1.66   1165.03      1.00\n",
      " age_impute[84]      0.21      0.91      0.21     -1.35      1.65   1584.00      1.00\n",
      " age_impute[85]      0.24      0.92      0.21     -1.33      1.63   1271.37      1.00\n",
      " age_impute[86]      0.31      0.81      0.30     -0.86      1.76   1198.70      1.00\n",
      " age_impute[87]     -0.11      0.84     -0.10     -1.42      1.23   1248.38      1.00\n",
      " age_impute[88]      0.21      0.94      0.22     -1.31      1.77   1082.82      1.00\n",
      " age_impute[89]      0.24      0.86      0.23     -1.08      1.67   2141.98      1.00\n",
      " age_impute[90]      0.41      0.84      0.45     -0.88      1.90   1518.73      1.00\n",
      " age_impute[91]      0.21      0.86      0.20     -1.21      1.58   1723.50      1.00\n",
      " age_impute[92]      0.21      0.84      0.20     -1.21      1.57   1742.44      1.00\n",
      " age_impute[93]      0.22      0.87      0.23     -1.29      1.50   1359.74      1.00\n",
      " age_impute[94]      0.22      0.87      0.18     -1.09      1.70    906.55      1.00\n",
      " age_impute[95]      0.22      0.87      0.23     -1.16      1.65   1112.58      1.00\n",
      " age_impute[96]      0.30      0.84      0.26     -1.18      1.57   1680.70      1.00\n",
      " age_impute[97]      0.23      0.87      0.25     -1.22      1.63   1408.40      1.00\n",
      " age_impute[98]     -0.36      0.91     -0.37     -1.96      1.03   1083.67      1.00\n",
      " age_impute[99]      0.15      0.87      0.14     -1.22      1.61   1644.46      1.00\n",
      "age_impute[100]      0.27      0.85      0.30     -1.27      1.45   1266.96      1.00\n",
      "age_impute[101]      0.25      0.87      0.25     -1.19      1.57   1220.96      1.00\n",
      "age_impute[102]     -0.29      0.85     -0.28     -1.70      1.10   1392.91      1.00\n",
      "age_impute[103]      0.01      0.89      0.01     -1.46      1.39   1137.34      1.00\n",
      "age_impute[104]      0.21      0.86      0.24     -1.16      1.64   1018.70      1.00\n",
      "age_impute[105]      0.24      0.93      0.21     -1.14      1.90   1479.67      1.00\n",
      "age_impute[106]      0.21      0.83      0.21     -1.09      1.55   1471.11      1.00\n",
      "age_impute[107]      0.22      0.85      0.22     -1.09      1.64   1941.83      1.00\n",
      "age_impute[108]      0.31      0.88      0.30     -1.10      1.76   1342.10      1.00\n",
      "age_impute[109]      0.22      0.86      0.23     -1.25      1.56   1198.01      1.00\n",
      "age_impute[110]      0.33      0.78      0.35     -0.95      1.62   1267.01      1.00\n",
      "age_impute[111]      0.22      0.88      0.21     -1.11      1.71   1404.51      1.00\n",
      "age_impute[112]     -0.03      0.90     -0.02     -1.38      1.55   1625.35      1.00\n",
      "age_impute[113]      0.24      0.85      0.23     -1.17      1.62   1361.84      1.00\n",
      "age_impute[114]      0.36      0.86      0.37     -0.99      1.76   1155.67      1.00\n",
      "age_impute[115]      0.26      0.96      0.28     -1.37      1.81   1245.97      1.00\n",
      "age_impute[116]      0.21      0.86      0.24     -1.18      1.69   1565.59      1.00\n",
      "age_impute[117]     -0.31      0.94     -0.33     -1.91      1.19   1593.65      1.00\n",
      "age_impute[118]      0.21      0.87      0.22     -1.20      1.64   1315.42      1.00\n",
      "age_impute[119]     -0.69      0.88     -0.74     -2.00      0.90   1536.44      1.00\n",
      "age_impute[120]      0.63      0.81      0.66     -0.65      1.89    899.61      1.00\n",
      "age_impute[121]      0.27      0.90      0.26     -1.16      1.74   1744.32      1.00\n",
      "age_impute[122]      0.18      0.87      0.18     -1.23      1.60   1625.58      1.00\n",
      "age_impute[123]     -0.39      0.88     -0.38     -1.71      1.12   1266.58      1.00\n",
      "age_impute[124]     -0.62      0.95     -0.63     -2.03      1.01   1600.28      1.00\n",
      "age_impute[125]      0.23      0.88      0.23     -1.15      1.71   1604.27      1.00\n",
      "age_impute[126]      0.18      0.91      0.18     -1.24      1.63   1527.38      1.00\n",
      "age_impute[127]      0.32      0.85      0.36     -1.08      1.73   1074.98      1.00\n",
      "age_impute[128]      0.25      0.88      0.25     -1.10      1.69   1486.79      1.00\n",
      "age_impute[129]     -0.70      0.87     -0.68     -2.20      0.56   1506.55      1.00\n",
      "age_impute[130]      0.21      0.88      0.20     -1.16      1.68   1451.63      1.00\n",
      "age_impute[131]      0.22      0.87      0.23     -1.22      1.61    905.86      1.00\n",
      "age_impute[132]      0.33      0.83      0.33     -1.01      1.66   1517.67      1.00\n",
      "age_impute[133]      0.18      0.86      0.18     -1.19      1.59   1050.00      1.00\n",
      "age_impute[134]     -0.14      0.92     -0.15     -1.77      1.24   1386.20      1.00\n",
      "age_impute[135]      0.19      0.85      0.18     -1.22      1.53   1290.94      1.00\n",
      "age_impute[136]      0.16      0.92      0.16     -1.35      1.74   1767.36      1.00\n",
      "age_impute[137]     -0.71      0.90     -0.68     -2.24      0.82   1154.14      1.00\n",
      "age_impute[138]      0.18      0.91      0.16     -1.30      1.67   1160.90      1.00\n",
      "age_impute[139]      0.24      0.90      0.24     -1.15      1.76   1289.37      1.00\n",
      "age_impute[140]      0.41      0.80      0.39     -1.05      1.53   1532.92      1.00\n",
      "age_impute[141]      0.27      0.83      0.29     -1.04      1.60   1310.29      1.00\n",
      "age_impute[142]     -0.28      0.89     -0.29     -1.68      1.22   1088.65      1.00\n",
      "age_impute[143]     -0.12      0.91     -0.11     -1.56      1.40   1324.74      1.00\n",
      "age_impute[144]     -0.65      0.87     -0.63     -1.91      0.93   1672.31      1.00\n",
      "age_impute[145]     -1.73      0.26     -1.74     -2.11     -1.26   1502.96      1.00\n",
      "age_impute[146]      0.40      0.85      0.40     -0.85      1.84   1443.81      1.00\n",
      "age_impute[147]      0.23      0.87      0.20     -1.37      1.49   1220.62      1.00\n",
      "age_impute[148]     -0.70      0.88     -0.70     -2.08      0.87   1846.67      1.00\n",
      "age_impute[149]      0.27      0.87      0.29     -1.11      1.76   1451.79      1.00\n",
      "age_impute[150]      0.21      0.90      0.20     -1.10      1.78   1409.94      1.00\n",
      "age_impute[151]      0.25      0.87      0.26     -1.21      1.63   1224.08      1.00\n",
      "age_impute[152]      0.05      0.85      0.05     -1.42      1.39   1164.23      1.00\n",
      "age_impute[153]      0.18      0.90      0.15     -1.19      1.72   1697.92      1.00\n",
      "age_impute[154]      1.05      0.93      1.04     -0.24      2.84   1212.82      1.00\n",
      "age_impute[155]      0.20      0.84      0.18     -1.18      1.54   1398.45      1.00\n",
      "age_impute[156]      0.23      0.95      0.19     -1.19      1.87   1773.79      1.00\n",
      "age_impute[157]      0.19      0.85      0.22     -1.13      1.64   1123.21      1.00\n",
      "age_impute[158]      0.22      0.86      0.22     -1.18      1.60   1307.64      1.00\n",
      "age_impute[159]      0.18      0.84      0.18     -1.09      1.59   1499.97      1.00\n",
      "age_impute[160]      0.24      0.89      0.28     -1.23      1.65   1100.08      1.00\n",
      "age_impute[161]     -0.45      0.88     -0.45     -1.86      1.05   1414.97      1.00\n",
      "age_impute[162]      0.39      0.89      0.40     -1.00      1.87   1525.80      1.00\n",
      "age_impute[163]      0.34      0.89      0.35     -1.14      1.75   1600.03      1.00\n",
      "age_impute[164]      0.21      0.94      0.19     -1.13      1.91   1090.05      1.00\n",
      "age_impute[165]      0.22      0.85      0.20     -1.11      1.60   1330.87      1.00\n",
      "age_impute[166]     -0.13      0.91     -0.15     -1.69      1.28   1284.90      1.00\n",
      "age_impute[167]      0.22      0.89      0.24     -1.15      1.76   1261.93      1.00\n",
      "age_impute[168]      0.20      0.90      0.18     -1.18      1.83   1217.16      1.00\n",
      "age_impute[169]      0.07      0.89      0.05     -1.29      1.60   2007.16      1.00\n",
      "age_impute[170]      0.23      0.90      0.24     -1.25      1.67    937.57      1.00\n",
      "age_impute[171]      0.41      0.80      0.42     -0.82      1.82   1404.02      1.00\n",
      "age_impute[172]      0.23      0.87      0.20     -1.33      1.51   2032.72      1.00\n",
      "age_impute[173]     -0.44      0.88     -0.44     -1.81      1.08   1006.62      1.00\n",
      "age_impute[174]      0.19      0.84      0.19     -1.11      1.63   1495.21      1.00\n",
      "age_impute[175]      0.20      0.85      0.20     -1.17      1.63   1551.22      1.00\n",
      "age_impute[176]     -0.43      0.92     -0.44     -1.83      1.21   1477.58      1.00\n",
      "      age_mu[0]      0.19      0.04      0.19      0.12      0.26    749.16      1.00\n",
      "      age_mu[1]     -0.54      0.07     -0.54     -0.66     -0.42    786.30      1.00\n",
      "      age_mu[2]      0.43      0.08      0.42      0.31      0.55   1134.72      1.00\n",
      "      age_mu[3]     -1.73      0.04     -1.73     -1.79     -1.65   1194.53      1.00\n",
      "      age_mu[4]      0.85      0.17      0.85      0.58      1.13   1111.96      1.00\n",
      "   age_sigma[0]      0.88      0.03      0.88      0.82      0.93    766.67      1.00\n",
      "   age_sigma[1]      0.90      0.06      0.90      0.81      0.99    992.72      1.00\n",
      "   age_sigma[2]      0.79      0.05      0.78      0.71      0.87    708.34      1.00\n",
      "   age_sigma[3]      0.26      0.03      0.25      0.20      0.31    959.62      1.00\n",
      "   age_sigma[4]      0.93      0.13      0.93      0.74      1.15   1092.88      1.00\n",
      "          b_Age     -0.45      0.14     -0.44     -0.66     -0.22    744.95      1.00\n",
      "  b_Embarked[0]     -0.28      0.58     -0.30     -1.28      0.64    496.51      1.00\n",
      "  b_Embarked[1]      0.30      0.60      0.29     -0.74      1.20    495.25      1.00\n",
      "  b_Embarked[2]      0.04      0.61      0.03     -0.93      1.02    482.67      1.00\n",
      "     b_Parch[0]      0.45      0.57      0.47     -0.45      1.42    336.02      1.02\n",
      "     b_Parch[1]      0.12      0.58      0.14     -0.91      1.00    377.61      1.02\n",
      "     b_Parch[2]     -0.49      0.58     -0.45     -1.48      0.41    358.61      1.01\n",
      "    b_Pclass[0]      1.22      0.57      1.24      0.33      2.17    371.15      1.00\n",
      "    b_Pclass[1]      0.06      0.57      0.07     -0.84      1.03    369.58      1.00\n",
      "    b_Pclass[2]     -1.18      0.57     -1.16     -2.18     -0.31    373.55      1.00\n",
      "       b_Sex[0]      1.15      0.74      1.18     -0.03      2.31    568.65      1.00\n",
      "       b_Sex[1]     -1.05      0.74     -1.02     -2.18      0.21    709.29      1.00\n",
      "     b_SibSp[0]      0.28      0.66      0.26     -0.86      1.25    585.03      1.00\n",
      "     b_SibSp[1]     -0.17      0.67     -0.18     -1.28      0.87    596.44      1.00\n",
      "     b_Title[0]     -0.94      0.54     -0.96     -1.86     -0.11    437.32      1.00\n",
      "     b_Title[1]     -0.33      0.61     -0.33     -1.32      0.60    570.32      1.00\n",
      "     b_Title[2]      0.53      0.62      0.53     -0.52      1.46    452.87      1.00\n",
      "     b_Title[3]      1.48      0.59      1.48      0.60      2.48    562.71      1.00\n",
      "     b_Title[4]     -0.68      0.58     -0.66     -1.71      0.15    472.57      1.00\n",
      "\n",
      "Number of divergences: 0\n"
     ]
    }
   ],
   "source": [
    "mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)\n",
    "mcmc.run(random.PRNGKey(0), **data, survived=survived)\n",
    "mcmc.print_summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To double-check that the assumption \"age is correlated with title\" is reasonable, let's look at the inferred age by title. Recall that we standardized `age`, so here we need to scale the values back to the original domain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Mr.': 32.434227,\n",
       " 'Miss.': 21.763992,\n",
       " 'Mrs.': 35.852997,\n",
       " 'Master.': 4.6297398,\n",
       " 'Misc.': 42.081936}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "age_by_title = age_mean + age_std * mcmc.get_samples()[\"age_mu\"].mean(axis=0)\n",
    "dict(zip(title_cat.categories, age_by_title))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The inferred result confirms our assumption that `Age` is correlated with `Title`:\n",
    "\n",
    "+ those with the `Master.` title are much younger than the other groups (in other words, they are the children on the ship),\n",
    "+ those with the `Mrs.` title are older on average than those with the `Miss.` title.\n",
    "\n",
    "We can also see that the result is close to the empirical mean of `Age` given `Title` in our training dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Title\n",
       "Master.     4.574167\n",
       "Misc.      42.384615\n",
       "Miss.      21.773973\n",
       "Mr.        32.368090\n",
       "Mrs.       35.898148\n",
       "Name: Age, dtype: float64"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_df.groupby(\"Title\")[\"Age\"].mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So far so good: we now have a lot of information about the regression coefficients, together with the imputed values and their uncertainties. Let's inspect those results a bit:\n",
    "\n",
    "+ The posterior mean `-0.45` of `b_Age` implies that younger passengers had a better chance of surviving.\n",
    "+ The posterior mean `(1.15, -1.05)` of `b_Sex` implies that female passengers had a higher chance of surviving than male passengers."
   ]
  },
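  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make these coefficients more concrete, recall that in a logistic regression each coefficient acts on the log-odds scale. The sketch below is an illustration rather than part of the original analysis: it only assumes the posterior means printed in the summary table above and the fact that `Age` was standardized, and it converts those means into odds ratios with plain Python.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# Posterior means copied from the summary table above.\n",
    "b_age = -0.45\n",
    "b_sex = (1.15, -1.05)\n",
    "\n",
    "# A one-standard-deviation increase in age multiplies the odds of survival by exp(b_Age).\n",
    "odds_ratio_age = math.exp(b_age)\n",
    "\n",
    "# Odds ratio between the two Sex categories (index 0 versus index 1), other features held fixed.\n",
    "odds_ratio_sex = math.exp(b_sex[0] - b_sex[1])\n",
    "\n",
    "print(f\"odds ratio per +1 std of age: {odds_ratio_age:.2f}\")\n",
    "print(f\"odds ratio, Sex index 0 vs index 1: {odds_ratio_sex:.2f}\")\n",
    "```\n",
    "\n",
    "With these values the first ratio is about 0.64 and the second about 9.0. Both use posterior means only and ignore the posterior uncertainty, so they should be read as rough summaries."
   ]
  },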
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In NumPyro, we can use the [Predictive](http://num.pyro.ai/en/stable/utilities.html#numpyro.infer.util.Predictive) utility to make predictions from posterior samples. Let's check how well the model performs on the training dataset. For simplicity, we will draw a `survived` prediction for each posterior sample and take a majority vote over those predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.8271605\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>predict</th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>actual</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.876138</td>\n",
       "      <td>0.198830</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.156648</td>\n",
       "      <td>0.748538</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "predict         0         1\n",
       "actual                     \n",
       "0        0.876138  0.198830\n",
       "1        0.156648  0.748538"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "posterior = mcmc.get_samples()\n",
    "survived_pred = Predictive(model, posterior)(random.PRNGKey(1), **data)[\"survived\"]\n",
    "survived_pred = (survived_pred.mean(axis=0) >= 0.5).astype(jnp.uint8)\n",
    "print(\"Accuracy:\", (survived_pred == survived).sum() / survived.shape[0])\n",
    "confusion_matrix = pd.crosstab(\n",
    "    pd.Series(survived, name=\"actual\"), pd.Series(survived_pred, name=\"predict\")\n",
    ")\n",
    "confusion_matrix / confusion_matrix.sum(axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a pretty good result using a simple logistic regression model. Let's see how the model performs if we don't use Bayesian imputation here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "sample: 100%|██████████| 2000/2000 [00:11<00:00, 166.79it/s, 63 steps of size 7.18e-02. acc. prob=0.93] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.82042646\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>predict</th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>actual</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.872495</td>\n",
       "      <td>0.204678</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.163934</td>\n",
       "      <td>0.736842</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "predict         0         1\n",
       "actual                     \n",
       "0        0.872495  0.204678\n",
       "1        0.163934  0.736842"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mcmc.run(random.PRNGKey(2), **data, survived=survived, bayesian_impute=False)\n",
    "posterior_1 = mcmc.get_samples()\n",
    "survived_pred_1 = Predictive(model, posterior_1)(random.PRNGKey(2), **data)[\"survived\"]\n",
    "survived_pred_1 = (survived_pred_1.mean(axis=0) >= 0.5).astype(jnp.uint8)\n",
    "print(\"Accuracy:\", (survived_pred_1 == survived).sum() / survived.shape[0])\n",
    "confusion_matrix = pd.crosstab(\n",
    "    pd.Series(survived, name=\"actual\"), pd.Series(survived_pred_1, name=\"predict\")\n",
    ")\n",
    "confusion_matrix / confusion_matrix.sum(axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that Bayesian imputation does slightly better here: the training accuracy is about 0.827 with it, versus about 0.820 without it."
   ]
  },
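  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A note on reading the two tables above: dividing a crosstab by `confusion_matrix.sum(axis=1)` aligns the row totals with the column labels, so the rows of the displayed tables do not sum to 1. If per-class rates are wanted instead, each row can be divided by its own total. The sketch below shows that normalization in plain Python on small made-up labels (hypothetical values, not the Titanic data).\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "# Hypothetical toy labels, chosen only for illustration.\n",
    "actual = [0, 0, 0, 0, 1, 1]\n",
    "predict = [0, 0, 0, 1, 0, 1]\n",
    "\n",
    "# counts[(a, p)] is the number of samples with actual label a and predicted label p.\n",
    "counts = Counter(zip(actual, predict))\n",
    "row_totals = Counter(actual)\n",
    "\n",
    "# Divide each cell by the total of its own row, so every row sums to 1.\n",
    "row_normalized = {\n",
    "    a: {p: counts[(a, p)] / row_totals[a] for p in (0, 1)} for a in (0, 1)\n",
    "}\n",
    "print(row_normalized)\n",
    "```\n",
    "\n",
    "With pandas, the same row-wise normalization can be written as `confusion_matrix.div(confusion_matrix.sum(axis=1), axis=0)`."
   ]
  },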
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Remark.** When using `posterior` samples to make predictions on new data, we need to marginalize out `age_impute`, because those imputed values are specific to the training data:\n",
    "```python\n",
    "posterior.pop(\"age_impute\")\n",
    "survived_pred = Predictive(model, posterior)(random.PRNGKey(3), **new_data)\n",
    "```"
   ]
  },
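  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that `posterior.pop(\"age_impute\")` modifies the dictionary in place, so `posterior` no longer contains the imputed ages afterwards. A minimal sketch of a non-mutating alternative follows; it uses a stand-in dictionary with placeholder values instead of real MCMC samples, and the name `posterior_new` is only illustrative.\n",
    "\n",
    "```python\n",
    "# Stand-in for mcmc.get_samples(); the values are placeholders, not real draws.\n",
    "posterior = {\"age_impute\": [0.1, 0.2], \"age_mu\": [0.19], \"b_Age\": [-0.45]}\n",
    "\n",
    "# Build a new dictionary without the training-specific site, leaving the original intact.\n",
    "posterior_new = {k: v for k, v in posterior.items() if k != \"age_impute\"}\n",
    "\n",
    "print(sorted(posterior_new))\n",
    "print(\"age_impute\" in posterior)\n",
    "```\n",
    "\n",
    "The filtered dictionary could then be passed to `Predictive(model, posterior_new)` in the same way as in the remark above."
   ]
  },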
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "1. McElreath, R. (2016). Statistical Rethinking: A Bayesian Course with Examples in R and Stan.\n",
    "2. Kaggle competition: [Titanic: Machine Learning from Disaster](https://www.kaggle.com/c/titanic)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
