{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook compares the overfitting of Fairlearn Vs AnonFair on a resampled version of the [myocardial infarction dataset](https://archive.ics.uci.edu/dataset/579/myocardial+infarction+complications).\n",
    "\n",
    "We use sex as the protected attribute.\n",
    "\n",
    "The initial dataset is balanced, and to induce unfairness in the downstream classifier, we drop half the datapoints that satisfy sex=1  and target_label=0.\n",
    "\n",
    "Because the dataset is relatively high-dimensional (dims ~= 100) with around 1,000 training points, xgboost overfits perfectly obtaining zero error on the train set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataset_loader\n",
    "from anonfair import FairPredictor, performance\n",
    "from anonfair import group_metrics as gm\n",
    "import xgboost\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "pd.options.mode.chained_assignment = None  # default='warn'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "sampler=dataset_loader.resample(1,0,0.5)\n",
    "train,val,test = dataset_loader.myocardial_infarction(resample=sampler,seed=0)"
   ]
  },
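  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the resampling step, the sketch below drops half of the rows with `SEX=1` and `target=0` from a toy DataFrame. The `drop_fraction` helper, the column names, and the toy data are assumptions made for illustration only; `dataset_loader.resample` may be implemented differently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "def drop_fraction(df, group_col, group_val, label_col, label_val, frac, seed=0):\n",
    "    # Hypothetical helper: randomly drop `frac` of the rows matching both values.\n",
    "    rng = np.random.default_rng(seed)\n",
    "    matching = df.index[(df[group_col] == group_val) & (df[label_col] == label_val)]\n",
    "    n_drop = int(len(matching) * frac)\n",
    "    to_drop = rng.choice(matching, size=n_drop, replace=False)\n",
    "    return df.drop(index=to_drop)\n",
    "\n",
    "# Toy data: two groups of eight rows, each with four negatives and four positives.\n",
    "toy = pd.DataFrame({'SEX': [1] * 8 + [0] * 8, 'target': [0, 0, 0, 0, 1, 1, 1, 1] * 2})\n",
    "toy_resampled = drop_fraction(toy, 'SEX', 1, 'target', 0, 0.5)\n",
    "# Count the rows left in each (group, label) cell after resampling.\n",
    "print(toy_resampled.groupby(['SEX', 'target']).size())"
   ]
  },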
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now train XGBoost, and specify a fair predictor over the validation set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier = xgboost.XGBClassifier().fit(X=train['data'], y=train['target'])\n",
    "fpred=FairPredictor(classifier,val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We call fit to enforce equal opportunity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "fpred.fit(gm.accuracy,gm.equal_opportunity,0.02)"
   ]
  },
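  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, equal opportunity is commonly measured as the gap in true positive rate (recall) between groups. The sketch below computes that gap by hand on toy arrays. The `tpr_gap` helper and the data are illustrative assumptions, and AnonFair's `gm.equal_opportunity` may differ in details such as how more than two groups are aggregated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def tpr_gap(y_true, y_pred, groups):\n",
    "    # Hypothetical helper: absolute difference in recall between exactly two groups.\n",
    "    rates = []\n",
    "    for g in np.unique(groups):\n",
    "        positives = (groups == g) & (y_true == 1)\n",
    "        rates.append(y_pred[positives].mean())\n",
    "    return abs(rates[0] - rates[1])\n",
    "\n",
    "toy_y = np.array([1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0])\n",
    "toy_pred = np.array([1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0])\n",
    "toy_groups = np.array([0] * 6 + [1] * 6)\n",
    "# Group 0 recall is 3/4 and group 1 recall is 1/4, so this prints 0.5.\n",
    "print(tpr_gap(toy_y, toy_pred, toy_groups))"
   ]
  },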
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And evaluate fairness on validation data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>original</th>\n",
       "      <th>updated</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Statistical Parity</th>\n",
       "      <td>0.053529</td>\n",
       "      <td>0.020196</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Predictive Parity</th>\n",
       "      <td>0.035982</td>\n",
       "      <td>0.160000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equal Opportunity</th>\n",
       "      <td>0.214719</td>\n",
       "      <td>0.006061</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in False Negative Rate</th>\n",
       "      <td>0.214719</td>\n",
       "      <td>0.006061</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equalized Odds</th>\n",
       "      <td>0.108903</td>\n",
       "      <td>0.019030</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Conditional Use Accuracy</th>\n",
       "      <td>0.042487</td>\n",
       "      <td>0.080671</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in Accuracy</th>\n",
       "      <td>0.043367</td>\n",
       "      <td>0.024065</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Treatment Equality</th>\n",
       "      <td>0.250000</td>\n",
       "      <td>0.285714</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                 original   updated\n",
       "Statistical Parity                               0.053529  0.020196\n",
       "Predictive Parity                                0.035982  0.160000\n",
       "Equal Opportunity                                0.214719  0.006061\n",
       "Average Group Difference in False Negative Rate  0.214719  0.006061\n",
       "Equalized Odds                                   0.108903  0.019030\n",
       "Conditional Use Accuracy                         0.042487  0.080671\n",
       "Average Group Difference in Accuracy             0.043367  0.024065\n",
       "Treatment Equality                               0.250000  0.285714"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fpred.evaluate_fairness()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>original</th>\n",
       "      <th>updated</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Statistical Parity</th>\n",
       "      <td>0.002239</td>\n",
       "      <td>0.069792</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Predictive Parity</th>\n",
       "      <td>0.208333</td>\n",
       "      <td>0.240000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equal Opportunity</th>\n",
       "      <td>0.088235</td>\n",
       "      <td>0.176471</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in False Negative Rate</th>\n",
       "      <td>0.088235</td>\n",
       "      <td>0.176471</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equalized Odds</th>\n",
       "      <td>0.064279</td>\n",
       "      <td>0.112429</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Conditional Use Accuracy</th>\n",
       "      <td>0.112137</td>\n",
       "      <td>0.141967</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in Accuracy</th>\n",
       "      <td>0.044950</td>\n",
       "      <td>0.009946</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Treatment Equality</th>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.400000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                 original   updated\n",
       "Statistical Parity                               0.002239  0.069792\n",
       "Predictive Parity                                0.208333  0.240000\n",
       "Equal Opportunity                                0.088235  0.176471\n",
       "Average Group Difference in False Negative Rate  0.088235  0.176471\n",
       "Equalized Odds                                   0.064279  0.112429\n",
       "Conditional Use Accuracy                         0.112137  0.141967\n",
       "Average Group Difference in Accuracy             0.044950  0.009946\n",
       "Treatment Equality                               0.333333  0.400000"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fpred.evaluate_fairness(test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now check validation performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>original</th>\n",
       "      <th>updated</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Accuracy</th>\n",
       "      <td>0.895765</td>\n",
       "      <td>0.899023</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Balanced Accuracy</th>\n",
       "      <td>0.806793</td>\n",
       "      <td>0.793102</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>F1 score</th>\n",
       "      <td>0.733333</td>\n",
       "      <td>0.725664</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MCC</th>\n",
       "      <td>0.679293</td>\n",
       "      <td>0.688249</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Precision</th>\n",
       "      <td>0.846154</td>\n",
       "      <td>0.911111</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Recall</th>\n",
       "      <td>0.647059</td>\n",
       "      <td>0.602941</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ROC AUC</th>\n",
       "      <td>0.904504</td>\n",
       "      <td>0.878107</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   original   updated\n",
       "Accuracy           0.895765  0.899023\n",
       "Balanced Accuracy  0.806793  0.793102\n",
       "F1 score           0.733333  0.725664\n",
       "MCC                0.679293  0.688249\n",
       "Precision          0.846154  0.911111\n",
       "Recall             0.647059  0.602941\n",
       "ROC AUC            0.904504  0.878107"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fpred.evaluate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>original</th>\n",
       "      <th>updated</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Accuracy</th>\n",
       "      <td>0.895082</td>\n",
       "      <td>0.862295</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Balanced Accuracy</th>\n",
       "      <td>0.790922</td>\n",
       "      <td>0.722636</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>F1 score</th>\n",
       "      <td>0.719298</td>\n",
       "      <td>0.603774</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MCC</th>\n",
       "      <td>0.676716</td>\n",
       "      <td>0.561185</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Precision</th>\n",
       "      <td>0.891304</td>\n",
       "      <td>0.842105</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Recall</th>\n",
       "      <td>0.602941</td>\n",
       "      <td>0.470588</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ROC AUC</th>\n",
       "      <td>0.908724</td>\n",
       "      <td>0.850583</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   original   updated\n",
       "Accuracy           0.895082  0.862295\n",
       "Balanced Accuracy  0.790922  0.722636\n",
       "F1 score           0.719298  0.603774\n",
       "MCC                0.676716  0.561185\n",
       "Precision          0.891304  0.842105\n",
       "Recall             0.602941  0.470588\n",
       "ROC AUC            0.908724  0.850583"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fpred.evaluate(test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now run fairlearn on the same data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover 
label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>ExponentiatedGradient(constraints=&lt;fairlearn.reductions._moments.utility_parity.TruePositiveRateParity object at 0x164d29eb0&gt;,\n",
       "                      estimator=XGBClassifier(base_score=None, booster=None,\n",
       "                                              callbacks=None,\n",
       "                                              colsample_bylevel=None,\n",
       "                                              colsample_bynode=None,\n",
       "                                              colsample_bytree=None,\n",
       "                                              device=None,\n",
       "                                              early_stopping_rounds=None,\n",
       "                                              enable_categorical=False,\n",
       "                                              eval_metric=None,\n",
       "                                              feature_typ...\n",
       "                                              grow_policy=None,\n",
       "                                              importance_type=None,\n",
       "                                              interaction_constraints=None,\n",
       "                                              learning_rate=None, max_bin=None,\n",
       "                                              max_cat_threshold=None,\n",
       "                                              max_cat_to_onehot=None,\n",
       "                                              max_delta_step=None,\n",
       "                                              max_depth=None, max_leaves=None,\n",
       "                                              min_child_weight=None,\n",
       "                                              missing=nan,\n",
       "                                              monotone_constraints=None,\n",
       "                                              multi_strategy=None,\n",
       "                                              n_estimators=None, n_jobs=None,\n",
       "                                              num_parallel_tree=None,\n",
       "                                              random_state=None, ...),\n",
       "                      nu=0.0)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" ><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">ExponentiatedGradient</label><div class=\"sk-toggleable__content\"><pre>ExponentiatedGradient(constraints=&lt;fairlearn.reductions._moments.utility_parity.TruePositiveRateParity object at 0x164d29eb0&gt;,\n",
       "                      estimator=XGBClassifier(base_score=None, booster=None,\n",
       "                                              callbacks=None,\n",
       "                                              colsample_bylevel=None,\n",
       "                                              colsample_bynode=None,\n",
       "                                              colsample_bytree=None,\n",
       "                                              device=None,\n",
       "                                              early_stopping_rounds=None,\n",
       "                                              enable_categorical=False,\n",
       "                                              eval_metric=None,\n",
       "                                              feature_typ...\n",
       "                                              grow_policy=None,\n",
       "                                              importance_type=None,\n",
       "                                              interaction_constraints=None,\n",
       "                                              learning_rate=None, max_bin=None,\n",
       "                                              max_cat_threshold=None,\n",
       "                                              max_cat_to_onehot=None,\n",
       "                                              max_delta_step=None,\n",
       "                                              max_depth=None, max_leaves=None,\n",
       "                                              min_child_weight=None,\n",
       "                                              missing=nan,\n",
       "                                              monotone_constraints=None,\n",
       "                                              multi_strategy=None,\n",
       "                                              n_estimators=None, n_jobs=None,\n",
       "                                              num_parallel_tree=None,\n",
       "                                              random_state=None, ...),\n",
       "                      nu=0.0)</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" ><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">estimator: XGBClassifier</label><div class=\"sk-toggleable__content\"><pre>XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
       "              colsample_bylevel=None, colsample_bynode=None,\n",
       "              colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
       "              enable_categorical=False, eval_metric=None, feature_types=None,\n",
       "              gamma=None, grow_policy=None, importance_type=None,\n",
       "              interaction_constraints=None, learning_rate=None, max_bin=None,\n",
       "              max_cat_threshold=None, max_cat_to_onehot=None,\n",
       "              max_delta_step=None, max_depth=None, max_leaves=None,\n",
       "              min_child_weight=None, missing=nan, monotone_constraints=None,\n",
       "              multi_strategy=None, n_estimators=None, n_jobs=None,\n",
       "              num_parallel_tree=None, random_state=None, ...)</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" ><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">XGBClassifier</label><div class=\"sk-toggleable__content\"><pre>XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
       "              colsample_bylevel=None, colsample_bynode=None,\n",
       "              colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
       "              enable_categorical=False, eval_metric=None, feature_types=None,\n",
       "              gamma=None, grow_policy=None, importance_type=None,\n",
       "              interaction_constraints=None, learning_rate=None, max_bin=None,\n",
       "              max_cat_threshold=None, max_cat_to_onehot=None,\n",
       "              max_delta_step=None, max_depth=None, max_leaves=None,\n",
       "              min_child_weight=None, missing=nan, monotone_constraints=None,\n",
       "              multi_strategy=None, n_estimators=None, n_jobs=None,\n",
       "              num_parallel_tree=None, random_state=None, ...)</pre></div></div></div></div></div></div></div></div></div></div>"
      ],
      "text/plain": [
       "ExponentiatedGradient(constraints=<fairlearn.reductions._moments.utility_parity.TruePositiveRateParity object at 0x164d29eb0>,\n",
       "                      estimator=XGBClassifier(base_score=None, booster=None,\n",
       "                                              callbacks=None,\n",
       "                                              colsample_bylevel=None,\n",
       "                                              colsample_bynode=None,\n",
       "                                              colsample_bytree=None,\n",
       "                                              device=None,\n",
       "                                              early_stopping_rounds=None,\n",
       "                                              enable_categorical=False,\n",
       "                                              eval_metric=None,\n",
       "                                              feature_typ...\n",
       "                                              grow_policy=None,\n",
       "                                              importance_type=None,\n",
       "                                              interaction_constraints=None,\n",
       "                                              learning_rate=None, max_bin=None,\n",
       "                                              max_cat_threshold=None,\n",
       "                                              max_cat_to_onehot=None,\n",
       "                                              max_delta_step=None,\n",
       "                                              max_depth=None, max_leaves=None,\n",
       "                                              min_child_weight=None,\n",
       "                                              missing=nan,\n",
       "                                              monotone_constraints=None,\n",
       "                                              multi_strategy=None,\n",
       "                                              n_estimators=None, n_jobs=None,\n",
       "                                              num_parallel_tree=None,\n",
       "                                              random_state=None, ...),\n",
       "                      nu=0.0)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from fairlearn.reductions import TruePositiveRateParity, ExponentiatedGradient\n",
    "mitagator = ExponentiatedGradient(xgboost.XGBClassifier(),TruePositiveRateParity())\n",
    "mitagator.fit(X=train['data'],y=train['target'],sensitive_features=train['data']['SEX'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To evaluate fairlearn, we write a helper function to evaluate performance and fairness on train or test, and concat the outputs together.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>train</th>\n",
       "      <th>test</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Accuracy</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.895082</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Balanced Accuracy</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.790922</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>F1 score</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.719298</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MCC</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.676716</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Precision</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.891304</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Recall</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.602941</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ROC AUC</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.790922</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Statistical Parity</th>\n",
       "      <td>0.014158</td>\n",
       "      <td>0.002239</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Predictive Parity</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.208333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equal Opportunity</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.088235</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in False Negative Rate</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.088235</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equalized Odds</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.064279</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Conditional Use Accuracy</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.112137</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in Accuracy</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.044950</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Treatment Equality</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.333333</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                    train      test\n",
       "Accuracy                                         1.000000  0.895082\n",
       "Balanced Accuracy                                1.000000  0.790922\n",
       "F1 score                                         1.000000  0.719298\n",
       "MCC                                              1.000000  0.676716\n",
       "Precision                                        1.000000  0.891304\n",
       "Recall                                           1.000000  0.602941\n",
       "ROC AUC                                          1.000000  0.790922\n",
       "Statistical Parity                               0.014158  0.002239\n",
       "Predictive Parity                                0.000000  0.208333\n",
       "Equal Opportunity                                0.000000  0.088235\n",
       "Average Group Difference in False Negative Rate  0.000000  0.088235\n",
       "Equalized Odds                                   0.000000  0.064279\n",
       "Conditional Use Accuracy                         0.000000  0.112137\n",
       "Average Group Difference in Accuracy             0.000000  0.044950\n",
       "Treatment Equality                               0.000000  0.333333"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
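    "def eval(data, classifier=mitagator):\n",
    "    # Predict once so the performance and fairness metrics are computed on the same predictions.\n",
    "    predictions = classifier.predict(data['data'])\n",
    "    # Stack the performance metrics and the fairness metrics into a single column.\n",
    "    return pd.concat((performance.evaluate(data['target'], predictions),\n",
    "                      performance.evaluate_fairness(data['target'], predictions, data['groups'])), axis=0)\n",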
    "\n",
    "out = pd.concat((eval(train), eval(test)), axis=1)\n",
    "out.columns = ['train', 'test']\n",
    "out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Evaluating the initially trained baseline classifier, we find that, as expected, Fairlearn did not alter the performance or unfairness of the classifier: every train and test metric below matches the mitigated model's results above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>train</th>\n",
       "      <th>test</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Accuracy</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.895082</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Balanced Accuracy</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.790922</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>F1 score</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.719298</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MCC</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.676716</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Precision</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.891304</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Recall</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.602941</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ROC AUC</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.790922</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Statistical Parity</th>\n",
       "      <td>0.014158</td>\n",
       "      <td>0.002239</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Predictive Parity</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.208333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equal Opportunity</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.088235</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in False Negative Rate</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.088235</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equalized Odds</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.064279</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Conditional Use Accuracy</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.112137</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in Accuracy</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.044950</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Treatment Equality</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.333333</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                    train      test\n",
       "Accuracy                                         1.000000  0.895082\n",
       "Balanced Accuracy                                1.000000  0.790922\n",
       "F1 score                                         1.000000  0.719298\n",
       "MCC                                              1.000000  0.676716\n",
       "Precision                                        1.000000  0.891304\n",
       "Recall                                           1.000000  0.602941\n",
       "ROC AUC                                          1.000000  0.790922\n",
       "Statistical Parity                               0.014158  0.002239\n",
       "Predictive Parity                                0.000000  0.208333\n",
       "Equal Opportunity                                0.000000  0.088235\n",
       "Average Group Difference in False Negative Rate  0.000000  0.088235\n",
       "Equalized Odds                                   0.000000  0.064279\n",
       "Conditional Use Accuracy                         0.000000  0.112137\n",
       "Average Group Difference in Accuracy             0.000000  0.044950\n",
       "Treatment Equality                               0.000000  0.333333"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out = pd.concat((eval(train, classifier), eval(test, classifier)), axis=1)\n",
    "out.columns = ['train', 'test']\n",
    "out"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
