{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Trusted-AI/AIF360/blob/master/examples/sklearn/monthly_bee_datasets_metrics.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Method of the month (July) -- Dataset loading and running metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Install AIF360\n",
    "!pip install 'aif360' "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "from sklearn.compose import make_column_transformer\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.model_selection import GridSearchCV, train_test_split\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
    "\n",
    "from aif360.sklearn.datasets import fetch_adult, standardize_dataset\n",
    "from aif360.sklearn.metrics import disparate_impact_ratio, average_odds_error\n",
    "from aif360.sklearn.metrics import base_rate, ratio"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Datasets are formatted as separate `X` (# samples x # features) and `y` (# samples x # labels) DataFrames. The index of each DataFrame holds the protected attribute values for each sample. Datasets may also provide a `sample_weight` object for use with certain algorithms and metrics. This layout keeps aif360 compatible with scikit-learn objects.\n",
    "\n",
    "For example, we can easily load the Adult dataset from UCI with the following line:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Black</th>\n",
       "      <th>Male</th>\n",
       "      <td>25.0</td>\n",
       "      <td>Private</td>\n",
       "      <td>11th</td>\n",
       "      <td>7.0</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">White</th>\n",
       "      <th>Male</th>\n",
       "      <td>38.0</td>\n",
       "      <td>Private</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Farming-fishing</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>50.0</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Male</th>\n",
       "      <td>28.0</td>\n",
       "      <td>Local-gov</td>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>12.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Protective-serv</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Black</th>\n",
       "      <th>Male</th>\n",
       "      <td>44.0</td>\n",
       "      <td>Private</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>7688.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>White</th>\n",
       "      <th>Male</th>\n",
       "      <td>34.0</td>\n",
       "      <td>Private</td>\n",
       "      <td>10th</td>\n",
       "      <td>6.0</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Other-service</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>30.0</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             age  workclass     education  education-num      marital-status  \\\n",
       "race  sex                                                                      \n",
       "Black Male  25.0    Private          11th            7.0       Never-married   \n",
       "White Male  38.0    Private       HS-grad            9.0  Married-civ-spouse   \n",
       "      Male  28.0  Local-gov    Assoc-acdm           12.0  Married-civ-spouse   \n",
       "Black Male  44.0    Private  Some-college           10.0  Married-civ-spouse   \n",
       "White Male  34.0    Private          10th            6.0       Never-married   \n",
       "\n",
       "                   occupation   relationship   race   sex  capital-gain  \\\n",
       "race  sex                                                                 \n",
       "Black Male  Machine-op-inspct      Own-child  Black  Male           0.0   \n",
       "White Male    Farming-fishing        Husband  White  Male           0.0   \n",
       "      Male    Protective-serv        Husband  White  Male           0.0   \n",
       "Black Male  Machine-op-inspct        Husband  Black  Male        7688.0   \n",
       "White Male      Other-service  Not-in-family  White  Male           0.0   \n",
       "\n",
       "            capital-loss  hours-per-week native-country  \n",
       "race  sex                                                \n",
       "Black Male           0.0            40.0  United-States  \n",
       "White Male           0.0            50.0  United-States  \n",
       "      Male           0.0            40.0  United-States  \n",
       "Black Male           0.0            40.0  United-States  \n",
       "White Male           0.0            30.0  United-States  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X, y, sample_weight = fetch_adult(binary_race=False)\n",
    "X.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The protected attribute information is also replicated in the labels (this will be useful when running metrics):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "race   sex \n",
       "Black  Male    <=50K\n",
       "White  Male    <=50K\n",
       "       Male     >50K\n",
       "Black  Male     >50K\n",
       "White  Male    <=50K\n",
       "Name: annual-income, dtype: category\n",
       "Categories (2, object): ['<=50K' < '>50K']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, `fetch_adult` drops rows that contain missing (NA) values. Setting `dropna=False` keeps those rows so that we can handle them differently, if we so choose."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((45222, 13), (48842, 13))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_num, y_num, _ = fetch_adult(dropna=False)\n",
    "X.shape, X_num.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`fetch_adult` accepts several other arguments. For example, `numeric_only=True` will _drop_ all non-numeric columns (this will be useful for running `DisparateImpactRemover`). A similar selection can be made explicitly with the `usecols` or `dropcols` arguments. All of this can also be done manually, since the data are plain pandas objects, but the provided interface is more convenient, especially when combined with `dropna`."
   ]
  },
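  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the `numeric_only` option described above, the cell below loads a numeric-only copy of the features and lists the column types that remain. The variable name `X_numeric` is only illustrative, and the exact set of remaining columns is not asserted here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from aif360.sklearn.datasets import fetch_adult\n",
    "\n",
    "# Illustrative only: load into a separate variable so that the `X` loaded\n",
    "# earlier is left untouched for the rest of the notebook.\n",
    "X_numeric, _, _ = fetch_adult(numeric_only=True)\n",
    "\n",
    "# Inspect which feature columns survive and their dtypes.\n",
    "X_numeric.dtypes"
   ]
  },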
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can then split the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "(X_train, X_test,\n",
    " y_train, y_test) = train_test_split(X, y, train_size=0.7, random_state=1234567)"
   ]
  },
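  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "and one-hot encode the categorical features while standardizing the numeric ones. The `PandasMeta` wrapper defined below keeps the transformed output as a DataFrame, so the protected attributes stored in the index are preserved:"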
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.base import BaseEstimator, MetaEstimatorMixin, clone\n",
    "\n",
    "\n",
    "class PandasMeta(BaseEstimator, MetaEstimatorMixin):\n",
    "    \"\"\"Wrap a transformer so that its output is a DataFrame with the input's index.\"\"\"\n",
    "\n",
    "    def __init__(self, estimator):\n",
    "        self.estimator = estimator\n",
    "\n",
    "    def _to_frame(self, output, index):\n",
    "        # Re-attach the original index, which holds the protected attributes.\n",
    "        if not isinstance(output, pd.DataFrame):\n",
    "            output = pd.DataFrame(output, index=index)\n",
    "            try:\n",
    "                output.columns = self.estimator_.get_feature_names_out()\n",
    "            except Exception:\n",
    "                # Feature names are unavailable; keep the default integer columns.\n",
    "                pass\n",
    "        return output\n",
    "\n",
    "    def fit(self, X, y=None, **fit_params):\n",
    "        self.estimator_ = clone(self.estimator)\n",
    "        self.estimator_.fit(X, y, **fit_params)\n",
    "        return self\n",
    "\n",
    "    def transform(self, X):\n",
    "        assert isinstance(X, pd.DataFrame)\n",
    "        return self._to_frame(self.estimator_.transform(X), X.index)\n",
    "\n",
    "    def fit_transform(self, X, y=None, **fit_params):\n",
    "        assert isinstance(X, pd.DataFrame)\n",
    "        self.estimator_ = clone(self.estimator)\n",
    "        output = self.estimator_.fit_transform(X, y, **fit_params)\n",
    "        return self._to_frame(output, X.index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>workclass_Federal-gov</th>\n",
       "      <th>workclass_Local-gov</th>\n",
       "      <th>workclass_Private</th>\n",
       "      <th>workclass_Self-emp-inc</th>\n",
       "      <th>workclass_Self-emp-not-inc</th>\n",
       "      <th>workclass_State-gov</th>\n",
       "      <th>workclass_Without-pay</th>\n",
       "      <th>education_10th</th>\n",
       "      <th>education_11th</th>\n",
       "      <th>education_12th</th>\n",
       "      <th>...</th>\n",
       "      <th>native-country_Thailand</th>\n",
       "      <th>native-country_Trinadad&amp;Tobago</th>\n",
       "      <th>native-country_United-States</th>\n",
       "      <th>native-country_Vietnam</th>\n",
       "      <th>native-country_Yugoslavia</th>\n",
       "      <th>age</th>\n",
       "      <th>education-num</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"10\" valign=\"top\">White</th>\n",
       "      <th>Male</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.470648</td>\n",
       "      <td>0.345896</td>\n",
       "      <td>-0.147645</td>\n",
       "      <td>-0.217489</td>\n",
       "      <td>0.089733</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Female</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.941480</td>\n",
       "      <td>0.739149</td>\n",
       "      <td>-0.147645</td>\n",
       "      <td>-0.217489</td>\n",
       "      <td>-0.901858</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Male</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-0.948408</td>\n",
       "      <td>1.525654</td>\n",
       "      <td>-0.147645</td>\n",
       "      <td>4.469448</td>\n",
       "      <td>-0.075532</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Male</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.412311</td>\n",
       "      <td>-2.800127</td>\n",
       "      <td>-0.147645</td>\n",
       "      <td>-0.217489</td>\n",
       "      <td>-0.075532</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Male</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-0.419239</td>\n",
       "      <td>-1.620368</td>\n",
       "      <td>-0.147645</td>\n",
       "      <td>-0.217489</td>\n",
       "      <td>-0.075532</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Male</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.941480</td>\n",
       "      <td>-0.440610</td>\n",
       "      <td>-0.147645</td>\n",
       "      <td>-0.217489</td>\n",
       "      <td>-0.075532</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Male</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.336716</td>\n",
       "      <td>-0.440610</td>\n",
       "      <td>-0.147645</td>\n",
       "      <td>-0.217489</td>\n",
       "      <td>0.089733</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Male</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.395053</td>\n",
       "      <td>1.918907</td>\n",
       "      <td>-0.147645</td>\n",
       "      <td>-0.217489</td>\n",
       "      <td>-0.488695</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Male</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-0.948408</td>\n",
       "      <td>-0.440610</td>\n",
       "      <td>-0.147645</td>\n",
       "      <td>-0.217489</td>\n",
       "      <td>2.403445</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Black</th>\n",
       "      <th>Male</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.075412</td>\n",
       "      <td>-0.440610</td>\n",
       "      <td>-0.147645</td>\n",
       "      <td>-0.217489</td>\n",
       "      <td>-0.075532</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>31655 rows × 103 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "              workclass_Federal-gov  workclass_Local-gov  workclass_Private  \\\n",
       "race  sex                                                                     \n",
       "White Male                      0.0                  0.0                0.0   \n",
       "      Female                    0.0                  0.0                0.0   \n",
       "      Male                      0.0                  0.0                1.0   \n",
       "      Male                      0.0                  0.0                1.0   \n",
       "      Male                      0.0                  0.0                1.0   \n",
       "...                             ...                  ...                ...   \n",
       "      Male                      0.0                  0.0                1.0   \n",
       "      Male                      0.0                  0.0                1.0   \n",
       "      Male                      0.0                  0.0                0.0   \n",
       "      Male                      0.0                  0.0                1.0   \n",
       "Black Male                      0.0                  0.0                1.0   \n",
       "\n",
       "              workclass_Self-emp-inc  workclass_Self-emp-not-inc  \\\n",
       "race  sex                                                          \n",
       "White Male                       0.0                         1.0   \n",
       "      Female                     0.0                         1.0   \n",
       "      Male                       0.0                         0.0   \n",
       "      Male                       0.0                         0.0   \n",
       "      Male                       0.0                         0.0   \n",
       "...                              ...                         ...   \n",
       "      Male                       0.0                         0.0   \n",
       "      Male                       0.0                         0.0   \n",
       "      Male                       1.0                         0.0   \n",
       "      Male                       0.0                         0.0   \n",
       "Black Male                       0.0                         0.0   \n",
       "\n",
       "              workclass_State-gov  workclass_Without-pay  education_10th  \\\n",
       "race  sex                                                                  \n",
       "White Male                    0.0                    0.0             0.0   \n",
       "      Female                  0.0                    0.0             0.0   \n",
       "      Male                    0.0                    0.0             0.0   \n",
       "      Male                    0.0                    0.0             0.0   \n",
       "      Male                    0.0                    0.0             1.0   \n",
       "...                           ...                    ...             ...   \n",
       "      Male                    0.0                    0.0             0.0   \n",
       "      Male                    0.0                    0.0             0.0   \n",
       "      Male                    0.0                    0.0             0.0   \n",
       "      Male                    0.0                    0.0             0.0   \n",
       "Black Male                    0.0                    0.0             0.0   \n",
       "\n",
       "              education_11th  education_12th  ...  native-country_Thailand  \\\n",
       "race  sex                                     ...                            \n",
       "White Male               0.0             0.0  ...                      0.0   \n",
       "      Female             0.0             0.0  ...                      0.0   \n",
       "      Male               0.0             0.0  ...                      0.0   \n",
       "      Male               0.0             0.0  ...                      0.0   \n",
       "      Male               0.0             0.0  ...                      0.0   \n",
       "...                      ...             ...  ...                      ...   \n",
       "      Male               0.0             0.0  ...                      0.0   \n",
       "      Male               0.0             0.0  ...                      0.0   \n",
       "      Male               0.0             0.0  ...                      0.0   \n",
       "      Male               0.0             0.0  ...                      0.0   \n",
       "Black Male               0.0             0.0  ...                      0.0   \n",
       "\n",
       "              native-country_Trinadad&Tobago  native-country_United-States  \\\n",
       "race  sex                                                                    \n",
       "White Male                               0.0                           1.0   \n",
       "      Female                             0.0                           0.0   \n",
       "      Male                               0.0                           1.0   \n",
       "      Male                               0.0                           0.0   \n",
       "      Male                               0.0                           1.0   \n",
       "...                                      ...                           ...   \n",
       "      Male                               0.0                           1.0   \n",
       "      Male                               0.0                           1.0   \n",
       "      Male                               0.0                           1.0   \n",
       "      Male                               0.0                           1.0   \n",
       "Black Male                               0.0                           1.0   \n",
       "\n",
       "              native-country_Vietnam  native-country_Yugoslavia       age  \\\n",
       "race  sex                                                                   \n",
       "White Male                       0.0                        0.0  1.470648   \n",
       "      Female                     0.0                        0.0  0.941480   \n",
       "      Male                       0.0                        0.0 -0.948408   \n",
       "      Male                       0.0                        0.0  0.412311   \n",
       "      Male                       0.0                        0.0 -0.419239   \n",
       "...                              ...                        ...       ...   \n",
       "      Male                       0.0                        0.0  0.941480   \n",
       "      Male                       0.0                        0.0  0.336716   \n",
       "      Male                       0.0                        0.0  1.395053   \n",
       "      Male                       0.0                        0.0 -0.948408   \n",
       "Black Male                       0.0                        0.0  2.075412   \n",
       "\n",
       "              education-num  capital-gain  capital-loss  hours-per-week  \n",
       "race  sex                                                                \n",
       "White Male         0.345896     -0.147645     -0.217489        0.089733  \n",
       "      Female       0.739149     -0.147645     -0.217489       -0.901858  \n",
       "      Male         1.525654     -0.147645      4.469448       -0.075532  \n",
       "      Male        -2.800127     -0.147645     -0.217489       -0.075532  \n",
       "      Male        -1.620368     -0.147645     -0.217489       -0.075532  \n",
       "...                     ...           ...           ...             ...  \n",
       "      Male        -0.440610     -0.147645     -0.217489       -0.075532  \n",
       "      Male        -0.440610     -0.147645     -0.217489        0.089733  \n",
       "      Male         1.918907     -0.147645     -0.217489       -0.488695  \n",
       "      Male        -0.440610     -0.147645     -0.217489        2.403445  \n",
       "Black Male        -0.440610     -0.147645     -0.217489       -0.075532  \n",
       "\n",
       "[31655 rows x 103 columns]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pre = make_column_transformer(\n",
    "    # `sparse_output` replaced `sparse` in scikit-learn 1.2; use `sparse=False` on older versions.\n",
    "    (OneHotEncoder(sparse_output=False), X_train.dtypes == 'category'),\n",
    "    (StandardScaler(), X_train.dtypes != 'category'),\n",
    "    verbose_feature_names_out=False)\n",
    "pre = PandasMeta(pre)\n",
    "pre.fit_transform(X_train)"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can get a baseline measurement for disparate impact ratio on the training data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.36011001703880235"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "di_train = disparate_impact_ratio(y_train, prot_attr='sex', priv_group='Male', pos_label='>50K')\n",
    "di_train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sidenote: this is equivalent to:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.36011001703880235"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ratio(base_rate, y_train, prot_attr='sex', priv_group='Male', pos_label='>50K')"
   ]
  },
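  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of what the two calls above compute (the notation is an assumption chosen to match the odds formula later in this notebook, with $D$ the protected attribute and $Y$ the label), the disparate impact ratio on labels alone is the ratio of base rates:\n",
    "\n",
    "$$ \\frac{P(Y = \\text{favorable} \\mid D = \\text{unprivileged})}{P(Y = \\text{favorable} \\mid D = \\text{privileged})} $$\n",
    "\n",
    "A value of 1 would mean both groups receive the favorable label at the same rate. The value of about 0.36 above means females in the training data are labeled `>50K` at roughly 36% of the rate of males."
   ]
  },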
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the data in this format, we can easily train a scikit-learn model and get predictions for the test data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8461708557529299"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = LogisticRegression(max_iter=1000)\n",
    "grid = GridSearchCV(model, param_grid={'C': [1, 10]})\n",
    "pipe = make_pipeline(pre, grid)\n",
    "\n",
    "y_pred = pipe.fit(X_train, y_train).predict(X_test)\n",
    "accuracy_score(y_test, y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'C': 1}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pipe.named_steps['gridsearchcv'].best_params_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can analyze our predictions and quickly calucate the disparate impact for females vs. males:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.3006936739406486"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "di_test = disparate_impact_ratio(y_test, y_pred, prot_attr='sex', priv_group='Male', pos_label='>50K')\n",
    "di_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see our model (without mitigation) has worsened the disparate impact compared to the baseline present in the data itself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAMuElEQVR4nO3df6zd9V3H8edrrYgyBLPWaNqOkq3YVWI2uEEIJjYZmkJimzg11Cw6JesfDrdky5IaF375j0jUxKSblrhMNx1j/LHcbNX6C6ISOnvLGFnLqteCttWEjjH8QTZA3/5xD3p2ubfnlJ57b/vu85EQzvf7/dxz3iUnz3z7Pfd7SFUhSTr/vWGlB5AkTYZBl6QmDLokNWHQJakJgy5JTRh0SWpiZNCTfDzJs0m+ssjxJPndJLNJnkxyzeTHlCSNMs4Z+ieAbac5fjOwafDPLuBjZz+WJOlMjQx6Vf0N8PXTLNkB/FHNOQBcnuQHJjWgJGk8qyfwHOuA40PbJwb7/m3+wiS7mDuL55JLLrl28+bNE3h5SbpwHDp06GtVtXahY5MI+tiqai+wF2BqaqpmZmaW8+Ul6byX5J8XOzaJ33I5CWwY2l4/2CdJWkaTCPo08POD33a5Hnihql5zuUWStLRGXnJJ8mlgK7AmyQngTuA7AKrq94B9wC3ALPAi8ItLNawkaXEjg15VO0ccL+B9E5tIkvS6eKeoJDVh0CWpCYMuSU0YdElqwqBLUhMGXZKaMOiS1IRBl6QmDLokNWHQJakJgy5JTRh0SWrCoEtSEwZdkpow6JLUhEGXpCYMuiQ1YdAlqQmDLklNGHRJasKgS1ITBl2SmjDoktSEQZekJgy6JDVh0CWpCYMuSU0YdElqwqBLUhMGXZKaMOiS1IRBl6QmDLokNWHQJakJgy5JTYwV9CTbkhxNMptk9wLH35zk4SRfSvJkklsmP6ok6XRGBj3JKmAPcDOwBdiZZMu8ZR8BHqyqdwC3Ah+d9KCSpNNbPcaa64DZqjoGkOQBYAdwZGhNAd8zeHwZ8K8jn/XoUdi69UxmlSSdxjiXXNYBx4e2Twz2DbsLeHeSE8A+4FcWeqIku5LMJJl5+eWXX8e4kqTFjHOGPo6dwCeq6reS3AB8MsnVVfU/w4uqai+wF2Bqaqp45JEJvbwkXSCSRQ+Nc4Z+EtgwtL1+sG/YbcCDAFX1GHAxsOaMhpQknZVxgn4Q2JTkyiQXMfeh5/S8Nf8CvBMgyduYC/qpSQ4qSTq9kUGvqleA24H9wFPM/TbL4ST3JNk+WPYh4L1Jvgx8GnhPVdVSDS1Jeq2xrqFX1T7mPuwc3nfH0OMjwI2THU2SdCa8U1SSmjDoktSEQZekJgy6JDVh0CWpCYMuSU0YdElqwqBLUhMGXZKaMOiS1IRBl6QmDLokNWHQJakJgy5JTRh0SWrCoEtSEwZdkpow6JLUhEGXpCYMuiQ1YdAlqQmDLklNGHRJasKgS1ITBl2SmjDoktSEQZekJgy6JDVh0CWpCYMuSU0YdElqwqBLUhMGXZKaMOiS1MRYQU+yLcnRJLNJdi+y5meTHElyOMmfTHZMSdIoq0ctSLIK2AP8OHACOJhkuqqODK3ZBPwqcGNVPZ/k+5ZqYEnSwsY5Q78OmK2qY1X1EvAAsGPemvcCe6rqeYCqenayY0qSRhkn6OuA40PbJwb7hl0FXJXk0SQHkmxb6ImS7Eoyk2Tm1KlTr29iSdKCJvWh6GpgE7AV2Ancn+Ty+Yuqam9VTVXV1Nq1ayf00pIkGC/oJ4ENQ9vrB/uGnQCmq+rlqnoa+AfmAi9JWibjBP0gsCnJlUkuAm4Fpuet+RxzZ+ckWcPcJZhjkxtTkjTKyKBX1SvA7cB+4Cngwao6nOSeJNsHy/YDzyU5AjwMfLiqnluqoSVJr5WqWpEXnpqaqpmZmRV5bUk6XyU5VFVTCx3zTlFJasKgS1ITBl2SmjDoktSEQZekJgy6JDVh
0CWpCYMuSU0YdElqwqBLUhMGXZKaMOiS1IRBl6QmDLokNWHQJakJgy5JTRh0SWrCoEtSEwZdkpow6JLUhEGXpCYMuiQ1YdAlqQmDLklNGHRJasKgS1ITBl2SmjDoktSEQZekJgy6JDVh0CWpCYMuSU0YdElqwqBLUhMGXZKaGCvoSbYlOZpkNsnu06x7V5JKMjW5ESVJ4xgZ9CSrgD3AzcAWYGeSLQusuxT4APDFSQ8pSRptnDP064DZqjpWVS8BDwA7Flj368C9wDcnOJ8kaUzjBH0dcHxo+8Rg3/9Jcg2woaq+cLonSrIryUySmVOnTp3xsJKkxZ31h6JJ3gD8NvChUWuram9VTVXV1Nq1a8/2pSVJQ8YJ+klgw9D2+sG+V10KXA08kuQZ4Hpg2g9GJWl5jRP0g8CmJFcmuQi4FZh+9WBVvVBVa6pqY1VtBA4A26tqZkkmliQtaGTQq+oV4HZgP/AU8GBVHU5yT5LtSz2gJGk8q8dZVFX7gH3z9t2xyNqtZz+WJOlMeaeoJDVh0CWpibEuuZyr7r777pUeQeegO++8c6VHkFaEZ+iS1IRBl6QmDLokNWHQJakJgy5JTRh0SWrCoEtSEwZdkpow6JLUhEGXpCYMuiQ1YdAlqYnz+su5pHPVlc88s9Ij6Bz09MaNS/r8nqFLUhMGXZKaMOiS1IRBl6QmDLokNWHQJakJgy5JTRh0SWrCoEtSEwZdkpow6JLUhEGXpCYMuiQ1YdAlqQmDLklNGHRJasKgS1ITBl2Smhgr6Em2JTmaZDbJ7gWOfzDJkSRPJvmrJFdMflRJ0umMDHqSVcAe4GZgC7AzyZZ5y74ETFXVDwMPAb856UElSac3zhn6dcBsVR2rqpeAB4Adwwuq6uGqenGweQBYP9kxJUmjjBP0dcDxoe0Tg32LuQ3404UOJNmVZCbJzKlTp8afUpI00kQ/FE3ybmAKuG+h41W1t6qmqmpq7dq1k3xpSbrgrR5jzUlgw9D2+sG+b5PkJuDXgB+rqm9NZjxJ0rjGOUM/CGxKcmWSi4BbgenhBUneAfw+sL2qnp38mJKkUUYGvapeAW4H9gNPAQ9W1eEk9yTZPlh2H/BG4LNJnkgyvcjTSZKWyDiXXKiqfcC+efvuGHp804TnkiSdIe8UlaQmDLokNWHQJakJgy5JTRh0SWrCoEtSEwZdkpow6JLUhEGXpCYMuiQ1YdAlqQmDLklNGHRJasKgS1ITBl2SmjDoktSEQZekJgy6JDVh0CWpCYMuSU0YdElqwqBLUhMGXZKaMOiS1IRBl6QmDLokNWHQJakJgy5JTRh0SWrCoEtSEwZdkpow6JLUhEGXpCYMuiQ1YdAlqYmxgp5kW5KjSWaT7F7g+Hcm+czg+BeTbJz4pJKk0xoZ9CSrgD3AzcAWYGeSLfOW3QY8X1VvBX4HuHfSg0qSTm+cM/TrgNmqOlZVLwEPADvmrdkB/OHg8UPAO5NkcmNKkkZZPcaadcDxoe0TwI8stqaqXknyAvAm4GvDi5LsAnYNNv8zydHXM7QWtIZ5/70vVHfddddKj6Bv53tzYEJnuVcsdmCcoE9MVe0F9i7na14oksxU1dRKzyHN53tz+YxzyeUksGFoe/1g34JrkqwGLgOem8SAkqTxjBP0g8CmJFcmuQi4FZiet2Ya+IXB458G/rqqanJjSpJGGXnJZXBN/HZgP7AK+HhVHU5yDzBTVdPAHwCfTDILfJ256Gt5eSlL5yrfm8sknkhLUg/eKSpJTRh0SWrCoJ+jklye5Jdfx8/tS3L5EowkLZkkW5N8fqXnON8Z9HPX5cBrgj74tdBFVdUtVfWNJZpJOiODrw7RMjHo567fAN6S5IkkB5P8bZJp4AhAks8lOZTk8OAOXAb7n0myJsnGJE8luX+w5s+TfNdK/WHUz+A99tUkfzx4rz2U5LsH78F7kzwO/EySn0jyWJLHk3w2yRsHP79t8POPAz+1sn+aHgz6uWs38E9V9Xbgw8A1wAeq
6qrB8V+qqmuBKeD9Sd60wHNsAvZU1Q8B3wDeteRT60Lzg8BHq+ptwL/z/3+rfK6qrgH+EvgIcNNgewb4YJKLgfuBnwSuBb5/2SdvyKCfP/6+qp4e2n5/ki8DB5i7S3fTAj/zdFU9MXh8CNi4pBPqQnS8qh4dPP4U8KODx58Z/Pt65r6l9dEkTzB3A+IVwGbm3p//OLgJ8VPLN3Jfy/pdLjor//XqgyRbgZuAG6rqxSSPABcv8DPfGnr834CXXDRp829keXX71fdrgL+oqp3Di5K8fYnnuiB5hn7u+g/g0kWOXcbc98+/mGQzc2dB0kp4c5IbBo9/Dvi7eccPADcmeStAkkuSXAV8FdiY5C2DdTvRWTPo56iqeo65v6Z+Bbhv3uE/A1YneYq5D08PLPd80sBR4H2D9+L3Ah8bPlhVp4D3AJ9O8iTwGLC5qr7J3Fdpf2Hwoeizyzp1U976L+l1GfyvJj9fVVev9Cya4xm6JDXhGbokNeEZuiQ1YdAlqQmDLklNGHRJasKgS1IT/wtpE7YbpMXyMgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "ax = sns.barplot(x=['train', 'pred'], y=[di_train, di_test], palette=['gray', 'aqua'])\n",
    "ax.axes.set_ylim([0, 1]);\n",
    "ax.axes.axhline(y=0.8, c='r');"
   ]
  },
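  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To put the two bars in perspective, here is a rough worked comparison using the values printed above (the 0.8 line is assumed here to represent the commonly cited \"four-fifths\" threshold):\n",
    "\n",
    "$$ \\frac{DI_{\\text{pred}}}{DI_{\\text{train}}} = \\frac{0.3007}{0.3601} \\approx 0.835 $$\n",
    "\n",
    "So the model's predictions retain only about 83.5% of the already-low baseline ratio, and both values fall well short of 0.8."
   ]
  },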
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And similarly, we can assess how close the predictions are to equality of odds.\n",
    "\n",
    "`average_odds_error()` computes the (unweighted) average of the absolute values of the true positive rate (TPR) difference and false positive rate (FPR) difference, i.e.:\n",
    "\n",
    "$$ \\tfrac{1}{2}\\left(|FPR_{D = \\text{unprivileged}} - FPR_{D = \\text{privileged}}| + |TPR_{D = \\text{unprivileged}} - TPR_{D = \\text{privileged}}|\\right) $$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.08822379061098651"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "average_odds_error(y_test, y_pred, priv_group=('White', 'Male'), pos_label='>50K')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In that case, we chose to look at the intersection of all protected attributes (race and sex) and designate a single combination (white males) as privileged."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load a custom dataset\n",
    "\n",
    "We can just as easily load a dataset other than the ones included out-of-the-box. As an example, let us examine the ubiquitous Titanic dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pclass</th>\n",
       "      <th>name</th>\n",
       "      <th>sex</th>\n",
       "      <th>age</th>\n",
       "      <th>sibsp</th>\n",
       "      <th>parch</th>\n",
       "      <th>ticket</th>\n",
       "      <th>fare</th>\n",
       "      <th>cabin</th>\n",
       "      <th>embarked</th>\n",
       "      <th>boat</th>\n",
       "      <th>body</th>\n",
       "      <th>home.dest</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.0</td>\n",
       "      <td>Allen, Miss. Elisabeth Walton</td>\n",
       "      <td>female</td>\n",
       "      <td>29.0000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>24160</td>\n",
       "      <td>211.3375</td>\n",
       "      <td>B5</td>\n",
       "      <td>S</td>\n",
       "      <td>2</td>\n",
       "      <td>NaN</td>\n",
       "      <td>St Louis, MO</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1.0</td>\n",
       "      <td>Allison, Master. Hudson Trevor</td>\n",
       "      <td>male</td>\n",
       "      <td>0.9167</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>113781</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>C22 C26</td>\n",
       "      <td>S</td>\n",
       "      <td>11</td>\n",
       "      <td>NaN</td>\n",
       "      <td>Montreal, PQ / Chesterville, ON</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.0</td>\n",
       "      <td>Allison, Miss. Helen Loraine</td>\n",
       "      <td>female</td>\n",
       "      <td>2.0000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>113781</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>C22 C26</td>\n",
       "      <td>S</td>\n",
       "      <td>None</td>\n",
       "      <td>NaN</td>\n",
       "      <td>Montreal, PQ / Chesterville, ON</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1.0</td>\n",
       "      <td>Allison, Mr. Hudson Joshua Creighton</td>\n",
       "      <td>male</td>\n",
       "      <td>30.0000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>113781</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>C22 C26</td>\n",
       "      <td>S</td>\n",
       "      <td>None</td>\n",
       "      <td>135.0</td>\n",
       "      <td>Montreal, PQ / Chesterville, ON</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1.0</td>\n",
       "      <td>Allison, Mrs. Hudson J C (Bessie Waldo Daniels)</td>\n",
       "      <td>female</td>\n",
       "      <td>25.0000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>113781</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>C22 C26</td>\n",
       "      <td>S</td>\n",
       "      <td>None</td>\n",
       "      <td>NaN</td>\n",
       "      <td>Montreal, PQ / Chesterville, ON</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   pclass                                             name     sex      age  \\\n",
       "0     1.0                    Allen, Miss. Elisabeth Walton  female  29.0000   \n",
       "1     1.0                   Allison, Master. Hudson Trevor    male   0.9167   \n",
       "2     1.0                     Allison, Miss. Helen Loraine  female   2.0000   \n",
       "3     1.0             Allison, Mr. Hudson Joshua Creighton    male  30.0000   \n",
       "4     1.0  Allison, Mrs. Hudson J C (Bessie Waldo Daniels)  female  25.0000   \n",
       "\n",
       "   sibsp  parch  ticket      fare    cabin embarked  boat   body  \\\n",
       "0    0.0    0.0   24160  211.3375       B5        S     2    NaN   \n",
       "1    1.0    2.0  113781  151.5500  C22 C26        S    11    NaN   \n",
       "2    1.0    2.0  113781  151.5500  C22 C26        S  None    NaN   \n",
       "3    1.0    2.0  113781  151.5500  C22 C26        S  None  135.0   \n",
       "4    1.0    2.0  113781  151.5500  C22 C26        S  None    NaN   \n",
       "\n",
       "                         home.dest  \n",
       "0                     St Louis, MO  \n",
       "1  Montreal, PQ / Chesterville, ON  \n",
       "2  Montreal, PQ / Chesterville, ON  \n",
       "3  Montreal, PQ / Chesterville, ON  \n",
       "4  Montreal, PQ / Chesterville, ON  "
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.datasets import fetch_openml\n",
    "\n",
    "X_ti, y_ti = fetch_openml(\"titanic\", version=1, return_X_y=True)\n",
    "X_ti.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    1\n",
       "1    1\n",
       "2    0\n",
       "3    0\n",
       "4    0\n",
       "Name: survived, dtype: int64"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_ti = y_ti.astype(int)\n",
    "y_ti.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use the `standardize_dataset` function to handle some of the preprocessing and protected attribute storage. We will only look at \"age\", \"sex\", \"pclass\" (passenger class), and \"fare\". We also drop samples with missing values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>sex</th>\n",
       "      <th>pclass</th>\n",
       "      <th>fare</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>pclass</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1.0</th>\n",
       "      <td>29.0000</td>\n",
       "      <td>female</td>\n",
       "      <td>1.0</td>\n",
       "      <td>211.3375</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1.0</th>\n",
       "      <td>0.9167</td>\n",
       "      <td>male</td>\n",
       "      <td>1.0</td>\n",
       "      <td>151.5500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1.0</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>female</td>\n",
       "      <td>1.0</td>\n",
       "      <td>151.5500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1.0</th>\n",
       "      <td>30.0000</td>\n",
       "      <td>male</td>\n",
       "      <td>1.0</td>\n",
       "      <td>151.5500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1.0</th>\n",
       "      <td>25.0000</td>\n",
       "      <td>female</td>\n",
       "      <td>1.0</td>\n",
       "      <td>151.5500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3.0</th>\n",
       "      <td>45.5000</td>\n",
       "      <td>male</td>\n",
       "      <td>3.0</td>\n",
       "      <td>7.2250</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3.0</th>\n",
       "      <td>14.5000</td>\n",
       "      <td>female</td>\n",
       "      <td>3.0</td>\n",
       "      <td>14.4542</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3.0</th>\n",
       "      <td>26.5000</td>\n",
       "      <td>male</td>\n",
       "      <td>3.0</td>\n",
       "      <td>7.2250</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3.0</th>\n",
       "      <td>27.0000</td>\n",
       "      <td>male</td>\n",
       "      <td>3.0</td>\n",
       "      <td>7.2250</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3.0</th>\n",
       "      <td>29.0000</td>\n",
       "      <td>male</td>\n",
       "      <td>3.0</td>\n",
       "      <td>7.8750</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1045 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "            age     sex  pclass      fare\n",
       "pclass                                   \n",
       "1.0     29.0000  female     1.0  211.3375\n",
       "1.0      0.9167    male     1.0  151.5500\n",
       "1.0      2.0000  female     1.0  151.5500\n",
       "1.0     30.0000    male     1.0  151.5500\n",
       "1.0     25.0000  female     1.0  151.5500\n",
       "...         ...     ...     ...       ...\n",
       "3.0     45.5000    male     3.0    7.2250\n",
       "3.0     14.5000  female     3.0   14.4542\n",
       "3.0     26.5000    male     3.0    7.2250\n",
       "3.0     27.0000    male     3.0    7.2250\n",
       "3.0     29.0000    male     3.0    7.8750\n",
       "\n",
       "[1045 rows x 4 columns]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_ti, y_ti = standardize_dataset(X_ti, prot_attr='pclass', target=y_ti, usecols=['age', 'sex', 'pclass', 'fare'])\n",
    "X_ti"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Class 1 passengers paid more for their tickets so we can assume they were of higher socio-economic status."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "pclass\n",
       "1.0    92.229358\n",
       "2.0    21.855044\n",
       "3.0    12.879299\n",
       "Name: fare, dtype: float64"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_ti.groupby(X_ti.pclass).fare.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "They also were about twice as likely to survive as non-first class passengers:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5072128124523562"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "disparate_impact_ratio(y_ti, priv_group=1)"
   ]
  }
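  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of the \"about twice as likely\" reading (a sketch, writing $S$ for survival and treating first class as the privileged group, as in the call above):\n",
    "\n",
    "$$ \\frac{P(S = 1 \\mid \\text{pclass} \\neq 1)}{P(S = 1 \\mid \\text{pclass} = 1)} \\approx 0.507 \\quad\\Longrightarrow\\quad \\frac{P(S = 1 \\mid \\text{pclass} = 1)}{P(S = 1 \\mid \\text{pclass} \\neq 1)} \\approx \\frac{1}{0.507} \\approx 1.97 $$"
   ]
  }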
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "d0c5ced7753e77a483fec8ff7063075635521cce6e0bd54998c8f174742209dd"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
