{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SenSeI: Sensitive Set Invariance for Enforcing Individual Fairness\n",
    "\n",
    "SenSeI is an in-processing method for individual fairness. In this method, individual fairness is formulated as invariance on certain sensitive sets. SenSeI minimizes a transport-based regularizer that enforces this version of individual fairness."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.metrics import accuracy_score, balanced_accuracy_score\n",
    "from sklearn.compose import make_column_transformer\n",
    "from sklearn.preprocessing import OneHotEncoder, StandardScaler, minmax_scale\n",
    "from sklearn.model_selection import train_test_split\n",
    "from skorch import NeuralNetClassifier\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from inFairness import distances\n",
    "from inFairness.auditor import SenSeIAuditor\n",
    "\n",
    "import aif360\n",
    "from aif360.sklearn.datasets import fetch_adult\n",
    "from aif360.sklearn.metrics import consistency_score\n",
    "from aif360.sklearn.inprocessing import SenSeI"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will be using the Adult income dataset for this tutorial. For pre-processing, we apply the usual one-hot encoding for categorical features and standard scaling for continuous features. We divide the data into train and test splits with a 80/20 ratio. Finally, note we convert the dtype to 32-bit floats as this is the default precision for torch models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>workclass_Federal-gov</th>\n",
       "      <th>workclass_Local-gov</th>\n",
       "      <th>workclass_Private</th>\n",
       "      <th>workclass_Self-emp-inc</th>\n",
       "      <th>workclass_Self-emp-not-inc</th>\n",
       "      <th>workclass_State-gov</th>\n",
       "      <th>workclass_Without-pay</th>\n",
       "      <th>marital-status_Divorced</th>\n",
       "      <th>marital-status_Married-AF-spouse</th>\n",
       "      <th>marital-status_Married-civ-spouse</th>\n",
       "      <th>...</th>\n",
       "      <th>race_Asian-Pac-Islander</th>\n",
       "      <th>race_Black</th>\n",
       "      <th>race_Other</th>\n",
       "      <th>race_White</th>\n",
       "      <th>sex_Male</th>\n",
       "      <th>age</th>\n",
       "      <th>education-num</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">White</th>\n",
       "      <th>Male</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-0.500934</td>\n",
       "      <td>1.114976</td>\n",
       "      <td>-0.146659</td>\n",
       "      <td>-0.219919</td>\n",
       "      <td>-0.080047</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Male</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.484367</td>\n",
       "      <td>1.114976</td>\n",
       "      <td>-0.146659</td>\n",
       "      <td>-0.219919</td>\n",
       "      <td>0.835685</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Male</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-0.197765</td>\n",
       "      <td>-0.444540</td>\n",
       "      <td>-0.146659</td>\n",
       "      <td>-0.219919</td>\n",
       "      <td>0.752437</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Non-white</th>\n",
       "      <th>Male</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-1.107273</td>\n",
       "      <td>-2.004057</td>\n",
       "      <td>-0.146659</td>\n",
       "      <td>-0.219919</td>\n",
       "      <td>1.418425</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"7\" valign=\"top\">White</th>\n",
       "      <th>Female</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-1.410443</td>\n",
       "      <td>-0.054661</td>\n",
       "      <td>-0.146659</td>\n",
       "      <td>-0.219919</td>\n",
       "      <td>-0.080047</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Male</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-1.562028</td>\n",
       "      <td>-1.224298</td>\n",
       "      <td>-0.146659</td>\n",
       "      <td>-0.219919</td>\n",
       "      <td>-0.912532</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Male</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-0.728312</td>\n",
       "      <td>-0.054661</td>\n",
       "      <td>-0.146659</td>\n",
       "      <td>-0.219919</td>\n",
       "      <td>1.418425</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Male</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.393875</td>\n",
       "      <td>-3.173694</td>\n",
       "      <td>-0.146659</td>\n",
       "      <td>-0.219919</td>\n",
       "      <td>-0.080047</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Female</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-0.728312</td>\n",
       "      <td>-0.444540</td>\n",
       "      <td>-0.146659</td>\n",
       "      <td>-0.219919</td>\n",
       "      <td>-0.080047</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Male</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-0.425142</td>\n",
       "      <td>1.504856</td>\n",
       "      <td>-0.146659</td>\n",
       "      <td>-0.219919</td>\n",
       "      <td>0.752437</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>36826 rows × 45 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                  workclass_Federal-gov  workclass_Local-gov  \\\n",
       "race      sex                                                  \n",
       "White     Male                      0.0                  0.0   \n",
       "          Male                      0.0                  0.0   \n",
       "          Male                      0.0                  0.0   \n",
       "Non-white Male                      0.0                  0.0   \n",
       "White     Female                    0.0                  0.0   \n",
       "...                                 ...                  ...   \n",
       "          Male                      0.0                  0.0   \n",
       "          Male                      0.0                  0.0   \n",
       "          Male                      0.0                  0.0   \n",
       "          Female                    0.0                  0.0   \n",
       "          Male                      0.0                  0.0   \n",
       "\n",
       "                  workclass_Private  workclass_Self-emp-inc  \\\n",
       "race      sex                                                 \n",
       "White     Male                  1.0                     0.0   \n",
       "          Male                  0.0                     1.0   \n",
       "          Male                  0.0                     0.0   \n",
       "Non-white Male                  1.0                     0.0   \n",
       "White     Female                1.0                     0.0   \n",
       "...                             ...                     ...   \n",
       "          Male                  1.0                     0.0   \n",
       "          Male                  1.0                     0.0   \n",
       "          Male                  1.0                     0.0   \n",
       "          Female                1.0                     0.0   \n",
       "          Male                  1.0                     0.0   \n",
       "\n",
       "                  workclass_Self-emp-not-inc  workclass_State-gov  \\\n",
       "race      sex                                                       \n",
       "White     Male                           0.0                  0.0   \n",
       "          Male                           0.0                  0.0   \n",
       "          Male                           1.0                  0.0   \n",
       "Non-white Male                           0.0                  0.0   \n",
       "White     Female                         0.0                  0.0   \n",
       "...                                      ...                  ...   \n",
       "          Male                           0.0                  0.0   \n",
       "          Male                           0.0                  0.0   \n",
       "          Male                           0.0                  0.0   \n",
       "          Female                         0.0                  0.0   \n",
       "          Male                           0.0                  0.0   \n",
       "\n",
       "                  workclass_Without-pay  marital-status_Divorced  \\\n",
       "race      sex                                                      \n",
       "White     Male                      0.0                      0.0   \n",
       "          Male                      0.0                      0.0   \n",
       "          Male                      0.0                      0.0   \n",
       "Non-white Male                      0.0                      0.0   \n",
       "White     Female                    0.0                      0.0   \n",
       "...                                 ...                      ...   \n",
       "          Male                      0.0                      0.0   \n",
       "          Male                      0.0                      1.0   \n",
       "          Male                      0.0                      1.0   \n",
       "          Female                    0.0                      0.0   \n",
       "          Male                      0.0                      0.0   \n",
       "\n",
       "                  marital-status_Married-AF-spouse  \\\n",
       "race      sex                                        \n",
       "White     Male                                 0.0   \n",
       "          Male                                 0.0   \n",
       "          Male                                 0.0   \n",
       "Non-white Male                                 0.0   \n",
       "White     Female                               0.0   \n",
       "...                                            ...   \n",
       "          Male                                 0.0   \n",
       "          Male                                 0.0   \n",
       "          Male                                 0.0   \n",
       "          Female                               0.0   \n",
       "          Male                                 0.0   \n",
       "\n",
       "                  marital-status_Married-civ-spouse  ...  \\\n",
       "race      sex                                        ...   \n",
       "White     Male                                  1.0  ...   \n",
       "          Male                                  1.0  ...   \n",
       "          Male                                  1.0  ...   \n",
       "Non-white Male                                  0.0  ...   \n",
       "White     Female                                0.0  ...   \n",
       "...                                             ...  ...   \n",
       "          Male                                  0.0  ...   \n",
       "          Male                                  0.0  ...   \n",
       "          Male                                  0.0  ...   \n",
       "          Female                                0.0  ...   \n",
       "          Male                                  1.0  ...   \n",
       "\n",
       "                  race_Asian-Pac-Islander  race_Black  race_Other  race_White  \\\n",
       "race      sex                                                                   \n",
       "White     Male                        0.0         0.0         0.0         1.0   \n",
       "          Male                        0.0         0.0         0.0         1.0   \n",
       "          Male                        0.0         0.0         0.0         1.0   \n",
       "Non-white Male                        0.0         0.0         1.0         0.0   \n",
       "White     Female                      0.0         0.0         0.0         1.0   \n",
       "...                                   ...         ...         ...         ...   \n",
       "          Male                        0.0         0.0         0.0         1.0   \n",
       "          Male                        0.0         0.0         0.0         1.0   \n",
       "          Male                        0.0         0.0         0.0         1.0   \n",
       "          Female                      0.0         0.0         0.0         1.0   \n",
       "          Male                        0.0         0.0         0.0         1.0   \n",
       "\n",
       "                  sex_Male       age  education-num  capital-gain  \\\n",
       "race      sex                                                       \n",
       "White     Male         1.0 -0.500934       1.114976     -0.146659   \n",
       "          Male         1.0  0.484367       1.114976     -0.146659   \n",
       "          Male         1.0 -0.197765      -0.444540     -0.146659   \n",
       "Non-white Male         1.0 -1.107273      -2.004057     -0.146659   \n",
       "White     Female       0.0 -1.410443      -0.054661     -0.146659   \n",
       "...                    ...       ...            ...           ...   \n",
       "          Male         1.0 -1.562028      -1.224298     -0.146659   \n",
       "          Male         1.0 -0.728312      -0.054661     -0.146659   \n",
       "          Male         1.0  1.393875      -3.173694     -0.146659   \n",
       "          Female       0.0 -0.728312      -0.444540     -0.146659   \n",
       "          Male         1.0 -0.425142       1.504856     -0.146659   \n",
       "\n",
       "                  capital-loss  hours-per-week  \n",
       "race      sex                                   \n",
       "White     Male       -0.219919       -0.080047  \n",
       "          Male       -0.219919        0.835685  \n",
       "          Male       -0.219919        0.752437  \n",
       "Non-white Male       -0.219919        1.418425  \n",
       "White     Female     -0.219919       -0.080047  \n",
       "...                        ...             ...  \n",
       "          Male       -0.219919       -0.912532  \n",
       "          Male       -0.219919        1.418425  \n",
       "          Male       -0.219919       -0.080047  \n",
       "          Female     -0.219919       -0.080047  \n",
       "          Male       -0.219919        0.752437  \n",
       "\n",
       "[36826 rows x 45 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X, y, _ = fetch_adult(dropcols=['native-country', 'education'])\n",
    "(X_train, X_test,\n",
    " y_train, y_test) = train_test_split(X, y, train_size=0.8, random_state=123)\n",
    "\n",
    "pre = make_column_transformer(\n",
    "        (OneHotEncoder(sparse=False, drop='if_binary'), X_train.dtypes == 'category'),\n",
    "        (StandardScaler(), X_train.dtypes != 'category'),\n",
    "        verbose_feature_names_out=False)\n",
    "# NOTE: the torch models will only handle 32-bit floats\n",
    "X_train = pd.DataFrame(pre.fit_transform(X_train), index=X_train.index,\n",
    "                       columns=pre.get_feature_names_out(), dtype='float32')\n",
    "X_test = pd.DataFrame(pre.transform(X_test), index=X_test.index,\n",
    "                      columns=pre.get_feature_names_out(), dtype='float32')\n",
    "X_train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At this point, we can create a copy of the test data with the spouse variable flipped. This will be used in a counterfactual assessment of the model later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_test_spouse_flipped = X_test.copy()\n",
    "X_test_spouse_flipped.relationship_Wife = 1 - X_test_spouse_flipped.relationship_Wife"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Another thing we need to keep track of for later is the protected attribute indices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[38, 39]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Column positions of the protected attributes (used when fitting `distance_x`).\n",
    "protected_vars = ['race_White', 'sex_Male']\n",
    "protected_idxs = list(map(X_train.columns.get_loc, protected_vars))\n",
    "protected_idxs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is the neural network we will use for the following experiment. It is a simple fully-connected network with ReLU activations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model(nn.Module):\n",
    "    def __init__(self, input_size, output_size=1):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(input_size, 100)\n",
    "        self.fc2 = nn.Linear(100, 100)\n",
    "        self.fcout = nn.Linear(100, output_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fcout(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Standard training\n",
    "\n",
    "Now let's train our model with no individual fairness loss. We can use the skorch library to convert the PyTorch model to a sklearn-friendly estimator.\n",
    "\n",
    "Note: we could alternatively set `output_size = 2` and `criterion = nn.CrossEntropyLoss`. SenSeI will encode the nominal values automatically, though, so this way we skip that step later since for a binary y it assumes the loss is BCE."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "EPOCHS = 10\n",
    "input_size = X_train.shape[1]\n",
    "output_size = 1\n",
    "optimizer = torch.optim.Adam\n",
    "criterion = nn.BCEWithLogitsLoss\n",
    "lr = 1e-3\n",
    "device = torch.device('cpu')\n",
    "\n",
    "network_standard = NeuralNetClassifier(\n",
    "    Model,\n",
    "    module__input_size=input_size,\n",
    "    module__output_size=output_size,\n",
    "    max_epochs=EPOCHS,\n",
    "    criterion=criterion,\n",
    "    optimizer=optimizer,\n",
    "    lr=lr,\n",
    "    train_split=None,\n",
    "    # this is not strictly necessary; it just handles the conversion from DataFrame -> ndarray\n",
    "    dataset=aif360.sklearn.inprocessing.infairness.Dataset,\n",
    "    iterator_train__shuffle=True, # Shuffle training data on each epoch\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "skorch does not automatically encode the targets so we need to convert them to 0/1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Category codes of the two-class target become the float targets used with BCEWithLogitsLoss.\n",
    "y_train_enc, y_test_enc = (labels.cat.codes.astype('float32')\n",
    "                           for labels in (y_train, y_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  epoch    train_loss     dur\n",
      "-------  ------------  ------\n",
      "      1        \u001b[36m0.3598\u001b[0m  0.5932\n",
      "      2        \u001b[36m0.3162\u001b[0m  0.5971\n",
      "      3        \u001b[36m0.3139\u001b[0m  0.6378\n",
      "      4        \u001b[36m0.3112\u001b[0m  0.4824\n",
      "      5        \u001b[36m0.3091\u001b[0m  0.5530\n",
      "      6        \u001b[36m0.3078\u001b[0m  0.5659\n",
      "      7        \u001b[36m0.3067\u001b[0m  0.5196\n",
      "      8        \u001b[36m0.3052\u001b[0m  0.4765\n",
      "      9        \u001b[36m0.3036\u001b[0m  0.5145\n",
      "     10        \u001b[36m0.3025\u001b[0m  0.5702\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<class 'skorch.classifier.NeuralNetClassifier'>[initialized](\n",
       "  module_=Model(\n",
       "    (fc1): Linear(in_features=45, out_features=100, bias=True)\n",
       "    (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
       "    (fcout): Linear(in_features=100, out_features=1, bias=True)\n",
       "  ),\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the shape of y also needs to match the output of the network so we convert it to 2D first\n",
    "network_standard.fit(X_train, y_train_enc.to_frame())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a baseline, let's print the accuracy, balanced accuracy, consistency with nearest neighbors, and the consistency of the predictions when the spouse column is flipped. The spouse feature should have no causal impact on the prediction so for an individually fair model, this should be close to 100%."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 85.13%\n",
      "Balanced accuracy: 77.59%\n",
      "Consistency: 93.79%\n",
      "Spouse consistency: 92.77%\n"
     ]
    }
   ],
   "source": [
    "y_pred_standard = network_standard.predict(X_test)\n",
    "accuracy = accuracy_score(y_test_enc, y_pred_standard)\n",
    "balanced_acc = balanced_accuracy_score(y_test_enc, y_pred_standard)\n",
    "consistency = consistency_score(minmax_scale(X_test), y_pred_standard.ravel())\n",
    "\n",
    "y_pred_flipped = network_standard.predict(X_test_spouse_flipped)\n",
    "spouse_consistency = accuracy_score(y_pred_standard, y_pred_flipped)\n",
    "\n",
    "print(f'Accuracy: {accuracy:.2%}')\n",
    "print(f'Balanced accuracy: {balanced_acc:.2%}')\n",
    "print(f'Consistency: {consistency:.2%}')\n",
    "print(f'Spouse consistency: {spouse_consistency:.2%}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Individually fair training\n",
    "\n",
    "Now let's train an individually fair model using SenSeI. First, we must define the distance functions we will be using in both the input and output spaces. For the input (X) space, we will use the Logistic Regression Sensitive Subspace distance metric and for the output (y) space, we will use a simple Squared Euclidean distance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "distance_x = distances.LogisticRegSensitiveSubspace()\n",
    "distance_y = distances.SquaredEuclideanDistance()\n",
    "\n",
    "X_train_tensor = torch.as_tensor(X_train.to_numpy())\n",
    "distance_x.fit(X_train_tensor, protected_idxs=protected_idxs)\n",
    "distance_y.fit(num_dims=output_size)\n",
    "\n",
    "distance_x.to(device)\n",
    "distance_y.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `SenSeI` class inherits from skorch so it looks very similar to the standard training setup."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "rho = 2.5\n",
    "eps = 0.1\n",
    "auditor_nsteps = 100\n",
    "auditor_lr = 1e-3\n",
    "\n",
    "network_fair = SenSeI(\n",
    "    Model,\n",
    "    module__input_size=input_size,\n",
    "    module__output_size=output_size,\n",
    "    distance_x=distance_x,\n",
    "    distance_y=distance_y,\n",
    "    rho=rho,\n",
    "    eps=eps,\n",
    "    auditor_nsteps=auditor_nsteps,\n",
    "    auditor_lr=auditor_lr,\n",
    "    max_epochs=EPOCHS,\n",
    "    criterion=criterion,\n",
    "    optimizer=optimizer,\n",
    "    lr=lr,\n",
    "    device=device,\n",
    "    iterator_train__shuffle=True, # Shuffle training data on each epoch\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  epoch    train_loss      dur\n",
      "-------  ------------  -------\n",
      "      1        \u001b[36m0.5119\u001b[0m  23.9251\n",
      "      2        \u001b[36m0.4335\u001b[0m  23.2762\n",
      "      3        \u001b[36m0.4028\u001b[0m  23.8665\n",
      "      4        \u001b[36m0.3949\u001b[0m  24.0242\n",
      "      5        \u001b[36m0.3914\u001b[0m  23.4092\n",
      "      6        \u001b[36m0.3886\u001b[0m  23.2872\n",
      "      7        \u001b[36m0.3864\u001b[0m  23.1221\n",
      "      8        \u001b[36m0.3853\u001b[0m  23.8168\n",
      "      9        \u001b[36m0.3836\u001b[0m  24.1062\n",
      "     10        \u001b[36m0.3822\u001b[0m  24.3297\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<class 'aif360.sklearn.inprocessing.infairness.SenSeI'>[initialized](\n",
       "  module_=SenSeI(\n",
       "    (distance_x): LogisticRegSensitiveSubspace()\n",
       "    (distance_y): SquaredEuclideanDistance()\n",
       "    (network): Model(\n",
       "      (fc1): Linear(in_features=45, out_features=100, bias=True)\n",
       "      (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
       "      (fcout): Linear(in_features=100, out_features=1, bias=True)\n",
       "    )\n",
       "    (loss_fn): BCEWithLogitsLoss()\n",
       "  ),\n",
       ")"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train the fair model; the per-epoch training log is printed below.\n",
    "network_fair.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This time, spouse consistency is almost exactly 100%, while accuracy and balanced accuracy are only slightly lower than the standard model's and nearest-neighbor consistency is slightly higher. Great!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 83.70%\n",
      "Balanced accuracy: 73.62%\n",
      "Consistency: 95.79%\n",
      "Spouse consistency: 99.97%\n"
     ]
    }
   ],
   "source": [
    "y_pred_fair = network_fair.predict(X_test)\n",
    "y_pred_fair_flipped = network_fair.predict(X_test_spouse_flipped)\n",
    "\n",
    "accuracy = accuracy_score(y_test, y_pred_fair)\n",
    "balanced_acc = balanced_accuracy_score(y_test, y_pred_fair)\n",
    "# nearest-neighbor consistency on min-max scaled features, with predictions as booleans (== '>50K')\n",
    "consistency = consistency_score(minmax_scale(X_test), y_pred_fair.ravel() == '>50K')\n",
    "# agreement between predictions on X_test and on X_test_spouse_flipped\n",
    "spouse_consistency = accuracy_score(y_pred_fair, y_pred_fair_flipped)\n",
    "\n",
    "for metric_name, metric_value in [('Accuracy', accuracy),\n",
    "                                  ('Balanced accuracy', balanced_acc),\n",
    "                                  ('Consistency', consistency),\n",
    "                                  ('Spouse consistency', spouse_consistency)]:\n",
    "    print(f'{metric_name}: {metric_value:.2%}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Individual fairness auditing\n",
    "\n",
    "Let's now audit both models to check whether they comply with individual fairness."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "invalid value encountered in true_divide\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss ratio (standard model) : 221.515. Is model fair: False\n",
      "Loss ratio (fair model) : 1.000. Is model fair: True\n"
     ]
    }
   ],
   "source": [
    "audit_nsteps = 500\n",
    "audit_lr = 0.001\n",
    "audit_threshold = 1.15  # loss-ratio bound used to decide `is_model_fair`\n",
    "audit_lambda = 50.0     # passed to the auditor as `lambda_param`\n",
    "loss_fn = F.binary_cross_entropy_with_logits\n",
    "\n",
    "auditor = SenSeIAuditor(distance_x=distance_x, distance_y=distance_y,\n",
    "    num_steps=audit_nsteps, lr=audit_lr, max_noise=0.5, min_noise=-0.5)\n",
    "\n",
    "# Create the audit data on the same device as the models and the fitted distances;\n",
    "# CPU tensors would cause a device mismatch when `device` is a GPU.\n",
    "X_test_tensor = torch.as_tensor(X_test.to_numpy(), device=device)\n",
    "y_test_tensor = torch.as_tensor(y_test_enc.to_numpy().reshape(-1, 1), device=device)\n",
    "\n",
    "audit_result_stdmodel = auditor.audit(network_standard.module_, X_test_tensor,\n",
    "                                      y_test_tensor, loss_fn,\n",
    "                                      audit_threshold=audit_threshold,\n",
    "                                      lambda_param=audit_lambda)\n",
    "# The SenSeI module wraps the classifier (plus the distances); audit the classifier itself.\n",
    "audit_result_fairmodel = auditor.audit(network_fair.module_.network,\n",
    "                                       X_test_tensor, y_test_tensor, loss_fn,\n",
    "                                       audit_threshold=audit_threshold,\n",
    "                                       lambda_param=audit_lambda)\n",
    "\n",
    "print(f\"Loss ratio (standard model) : {audit_result_stdmodel.lower_bound:.3f}. \"\n",
    "      f\"Is model fair: {audit_result_stdmodel.is_model_fair}\")\n",
    "print(f\"Loss ratio (fair model) : {audit_result_fairmodel.lower_bound:.3f}. \"\n",
    "      f\"Is model fair: {audit_result_fairmodel.is_model_fair}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
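    "As these numbers show, the standard model fails the audit with a loss ratio of about 221, far above the 1.15 threshold. The SenSeI model passes with a loss ratio of about 1.0, so it is much fairer in the individual-fairness sense."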
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.7 ('aif360')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "d0c5ced7753e77a483fec8ff7063075635521cce6e0bd54998c8f174742209dd"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
