{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook compares how Fairlearn and AnonFair behave when the underlying decision trees and random forests overfit, using the Adult dataset.\n",
    "\n",
    "We use sex as the protected attribute.\n",
    "\n",
    "Even on this low-dimensional data, scikit-learn's default parameters cause both decision trees and random forests to overfit. With `max_depth=None`, each tree keeps splitting until its leaves are pure (or cannot be split further), which almost memorises the training set.\n",
    "\n",
    "This can be mitigated by specifying a low maximum tree depth. The examples in the Fairlearn documentation typically use a depth of 4 on Adult."
   ]
  },
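  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick illustration of this overfitting, the following sketch uses synthetic data from scikit-learn's `make_classification`. This is an assumption made for illustration: it is not the Adult data. The sketch compares an unconstrained `DecisionTreeClassifier` with one limited to `max_depth=4`. With continuous features and no depth limit, the tree typically reaches 100% training accuracy, and the gap between training and test accuracy gives a simple measure of how much it has overfit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch on synthetic data, not the Adult dataset used in the rest of this notebook\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "# flip_y adds label noise, which an unconstrained tree will memorise\n",
    "X, y = make_classification(n_samples=4000, n_features=10, flip_y=0.1, random_state=0)\n",
    "X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.5, random_state=0)\n",
    "\n",
    "for depth in (None, 4):  # None is scikit-learn's default: no depth limit\n",
    "    clf = DecisionTreeClassifier(max_depth=depth, random_state=0).fit(X_tr, y_tr)\n",
    "    train_acc, test_acc = clf.score(X_tr, y_tr), clf.score(X_te, y_te)\n",
    "    print(f'max_depth={depth}: train={train_acc:.3f}, test={test_acc:.3f}, gap={train_acc - test_acc:.3f}')"
   ]
  },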
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataset_loader\n",
    "from anonfair import FairPredictor, performance\n",
    "from anonfair import group_metrics as gm\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.tree import DecisionTreeClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "train,val,test = dataset_loader.adult()\n",
    "basetree = DecisionTreeClassifier().fit(X=train['data'], y=train['target'])\n",
    "baseforest = RandomForestClassifier().fit(X=train['data'], y=train['target'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now construct a fair predictor for each model, using the validation set to fit the fairness adjustments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A fully grown tree outputs scores that are (almost) all 0 or 1, so we add a little Gaussian noise to break ties and allow thresholding to work\n",
    "ftree = FairPredictor(basetree, val, add_noise=0.001)\n",
    "fforest = FairPredictor(baseforest, val)"
   ]
  },
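  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the noise is needed, the sketch below uses made-up scores. It is our own illustration, not AnonFair's internals.\n",
    "\n",
    "- **Without noise:** a classifier whose scores are all 0 or 1 has only two distinct values, so every threshold strictly between 0 and 1 yields the same predictions and there is nothing to tune.\n",
    "- **With noise:** small Gaussian noise breaks these ties. A threshold placed close to 1 then flips a controllable fraction of the positive predictions.\n",
    "\n",
    "We assume that `add_noise=0.001` adds noise with a standard deviation of roughly this size. The AnonFair source is the authority on its exact behaviour."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch with made-up scores; this is not AnonFair's implementation of add_noise\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "# Hypothetical hard scores from a fully grown tree: every score is exactly 0 or 1\n",
    "scores = rng.integers(0, 2, size=1000).astype(float)\n",
    "# Add small Gaussian noise (standard deviation 0.001) to break the ties\n",
    "noisy = scores + rng.normal(0, 0.001, size=scores.shape)\n",
    "print('distinct values without noise:', np.unique(scores).size)\n",
    "print('distinct values with noise:', np.unique(noisy).size)\n",
    "\n",
    "# Without noise, every threshold strictly between 0 and 1 gives the same positive rate;\n",
    "# with noise, thresholds close to 1 flip a controllable fraction of the positives to negative\n",
    "for t in (0.5, 0.999, 1.0, 1.001):\n",
    "    print(f'threshold={t}: positive rate without noise={np.mean(scores > t):.3f}, with noise={np.mean(noisy > t):.3f}')"
   ]
  },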
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We call `fit` to maximize accuracy subject to equal opportunity, allowing a violation of at most 0.02 on the validation set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "ftree.fit(gm.accuracy, gm.equal_opportunity, 0.02)\n",
    "fforest.fit(gm.accuracy, gm.equal_opportunity, 0.02)"
   ]
  },
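  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, equal opportunity asks that the true positive rate (equivalently, the false negative rate) be the same across groups. The sketch below computes this gap for a binary protected attribute on made-up labels.\n",
    "\n",
    "This is our own illustration. We assume, but have not verified, that it matches `gm.equal_opportunity` when there are exactly two groups. The evaluation tables in this notebook report identical values for Equal Opportunity and for Average Group Difference in False Negative Rate, which is consistent with this reading."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch; we assume it matches gm.equal_opportunity for a binary protected attribute\n",
    "import numpy as np\n",
    "\n",
    "def equal_opportunity_gap(y_true, y_pred, groups):\n",
    "    # Absolute difference in true positive rate between the two groups\n",
    "    y_true, y_pred, groups = (np.asarray(a) for a in (y_true, y_pred, groups))\n",
    "    tprs = [y_pred[(groups == g) & (y_true == 1)].mean() for g in np.unique(groups)]\n",
    "    if len(tprs) != 2:\n",
    "        raise ValueError('this sketch assumes a binary protected attribute')\n",
    "    return abs(tprs[0] - tprs[1])\n",
    "\n",
    "# Toy example: group 'a' has TPR 3/4 and group 'b' has TPR 1/2\n",
    "y_true = [1, 1, 1, 1, 0, 1, 1, 0]\n",
    "y_pred = [1, 1, 1, 0, 0, 1, 0, 1]\n",
    "groups = ['a', 'a', 'a', 'a', 'a', 'b', 'b', 'b']\n",
    "print(equal_opportunity_gap(y_true, y_pred, groups))  # prints 0.25"
   ]
  },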
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now focus on the decision tree only, and evaluate its fairness on the validation data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>original</th>\n",
       "      <th>updated</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Statistical Parity</th>\n",
       "      <td>0.180488</td>\n",
       "      <td>0.152779</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Predictive Parity</th>\n",
       "      <td>0.115453</td>\n",
       "      <td>0.122845</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equal Opportunity</th>\n",
       "      <td>0.062360</td>\n",
       "      <td>0.007663</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in False Negative Rate</th>\n",
       "      <td>0.062360</td>\n",
       "      <td>0.007663</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equalized Odds</th>\n",
       "      <td>0.079554</td>\n",
       "      <td>0.044270</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Conditional Use Accuracy</th>\n",
       "      <td>0.113697</td>\n",
       "      <td>0.125512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in Accuracy</th>\n",
       "      <td>0.122871</td>\n",
       "      <td>0.127511</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Treatment Equality</th>\n",
       "      <td>0.231433</td>\n",
       "      <td>0.467860</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                 original   updated\n",
       "Statistical Parity                               0.180488  0.152779\n",
       "Predictive Parity                                0.115453  0.122845\n",
       "Equal Opportunity                                0.062360  0.007663\n",
       "Average Group Difference in False Negative Rate  0.062360  0.007663\n",
       "Equalized Odds                                   0.079554  0.044270\n",
       "Conditional Use Accuracy                         0.113697  0.125512\n",
       "Average Group Difference in Accuracy             0.122871  0.127511\n",
       "Treatment Equality                               0.231433  0.467860"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ftree.evaluate_fairness()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We repeat the fairness evaluation on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>original</th>\n",
       "      <th>updated</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Statistical Parity</th>\n",
       "      <td>0.182991</td>\n",
       "      <td>0.157615</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Predictive Parity</th>\n",
       "      <td>0.114898</td>\n",
       "      <td>0.126459</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equal Opportunity</th>\n",
       "      <td>0.056253</td>\n",
       "      <td>0.011652</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in False Negative Rate</th>\n",
       "      <td>0.056253</td>\n",
       "      <td>0.011652</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equalized Odds</th>\n",
       "      <td>0.078540</td>\n",
       "      <td>0.047740</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Conditional Use Accuracy</th>\n",
       "      <td>0.113885</td>\n",
       "      <td>0.126001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in Accuracy</th>\n",
       "      <td>0.125034</td>\n",
       "      <td>0.125877</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Treatment Equality</th>\n",
       "      <td>0.269692</td>\n",
       "      <td>0.499027</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                 original   updated\n",
       "Statistical Parity                               0.182991  0.157615\n",
       "Predictive Parity                                0.114898  0.126459\n",
       "Equal Opportunity                                0.056253  0.011652\n",
       "Average Group Difference in False Negative Rate  0.056253  0.011652\n",
       "Equalized Odds                                   0.078540  0.047740\n",
       "Conditional Use Accuracy                         0.113885  0.126001\n",
       "Average Group Difference in Accuracy             0.125034  0.125877\n",
       "Treatment Equality                               0.269692  0.499027"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ftree.evaluate_fairness(test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now check validation performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>original</th>\n",
       "      <th>updated</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Accuracy</th>\n",
       "      <td>0.809910</td>\n",
       "      <td>0.806552</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Balanced Accuracy</th>\n",
       "      <td>0.739240</td>\n",
       "      <td>0.723428</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>F1 score</th>\n",
       "      <td>0.603180</td>\n",
       "      <td>0.582538</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MCC</th>\n",
       "      <td>0.478201</td>\n",
       "      <td>0.457241</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Precision</th>\n",
       "      <td>0.602665</td>\n",
       "      <td>0.602339</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Recall</th>\n",
       "      <td>0.603696</td>\n",
       "      <td>0.563997</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ROC AUC</th>\n",
       "      <td>0.739219</td>\n",
       "      <td>0.697202</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   original   updated\n",
       "Accuracy           0.809910  0.806552\n",
       "Balanced Accuracy  0.739240  0.723428\n",
       "F1 score           0.603180  0.582538\n",
       "MCC                0.478201  0.457241\n",
       "Precision          0.602665  0.602339\n",
       "Recall             0.603696  0.563997\n",
       "ROC AUC            0.739219  0.697202"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ftree.evaluate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We repeat the performance evaluation on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>original</th>\n",
       "      <th>updated</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Accuracy</th>\n",
       "      <td>0.806732</td>\n",
       "      <td>0.804029</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Balanced Accuracy</th>\n",
       "      <td>0.738907</td>\n",
       "      <td>0.723995</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>F1 score</th>\n",
       "      <td>0.601217</td>\n",
       "      <td>0.582155</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MCC</th>\n",
       "      <td>0.473767</td>\n",
       "      <td>0.454384</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Precision</th>\n",
       "      <td>0.593792</td>\n",
       "      <td>0.594296</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Recall</th>\n",
       "      <td>0.608830</td>\n",
       "      <td>0.570500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ROC AUC</th>\n",
       "      <td>0.738907</td>\n",
       "      <td>0.695082</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   original   updated\n",
       "Accuracy           0.806732  0.804029\n",
       "Balanced Accuracy  0.738907  0.723995\n",
       "F1 score           0.601217  0.582155\n",
       "MCC                0.473767  0.454384\n",
       "Precision          0.593792  0.594296\n",
       "Recall             0.608830  0.570500\n",
       "ROC AUC            0.738907  0.695082"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ftree.evaluate(test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
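    "We now run Fairlearn's `ExponentiatedGradient` with a `TruePositiveRateParity` constraint, which corresponds to equal opportunity. AnonFair post-processes the already trained tree using the validation set. Fairlearn instead retrains a decision tree from scratch on the training set."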
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ExponentiatedGradient(constraints=<fairlearn.reductions._moments.utility_parity.TruePositiveRateParity object at 0x164572ab0>,\n",
       "                      estimator=DecisionTreeClassifier(),\n",
       "                      nu=2.0474182056426738e-05)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
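    "from fairlearn.reductions import TruePositiveRateParity, ExponentiatedGradient\n",
    "# TruePositiveRateParity is Fairlearn's equal opportunity constraint; the unconstrained tree is retrained on the training data\n",
    "mitagator = ExponentiatedGradient(DecisionTreeClassifier(), TruePositiveRateParity())\n",
    "mitagator.fit(X=train['data'], y=train['target'], sensitive_features=train['data']['sex'])"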
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To evaluate Fairlearn, we write a helper function that computes performance and fairness metrics on a given split (train or test) and concatenates the results into a single table."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>train</th>\n",
       "      <th>test</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Accuracy</th>\n",
       "      <td>0.999959</td>\n",
       "      <td>0.807796</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Balanced Accuracy</th>\n",
       "      <td>0.999973</td>\n",
       "      <td>0.738434</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>F1 score</th>\n",
       "      <td>0.999914</td>\n",
       "      <td>0.601189</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MCC</th>\n",
       "      <td>0.999888</td>\n",
       "      <td>0.474606</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Precision</th>\n",
       "      <td>0.999829</td>\n",
       "      <td>0.597030</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Recall</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.605407</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ROC AUC</th>\n",
       "      <td>0.999973</td>\n",
       "      <td>0.738434</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Statistical Parity</th>\n",
       "      <td>0.194639</td>\n",
       "      <td>0.184861</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Predictive Parity</th>\n",
       "      <td>0.000202</td>\n",
       "      <td>0.088625</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equal Opportunity</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.041558</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in False Negative Rate</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.041558</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equalized Odds</th>\n",
       "      <td>0.000044</td>\n",
       "      <td>0.073702</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Conditional Use Accuracy</th>\n",
       "      <td>0.000101</td>\n",
       "      <td>0.102414</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in Accuracy</th>\n",
       "      <td>0.000061</td>\n",
       "      <td>0.132310</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Treatment Equality</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.209488</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                    train      test\n",
       "Accuracy                                         0.999959  0.807796\n",
       "Balanced Accuracy                                0.999973  0.738434\n",
       "F1 score                                         0.999914  0.601189\n",
       "MCC                                              0.999888  0.474606\n",
       "Precision                                        0.999829  0.597030\n",
       "Recall                                           1.000000  0.605407\n",
       "ROC AUC                                          0.999973  0.738434\n",
       "Statistical Parity                               0.194639  0.184861\n",
       "Predictive Parity                                0.000202  0.088625\n",
       "Equal Opportunity                                0.000000  0.041558\n",
       "Average Group Difference in False Negative Rate  0.000000  0.041558\n",
       "Equalized Odds                                   0.000044  0.073702\n",
       "Conditional Use Accuracy                         0.000101  0.102414\n",
       "Average Group Difference in Accuracy             0.000061  0.132310\n",
       "Treatment Equality                               1.000000  0.209488"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def eval(data, classifier=mitagator):\n",
    "    # Predict once so that both tables use the same predictions (ExponentiatedGradient.predict is randomized)\n",
    "    pred = classifier.predict(data['data'])\n",
    "    return pd.concat((performance.evaluate(data['target'], pred),\n",
    "                      performance.evaluate_fairness(data['target'], pred, data['groups'])), axis=0)\n",
    "\n",
    "out = pd.concat((eval(train), eval(test)), axis=1)\n",
    "out.columns = ['train', 'test']\n",
    "out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
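    "We now evaluate the original baseline tree. As expected, Fairlearn did not substantially alter the classifier's performance or unfairness, beyond the randomness introduced by retraining the tree. Because the unconstrained tree fits the training data almost perfectly, its equal opportunity violation on the training set is already close to zero (0.000202). The constraint that Fairlearn enforces on the training data is therefore essentially satisfied from the start."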
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>train</th>\n",
       "      <th>test</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Accuracy</th>\n",
       "      <td>0.999959</td>\n",
       "      <td>0.806732</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Balanced Accuracy</th>\n",
       "      <td>0.999914</td>\n",
       "      <td>0.738907</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>F1 score</th>\n",
       "      <td>0.999914</td>\n",
       "      <td>0.601217</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MCC</th>\n",
       "      <td>0.999888</td>\n",
       "      <td>0.473767</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Precision</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.593792</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Recall</th>\n",
       "      <td>0.999829</td>\n",
       "      <td>0.608830</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ROC AUC</th>\n",
       "      <td>0.999914</td>\n",
       "      <td>0.738907</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Statistical Parity</th>\n",
       "      <td>0.194516</td>\n",
       "      <td>0.182991</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Predictive Parity</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.114898</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equal Opportunity</th>\n",
       "      <td>0.000202</td>\n",
       "      <td>0.056253</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in False Negative Rate</th>\n",
       "      <td>0.000202</td>\n",
       "      <td>0.056253</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equalized Odds</th>\n",
       "      <td>0.000101</td>\n",
       "      <td>0.078540</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Conditional Use Accuracy</th>\n",
       "      <td>0.000044</td>\n",
       "      <td>0.113885</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in Accuracy</th>\n",
       "      <td>0.000061</td>\n",
       "      <td>0.125034</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Treatment Equality</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.269692</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                    train      test\n",
       "Accuracy                                         0.999959  0.806732\n",
       "Balanced Accuracy                                0.999914  0.738907\n",
       "F1 score                                         0.999914  0.601217\n",
       "MCC                                              0.999888  0.473767\n",
       "Precision                                        1.000000  0.593792\n",
       "Recall                                           0.999829  0.608830\n",
       "ROC AUC                                          0.999914  0.738907\n",
       "Statistical Parity                               0.194516  0.182991\n",
       "Predictive Parity                                0.000000  0.114898\n",
       "Equal Opportunity                                0.000202  0.056253\n",
       "Average Group Difference in False Negative Rate  0.000202  0.056253\n",
       "Equalized Odds                                   0.000101  0.078540\n",
       "Conditional Use Accuracy                         0.000044  0.113885\n",
       "Average Group Difference in Accuracy             0.000061  0.125034\n",
       "Treatment Equality                               0.000000  0.269692"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out = pd.concat((eval(train, basetree), eval(test, basetree)), axis=1)\n",
    "out.columns = ['train', 'test']\n",
    "out"
   ]
  },
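  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below illustrates why an overfit model leaves a training-time fairness constraint little to correct. It uses synthetic data and a made-up binary group attribute, both assumptions for illustration rather than the Adult data.\n",
    "\n",
    "An unconstrained tree fits its training data (essentially) perfectly, so the true positive rate is 1.0 for every group on the training set. The equal opportunity gap measured there is therefore zero, while on held-out data the gap is generally non-zero."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch on synthetic data with a hypothetical binary group attribute\n",
    "import numpy as np\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "X, y = make_classification(n_samples=4000, n_features=10, flip_y=0.1, random_state=1)\n",
    "group = (X[:, 0] > 0).astype(int)  # hypothetical protected attribute\n",
    "X_tr, X_te, y_tr, y_te, g_tr, g_te = train_test_split(X, y, group, test_size=0.5, random_state=1)\n",
    "\n",
    "tree = DecisionTreeClassifier(random_state=1).fit(X_tr, y_tr)  # default parameters: no depth limit\n",
    "\n",
    "def tpr_gap(y_true, y_pred, g):\n",
    "    # Absolute difference in true positive rate between groups 0 and 1\n",
    "    tpr = [y_pred[(g == k) & (y_true == 1)].mean() for k in (0, 1)]\n",
    "    return abs(tpr[0] - tpr[1])\n",
    "\n",
    "print('train TPR gap:', tpr_gap(y_tr, tree.predict(X_tr), g_tr))\n",
    "print('test TPR gap:', tpr_gap(y_te, tree.predict(X_te), g_te))"
   ]
  },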
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now do the same with the random forest classifier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>original</th>\n",
       "      <th>updated</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Statistical Parity</th>\n",
       "      <td>0.171365</td>\n",
       "      <td>0.147358</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Predictive Parity</th>\n",
       "      <td>0.018927</td>\n",
       "      <td>0.054233</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equal Opportunity</th>\n",
       "      <td>0.083080</td>\n",
       "      <td>0.005127</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in False Negative Rate</th>\n",
       "      <td>0.083080</td>\n",
       "      <td>0.005127</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equalized Odds</th>\n",
       "      <td>0.075989</td>\n",
       "      <td>0.029586</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Conditional Use Accuracy</th>\n",
       "      <td>0.059969</td>\n",
       "      <td>0.083430</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in Accuracy</th>\n",
       "      <td>0.109057</td>\n",
       "      <td>0.109686</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Treatment Equality</th>\n",
       "      <td>0.120955</td>\n",
       "      <td>0.168906</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                 original   updated\n",
       "Statistical Parity                               0.171365  0.147358\n",
       "Predictive Parity                                0.018927  0.054233\n",
       "Equal Opportunity                                0.083080  0.005127\n",
       "Average Group Difference in False Negative Rate  0.083080  0.005127\n",
       "Equalized Odds                                   0.075989  0.029586\n",
       "Conditional Use Accuracy                         0.059969  0.083430\n",
       "Average Group Difference in Accuracy             0.109057  0.109686\n",
       "Treatment Equality                               0.120955  0.168906"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
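    "# Fairness metrics on the validation set passed to FairPredictor (original vs. updated predictor)\n",
    "fforest.evaluate_fairness()"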
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>original</th>\n",
       "      <th>updated</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Statistical Parity</th>\n",
       "      <td>0.176631</td>\n",
       "      <td>0.141287</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Predictive Parity</th>\n",
       "      <td>0.028393</td>\n",
       "      <td>0.074746</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equal Opportunity</th>\n",
       "      <td>0.074845</td>\n",
       "      <td>0.020783</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in False Negative Rate</th>\n",
       "      <td>0.074845</td>\n",
       "      <td>0.020783</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equalized Odds</th>\n",
       "      <td>0.075552</td>\n",
       "      <td>0.035816</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Conditional Use Accuracy</th>\n",
       "      <td>0.065604</td>\n",
       "      <td>0.095488</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in Accuracy</th>\n",
       "      <td>0.116183</td>\n",
       "      <td>0.109884</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Treatment Equality</th>\n",
       "      <td>0.210769</td>\n",
       "      <td>0.280337</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                 original   updated\n",
       "Statistical Parity                               0.176631  0.141287\n",
       "Predictive Parity                                0.028393  0.074746\n",
       "Equal Opportunity                                0.074845  0.020783\n",
       "Average Group Difference in False Negative Rate  0.074845  0.020783\n",
       "Equalized Odds                                   0.075552  0.035816\n",
       "Conditional Use Accuracy                         0.065604  0.095488\n",
       "Average Group Difference in Accuracy             0.116183  0.109884\n",
       "Treatment Equality                               0.210769  0.280337"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The same fairness metrics on the held-out test set\n",
    "fforest.evaluate_fairness(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>original</th>\n",
       "      <th>updated</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Accuracy</th>\n",
       "      <td>0.853235</td>\n",
       "      <td>0.854791</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Balanced Accuracy</th>\n",
       "      <td>0.766780</td>\n",
       "      <td>0.767685</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>F1 score</th>\n",
       "      <td>0.662142</td>\n",
       "      <td>0.664395</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MCC</th>\n",
       "      <td>0.574487</td>\n",
       "      <td>0.578378</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Precision</th>\n",
       "      <td>0.737196</td>\n",
       "      <td>0.743329</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Recall</th>\n",
       "      <td>0.600958</td>\n",
       "      <td>0.600616</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ROC AUC</th>\n",
       "      <td>0.904615</td>\n",
       "      <td>0.891409</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   original   updated\n",
       "Accuracy           0.853235  0.854791\n",
       "Balanced Accuracy  0.766780  0.767685\n",
       "F1 score           0.662142  0.664395\n",
       "MCC                0.574487  0.578378\n",
       "Precision          0.737196  0.743329\n",
       "Recall             0.600958  0.600616\n",
       "ROC AUC            0.904615  0.891409"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Classification performance on the validation set (original vs. updated predictor)\n",
    "fforest.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>original</th>\n",
       "      <th>updated</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Accuracy</th>\n",
       "      <td>0.853657</td>\n",
       "      <td>0.851691</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Balanced Accuracy</th>\n",
       "      <td>0.767404</td>\n",
       "      <td>0.763884</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>F1 score</th>\n",
       "      <td>0.663148</td>\n",
       "      <td>0.657721</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MCC</th>\n",
       "      <td>0.575743</td>\n",
       "      <td>0.569434</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Precision</th>\n",
       "      <td>0.738145</td>\n",
       "      <td>0.734487</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Recall</th>\n",
       "      <td>0.601985</td>\n",
       "      <td>0.595483</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ROC AUC</th>\n",
       "      <td>0.903594</td>\n",
       "      <td>0.890906</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   original   updated\n",
       "Accuracy           0.853657  0.851691\n",
       "Balanced Accuracy  0.767404  0.763884\n",
       "F1 score           0.663148  0.657721\n",
       "MCC                0.575743  0.569434\n",
       "Precision          0.738145  0.734487\n",
       "Recall             0.601985  0.595483\n",
       "ROC AUC            0.903594  0.890906"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Classification performance on the held-out test set\n",
    "fforest.evaluate(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ExponentiatedGradient(constraints=<fairlearn.reductions._moments.utility_parity.TruePositiveRateParity object at 0x164572c90>,\n",
       "                      estimator=RandomForestClassifier(),\n",
       "                      nu=2.895427308515526e-05)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
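    "# Fairlearn baseline: ExponentiatedGradient retrains the random forest on the training set\n",
    "# under an equal opportunity (true positive rate parity) constraint on sex\n",
    "mitagator = ExponentiatedGradient(RandomForestClassifier(), TruePositiveRateParity())\n",
    "mitagator.fit(X=train['data'], y=train['target'], sensitive_features=train['data']['sex'])"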
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>train</th>\n",
       "      <th>test</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Accuracy</th>\n",
       "      <td>0.999959</td>\n",
       "      <td>0.854967</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Balanced Accuracy</th>\n",
       "      <td>0.999914</td>\n",
       "      <td>0.769790</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>F1 score</th>\n",
       "      <td>0.999914</td>\n",
       "      <td>0.666792</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MCC</th>\n",
       "      <td>0.999888</td>\n",
       "      <td>0.579960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Precision</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.740493</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Recall</th>\n",
       "      <td>0.999829</td>\n",
       "      <td>0.606434</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ROC AUC</th>\n",
       "      <td>0.999914</td>\n",
       "      <td>0.769790</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Statistical Parity</th>\n",
       "      <td>0.194516</td>\n",
       "      <td>0.178964</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Predictive Parity</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.008129</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equal Opportunity</th>\n",
       "      <td>0.000202</td>\n",
       "      <td>0.098747</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in False Negative Rate</th>\n",
       "      <td>0.000202</td>\n",
       "      <td>0.098747</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Equalized Odds</th>\n",
       "      <td>0.000101</td>\n",
       "      <td>0.086332</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Conditional Use Accuracy</th>\n",
       "      <td>0.000044</td>\n",
       "      <td>0.053082</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average Group Difference in Accuracy</th>\n",
       "      <td>0.000061</td>\n",
       "      <td>0.110158</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Treatment Equality</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.208607</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                    train      test\n",
       "Accuracy                                         0.999959  0.854967\n",
       "Balanced Accuracy                                0.999914  0.769790\n",
       "F1 score                                         0.999914  0.666792\n",
       "MCC                                              0.999888  0.579960\n",
       "Precision                                        1.000000  0.740493\n",
       "Recall                                           0.999829  0.606434\n",
       "ROC AUC                                          0.999914  0.769790\n",
       "Statistical Parity                               0.194516  0.178964\n",
       "Predictive Parity                                0.000000  0.008129\n",
       "Equal Opportunity                                0.000202  0.098747\n",
       "Average Group Difference in False Negative Rate  0.000202  0.098747\n",
       "Equalized Odds                                   0.000101  0.086332\n",
       "Conditional Use Accuracy                         0.000044  0.053082\n",
       "Average Group Difference in Accuracy             0.000061  0.110158\n",
       "Treatment Equality                               0.000000  0.208607"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
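    "# Train vs. test metrics for the Fairlearn-mitigated random forest; a large gap between the columns indicates overfitting\n",
    "out = pd.concat((eval(train, mitagator), eval(test, mitagator)), axis=1)\n",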
    "out.columns = ['train', 'test']\n",
    "out"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
