{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "39afab70",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(r'../src')\n",
    "import uncertainpy\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1acfe532",
   "metadata": {},
   "source": [
    "# Read heart failure data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d724f804",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>anaemia</th>\n",
       "      <th>creatinine_phosphokinase</th>\n",
       "      <th>diabetes</th>\n",
       "      <th>ejection_fraction</th>\n",
       "      <th>high_blood_pressure</th>\n",
       "      <th>platelets</th>\n",
       "      <th>serum_creatinine</th>\n",
       "      <th>serum_sodium</th>\n",
       "      <th>sex</th>\n",
       "      <th>smoking</th>\n",
       "      <th>time</th>\n",
       "      <th>DEATH_EVENT</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>75.0</td>\n",
       "      <td>0</td>\n",
       "      <td>582</td>\n",
       "      <td>0</td>\n",
       "      <td>20</td>\n",
       "      <td>1</td>\n",
       "      <td>265000.00</td>\n",
       "      <td>1.9</td>\n",
       "      <td>130</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>55.0</td>\n",
       "      <td>0</td>\n",
       "      <td>7861</td>\n",
       "      <td>0</td>\n",
       "      <td>38</td>\n",
       "      <td>0</td>\n",
       "      <td>263358.03</td>\n",
       "      <td>1.1</td>\n",
       "      <td>136</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>6</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>65.0</td>\n",
       "      <td>0</td>\n",
       "      <td>146</td>\n",
       "      <td>0</td>\n",
       "      <td>20</td>\n",
       "      <td>0</td>\n",
       "      <td>162000.00</td>\n",
       "      <td>1.3</td>\n",
       "      <td>129</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>7</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>50.0</td>\n",
       "      <td>1</td>\n",
       "      <td>111</td>\n",
       "      <td>0</td>\n",
       "      <td>20</td>\n",
       "      <td>0</td>\n",
       "      <td>210000.00</td>\n",
       "      <td>1.9</td>\n",
       "      <td>137</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>7</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>65.0</td>\n",
       "      <td>1</td>\n",
       "      <td>160</td>\n",
       "      <td>1</td>\n",
       "      <td>20</td>\n",
       "      <td>0</td>\n",
       "      <td>327000.00</td>\n",
       "      <td>2.7</td>\n",
       "      <td>116</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>8</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>90.0</td>\n",
       "      <td>1</td>\n",
       "      <td>47</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>204000.00</td>\n",
       "      <td>2.1</td>\n",
       "      <td>132</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>8</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>75.0</td>\n",
       "      <td>1</td>\n",
       "      <td>246</td>\n",
       "      <td>0</td>\n",
       "      <td>15</td>\n",
       "      <td>0</td>\n",
       "      <td>127000.00</td>\n",
       "      <td>1.2</td>\n",
       "      <td>137</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>10</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>60.0</td>\n",
       "      <td>1</td>\n",
       "      <td>315</td>\n",
       "      <td>1</td>\n",
       "      <td>60</td>\n",
       "      <td>0</td>\n",
       "      <td>454000.00</td>\n",
       "      <td>1.1</td>\n",
       "      <td>131</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>10</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    age  anaemia  creatinine_phosphokinase  diabetes  ejection_fraction  \\\n",
       "0  75.0        0                       582         0                 20   \n",
       "1  55.0        0                      7861         0                 38   \n",
       "2  65.0        0                       146         0                 20   \n",
       "3  50.0        1                       111         0                 20   \n",
       "4  65.0        1                       160         1                 20   \n",
       "5  90.0        1                        47         0                 40   \n",
       "6  75.0        1                       246         0                 15   \n",
       "7  60.0        1                       315         1                 60   \n",
       "\n",
       "   high_blood_pressure  platelets  serum_creatinine  serum_sodium  sex  \\\n",
       "0                    1  265000.00               1.9           130    1   \n",
       "1                    0  263358.03               1.1           136    1   \n",
       "2                    0  162000.00               1.3           129    1   \n",
       "3                    0  210000.00               1.9           137    1   \n",
       "4                    0  327000.00               2.7           116    0   \n",
       "5                    1  204000.00               2.1           132    1   \n",
       "6                    0  127000.00               1.2           137    1   \n",
       "7                    0  454000.00               1.1           131    1   \n",
       "\n",
       "   smoking  time  DEATH_EVENT  \n",
       "0        0     4            1  \n",
       "1        0     6            1  \n",
       "2        1     7            1  \n",
       "3        0     7            1  \n",
       "4        0     8            1  \n",
       "5        1     8            1  \n",
       "6        0    10            1  \n",
       "7        1    10            1  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = pd.read_csv('../data/heart_failure_clinical_records_dataset.csv')\n",
    "data.head(8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c9c40f51",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>anaemia</th>\n",
       "      <th>creatinine_phosphokinase</th>\n",
       "      <th>diabetes</th>\n",
       "      <th>ejection_fraction</th>\n",
       "      <th>high_blood_pressure</th>\n",
       "      <th>platelets</th>\n",
       "      <th>serum_creatinine</th>\n",
       "      <th>serum_sodium</th>\n",
       "      <th>sex</th>\n",
       "      <th>smoking</th>\n",
       "      <th>time</th>\n",
       "      <th>DEATH_EVENT</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>299.000000</td>\n",
       "      <td>299.000000</td>\n",
       "      <td>299.000000</td>\n",
       "      <td>299.000000</td>\n",
       "      <td>299.000000</td>\n",
       "      <td>299.000000</td>\n",
       "      <td>299.000000</td>\n",
       "      <td>299.00000</td>\n",
       "      <td>299.000000</td>\n",
       "      <td>299.000000</td>\n",
       "      <td>299.00000</td>\n",
       "      <td>299.000000</td>\n",
       "      <td>299.00000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>60.833893</td>\n",
       "      <td>0.431438</td>\n",
       "      <td>581.839465</td>\n",
       "      <td>0.418060</td>\n",
       "      <td>38.083612</td>\n",
       "      <td>0.351171</td>\n",
       "      <td>263358.029264</td>\n",
       "      <td>1.39388</td>\n",
       "      <td>136.625418</td>\n",
       "      <td>0.648829</td>\n",
       "      <td>0.32107</td>\n",
       "      <td>130.260870</td>\n",
       "      <td>0.32107</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>11.894809</td>\n",
       "      <td>0.496107</td>\n",
       "      <td>970.287881</td>\n",
       "      <td>0.494067</td>\n",
       "      <td>11.834841</td>\n",
       "      <td>0.478136</td>\n",
       "      <td>97804.236869</td>\n",
       "      <td>1.03451</td>\n",
       "      <td>4.412477</td>\n",
       "      <td>0.478136</td>\n",
       "      <td>0.46767</td>\n",
       "      <td>77.614208</td>\n",
       "      <td>0.46767</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>40.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>23.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>14.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>25100.000000</td>\n",
       "      <td>0.50000</td>\n",
       "      <td>113.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.00000</td>\n",
       "      <td>4.000000</td>\n",
       "      <td>0.00000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>51.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>116.500000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>30.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>212500.000000</td>\n",
       "      <td>0.90000</td>\n",
       "      <td>134.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.00000</td>\n",
       "      <td>73.000000</td>\n",
       "      <td>0.00000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>60.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>250.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>38.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>262000.000000</td>\n",
       "      <td>1.10000</td>\n",
       "      <td>137.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.00000</td>\n",
       "      <td>115.000000</td>\n",
       "      <td>0.00000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>70.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>582.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>45.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>303500.000000</td>\n",
       "      <td>1.40000</td>\n",
       "      <td>140.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.00000</td>\n",
       "      <td>203.000000</td>\n",
       "      <td>1.00000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>95.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>7861.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>80.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>850000.000000</td>\n",
       "      <td>9.40000</td>\n",
       "      <td>148.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.00000</td>\n",
       "      <td>285.000000</td>\n",
       "      <td>1.00000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              age     anaemia  creatinine_phosphokinase    diabetes  \\\n",
       "count  299.000000  299.000000                299.000000  299.000000   \n",
       "mean    60.833893    0.431438                581.839465    0.418060   \n",
       "std     11.894809    0.496107                970.287881    0.494067   \n",
       "min     40.000000    0.000000                 23.000000    0.000000   \n",
       "25%     51.000000    0.000000                116.500000    0.000000   \n",
       "50%     60.000000    0.000000                250.000000    0.000000   \n",
       "75%     70.000000    1.000000                582.000000    1.000000   \n",
       "max     95.000000    1.000000               7861.000000    1.000000   \n",
       "\n",
       "       ejection_fraction  high_blood_pressure      platelets  \\\n",
       "count         299.000000           299.000000     299.000000   \n",
       "mean           38.083612             0.351171  263358.029264   \n",
       "std            11.834841             0.478136   97804.236869   \n",
       "min            14.000000             0.000000   25100.000000   \n",
       "25%            30.000000             0.000000  212500.000000   \n",
       "50%            38.000000             0.000000  262000.000000   \n",
       "75%            45.000000             1.000000  303500.000000   \n",
       "max            80.000000             1.000000  850000.000000   \n",
       "\n",
       "       serum_creatinine  serum_sodium         sex    smoking        time  \\\n",
       "count         299.00000    299.000000  299.000000  299.00000  299.000000   \n",
       "mean            1.39388    136.625418    0.648829    0.32107  130.260870   \n",
       "std             1.03451      4.412477    0.478136    0.46767   77.614208   \n",
       "min             0.50000    113.000000    0.000000    0.00000    4.000000   \n",
       "25%             0.90000    134.000000    0.000000    0.00000   73.000000   \n",
       "50%             1.10000    137.000000    1.000000    0.00000  115.000000   \n",
       "75%             1.40000    140.000000    1.000000    1.00000  203.000000   \n",
       "max             9.40000    148.000000    1.000000    1.00000  285.000000   \n",
       "\n",
       "       DEATH_EVENT  \n",
       "count    299.00000  \n",
       "mean       0.32107  \n",
       "std        0.46767  \n",
       "min        0.00000  \n",
       "25%        0.00000  \n",
       "50%        0.00000  \n",
       "75%        1.00000  \n",
       "max        1.00000  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.describe()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2c80fd2",
   "metadata": {},
   "source": [
    "# Turn features into boolean features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c3a4f7c6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>anaemia</th>\n",
       "      <th>diabetes</th>\n",
       "      <th>high_blood_pressure</th>\n",
       "      <th>sex</th>\n",
       "      <th>smoking</th>\n",
       "      <th>DEATH_EVENT</th>\n",
       "      <th>age_small</th>\n",
       "      <th>age_med</th>\n",
       "      <th>age_large</th>\n",
       "      <th>creatinine_phosphokinase_small</th>\n",
       "      <th>...</th>\n",
       "      <th>platelets_large</th>\n",
       "      <th>serum_creatinine_small</th>\n",
       "      <th>serum_creatinine_med</th>\n",
       "      <th>serum_creatinine_large</th>\n",
       "      <th>serum_sodium_small</th>\n",
       "      <th>serum_sodium_med</th>\n",
       "      <th>serum_sodium_large</th>\n",
       "      <th>time_small</th>\n",
       "      <th>time_med</th>\n",
       "      <th>time_large</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>...</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>...</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>...</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>...</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>...</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 27 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   anaemia  diabetes  high_blood_pressure    sex  smoking  DEATH_EVENT  \\\n",
       "0    False     False                 True   True    False         True   \n",
       "1    False     False                False   True    False         True   \n",
       "2    False     False                False   True     True         True   \n",
       "3     True     False                False   True    False         True   \n",
       "4     True      True                False  False    False         True   \n",
       "\n",
       "   age_small  age_med  age_large  creatinine_phosphokinase_small  ...  \\\n",
       "0      False    False       True                           False  ...   \n",
       "1      False     True      False                           False  ...   \n",
       "2      False     True      False                           False  ...   \n",
       "3       True    False      False                            True  ...   \n",
       "4      False     True      False                           False  ...   \n",
       "\n",
       "   platelets_large  serum_creatinine_small  serum_creatinine_med  \\\n",
       "0            False                   False                 False   \n",
       "1            False                   False                  True   \n",
       "2            False                   False                  True   \n",
       "3            False                   False                 False   \n",
       "4             True                   False                 False   \n",
       "\n",
       "   serum_creatinine_large  serum_sodium_small  serum_sodium_med  \\\n",
       "0                    True                True             False   \n",
       "1                   False               False              True   \n",
       "2                   False                True             False   \n",
       "3                    True               False              True   \n",
       "4                    True                True             False   \n",
       "\n",
       "   serum_sodium_large  time_small  time_med  time_large  \n",
       "0               False        True     False       False  \n",
       "1               False        True     False       False  \n",
       "2               False        True     False       False  \n",
       "3               False        True     False       False  \n",
       "4               False        True     False       False  \n",
       "\n",
       "[5 rows x 27 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from uncertainpy.util.dataTransformation import booleanizeNumericalColumnUsingQuartiles\n",
    "from uncertainpy.util.dataTransformation import booleanizeBinaryColumn\n",
    "\n",
    "#assuming that all features are numerical or binary (categorical features should be treated  by binarizing them)\n",
    "numerical_features = [\"age\",\"creatinine_phosphokinase\",\"ejection_fraction\",\"platelets\",\"serum_creatinine\",\"serum_sodium\",\"time\"]\n",
    "binary_features = [\"anaemia\",\"diabetes\",\"high_blood_pressure\",\"sex\",\"smoking\",\"DEATH_EVENT\"]\n",
    "\n",
    "for f in numerical_features:\n",
    "    booleanizeNumericalColumnUsingQuartiles(data, f)\n",
    "    \n",
    "for f in binary_features:\n",
    "    booleanizeBinaryColumn(data, f)\n",
    "\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06139e85",
   "metadata": {},
   "source": [
    "# Build small trees to find short rules that generalize well"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "46f9ff97",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict anaemia using the following features:\n",
      "['diabetes', 'high_blood_pressure', 'sex', 'smoking', 'DEATH_EVENT', 'age_small', 'age_med', 'age_large', 'creatinine_phosphokinase_small', 'creatinine_phosphokinase_med', 'creatinine_phosphokinase_large', 'ejection_fraction_small', 'ejection_fraction_med', 'ejection_fraction_large', 'platelets_small', 'platelets_med', 'platelets_large', 'serum_creatinine_small', 'serum_creatinine_med', 'serum_creatinine_large', 'serum_sodium_small', 'serum_sodium_med', 'serum_sodium_large', 'time_small', 'time_med', 'time_large']\n"
     ]
    }
   ],
   "source": [
    "target_variable_index = 0\n",
    "bmask = np.full((1, data.shape[1]), True).flatten()\n",
    "bmask[target_variable_index] = False\n",
    "\n",
    "features = data.columns[bmask].to_list()\n",
    "print(f'predict {data.columns[target_variable_index]} using the following features:')\n",
    "print(features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8dd5f708",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "|--- creatinine_phosphokinase_large <= 0.50\n",
      "|   |--- smoking <= 0.50\n",
      "|   |   |--- time_large <= 0.50\n",
      "|   |   |   |--- serum_creatinine_med <= 0.50\n",
      "|   |   |   |   |--- platelets_small <= 0.50\n",
      "|   |   |   |   |   |--- weights: [23.00, 21.00] class: False\n",
      "|   |   |   |   |--- platelets_small >  0.50\n",
      "|   |   |   |   |   |--- weights: [2.00, 11.00] class: True\n",
      "|   |   |   |--- serum_creatinine_med >  0.50\n",
      "|   |   |   |   |--- weights: [7.00, 31.00] class: True\n",
      "|   |   |--- time_large >  0.50\n",
      "|   |   |   |--- serum_sodium_small <= 0.50\n",
      "|   |   |   |   |--- weights: [11.00, 10.00] class: False\n",
      "|   |   |   |--- serum_sodium_small >  0.50\n",
      "|   |   |   |   |--- weights: [8.00, 1.00] class: False\n",
      "|   |--- smoking >  0.50\n",
      "|   |   |--- weights: [37.00, 26.00] class: False\n",
      "|--- creatinine_phosphokinase_large >  0.50\n",
      "|   |--- age_med <= 0.50\n",
      "|   |   |--- weights: [52.00, 12.00] class: False\n",
      "|   |--- age_med >  0.50\n",
      "|   |   |--- weights: [30.00, 17.00] class: False\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn import tree\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "X = data.loc[:, bmask]\n",
    "Y = data.loc[:, bmask==False]\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size = 0.7, random_state = 1)\n",
    "clf = tree.DecisionTreeClassifier(max_leaf_nodes=8) #min_samples_split=40, min_samples_leaf=40)\n",
    "clf = clf.fit(X, Y)\n",
    "\n",
    "rules = tree.export_text(clf, feature_names=features,show_weights=True)\n",
    "print(rules)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74b48852",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
