{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### This notebook demonstrates the use of the Reject Option Classification (ROC) post-processing algorithm for bias mitigation.\n",
    "- The debiasing function used is implemented in the `RejectOptionClassification` class.\n",
    "- Divide the dataset into training, validation, and testing partitions.\n",
    "- Train classifier on original training data.\n",
    "- Estimate the optimal classification threshold, that maximizes balanced accuracy without fairness constraints.\n",
    "- Estimate the optimal classification threshold, and the critical region boundary (ROC margin) using a validation set for the desired constraint on fairness. The best parameters are those that maximize the classification threshold while satisfying the fairness constraints.\n",
    "- The constraints can be used on the following fairness measures:\n",
    "    * Statistical parity difference on the predictions of the classifier\n",
    "    * Average odds difference for the classifier\n",
    "    * Equal opportunity difference for the classifier\n",
    "- Determine the prediction scores for testing data. Using the estimated optimal classification threshold, compute accuracy and fairness metrics.\n",
    "- Using the determined optimal classification threshold and the ROC margin, adjust the predictions. Report accuracy and fairness metric on the new predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:No module named 'tempeh': LawSchoolGPADataset will be unavailable. To install, run:\n",
      "pip install 'aif360[LawSchoolGPA]'\n",
      "WARNING:root:No module named 'fairlearn': ExponentiatedGradientReduction will be unavailable. To install, run:\n",
      "pip install 'aif360[Reductions]'\n",
      "WARNING:root:No module named 'fairlearn': GridSearchReduction will be unavailable. To install, run:\n",
      "pip install 'aif360[Reductions]'\n",
      "WARNING:root:No module named 'fairlearn': GridSearchReduction will be unavailable. To install, run:\n",
      "pip install 'aif360[Reductions]'\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "# Load all necessary packages\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from warnings import warn\n",
    "\n",
    "from aif360.datasets import BinaryLabelDataset\n",
    "from aif360.datasets import AdultDataset, GermanDataset, CompasDataset\n",
    "from aif360.metrics import ClassificationMetric, BinaryLabelDatasetMetric\n",
    "from aif360.metrics.utils import compute_boolean_conditioning_vector\n",
    "from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions\\\n",
    "        import load_preproc_data_adult, load_preproc_data_german, load_preproc_data_compas\n",
    "from aif360.algorithms.postprocessing.reject_option_classification\\\n",
    "        import RejectOptionClassification\n",
    "from common_utils import compute_metrics\n",
    "\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "from IPython.display import Markdown, display\n",
    "import matplotlib.pyplot as plt\n",
    "from ipywidgets import interactive, FloatSlider"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load dataset and specify options"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "## import dataset\n",
    "dataset_used = \"adult\" # \"adult\", \"german\", \"compas\"\n",
    "protected_attribute_used = 1 # 1, 2\n",
    "\n",
    "if dataset_used == \"adult\":\n",
    "#     dataset_orig = AdultDataset()\n",
    "    if protected_attribute_used == 1:\n",
    "        privileged_groups = [{'sex': 1}]\n",
    "        unprivileged_groups = [{'sex': 0}]\n",
    "        dataset_orig = load_preproc_data_adult(['sex'])\n",
    "    else:\n",
    "        privileged_groups = [{'race': 1}]\n",
    "        unprivileged_groups = [{'race': 0}]\n",
    "        dataset_orig = load_preproc_data_adult(['race'])\n",
    "    \n",
    "elif dataset_used == \"german\":\n",
    "#     dataset_orig = GermanDataset()\n",
    "    if protected_attribute_used == 1:\n",
    "        privileged_groups = [{'sex': 1}]\n",
    "        unprivileged_groups = [{'sex': 0}]\n",
    "        dataset_orig = load_preproc_data_german(['sex'])\n",
    "    else:\n",
    "        privileged_groups = [{'age': 1}]\n",
    "        unprivileged_groups = [{'age': 0}]\n",
    "        dataset_orig = load_preproc_data_german(['age'])\n",
    "    \n",
    "elif dataset_used == \"compas\":\n",
    "#     dataset_orig = CompasDataset()\n",
    "    if protected_attribute_used == 1:\n",
    "        privileged_groups = [{'sex': 1}]\n",
    "        unprivileged_groups = [{'sex': 0}]\n",
    "        dataset_orig = load_preproc_data_compas(['sex'])\n",
    "    else:\n",
    "        privileged_groups = [{'race': 1}]\n",
    "        unprivileged_groups = [{'race': 0}]  \n",
    "        dataset_orig = load_preproc_data_compas(['race'])\n",
    "\n",
    "        \n",
    "# Metric used (should be one of allowed_metrics)\n",
    "metric_name = \"Statistical parity difference\"\n",
    "\n",
    "# Upper and lower bound on the fairness metric used\n",
    "metric_ub = 0.11\n",
    "metric_lb = -0.11\n",
    "        \n",
    "#random seed for calibrated equal odds prediction\n",
    "np.random.seed(1)\n",
    "\n",
    "# Verify metric name\n",
    "allowed_metrics = [\"Statistical parity difference\",\n",
    "                   \"Average odds difference\",\n",
    "                   \"Equal opportunity difference\"]\n",
    "if metric_name not in allowed_metrics:\n",
    "    raise ValueError(\"Metric name should be one of allowed metrics\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Split into train, test and validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 1/50 [00:10<08:20, 10.21s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6834993858332196\n",
      "0.023940942424715228\n",
      "0.015054750562392472\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▍         | 2/50 [00:20<08:01, 10.03s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.4154497065647604\n",
      "0.11510011816741605\n",
      "0.03682241842422365\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 3/50 [00:29<07:47,  9.95s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.1044083526682136\n",
      "0.13257489175050485\n",
      "0.057065539515779834\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|▊         | 4/50 [00:39<07:35,  9.90s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.838269414494336\n",
      "0.1420004260714367\n",
      "0.068522257888251\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 5/50 [00:49<07:25,  9.89s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3.5244984304626725\n",
      "0.1516629119410604\n",
      "0.08791839967543691\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|█▏        | 6/50 [00:59<07:16,  9.91s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.199399481370275\n",
      "0.1935196059716675\n",
      "0.09604406174462382\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█▍        | 7/50 [01:09<07:05,  9.89s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.874573495291388\n",
      "0.20514196709323795\n",
      "0.10991110433254486\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█▌        | 8/50 [01:19<06:58,  9.95s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5.552067694827351\n",
      "0.21430208393923655\n",
      "0.11575540424184516\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█▊        | 9/50 [01:29<06:49,  9.99s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6.222464856012011\n",
      "0.2372381855033805\n",
      "0.1341247880698248\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 10/50 [01:39<06:39,  9.99s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6.925344615804559\n",
      "0.31150807594466\n",
      "0.15345097007364328\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|██▏       | 11/50 [01:49<06:31, 10.04s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7.65893271461717\n",
      "0.3536429367123418\n",
      "0.17022085426562378\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 24%|██▍       | 12/50 [02:00<06:25, 10.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8.337109321686912\n",
      "0.41500765402379003\n",
      "0.1875876145330711\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 26%|██▌       | 13/50 [02:10<06:14, 10.11s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9.011055002047224\n",
      "0.42457154854711654\n",
      "0.20810271524419843\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|██▊       | 14/50 [02:20<06:01, 10.05s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9.687730312542651\n",
      "0.4601705712709919\n",
      "0.22270483868024415\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 15/50 [02:30<05:50, 10.01s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10.427460079159275\n",
      "0.5269005561324793\n",
      "0.2341524811245444\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 32%|███▏      | 16/50 [02:40<05:40, 10.01s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "11.155588917701651\n",
      "0.5514133271317858\n",
      "0.24446175237402784\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 34%|███▍      | 17/50 [02:50<05:34, 10.13s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "11.827487375460626\n",
      "0.5966565013420831\n",
      "0.26885044285367976\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 36%|███▌      | 18/50 [03:00<05:27, 10.23s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12.549338064692234\n",
      "0.6176133094921391\n",
      "0.2715968333867886\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 38%|███▊      | 19/50 [03:11<05:20, 10.33s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "13.227514671761977\n",
      "0.64564615756978\n",
      "0.2967711368731532\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 20/50 [03:22<05:13, 10.46s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "13.921113689095128\n",
      "0.6603199418535572\n",
      "0.29955213955844257\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 42%|████▏     | 21/50 [03:34<05:17, 10.94s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "14.654155861880716\n",
      "0.7132769926696216\n",
      "0.3122285723132164\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|████▍     | 22/50 [03:52<06:11, 13.26s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "15.320185614849189\n",
      "0.7891752790156692\n",
      "0.3453258075040101\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 46%|████▌     | 23/50 [04:08<06:19, 14.04s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16.018834447932306\n",
      "0.8458647821261503\n",
      "0.35724382800914667\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 48%|████▊     | 24/50 [04:19<05:36, 12.93s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16.770847550156954\n",
      "0.8978978505337575\n",
      "0.3994553146386307\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 25/50 [04:29<05:04, 12.17s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "17.442200081888906\n",
      "0.9256609146553361\n",
      "0.4406701508551352\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 52%|█████▏    | 26/50 [04:39<04:39, 11.63s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "18.1239252081343\n",
      "0.9334318758005303\n",
      "0.459225959417824\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 54%|█████▍    | 27/50 [04:50<04:18, 11.25s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "18.805240889859427\n",
      "0.9683114634660418\n",
      "0.47061140282793157\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 56%|█████▌    | 28/50 [05:00<04:03, 11.05s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "19.536918247577457\n",
      "1.0139500746417376\n",
      "0.48694427039507954\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 58%|█████▊    | 29/50 [05:11<03:49, 10.95s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20.267913197761708\n",
      "1.0461055699278023\n",
      "0.510831104656192\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 30/50 [05:22<03:39, 10.98s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20.950457213047635\n",
      "1.0464471259050319\n",
      "0.5206032996742886\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 62%|██████▏   | 31/50 [05:33<03:29, 11.01s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "21.62822437559711\n",
      "1.0469114961562185\n",
      "0.5331505502284857\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 64%|██████▍   | 32/50 [05:44<03:16, 10.93s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "22.363996178517816\n",
      "1.092366041610764\n",
      "0.5354538142219767\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 66%|██████▌   | 33/50 [05:55<03:06, 10.95s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "23.044765934215917\n",
      "1.151956288900989\n",
      "0.5556246649013019\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 68%|██████▊   | 34/50 [08:21<13:44, 51.53s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "23.77930940357582\n",
      "1.1704428247474987\n",
      "0.5607271336882763\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 35/50 [08:33<09:53, 39.56s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "24.450116009280748\n",
      "1.1886046679565279\n",
      "0.5788288236649664\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 72%|███████▏  | 36/50 [08:43<07:12, 30.88s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "25.135799099222062\n",
      "1.2351642247432313\n",
      "0.5825789357454341\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 74%|███████▍  | 37/50 [08:54<05:24, 24.93s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "25.813975706291803\n",
      "1.2534980036403907\n",
      "0.6035087145877529\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 76%|███████▌  | 38/50 [09:10<04:25, 22.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "26.555479732496252\n",
      "1.3127531890583668\n",
      "0.6214516164025222\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 78%|███████▊  | 39/50 [09:31<04:00, 21.90s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "27.235567080660577\n",
      "1.3201729415810826\n",
      "0.6348914607815932\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 40/50 [09:48<03:23, 20.37s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "27.917838132932992\n",
      "1.3344994384633604\n",
      "0.6506820366750994\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 82%|████████▏ | 41/50 [10:07<02:58, 19.86s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "28.589190664664944\n",
      "1.401696827310234\n",
      "0.6724617457001837\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 84%|████████▍ | 42/50 [10:25<02:34, 19.25s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "29.256994677221243\n",
      "1.4227109624354506\n",
      "0.6906108135506703\n",
      "29.938992766480148\n",
      "1.468473888034442\n",
      "0.6957332546460286\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 88%|████████▊ | 44/50 [10:57<01:44, 17.37s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "30.618943633137714\n",
      "1.5337270109798724\n",
      "0.7245915502007344\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 45/50 [11:08<01:18, 15.62s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "31.34188617442337\n",
      "1.565348395060273\n",
      "0.7351626064040572\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 92%|█████████▏| 46/50 [11:24<01:02, 15.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32.0092807424594\n",
      "1.6361025622126215\n",
      "0.7747022353914155\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 94%|█████████▍| 47/50 [11:45<00:52, 17.38s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32.68854920158319\n",
      "1.6715816991315489\n",
      "0.7779850884736368\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 96%|█████████▌| 48/50 [12:07<00:37, 18.65s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "33.36495154906511\n",
      "1.6988701038386467\n",
      "0.7907756324766575\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 98%|█████████▊| 49/50 [12:29<00:19, 19.81s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "34.11136890951277\n",
      "1.7442411243504714\n",
      "0.7983415248837186\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [12:51<00:00, 15.42s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "34.77617032892044\n",
      "1.764951120060467\n",
      "0.8386069047808814\n",
      "0.6955234065784088\n",
      "0.03529902240120934\n",
      "0.016772138095617627\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "mis_mean = 0\n",
    "deoo_mean = 0\n",
    "dpe_mean = 0\n",
    "for step in tqdm(range(50)):\n",
    "\n",
    "    # Get the dataset and split into train and test\n",
    "    dataset_orig_train, dataset_orig_vt = dataset_orig.split([0.7], shuffle=True)\n",
    "    dataset_orig_valid, dataset_orig_test = dataset_orig_vt.split([0.5], shuffle=True)\n",
    "\n",
    "    # Logistic regression classifier and predictions\n",
    "    scale_orig = StandardScaler()\n",
    "    X_train = scale_orig.fit_transform(dataset_orig_train.features)\n",
    "    y_train = dataset_orig_train.labels.ravel()\n",
    "\n",
    "    lmod = LogisticRegression()\n",
    "    lmod.fit(X_train, y_train)\n",
    "    y_train_pred = lmod.predict(X_train)\n",
    "\n",
    "    # positive class index\n",
    "    pos_ind = np.where(lmod.classes_ == dataset_orig_train.favorable_label)[0][0]\n",
    "\n",
    "    dataset_orig_train_pred = dataset_orig_train.copy(deepcopy=True)\n",
    "    dataset_orig_train_pred.labels = y_train_pred\n",
    "\n",
    "    dataset_orig_valid_pred = dataset_orig_valid.copy(deepcopy=True)\n",
    "    X_valid = scale_orig.transform(dataset_orig_valid_pred.features)\n",
    "    y_valid = dataset_orig_valid_pred.labels\n",
    "    dataset_orig_valid_pred.scores = lmod.predict_proba(X_valid)[:,pos_ind].reshape(-1,1)\n",
    "\n",
    "    dataset_orig_test_pred = dataset_orig_test.copy(deepcopy=True)\n",
    "    X_test = scale_orig.transform(dataset_orig_test_pred.features)\n",
    "    y_test = dataset_orig_test_pred.labels.ravel()\n",
    "    Y_test = y_test.copy()\n",
    "    dataset_orig_test_pred.scores = lmod.predict_proba(X_test)[:,pos_ind].reshape(-1,1)\n",
    "\n",
    "    num_thresh = 100\n",
    "    ba_arr = np.zeros(num_thresh)\n",
    "    class_thresh_arr = np.linspace(0.01, 0.99, num_thresh)\n",
    "    for idx, class_thresh in enumerate(class_thresh_arr):\n",
    "        \n",
    "        fav_inds = dataset_orig_valid_pred.scores > class_thresh\n",
    "        dataset_orig_valid_pred.labels[fav_inds] = dataset_orig_valid_pred.favorable_label\n",
    "        dataset_orig_valid_pred.labels[~fav_inds] = dataset_orig_valid_pred.unfavorable_label\n",
    "        \n",
    "        classified_metric_orig_valid = ClassificationMetric(dataset_orig_valid,\n",
    "                                                dataset_orig_valid_pred, \n",
    "                                                unprivileged_groups=unprivileged_groups,\n",
    "                                                privileged_groups=privileged_groups)\n",
    "        \n",
    "        ba_arr[idx] = 0.5*(classified_metric_orig_valid.true_positive_rate()\\\n",
    "                        +classified_metric_orig_valid.true_negative_rate())\n",
    "\n",
    "    best_ind = np.where(ba_arr == np.max(ba_arr))[0][0]\n",
    "    best_class_thresh = class_thresh_arr[best_ind]\n",
    "\n",
    "    ROC = RejectOptionClassification(unprivileged_groups=unprivileged_groups, \n",
    "                                    privileged_groups=privileged_groups, \n",
    "                                    low_class_thresh=0.01, high_class_thresh=0.99,\n",
    "                                    num_class_thresh=100, num_ROC_margin=50,\n",
    "                                    metric_name=metric_name,\n",
    "                                    metric_ub=metric_ub, metric_lb=metric_lb)\n",
    "    ROC = ROC.fit(dataset_orig_valid, dataset_orig_valid_pred)\n",
    "\n",
    "    # Metrics for the test set\n",
    "    fav_inds = dataset_orig_valid_pred.scores > best_class_thresh\n",
    "    dataset_orig_valid_pred.labels[fav_inds] = dataset_orig_valid_pred.favorable_label\n",
    "    dataset_orig_valid_pred.labels[~fav_inds] = dataset_orig_valid_pred.unfavorable_label\n",
    "\n",
    "\n",
    "    # Transform the validation set\n",
    "    dataset_transf_valid_pred = ROC.predict(dataset_orig_valid_pred)\n",
    "\n",
    "\n",
    "\n",
    "    # Metrics for the test set\n",
    "    fav_inds = dataset_orig_test_pred.scores > best_class_thresh\n",
    "    dataset_orig_test_pred.labels[fav_inds] = dataset_orig_test_pred.favorable_label\n",
    "    dataset_orig_test_pred.labels[~fav_inds] = dataset_orig_test_pred.unfavorable_label\n",
    "\n",
    "\n",
    "    # Metrics for the transformed test set\n",
    "    dataset_transf_test_pred = ROC.predict(dataset_orig_test_pred)\n",
    "\n",
    "    eq = 0\n",
    "    for i in range(len(Y_test)):\n",
    "        if(dataset_transf_test_pred.labels.ravel()[i] == Y_test[i]):\n",
    "            eq += 1\n",
    "    mis_mean += eq / len(Y_test)\n",
    "    n_10 = 0\n",
    "    n_11 = 0\n",
    "    c_10 = 0\n",
    "    c_11 = 0\n",
    "    for i in range(len(Y_test)):\n",
    "        if(Y_test[i] == 1 and dataset_transf_test_pred.protected_attributes[:,0][i] == 0):\n",
    "            n_10 += 1\n",
    "            if(dataset_transf_test_pred.labels.ravel()[i] == 1):\n",
    "                c_10 += 1\n",
    "        elif(Y_test[i] == 1 and dataset_transf_test_pred.protected_attributes[:,0][i] == 1):\n",
    "            n_11 += 1\n",
    "            if(dataset_transf_test_pred.labels.ravel()[i] == 1):\n",
    "                c_11 += 1\n",
    "    deoo_mean += abs(c_10 / n_10 - c_11 / n_11)\n",
    "    n_00 = 0\n",
    "    n_01 = 0\n",
    "    c_00 = 0\n",
    "    c_01 = 0\n",
    "    for i in range(len(Y_test)):\n",
    "        if(Y_test[i] == 0 and dataset_transf_test_pred.protected_attributes[:,0][i] == 0):\n",
    "            n_00 += 1\n",
    "            if(dataset_transf_test_pred.labels.ravel()[i] == 1):\n",
    "                c_00 += 1\n",
    "        elif(Y_test[i] == 0 and dataset_transf_test_pred.protected_attributes[:,0][i] == 1):\n",
    "            n_01 += 1\n",
    "            if(dataset_transf_test_pred.labels.ravel()[i] == 1):\n",
    "                c_01 += 1\n",
    "    dpe_mean += abs(c_00 / n_00 - c_01 / n_01)\n",
    "    print(mis_mean)\n",
    "    print(deoo_mean)\n",
    "    print(dpe_mean)\n",
    "    del dataset_orig_train\n",
    "    del dataset_orig_vt\n",
    "    del dataset_orig_valid\n",
    "    del dataset_orig_test\n",
    "    del X_train\n",
    "    del y_train\n",
    "    del lmod\n",
    "    del y_train_pred\n",
    "    del pos_ind\n",
    "    del dataset_orig_train_pred\n",
    "    del dataset_orig_valid_pred\n",
    "    del X_valid\n",
    "    del y_valid\n",
    "    del dataset_orig_test_pred\n",
    "    del X_test\n",
    "    del y_test\n",
    "    del Y_test\n",
    "    del ROC\n",
    "    del dataset_transf_test_pred\n",
    "    del dataset_transf_valid_pred\n",
    "    gc.collect()\n",
    "print(mis_mean / 50)\n",
    "print(deoo_mean / 50)\n",
    "print(dpe_mean / 50)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.13 ('aif360')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "d27cd899c2aad711dbfe864d87d8e4a3aa2aa47421b84ce5060d18007824eabc"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
