{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "53100692",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-05-18T04:57:34.914704Z",
     "iopub.status.busy": "2023-05-18T04:57:34.914050Z",
     "iopub.status.idle": "2023-05-18T04:57:38.469607Z",
     "shell.execute_reply": "2023-05-18T04:57:38.466907Z"
    },
    "papermill": {
     "duration": 3.574555,
     "end_time": "2023-05-18T04:57:38.474479",
     "exception": false,
     "start_time": "2023-05-18T04:57:34.899924",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "from json import dumps\n",
    "import os\n",
    "import pickle\n",
    "import urllib.request\n",
    "\n",
    "import datasets\n",
    "from datasets.utils.logging import disable_progress_bar\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import sklearn.calibration\n",
    "import sklearn.linear_model\n",
    "import sklearn.model_selection\n",
    "import torch\n",
    "import tqdm\n",
    "from tqdm.contrib.concurrent import process_map\n",
    "import transformers\n",
    "\n",
    "import postprocess\n",
    "import utils\n",
    "\n",
    "# Fraction of all examples held out for testing.\n",
    "split_ratio_for_test = 0.3\n",
    "split_ratio_for_postprocessing = 0.5  # among all training data\n",
    "\n",
    "calibration_method = \"isotonic\"\n",
    "\n",
    "# The seeds control the randomness for the post-process/test split and in\n",
    "# postprocessing.  They do not affect the pre-training data nor the randomness\n",
    "# in pre-training, i.e., we assume the pre-trained predictor to be fixed.\n",
    "# Results will be aggregated over the seeds.\n",
    "seeds = range(33, 43)\n",
    "\n",
    "# This seed controls the randomness during pre-training (fixed).\n",
    "seed_for_pretraining = 33\n",
    "\n",
    "max_workers = 24\n",
    "\n",
    "# Preferred GPU for the BERT forward passes.  The index is clamped to the GPUs\n",
    "# that are actually present, so the notebook also runs on machines with fewer\n",
    "# than four GPUs; without CUDA it falls back to the CPU.\n",
    "cuda_device_index = 3\n",
    "if torch.cuda.is_available():\n",
    "  device = torch.device(\n",
    "      f\"cuda:{min(cuda_device_index, torch.cuda.device_count() - 1)}\")\n",
    "else:\n",
    "  device = torch.device(\"cpu\")\n",
    "\n",
    "data_dir = \"data/biasbios\"\n",
    "\n",
    "# BERT config\n",
    "model_name = \"bert-base-uncased\"\n",
    "batch_size = 128\n",
    "\n",
    "# Silence the `datasets` progress bars and fix the RNG seed used for\n",
    "# pre-training.\n",
    "disable_progress_bar()\n",
    "transformers.set_seed(seed_for_pretraining)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "1788041e",
   "metadata": {
    "papermill": {
     "duration": 0.007934,
     "end_time": "2023-05-18T04:57:38.493749",
     "exception": false,
     "start_time": "2023-05-18T04:57:38.485815",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Download and load BiasBios dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "163c02cc",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-05-18T04:57:38.508969Z",
     "iopub.status.busy": "2023-05-18T04:57:38.507601Z",
     "iopub.status.idle": "2023-05-18T04:57:42.808896Z",
     "shell.execute_reply": "2023-05-18T04:57:42.806940Z"
    },
    "papermill": {
     "duration": 4.313052,
     "end_time": "2023-05-18T04:57:42.812318",
     "exception": false,
     "start_time": "2023-05-18T04:57:38.499266",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"bio\": \"male. He produced scores of films including such as al-Dhareeh (the shrine), 1976, winner of the Cinema Institute Films\\u2019 Award at the Documentary and Short Films Festival, Egypt, 1977; as well as the Kelibia Festival Award, Tunisia, 1978; al-Mahatta (The Station), winner of a major award at Oberhausen Short Film Festival, Germany, 1989; the EU Award at FESPACO Festival, Burkina Faso, 1990; The Silver Sword Award at Damascus festival, 1990; and The Silver Tanit Award, Carthage festival, Tunisia, 1991. Eltayeb has served as head of the Sudanese Film group for several terms and as secretary of the Sudanese Film club. He has written numerous articles on cinema, published in major Sudanese newspapers. He is currently working on a long fiction film, al-Siraj wal-attama (The Lantern and Darkness).\",\n",
      "  \"title\": 9,\n",
      "  \"gender\": 1\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "label_names = [\n",
    "    \"accountant\", \"architect\", \"attorney\", \"chiropractor\", \"comedian\",\n",
    "    \"composer\", \"dentist\", \"dietitian\", \"dj\", \"filmmaker\", \"interior_designer\",\n",
    "    \"journalist\", \"model\", \"nurse\", \"painter\", \"paralegal\", \"pastor\",\n",
    "    \"personal_trainer\", \"photographer\", \"physician\", \"poet\", \"professor\",\n",
    "    \"psychologist\", \"rapper\", \"software_engineer\", \"surgeon\", \"teacher\",\n",
    "    \"yoga_teacher\"\n",
    "]\n",
    "n_classes = len(label_names)\n",
    "\n",
    "group_names = [\"female\", \"male\"]\n",
    "n_groups = len(group_names)\n",
    "\n",
    "# Schema of the combined dataset: the bio text, the occupation (target) and\n",
    "# the gender (group), the latter two stored as integer class labels.\n",
    "features = datasets.Features({\n",
    "    \"bio\": datasets.Value(\"string\"),\n",
    "    \"title\": datasets.ClassLabel(names=label_names),\n",
    "    \"gender\": datasets.ClassLabel(names=group_names),\n",
    "})\n",
    "\n",
    "train_path = f\"{data_dir}/train.pickle\"\n",
    "test_path = f\"{data_dir}/test.pickle\"\n",
    "dev_path = f\"{data_dir}/dev.pickle\"\n",
    "split_paths = {\"train\": train_path, \"test\": test_path, \"dev\": dev_path}\n",
    "\n",
    "# Fetch only the splits that are not on disk yet.  Each file is downloaded to\n",
    "# a temporary name and renamed once complete, so an interrupted download is\n",
    "# never mistaken for a valid cached file on the next run.\n",
    "biasbios_url = \"https://storage.googleapis.com/ai2i/nullspace/biasbios\"\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "for split, path in split_paths.items():\n",
    "  if not os.path.exists(path):\n",
    "    urllib.request.urlretrieve(f\"{biasbios_url}/{split}.pickle\",\n",
    "                               f\"{path}.part\")\n",
    "    os.replace(f\"{path}.part\", path)\n",
    "\n",
    "# Concatenate the splits in the fixed order train, test, dev.  The row order\n",
    "# matters: the cached embeddings are aligned with it.\n",
    "# NOTE: pickle.load can execute arbitrary code; only run this against files\n",
    "# from a source you trust.\n",
    "rows = {k: [] for k in features}\n",
    "for path in split_paths.values():\n",
    "  with open(path, \"rb\") as pickle_file:\n",
    "    for row in pickle.load(pickle_file):\n",
    "      gender_name = \"female\" if row[\"g\"] == \"f\" else \"male\"\n",
    "      rows[\"gender\"].append(gender_name)\n",
    "      rows[\"title\"].append(row[\"p\"])\n",
    "      # The gender is prepended to the bio, so it is part of the model input.\n",
    "      rows[\"bio\"].append(gender_name + \". \" + row[\"hard_text_untokenized\"])\n",
    "\n",
    "raw_dataset = datasets.Dataset.from_dict(rows, features=features)\n",
    "labels = np.array(raw_dataset[\"title\"])\n",
    "groups = np.array(raw_dataset[\"gender\"])\n",
    "\n",
    "# Show one example record (the pre-training seed is reused as the row index).\n",
    "print(dumps(raw_dataset[seed_for_pretraining], indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a9da82bc",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-05-18T04:57:42.834092Z",
     "iopub.status.busy": "2023-05-18T04:57:42.833812Z",
     "iopub.status.idle": "2023-05-18T04:57:46.642696Z",
     "shell.execute_reply": "2023-05-18T04:57:46.640775Z"
    },
    "papermill": {
     "duration": 3.821628,
     "end_time": "2023-05-18T04:57:46.646169",
     "exception": false,
     "start_time": "2023-05-18T04:57:42.824541",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset statistics:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>Group</th>\n",
       "      <th>female</th>\n",
       "      <th>male</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Target</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>accountant</th>\n",
       "      <td>0.011428</td>\n",
       "      <td>0.016898</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>architect</th>\n",
       "      <td>0.013168</td>\n",
       "      <td>0.036508</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>attorney</th>\n",
       "      <td>0.068610</td>\n",
       "      <td>0.095177</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>chiropractor</th>\n",
       "      <td>0.003789</td>\n",
       "      <td>0.009029</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>comedian</th>\n",
       "      <td>0.003251</td>\n",
       "      <td>0.010444</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>composer</th>\n",
       "      <td>0.005041</td>\n",
       "      <td>0.022156</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dentist</th>\n",
       "      <td>0.028297</td>\n",
       "      <td>0.044132</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dietitian</th>\n",
       "      <td>0.020258</td>\n",
       "      <td>0.001368</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dj</th>\n",
       "      <td>0.001159</td>\n",
       "      <td>0.006029</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>filmmaker</th>\n",
       "      <td>0.012685</td>\n",
       "      <td>0.022236</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>interior_designer</th>\n",
       "      <td>0.006496</td>\n",
       "      <td>0.001325</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>journalist</th>\n",
       "      <td>0.054217</td>\n",
       "      <td>0.047686</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>model</th>\n",
       "      <td>0.034124</td>\n",
       "      <td>0.006095</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>nurse</th>\n",
       "      <td>0.094650</td>\n",
       "      <td>0.008210</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>painter</th>\n",
       "      <td>0.019456</td>\n",
       "      <td>0.019842</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>paralegal</th>\n",
       "      <td>0.008232</td>\n",
       "      <td>0.001268</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>pastor</th>\n",
       "      <td>0.003344</td>\n",
       "      <td>0.009100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>personal_trainer</th>\n",
       "      <td>0.003591</td>\n",
       "      <td>0.003682</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>photographer</th>\n",
       "      <td>0.047715</td>\n",
       "      <td>0.073987</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>physician</th>\n",
       "      <td>0.107517</td>\n",
       "      <td>0.089844</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>poet</th>\n",
       "      <td>0.018896</td>\n",
       "      <td>0.016894</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>professor</th>\n",
       "      <td>0.292638</td>\n",
       "      <td>0.306737</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>psychologist</th>\n",
       "      <td>0.062520</td>\n",
       "      <td>0.032699</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>rapper</th>\n",
       "      <td>0.000747</td>\n",
       "      <td>0.006015</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>software_engineer</th>\n",
       "      <td>0.005980</td>\n",
       "      <td>0.027527</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>surgeon</th>\n",
       "      <td>0.010829</td>\n",
       "      <td>0.053478</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>teacher</th>\n",
       "      <td>0.053640</td>\n",
       "      <td>0.030418</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>yoga_teacher</th>\n",
       "      <td>0.007721</td>\n",
       "      <td>0.001216</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "Group                female      male\n",
       "Target                               \n",
       "accountant         0.011428  0.016898\n",
       "architect          0.013168  0.036508\n",
       "attorney           0.068610  0.095177\n",
       "chiropractor       0.003789  0.009029\n",
       "comedian           0.003251  0.010444\n",
       "composer           0.005041  0.022156\n",
       "dentist            0.028297  0.044132\n",
       "dietitian          0.020258  0.001368\n",
       "dj                 0.001159  0.006029\n",
       "filmmaker          0.012685  0.022236\n",
       "interior_designer  0.006496  0.001325\n",
       "journalist         0.054217  0.047686\n",
       "model              0.034124  0.006095\n",
       "nurse              0.094650  0.008210\n",
       "painter            0.019456  0.019842\n",
       "paralegal          0.008232  0.001268\n",
       "pastor             0.003344  0.009100\n",
       "personal_trainer   0.003591  0.003682\n",
       "photographer       0.047715  0.073987\n",
       "physician          0.107517  0.089844\n",
       "poet               0.018896  0.016894\n",
       "professor          0.292638  0.306737\n",
       "psychologist       0.062520  0.032699\n",
       "rapper             0.000747  0.006015\n",
       "software_engineer  0.005980  0.027527\n",
       "surgeon            0.010829  0.053478\n",
       "teacher            0.053640  0.030418\n",
       "yoga_teacher       0.007721  0.001216"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>Group</th>\n",
       "      <th>female</th>\n",
       "      <th>male</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Count</th>\n",
       "      <td>182102</td>\n",
       "      <td>211321</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "Group  female    male\n",
       "Count  182102  211321"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Perfect results if dataset equals population:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>error_rate</th>\n",
       "      <th>delta_dp</th>\n",
       "      <th>delta_dp_rms</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>alpha</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>inf</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.08644</td>\n",
       "      <td>0.023733</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.115624</td>\n",
       "      <td>0.00000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       error_rate  delta_dp  delta_dp_rms\n",
       "alpha                                    \n",
       "inf      0.000000   0.08644      0.023733\n",
       "0.0      0.115624   0.00000      0.000000"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Compute and print dataset statistics\n",
    "\n",
    "# One (group name, target name) pair of strings per example.\n",
    "group_name_per_row = np.array(group_names)[raw_dataset[\"gender\"]]\n",
    "target_name_per_row = np.array(label_names)[raw_dataset[\"title\"]]\n",
    "df = pd.DataFrame(\n",
    "    np.stack([group_name_per_row, target_name_per_row], axis=1),\n",
    "    columns=[\"Group\", \"Target\"],\n",
    ")\n",
    "# Rows are targets, columns are groups, entries are example counts.\n",
    "df_grouped = df.groupby([\"Target\", \"Group\"]).size().unstack()\n",
    "counts = df_grouped.sum(axis=0)\n",
    "\n",
    "print(\"Dataset statistics:\")\n",
    "# Class distribution within each group, followed by the group sizes.\n",
    "display(df_grouped / counts)\n",
    "display(pd.DataFrame(counts, columns=[\"Count\"]).T)\n",
    "\n",
    "print(\"Perfect results if dataset equals population:\")\n",
    "\n",
    "# alpha = inf: predicting the true labels (zero error) and reporting the\n",
    "# demographic-parity gap of the labels themselves.\n",
    "row_unconstrained = {\n",
    "    \"alpha\": np.inf,\n",
    "    \"error_rate\": 0.0,\n",
    "    \"delta_dp\": utils.delta_dp(raw_dataset[\"title\"], raw_dataset[\"gender\"]),\n",
    "    \"delta_dp_rms\": utils.delta_dp(\n",
    "        raw_dataset[\"title\"],\n",
    "        raw_dataset[\"gender\"],\n",
    "        ord=2,\n",
    "    ) / np.sqrt(n_classes),\n",
    "}\n",
    "\n",
    "# alpha = 0: parity gap fixed at zero.  The error is the `score_` of a\n",
    "# PostProcessorDP fitted on one-hot scores (one per class and group) weighted\n",
    "# by the per-group class frequencies -- presumably the smallest error\n",
    "# achievable under exact parity; confirm against postprocess.PostProcessorDP.\n",
    "row_exact_parity = {\n",
    "    \"alpha\": 0.0,\n",
    "    \"error_rate\": postprocess.PostProcessorDP().fit(\n",
    "        scores=np.concatenate([np.eye(n_classes) for _ in range(n_groups)],\n",
    "                              axis=0),\n",
    "        groups=np.repeat(np.arange(n_groups), n_classes),\n",
    "        r=np.nan_to_num(\n",
    "            (df_grouped.to_numpy() / counts.to_numpy())).T.flatten(),\n",
    "    ).score_,\n",
    "    \"delta_dp\": 0.0,\n",
    "    \"delta_dp_rms\": 0.0,\n",
    "}\n",
    "\n",
    "perfect_results = pd.DataFrame([row_unconstrained, row_exact_parity])\n",
    "display(\n",
    "    perfect_results.groupby('alpha').agg('mean').sort_index(ascending=False))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "f87b0449",
   "metadata": {
    "papermill": {
     "duration": 0.008494,
     "end_time": "2023-05-18T04:57:46.666911",
     "exception": false,
     "start_time": "2023-05-18T04:57:46.658417",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Get BERT embeddings and train prediction head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a9a4ae19",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-05-18T04:57:46.680460Z",
     "iopub.status.busy": "2023-05-18T04:57:46.680115Z",
     "iopub.status.idle": "2023-05-18T04:57:47.153086Z",
     "shell.execute_reply": "2023-05-18T04:57:47.151863Z"
    },
    "papermill": {
     "duration": 0.483847,
     "end_time": "2023-05-18T04:57:47.156709",
     "exception": false,
     "start_time": "2023-05-18T04:57:46.672862",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Tokenize the dataset and compute pre-trained BERT embeddings\n",
    "\n",
    "embeds_path = f\"{data_dir}/bert_embeds.pkl\"\n",
    "\n",
    "if os.path.exists(embeds_path):\n",
    "  # Reuse the cached embeddings.  The file is opened in a `with` block so the\n",
    "  # handle is closed as soon as loading finishes.\n",
    "  # NOTE: pickle can execute arbitrary code; only load caches you created.\n",
    "  with open(embeds_path, \"rb\") as f:\n",
    "    embeds = pickle.load(f)\n",
    "\n",
    "else:\n",
    "  # BERT will be used as an embedding model/feature extractor\n",
    "  tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)\n",
    "  model = transformers.AutoModel.from_pretrained(\n",
    "      model_name, add_pooling_layer=True).to(device)\n",
    "  # Variable names of model.forward; used below to drop batch entries (such\n",
    "  # as the labels) that are not inputs of the encoder.\n",
    "  model_input_args = list(model.forward.__code__.co_varnames)\n",
    "\n",
    "  def tokenize_function(examples):\n",
    "    \"\"\"Tokenizes a batch of bios and carries over the label columns.\n",
    "\n",
    "    Args:\n",
    "      examples: Batch dict passed by `Dataset.map(batched=True)`; its \"bio\",\n",
    "        \"title\" and \"gender\" columns are read.\n",
    "\n",
    "    Returns:\n",
    "      The tokenizer output (unpadded, truncated to the tokenizer's maximum\n",
    "      length) with \"labels\" and \"group_labels\" entries added.\n",
    "    \"\"\"\n",
    "    tokenized_examples = tokenizer(\n",
    "        examples[\"bio\"],\n",
    "        padding=False,\n",
    "        max_length=tokenizer.model_max_length,\n",
    "        truncation=True,\n",
    "    )\n",
    "    tokenized_examples[\"labels\"] = examples[\"title\"]\n",
    "    tokenized_examples[\"group_labels\"] = examples[\"gender\"]\n",
    "    return tokenized_examples\n",
    "\n",
    "  def embedding_fn(dataloader):\n",
    "    \"\"\"Runs the model over all batches and returns the pooled outputs.\n",
    "\n",
    "    Args:\n",
    "      dataloader: Iterable of dicts of tensors; entries whose key is not in\n",
    "        `model_input_args` are ignored.\n",
    "\n",
    "    Returns:\n",
    "      A NumPy array holding the `pooler_output` of every batch, concatenated\n",
    "      along axis 0 in dataloader order.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    embeds = []\n",
    "    with torch.no_grad():\n",
    "      for batch in tqdm.tqdm(dataloader, desc=\"Inference\"):\n",
    "        batch = {\n",
    "            k: v.to(device) for k, v in batch.items() if k in model_input_args\n",
    "        }\n",
    "        # Only the pooled output is used, so the per-layer hidden states are\n",
    "        # not requested (saves memory without changing the embeddings).\n",
    "        outputs = model(**batch)\n",
    "        embeds.append(outputs.pooler_output.cpu().numpy())\n",
    "    return np.concatenate(embeds, axis=0)\n",
    "\n",
    "  tokenized_dataset = raw_dataset.map(\n",
    "      tokenize_function,\n",
    "      batched=True,\n",
    "      remove_columns=raw_dataset.column_names,\n",
    "      desc=\"Running tokenizer on dataset\",\n",
    "  )\n",
    "\n",
    "  # Pad per batch.  The loader does not shuffle by default, which keeps the\n",
    "  # embeddings aligned row by row with `labels` and `groups`.\n",
    "  data_collator = transformers.DataCollatorWithPadding(tokenizer)\n",
    "  dataloader = torch.utils.data.DataLoader(\n",
    "      tokenized_dataset,\n",
    "      collate_fn=data_collator,\n",
    "      batch_size=batch_size,\n",
    "  )\n",
    "  embeds = embedding_fn(dataloader)\n",
    "\n",
    "  with open(embeds_path, \"wb\") as f:\n",
    "    pickle.dump(embeds, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "25a1c647",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-05-18T04:57:47.178883Z",
     "iopub.status.busy": "2023-05-18T04:57:47.178619Z",
     "iopub.status.idle": "2023-05-18T04:57:47.560104Z",
     "shell.execute_reply": "2023-05-18T04:57:47.558605Z"
    },
    "papermill": {
     "duration": 0.395091,
     "end_time": "2023-05-18T04:57:47.563769",
     "exception": false,
     "start_time": "2023-05-18T04:57:47.168678",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Hold out the pre-training split (suffix `_pre`); the remaining examples\n",
    "# (suffix `_`) are presumably divided into post-processing and test data in\n",
    "# later cells, see n_post and n_test below.\n",
    "pretrain_fraction = ((1 - split_ratio_for_test) *\n",
    "                     (1 - split_ratio_for_postprocessing))\n",
    "split_parts = sklearn.model_selection.train_test_split(\n",
    "    embeds,\n",
    "    labels,\n",
    "    groups,\n",
    "    test_size=pretrain_fraction,\n",
    "    random_state=seed_for_pretraining,\n",
    ")\n",
    "# train_test_split returns a (kept, held-out) pair for every input array.\n",
    "(embeds_, train_data_pre, labels_, train_labels_pre, groups_,\n",
    " train_groups_pre) = split_parts\n",
    "\n",
    "# Sizes of the test and post-processing portions of the remaining examples.\n",
    "n_samples = len(embeds)\n",
    "n_test = int(n_samples * split_ratio_for_test)\n",
    "n_post = n_samples - n_test - len(train_data_pre)\n",
    "\n",
    "# # Train logistic regression models on the pre-training data on each group\n",
    "# # separately\n",
    "# predictors = []\n",
    "# for a in range(n_groups):\n",
    "#   predictor = sklearn.linear_model.LogisticRegression(\n",
    "#       random_state=seed_for_pretraining, max_iter=10000)\n",
    "#   predictor.fit(train_data_pre[train_groups_pre == a],\n",
    "#                 train_labels_pre[train_groups_pre == a])\n",
    "#   predictors.append(predictor)\n",
    "\n",
    "# # Get predicted probabilities (uncalibrated)\n",
    "# probas_ = np.empty((len(embeds_), n_classes))\n",
    "# for a, predictor in enumerate(predictors):\n",
    "#   probas_[groups_ == a] = predictor.predict_proba(embeds_[groups_ == a])\n",
    "\n",
    "# # Train calibrated logistic regression models on the pre-training data on each\n",
    "# # group separately\n",
    "# predictors_cal = []\n",
    "# for a in range(n_groups):\n",
    "#   predictor_cal = sklearn.calibration.CalibratedClassifierCV(\n",
    "#       sklearn.linear_model.LogisticRegression(random_state=seed_for_pretraining,\n",
    "#                                               max_iter=10000),\n",
    "#       cv=5,\n",
    "#       method='isotonic',\n",
    "#   )\n",
    "#   predictor_cal.fit(train_data_pre[train_groups_pre == a],\n",
    "#                     train_labels_pre[train_groups_pre == a])\n",
    "#   predictors_cal.append(predictor_cal)\n",
    "\n",
    "# # Get predicted probabilities (calibrated)\n",
    "# probas_cal_ = np.empty((len(embeds_), n_classes))\n",
    "# for a, predictor_cal in enumerate(predictors_cal):\n",
    "#   probas_cal_[groups_ == a] = predictor_cal.predict_proba(embeds_[groups_ == a])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fa61f120",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-05-18T04:57:47.586511Z",
     "iopub.status.busy": "2023-05-18T04:57:47.586250Z",
     "iopub.status.idle": "2023-05-18T04:57:47.623109Z",
     "shell.execute_reply": "2023-05-18T04:57:47.621474Z"
    },
    "papermill": {
     "duration": 0.050676,
     "end_time": "2023-05-18T04:57:47.626479",
     "exception": false,
     "start_time": "2023-05-18T04:57:47.575803",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Calibrated class probabilities for the examples outside the pre-training\n",
    "# split (rows aligned with `embeds_`).  Fitting is slow, so the result is\n",
    "# cached on disk; when the cache is missing it is recomputed instead of\n",
    "# failing with FileNotFoundError on a fresh checkout.\n",
    "probas_cal_path = f\"{data_dir}/probas_cal.pkl\"\n",
    "\n",
    "if os.path.exists(probas_cal_path):\n",
    "  # NOTE: pickle can execute arbitrary code; only load caches you created.\n",
    "  with open(probas_cal_path, \"rb\") as f:\n",
    "    probas_cal_ = pickle.load(f)\n",
    "\n",
    "else:\n",
    "  probas_cal_ = np.zeros((len(embeds_), n_classes))\n",
    "  for a in range(n_groups):\n",
    "    # One calibrated logistic regression per group, fitted on that group's\n",
    "    # pre-training examples only.\n",
    "    predictor_cal = sklearn.calibration.CalibratedClassifierCV(\n",
    "        sklearn.linear_model.LogisticRegression(\n",
    "            random_state=seed_for_pretraining, max_iter=10000),\n",
    "        cv=5,\n",
    "        method=calibration_method,\n",
    "    )\n",
    "    predictor_cal.fit(train_data_pre[train_groups_pre == a],\n",
    "                      train_labels_pre[train_groups_pre == a])\n",
    "    # predict_proba has one column per class seen during fit; scatter the\n",
    "    # columns by classes_ so that a class missing from this group keeps\n",
    "    # probability zero instead of breaking the assignment.\n",
    "    idx_a = np.flatnonzero(groups_ == a)\n",
    "    probas_cal_[idx_a[:, None], predictor_cal.classes_] = (\n",
    "        predictor_cal.predict_proba(embeds_[idx_a]))\n",
    "\n",
    "  with open(probas_cal_path, \"wb\") as f:\n",
    "    pickle.dump(probas_cal_, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "67408d45",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-05-18T04:57:47.649818Z",
     "iopub.status.busy": "2023-05-18T04:57:47.649567Z",
     "iopub.status.idle": "2023-05-18T04:57:47.654148Z",
     "shell.execute_reply": "2023-05-18T04:57:47.652846Z"
    },
    "papermill": {
     "duration": 0.016957,
     "end_time": "2023-05-18T04:57:47.656219",
     "exception": false,
     "start_time": "2023-05-18T04:57:47.639262",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "# with open(f\"{data_dir}/probas_cal.pkl\", \"wb\") as f:\n",
    "#   pickle.dump(probas_cal_, f)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "9d7a9981",
   "metadata": {
    "papermill": {
     "duration": 0.005659,
     "end_time": "2023-05-18T04:57:47.668562",
     "exception": false,
     "start_time": "2023-05-18T04:57:47.662903",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## TPR parity post-processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bac696ec",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-05-18T04:57:47.680766Z",
     "iopub.status.busy": "2023-05-18T04:57:47.680263Z",
     "iopub.status.idle": "2023-05-18T18:33:00.516764Z",
     "shell.execute_reply": "2023-05-18T18:33:00.514948Z"
    },
    "papermill": {
     "duration": 48912.847034,
     "end_time": "2023-05-18T18:33:00.520712",
     "exception": false,
     "start_time": "2023-05-18T04:57:47.673678",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "01605ad0405042eb8ece0864bb4b077f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/220 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "With calibration:\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVYAAAE9CAYAAABQnEoaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAw4ElEQVR4nO3de1xUdf4/8NdwmwEUVJCLioCmIgKikIb8vAdeysq2r6631LRSVkvRTLMNbxtq6rZt4maZLmnFt7T9mpo6WRZmbl4zRfGaCHIJFAdFbjOf3x80k8MMMANnGId5PR8PHo+dz5zL+3zWXvOZzzlzjkwIIUBERJJxsHYBRETNDYOViEhiDFYiIokxWImIJMZgJSKSGIOViEhiDFYiIokxWImIJMZgJSKSGIPVzmzZsgUymUzvr23bthg0aBB27dplsLxMJsOSJUssUsvBgwcNamndujX69u2Lf//73wbLBwUFYcqUKRappS6XL1+GXC7Hjz/+2OT7NseSJUsgk8kk3+6AAQMwZ84cybfbnDlZuwCyjs2bNyMkJARCCOTl5eHdd9/FqFGjsHPnTowaNUq33I8//ogOHTpYtJY333wTgwcPBgAUFhYiNTUVU6ZMgUqlwuzZs3XLffHFF/Dw8LBoLcbMnz8fcXFxiImJafJ9PwiWL1+OuLg4zJw5E926dbN2ObZBkF3ZvHmzACCOHj2q115aWirkcrkYN25ck9Xy7bffCgDis88+02tXq9UiKChIxMTENFkttcnIyBAAxN69e61dSr2SkpKElP9J3717V/e/w8LCxPPPPy/Ztps7TgUQAEChUMDFxQXOzs567TWnAn777TckJCQgNDQULVq0gI+PD4YMGYL09HSDbW7YsAE9e/ZEixYt0LJlS4SEhOC1116rtxYHBwe0aNHCoBZjUwFZWVmYOHEifHx8IJfL0b17d6xduxYajUaSWjZs2AA/Pz/ExcXptQsh8OabbyIwMBAKhQLR0dFQKpUYNGgQBg0apLesSqXC/PnzERwcDBcXF7Rv3x5z5szB3bt39ZaTyWSYNWsWPvroI3Tv3h1ubm7o2bOn0Sma3bt3IzIyEnK5HMHBwVizZo3R+oUQSElJQWRkJFxdXdG6dWs888wzuHLlit5ygwYNQlhYGL7//nv069cPbm5ueO6553TvT5o0CR9//DFKSkrq7TPiVIDdUqvVqKqqghAC+fn5eOutt3D37l2MHz++zvVu3rwJAEhKSoKfnx/u3LmDL774AoMGDcKBAwd0ofLpp58iISEBs2fPxpo1a+Dg4IBLly4hIyPDYJsajQZVVVUAgKKiImzevBlnzpzBxo0b66zlt99+Q79+/VBRUYHly5cjKCgIu3btwvz583H58mWkpKSYXUtNu3fvxoABA+DgoD8GWbx4MZKTk/HCCy/g6aefxvXr1zF9+nRUVlaia9euuuVKS0sxcOBAZGdn47XXXkNERATOnj2LN954A7/88gu+/vprvXnR3bt34+jRo1i2bBlatGiB1atXY/To0cjMzESnTp0AAAcOHMCTTz6JmJgYfPrpp1Cr1Vi9ejXy8/MN6n/xxRexZcsWvPTSS1i1ahVu3ryJZcuWoV+/fvj555/h6+urWzY3NxcTJ07EggUL8Oabb+od86BBg/Dqq6/i4MGDelNFVAsrj5ipiWmnAmr+yeVykZKSYrA8AJGUlFTr9qqqqkRlZaUYOnSoGD16tK591qxZolWrVnXWop0KqPnn4OAgFi9ebLB8YGCgmDx5su71woULBQDx3//+V2+5mTNnCplMJjIzM02uxZj8/HwBQKxcuVKv/ebNm0Iul4uxY8fqtf/4448CgBg4cKCuLTk5WTg4OBhMvXz++ecCgNizZ4+uDYDw9fUVKpVK15aXlyccHBxEcnKyrq1v376iXbt24t69e7o2lUol2rRpozcVoK1n7dq1evu+fv26cHV1FQsWLNC1DRw4UAAQBw4cMNoX
FRUVQiaTiVdffdXo+6SPUwF2KjU1FUePHsXRo0fx1VdfYfLkyfjLX/6Cd999t951//Wvf6F3795QKBRwcnKCs7MzDhw4gHPnzumW6dOnD4qLizFu3Dj83//9HwoLC2vd3qpVq3S1KJVKLFiwACtXrsQrr7xSZx3ffPMNQkND0adPH732KVOmQAiBb775xuxa7nfjxg0AgI+Pj177kSNHUF5ejjFjxui1P/LIIwgKCtJr27VrF8LCwhAZGYmqqird37BhwyCTyXDw4EG95QcPHoyWLVvqXvv6+sLHxwfXrl0DANy9exdHjx7F008/DYVCoVuuZcuWBiPJXbt2QSaTYeLEiXr79vPzQ8+ePQ323bp1awwZMsRoXzg7O6NVq1bIyckx3lmkh1MBdqp79+6Ijo7WvR4+fDiuXbuGBQsWYOLEiWjVqpXR9datW4d58+ZhxowZWL58Oby9veHo6Ii//vWvesE6adIkVFVV4f3338ef/vQnaDQaPPzww1ixYoXBfGWnTp30ann00Udx69YtrF27FtOmTUNISIjRWoqKigyCDADatWune9/cWu537949ANALsPu3e//XaK2abfn5+bh06ZLBfLFWzZD38vIyWEYul+tquXXrFjQaDfz8/AyWq9mWn58PIYTROgHopha0/P39jS6npVAodHVQ3RispBMREYF9+/bhwoULBqNAra1bt2LQoEHYsGGDXruxkxpTp07F1KlTcffuXXz//fdISkrC448/jgsXLiAwMLDeWoQQOH36dK3B6uXlhdzcXIN27UjT29u7UbVo19fOK9+/XwBG5zTz8vL0wt7b2xuurq748MMP69yHqVq3bg2ZTIa8vDyj+665bZlMhvT0dMjlcoPla7bVdw3srVu3zK7XXnEqgHROnToFAGjbtm2ty8hkMoP/IE+fPl3nxfPu7u4YMWIEFi9ejIqKCpw9e9bkWmp+Db/f0KFDkZGRgRMnTui1p6amQiaT6a6NbWgtgYGBcHV1xeXLl/Xa+/btC7lcjrS0NL32I0eO6L6yaz3++OO4fPkyvLy8EB0dbfBnbMRdF3d3d/Tp0wc7duxAWVmZrr2kpARffvmlwb6FEMjJyTG67/DwcJP3e+PGDZSVlSE0NNSseu0VR6x26syZM3pn4nfs2AGlUonRo0cjODi41vUef/xxLF++HElJSRg4cCAyMzOxbNkyBAcH67YHAM8//zxcXV0RGxsLf39/5OXlITk5GZ6ennj44Yf1tnnx4kUcOXIEAHD79m18/fXX2LRpE6Kjo9G/f/9aa5k7dy5SU1Px2GOPYdmyZQgMDMTu3buRkpKCmTNn6s7Om1PL/VxcXBATE6OrTatNmzZITExEcnIyWrdujdGjRyM7OxtLly6Fv7+/3tn0OXPmYPv27RgwYADmzp2LiIgIaDQaZGVlYf/+/Zg3bx769u1baw3GLF++HMOHD0dcXBzmzZsHtVqNVatWwd3dXW90HRsbixdeeAFTp07FsWPHMGDAALi7uyM3NxeHDh1CeHg4Zs6cadI+tX1g7MOKjLDuuTNqasauCvD09BSRkZFi3bp1oqysTG951LgqoLy8XMyfP1+0b99eKBQK0bt3b/Gf//xHTJ48WQQGBuqW+/e//y0GDx4sfH19hYuLi2jXrp0YM2aMOH36tG4ZY1cFuLu7i9DQUJGUlCRu376tV0vNqwKEEOLatWti/PjxwsvLSzg7O4tu3bqJt956S6jVarNqqc2mTZuEo6OjuHHjhl67RqMRK1asEB06dBAuLi4iIiJC7Nq1S/Ts2VPv6gghhLhz5454/fXXRbdu3YSLi4vw9PQU4eHhYu7cuSIvL0+vr//yl78Y1GDsuHfu3CkiIiKEi4uL6Nixo1i5cmWtPxD48MMPRd++fYW7u7twdXUVnTt3Fs8++6w4duyYbpmBAweKHj161NoPkyZNEuHh4XX2Ff1BJgSf0kpUm7KyMnTs2BHz5s3Dq6++WueyV69eRUhICJKSkkz6
8YGtUKlUaNeuHf7+97/j+eeft3Y5NoHBSlSPDRs2YMmSJbhy5Qrc3d0BAD///DM++eQT9OvXDx4eHsjMzMTq1auhUqlw5syZWs/E26KlS5ciLS0Np0+fhpMTZw9NwV4iqscLL7yA4uJiXLlyRXfCx93dHceOHcOmTZtQXFwMT09PDBo0CH/729+aVagCgIeHB7Zs2cJQNQNHrEREEuPlVkREEmOwEhFJjMFKRCQxu5uN1mg0uHHjBlq2bGmRx1gQUfMlhEBJSQnatWtncCvJ+9ldsN64cQMBAQHWLoOIbNj169frfGSR3QWr9pZs169ft8rzk7QqKyuxf/9+xMfH13rno+bO3vvA3o8fsL0+UKlUCAgI0Lu1ozF2F6zar/8eHh5WD1Y3Nzd4eHjYxD8oS7D3PrD34wdstw/qm0a0+smrlJQUBAcHQ6FQICoqyuizk+5XXl6OxYsXIzAwEHK5HJ07d671lmxERNZg1RFrWloa5syZg5SUFMTGxuK9997DiBEjkJGRgY4dOxpdZ8yYMcjPz8emTZvw0EMPoaCgQO+uSkRE1mbVYF23bh2mTZuG6dOnAwDefvtt7Nu3Dxs2bEBycrLB8nv37sV3332HK1euoE2bNgBg9v0siYgszWpTARUVFTh+/Dji4+P12uPj43H48GGj6+zcuRPR0dFYvXo12rdvj65du2L+/Pl8XAQRmUytEfjxchH+71QOfrxcBLVG+l/1W23EWlhYCLVabXDDCl9fX6OPnQCAK1eu4NChQ1AoFPjiiy9QWFiIhIQE3Lx5s9Z51vLycpSXl+teq1QqANWT5pWVlRIdjfm0+7ZmDdZm731g78cPNH0f7Dubj+W7zyG/pELX5tvSBX99rDuG9aj/5jmm1mn1qwJqnl0TQtR6xk2j0UAmk2Hbtm3w9PQEUD2d8Mwzz2D9+vVwdXU1WCc5ORlLly41aN+/fz/c3NwkOILGUSqV1i7B6uy9D+z9+IGm6YOfi2T48IL2S/ofGZNfUo5Zn57Cc1016OlV9+i1tLTUpH1ZLVi1T/esOTotKCio9bZr/v7+aN++vS5UgeqnjQohkJ2djS5duhiss2jRIiQmJupea69Di4+Pt/rlVkqlEnFxcTZ1mYmU7L0P7P34gabrA7VGIGnltwCMneiuDtkd2XIsmDAYjg61X0ql/cZbH6sFq4uLC6KionTPWdJSKpV48sknja4TGxuLzz77DHfu3EGLFi0AABcuXICDg0Otv4KQy+VGn1Dp7Oz8QPxjflDqsCZ77wN7P37A8n3w08VCFN+r++qh4tIqHM9SIbZL7U+iNbVGq17HmpiYiA8++AAffvghzp07h7lz5yIrKwszZswAUD3afPbZZ3XLjx8/Hl5eXpg6dSoyMjLw/fff45VXXsFzzz1ndBqAiAgAfrxSKOly9bHqHOvYsWNRVFSEZcuWITc3F2FhYdizZ4/uOe+5ubnIysrSLd+iRQsolUrMnj0b0dHR8PLywpgxY7BixQprHQIR2QRTb7gkzY2ZrH7yKiEhAQkJCUbf27Jli0FbSEgIJ/uJyCwxnb3w7reXTFpOClb/SSsRkaU90skLrdzqnh9t7eaMRzoxWImITOLoIMPKp8PrXCb56fA6rwgwB4OViOzC8DB//Gtib/h5KPTa/T0V+NfE3hge5i/Zvqw+x0pE1FSGh/kjLtQPP129iYKSMvi0VKBPcBvJRqpaDFYisiuODjLJTlLVhlMBREQSY7ASEUmMwUpEJDEGKxGRxBisREQSY7ASEUmMwUpEJDEGKxGRxBisREQSY7ASEUmMwUpEJDEGKxGRxBisREQSY7ASEUmMwUpEJDEGKxGRxBisREQSY7ASEUmMwUpEJDEGKxGRxBisREQSY7ASEUmMwUpEJDEGKxGRxBisREQSY7ASEUmMwUpEJDEGKxGRxBisREQSY7ASEUmMwUpEJDEGKxGRxBisREQSY7ASEUmMwUpEJDEGKxGRxBisREQSY7ASEUmMwUpE
JDEGKxGRxBisREQSY7ASEUnM6sGakpKC4OBgKBQKREVFIT09vdZlDx48CJlMZvB3/vz5JqyYiKhuVg3WtLQ0zJkzB4sXL8bJkyfRv39/jBgxAllZWXWul5mZidzcXN1fly5dmqhiIqL6WTVY161bh2nTpmH69Ono3r073n77bQQEBGDDhg11rufj4wM/Pz/dn6OjYxNVTERUPydr7biiogLHjx/HwoUL9drj4+Nx+PDhOtft1asXysrKEBoaitdffx2DBw+uddny8nKUl5frXqtUKgBAZWUlKisrG3EEjaPdtzVrsDZ77wN7P37A9vrA1DqtFqyFhYVQq9Xw9fXVa/f19UVeXp7Rdfz9/bFx40ZERUWhvLwcH330EYYOHYqDBw9iwIABRtdJTk7G0qVLDdr3798PNze3xh9IIymVSmuXYHX23gf2fvyA7fRBaWmpSctZLVi1ZDKZ3mshhEGbVrdu3dCtWzfd65iYGFy/fh1r1qypNVgXLVqExMRE3WuVSoWAgADEx8fDw8NDgiNomMrKSiiVSsTFxcHZ2dlqdViTvfeBvR8/YHt9oP3GWx+rBau3tzccHR0NRqcFBQUGo9i6PPLII9i6dWut78vlcsjlcoN2Z2fnB+L/yAelDmuy9z6w9+MHbKcPTK3RaievXFxcEBUVZfAVQKlUol+/fiZv5+TJk/D395e6PCKiBrPqVEBiYiImTZqE6OhoxMTEYOPGjcjKysKMGTMAVH+Nz8nJQWpqKgDg7bffRlBQEHr06IGKigps3boV27dvx/bt2615GEREeqwarGPHjkVRURGWLVuG3NxchIWFYc+ePQgMDAQA5Obm6l3TWlFRgfnz5yMnJweurq7o0aMHdu/ejZEjR1rrEIiIDFj95FVCQgISEhKMvrdlyxa91wsWLMCCBQuaoCoiooaz+k9aiYiaGwYrEZHEGKxERBJjsBIRSYzBSkQkMQYrEZHEGKxERBJjsBIRSYzBSkQkMQYrEZHEGKxERBJrVLCWlZVJVQcRUbNhdrBqNBosX74c7du3R4sWLXDlyhUAwF//+lds2rRJ8gKJiGyN2cG6YsUKbNmyBatXr4aLi4uuPTw8HB988IGkxRER2SKzgzU1NRUbN27EhAkT9B47HRERgfPnz0taHBGRLTI7WHNycvDQQw8ZtGs0Gpt5hC0RkSWZHaw9evRAenq6Qftnn32GXr16SVIUEZEtM/sJAklJSZg0aRJycnKg0WiwY8cOZGZmIjU1Fbt27bJEjURENsXsEeuoUaOQlpaGPXv2QCaT4Y033sC5c+fw5ZdfIi4uzhI1EhHZlAY982rYsGEYNmyY1LUQETULZo9YO3XqhKKiIoP24uJidOrUSZKiiIhsmdnB+uuvv0KtVhu0l5eXIycnR5KiiIhsmclTATt37tT973379sHT01P3Wq1W48CBAwgKCpK0OCIiW2RysD711FMAAJlMhsmTJ+u95+zsjKCgIKxdu1bS4oiIbJHJwarRaAAAwcHBOHr0KLy9vS1WFBGRLTP7qoCrV69aog4iomajQZdb3b17F9999x2ysrJQUVGh995LL70kSWFERLbK7GA9efIkRo4cidLSUty9exdt2rRBYWEh3Nzc4OPjw2AlIrtn9uVWc+fOxahRo3Dz5k24urriyJEjuHbtGqKiorBmzRpL1EhEZFPMDtZTp05h3rx5cHR0hKOjI8rLyxEQEIDVq1fjtddes0SNREQ2xexgdXZ2hkwmAwD4+voiKysLAODp6an730RE9szsOdZevXrh2LFj6Nq1KwYPHow33ngDhYWF+OijjxAeHm6JGomIbIrZI9Y333wT/v7+AIDly5fDy8sLM2fOREFBATZu3Ch5gUREtsasEasQAm3btkWPHj0AAG3btsWePXssUhgRka0ya8QqhECXLl2QnZ1tqXqIiGyeWcHq4OCALl26GL1tIBERVTN7jnX16tV45ZVXcObMGUvUQ0Rk88y+KmDixIkoLS1Fz549
4eLiAldXV733b968KVlxRES2yOxgffvtty1QBhFR82F2sNa8FysREekze46ViIjqxmAlIpIYg5WISGIMViIiiZkVrFVVVXBycuI1rEREdTArWJ2cnBAYGAi1Wm2peoiIbJ7ZUwGvv/46Fi1aJNkPAVJSUhAcHAyFQoGoqCikp6ebtN4PP/wAJycnREZGSlIHEZFUzL6O9Z133sGlS5fQrl07BAYGwt3dXe/9EydOmLyttLQ0zJkzBykpKYiNjcV7772HESNGICMjAx07dqx1vdu3b+PZZ5/F0KFDkZ+fb+4hEBFZlNnB+tRTT0m283Xr1mHatGmYPn06gOpfde3btw8bNmxAcnJyreu9+OKLGD9+PBwdHfGf//xHsnqIiKRgdrAmJSVJsuOKigocP34cCxcu1GuPj4/H4cOHa11v8+bNuHz5MrZu3YoVK1bUu5/y8nKUl5frXqtUKgBAZWUlKisrG1h942n3bc0arM3e+8Dejx+wvT4wtU6zg1Xr+PHjOHfuHGQyGUJDQ9GrVy+z1i8sLIRarYavr69eu6+vL/Ly8oyuc/HiRSxcuBDp6elwcjKt9OTkZCxdutSgff/+/XBzczOrZktQKpXWLsHq7L0P7P34Advpg9LSUpOWMztYCwoK8Oc//xkHDx5Eq1atIITA7du3MXjwYHz66ado27atWdvTPphQSwhh0AYAarUa48ePx9KlS9G1a1eTt79o0SIkJibqXqtUKgQEBCA+Ph4eHh5m1SqlyspKKJVKxMXFwdnZ2Wp1WJO994G9Hz9ge32g/cZbH7ODdfbs2VCpVDh79iy6d+8OAMjIyMDkyZPx0ksv4ZNPPjFpO97e3nB0dDQYnRYUFBiMYgGgpKQEx44dw8mTJzFr1iwAgEajgRACTk5O2L9/P4YMGWKwnlwuh1wuN2h3dnZ+IP6PfFDqsCZ77wN7P37AdvrA1BrNDta9e/fi66+/1oUqAISGhmL9+vWIj483eTsuLi6IioqCUqnE6NGjde1KpRJPPvmkwfIeHh745Zdf9NpSUlLwzTff4PPPP0dwcLC5h0JEZBFmB6tGozGa2s7OztBoNGZtKzExEZMmTUJ0dDRiYmKwceNGZGVlYcaMGQCqv8bn5OQgNTUVDg4OCAsL01vfx8cHCoXCoJ2IyJrMDtYhQ4bg5ZdfxieffIJ27doBAHJycjB37lwMHTrUrG2NHTsWRUVFWLZsGXJzcxEWFoY9e/YgMDAQAJCbm4usrCxzSyQisiqzf3n17rvvoqSkBEFBQejcuTMeeughBAcHo6SkBP/85z/NLiAhIQG//vorysvLcfz4cQwYMED33pYtW3Dw4MFa112yZAlOnTpl9j6JiCzJ7BFrQEAATpw4AaVSifPnz0MIgdDQUDz66KOWqI+IyOaYFaxVVVVQKBQ4deoU4uLiEBcXZ6m6iIhsFu9uRUQkMavf3YqIqLmx6t2tiIiaI6ve3YqIqDky++QVADz33HMICAiwSEFERLbO7JNXa9as4ckrIqI6mH3yaujQoXVetE9EZO/MnmMdMWIEFi1ahDNnziAqKsrg5NUTTzwhWXFERLbI7GCdOXMmgOrHqtQkk8k4TUBEdq9Bd7ciIqLamT3HSkREdTM5WEeOHInbt2/rXv/tb39DcXGx7nVRURFCQ0MlLY6IyBaZHKz79u3Te9rpqlWr9H7WWlVVhczMTGmrIyKyQSYHqxCiztdERFSNc6xERBIzOVhlMpnBY6mNPaaaiMjemXy5lRACU6ZM0T1KuqysDDNmzND9QOD++VciIntmcrBOnjxZ7/XEiRMNlnn22WcbXxERkY0zOVg3b95syTqIiJoNnrwiIpIYg5WISGIMViIiiTFYiYgkxmAlIpIYg5WISGIMViIiiTFYiYgkxmAlIpIYg5WISGIMViIiiTFYiYgkxmAlIpIYg5WISGIMViIiiTFYiYgkxmAlIpIYg5WISGIMViIiiTFYiYgkxmAlIpIYg5WISGIMViIiiTFY
iYgkxmAlIpKY1YM1JSUFwcHBUCgUiIqKQnp6eq3LHjp0CLGxsfDy8oKrqytCQkLw97//vQmrJSKqn5M1d56WloY5c+YgJSUFsbGxeO+99zBixAhkZGSgY8eOBsu7u7tj1qxZiIiIgLu7Ow4dOoQXX3wR7u7ueOGFF6xwBEREhqw6Yl23bh2mTZuG6dOno3v37nj77bcREBCADRs2GF2+V69eGDduHHr06IGgoCBMnDgRw4YNq3OUS0TU1KwWrBUVFTh+/Dji4+P12uPj43H48GGTtnHy5EkcPnwYAwcOtESJREQNYrWpgMLCQqjVavj6+uq1+/r6Ii8vr851O3TogN9++w1VVVVYsmQJpk+fXuuy5eXlKC8v171WqVQAgMrKSlRWVjbiCBpHu29r1mBt9t4H9n78gO31gal1WnWOFQBkMpneayGEQVtN6enpuHPnDo4cOYKFCxfioYcewrhx44wum5ycjKVLlxq079+/H25ubg0vXCJKpdLaJVidvfeBvR8/YDt9UFpaatJyVgtWb29vODo6GoxOCwoKDEaxNQUHBwMAwsPDkZ+fjyVLltQarIsWLUJiYqLutUqlQkBAAOLj4+Hh4dHIo2i4yspKKJVKxMXFwdnZ2Wp1WJO994G9Hz9ge32g/cZbH6sFq4uLC6KioqBUKjF69Ghdu1KpxJNPPmnydoQQel/1a5LL5ZDL5Qbtzs7OD8T/kQ9KHdZk731g78cP2E4fmFqjVacCEhMTMWnSJERHRyMmJgYbN25EVlYWZsyYAaB6tJmTk4PU1FQAwPr169GxY0eEhIQAqL6udc2aNZg9e7bVjoGIqCarBuvYsWNRVFSEZcuWITc3F2FhYdizZw8CAwMBALm5ucjKytItr9FosGjRIly9ehVOTk7o3LkzVq5ciRdffNFah0BEZMDqJ68SEhKQkJBg9L0tW7bovZ49ezZHp0T0wLP6T1qJiJobBisRkcQYrEREErP6HCtVU2sEfrp6EwUlZfBpqUCf4DZwdKj7hxJE9GBisFqZWiOQ8vVFbP7hKorv/fFzOX9PBZJGhWJ4mL8VqyOihuBUgBX9XCTDIyu/xd+/vqAXqgCQd7sMM7eewN4zuVaqjogaiiNWE1WpNZJub88vefjwggOAKqPvCwAyAEu/zMDgbj4WmRZwcuTnKpElMFhN9OlPWfUvZCKNEFiz/0K9ywkAubfLsGrveXTydm/UPh0cDEN0fF/Dm4kTUeMxWE109Nebkm2r8E45VGVVqB6T1u/n68UoulP7/RBM8XBQGwDGA5aIpMVgNZGADEJIMx1wr0Jt1vIuTjJohGjUPtWa6vX/3KdDo7ZDRPXj8MVEDwe1gUzmIMmfwsX0zzMnBxm8Wygavc+T12/j5PXbcHJ00P0RkWVwxGqicX0CJNuWRlRfs1qlqX8U2tW3BfoEe0m2byKyPAariZwcHSQL131n80wKVQBYODwE/R7ylmS/RNQ0GKxmuP/rc0Mvv1JrBJbvOmfSsq3cnNG3k3mjVX7FJ7I+BmsD/e+x7Aatd+W3O8hTlZm07Mgwf2w/kWPW9nkJFZH1cXjTxErKjP8goKbYzl4Ia+9p4WqIyBI4Ym2gMdENu2ypYxtXpB27Xu9yMwd1xiNmTgMQ0YOBwdpADZ3LjOnsDX9PBfJul8HY6SsZAD9PBWI6e/PuVkQ2ilMBTczRQYakUaG/vzJ+ZUDSqFCGKpENY7BawfAwf/zzzz3hZuT7gqfbg/8IYCKqG4PVikqNnMe6XVrJ2wUS2TjOsTZSQ65nVWsEVuw5b/Q9U28XqNYIHP31Jn4rKUfblnI8HFT9xAFex0pkfQzWRmrI9azV17KWo7a7W2lvF/jWvkx0atvC4P0zObex63QuVGV/3BzbQ+GMxyP88ebT4XzMC5GVMVitwNRrWY0tdybnNj42cm9YVVklPv4pCy0Vjtj5cy5yb//xIwQ+5oWoaTFYG8nc61mr1Bq0b6Uw6VrWkeF+6BvcRvdarRF450Dd
N8h+7/urBm3ax7y8O74XhvXwq3N9TiUQNR6DtZHMDaL/PZYNjRDwUDj9/lXe+Fd0T1cnXCu6i+s3S3VtVwrv/j6FYB7tRV2vffELiu6Uw0FmfJ8aIdCpbQu9eVu5s6PZ+yOydwxWK3CQyTAyzA+f1jFq7ebrgRNZt/Xasm+V1rK0aW7fq8K+swVo21Ju8N6N4ns4nVOMsso/TsZ5KJyx+plwk6cQOLdLVI3B2sS0Uweje/pCobqGvXlueqNQf08FXn+sO27erTRYt427M45du9Wo/bdrpUBEh1Z6bWdv3MZPRh49oyqrvvRrw8Te9Ybr3jO5WPplBud2icBgbXLaqQOhcUAvL4GFEwbgZHaJwSjP2GVcao3Anl/ykK8y/nNYUzwe4W8wb7v+24t1rlPfpV/7zuZh1scnDWrSzu2aEsymaqpRMUff1BgMVitzdJAhprPhzVaMzd06OQJLngjFzK0nIIP+D2JrvjbG09UZ12/dQ07xDV3bH5d+GVffk2K1T5w1tu/6rsk19xrgphoVN2Y/FVUa/Pvwrzj66024uTjiT706oF8Xy933obEfANr1c26VYt/ZPNyrVCPY2x2vjQyFqwvn1xuKwWpjhof5Y8PE3gb/4ft5KvBET3+jVwVoPRbub3DiytRLv05n3zY6PfFbifaJs8bVdU2uWq2Gh0l7rw67mVtPWHxU3Jj9JO/JwMb0q7j/uY//OXUDbi6OWDemp+RTIo39oNl7Jhcrdp9D9q17ujZHBxkOXSrCR0eyEBfqg/effVjSmu0Fg9UGDQ/zR1yon9GRSkSHVli+65zezbS187bGLrUy9TaGER08Eext+GOF09nFJtVsaoADhiNZtUZgyc4Mk0fFVWpNgy4bM3c/91u19zzeTzf+oVZaocaMrSewfnwvPBbRzuy6jGnsB83eM7mYue0EhnTzwTvjeqGbb0tk5pfg3W8u4ZvzBQAAZUYBnk89ynBtAAarjaptCuGxiHYYHuZv8tdDU29j+OrwEKPbCPZ2w2fH6//12aie/gb3l62srMT+facNlq35a7b6nrpQc1T809Wb6HPfPLKpTN3Pqr3nEejlBs3vzy2r0mjwQS2her9XP/8Z/6+zl64fHWXmzZRrP3DM/QCo+UghtUZgxe5zGBLig/cnRcPh93p6d2yND56NxvTUY/juwm9QawSUGQW4V6HmtICZGKzNUG2hW9uySaNqn7cFqm9jWNv1rI25v6zQmDaqbMwv1cxhzrTIiWvFKLxTPTetKqs06WTinQoNnt7wI1oqqu9g9qfe7UyeCgH++MAx94Pm/sf1/O+xbFz57Q6yb93DO+N66UJVy8FBhr8Mfkg3agWAN/dkYPlT4WZUSgxWqnPetr75OlOD2ZwTKjV/zWbqdIV2VPx0r3YNmgowZ1qk8E4FxO+TqfcqzZjmMPHpvHVp7AeNtr2bb0uj73fz02//tahx10/bIwYrAah73taUdRsazMbUDEVzR8WWfrrDq8NDIITQfTVPPXINq/fV/VNjraRR3dEnqHqawlEmjE6F1Eb7gWPuB03NbWjXz8wvQe+OrQ3Wy8wr0Xsd5OVmco1UjcFKOuZMIdTUmGA2pS6pR8WN2U/NaZHp/Ttjzf4LqG8w6tvSBQO7+erqrKw0vMqiLtoPjMZMvzg5OiCmszc6tHbF+m8v6c2xAoBGI7D+20twdJBB/fsBvTYyFGQe3nGDJKMN5icj2yPmvpM0UtCOiv08FXrtfp4KSX+A0JD9uDg54Pn+wfVue+mTYZJ+0ACGd5ow5YPG0UGG1x/rjm/OF2B66jEcv3YLd8qrcPzaLUxPPYZvzhfoQjUu1IcnrhqAI1ayGZYcFTd2P4t+H9XVvI4VgEWuY23s9MvwMH9smNAbK3afw582HNa133+MvI614RisZFMaM11h6f0sGhmKefEhTfbLq8Z+0Ny/Pn95JS0GK5GEXJwc8PyATnh+QKcm2V9jP2j+WN8Lz0QHSFeYneMcKxGRxBisREQS
Y7ASEUnM7uZYtb+WUalUVq2jsrISpaWlUKlUcHZ2tmot1mLvfWDvxw/YXh9oc0PUvPSjBrsL1pKS6l+VBARwop6IGqakpASenp61vi8T9UVvM6PRaHDjxg20bNkSsloeqtcUVCoVAgICcP36dXh4mHMrjubD3vvA3o8fsL0+EEKgpKQE7dq1g4ND7TOpdjdidXBwQIcO5j2y2pI8PDxs4h+UJdl7H9j78QO21Qd1jVS1ePKKiEhiDFYiIokxWK1ELpcjKSkJcrnc2qVYjb33gb0fP9B8+8DuTl4REVkaR6xERBJjsBIRSYzBSkQkMQYrEZHEGKwWlJKSguDgYCgUCkRFRSE9Pb3WZXfs2IG4uDi0bdsWHh4eiImJwb59+5qwWsswpw/u98MPP8DJyQmRkZGWLdDCzD3+8vJyLF68GIGBgZDL5ejcuTM+/PDDJqrWMsztg23btqFnz55wc3ODv78/pk6diqKioiaqViKCLOLTTz8Vzs7O4v333xcZGRni5ZdfFu7u7uLatWtGl3/55ZfFqlWrxE8//SQuXLggFi1aJJydncWJEyeauHLpmNsHWsXFxaJTp04iPj5e9OzZs2mKtYCGHP8TTzwh+vbtK5RKpbh69ar473//K3744YcmrFpa5vZBenq6cHBwEP/4xz/ElStXRHp6uujRo4d46qmnmrjyxmGwWkifPn3EjBkz9NpCQkLEwoULTd5GaGioWLp0qdSlNZmG9sHYsWPF66+/LpKSkmw6WM09/q+++kp4enqKoqKipiivSZjbB2+99Zbo1KmTXts777wjOnToYLEaLYFTARZQUVGB48ePIz4+Xq89Pj4ehw8frmUtfRqNBiUlJWjTpo0lSrS4hvbB5s2bcfnyZSQlJVm6RItqyPHv3LkT0dHRWL16Ndq3b4+uXbti/vz5uHfvXlOULLmG9EG/fv2QnZ2NPXv2QAiB/Px8fP7553jssceaomTJ2N1NWJpCYWEh1Go1fH199dp9fX2Rl5dn0jbWrl2Lu3fvYsyYMZYo0eIa0gcXL17EwoULkZ6eDicn2/6n2ZDjv3LlCg4dOgSFQoEvvvgChYWFSEhIwM2bN21ynrUhfdCvXz9s27YNY8eORVlZGaqqqvDEE0/gn//8Z1OULBmOWC2o5m0JhRAm3arwk08+wZIlS5CWlgYfHx9LldckTO0DtVqN8ePHY+nSpejatWtTlWdx5vwb0Gg0kMlk2LZtG/r06YORI0di3bp12LJli82OWgHz+iAjIwMvvfQS3njjDRw/fhx79+7F1atXMWPGjKYoVTK2PSx4QHl7e8PR0dHgU7mgoMDg07umtLQ0TJs2DZ999hkeffRRS5ZpUeb2QUlJCY4dO4aTJ09i1qxZAKqDRggBJycn7N+/H0OGDGmS2qXQkH8D/v7+aN++vd5t6bp37w4hBLKzs9GlSxeL1iy1hvRBcnIyYmNj8corrwAAIiIi4O7ujv79+2PFihXw9/e3eN1S4IjVAlxcXBAVFQWlUqnXrlQq0a9fv1rX++STTzBlyhR8/PHHNjenVJO5feDh4YFffvkFp06d0v3NmDED3bp1w6lTp9C3b9+mKl0SDfk3EBsbixs3buDOnTu6tgsXLjxw9xA2VUP6oLS01OAG0o6OjgDqfxzKA8V6582aN+1lJps2bRIZGRlizpw5wt3dXfz6669CCCEWLlwoJk2apFv+448/Fk5OTmL9+vUiNzdX91dcXGytQ2g0c/ugJlu/KsDc4y8pKREdOnQQzzzzjDh79qz47rvvRJcuXcT06dOtdQiNZm4fbN68WTg5OYmUlBRx+fJlcejQIREdHS369OljrUNoEAarBa1fv14EBgYKFxcX0bt3b/Hdd9/p3ps8ebIYOHCg7vXAgQMFAIO/yZMnN33hEjKnD2qy9WAVwvzjP3funHj00UeFq6ur6NChg0hMTBSlpaVNXLW0zO2Dd955R4SGhgpXV1fh7+8vJkyYILKzs5u46sbhbQOJiCTG
OVYiIokxWImIJMZgJSKSGIOViEhiDFYiIokxWImIJMZgJSKSGIOViEhiDFYiIokxWMkuHD58GDKZDMOHD7d2KWQH+JNWsgvTp09HaWkptm/fjosXL6Jjx47WLomaMY5Yqdm7e/cu0tLSMGfOHAwZMgRbtmyxdknUzDFYqdlLS0uDn58f+vTpgwkTJmDz5s22dW9PsjkMVmr2Nm3ahAkTJgAAnnrqKRQUFODAgQNQq9WIjIxEZGQk/Pz80KFDB0RGRqJ///4AACcnJ0RGRiIsLAz/8z//g9LSUr32Hj16YNSoUSguLrbWodEDinOs1KxlZmYiJCQEmZmZumdpjR8/HgDw8ccf65ZbsmQJvL29dY+FAaofLVJYWAgAmDBhAqKiopCYmKjXPmnSJISEhGDx4sVNdUhkAzhipWZt06ZNePjhh/UeUDhhwgTs2LEDt27dMnk7/fv3x6VLlwzaY2NjkZ2dLUmt1HwwWKnZqqqqQmpqqm6EqjVs2DC0bNkS27ZtM3k7X331FcLDw/Xa1Wo1lEolHn/8cclqpuaBT2mlZmvXrl3Iz89HWFgYzpw5o/de//79sWnTJr2v/jUVFxcjMjJSt/y0adP02rOzs9GjRw8MGzbMYsdAtonBSs3Wpk2bAABxcXG1LnPixAn07t3b6HutWrXCqVOnam0vLS1FXFwcUlJS8NJLL0lSMzUPnAqgZuvLL7+EqH5gZq1/tYWqKdzc3PCPf/wDa9euRVVVlYSVk61jsBI1QnR0NMLDw7F9+3Zrl0IPEF5uRUQkMY5YiYgkxmAlIpIYg5WISGIMViIiiTFYiYgkxmAlIpIYg5WISGIMViIiiTFYiYgkxmAlIpIYg5WISGIMViIiif1/eMphQrO/RhYAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 350x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr:last-of-type th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th colspan=\"2\" halign=\"left\">error_rate</th>\n",
       "      <th colspan=\"2\" halign=\"left\">delta_tpr</th>\n",
       "      <th colspan=\"2\" halign=\"left\">delta_tpr_rms</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>alpha</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>inf</th>\n",
       "      <td>0.229155</td>\n",
       "      <td>0.000972</td>\n",
       "      <td>0.705258</td>\n",
       "      <td>0.019569</td>\n",
       "      <td>0.269118</td>\n",
       "      <td>0.006950</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.6500</th>\n",
       "      <td>0.229150</td>\n",
       "      <td>0.000967</td>\n",
       "      <td>0.709943</td>\n",
       "      <td>0.015646</td>\n",
       "      <td>0.271260</td>\n",
       "      <td>0.005905</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.6000</th>\n",
       "      <td>0.229150</td>\n",
       "      <td>0.000967</td>\n",
       "      <td>0.709943</td>\n",
       "      <td>0.015646</td>\n",
       "      <td>0.271260</td>\n",
       "      <td>0.005905</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.5500</th>\n",
       "      <td>0.229141</td>\n",
       "      <td>0.000951</td>\n",
       "      <td>0.709851</td>\n",
       "      <td>0.015610</td>\n",
       "      <td>0.267288</td>\n",
       "      <td>0.006723</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.5000</th>\n",
       "      <td>0.229178</td>\n",
       "      <td>0.001002</td>\n",
       "      <td>0.655175</td>\n",
       "      <td>0.022320</td>\n",
       "      <td>0.256251</td>\n",
       "      <td>0.005978</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.4500</th>\n",
       "      <td>0.229368</td>\n",
       "      <td>0.000982</td>\n",
       "      <td>0.572453</td>\n",
       "      <td>0.025638</td>\n",
       "      <td>0.244670</td>\n",
       "      <td>0.005649</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.4000</th>\n",
       "      <td>0.229743</td>\n",
       "      <td>0.000976</td>\n",
       "      <td>0.567343</td>\n",
       "      <td>0.028135</td>\n",
       "      <td>0.232569</td>\n",
       "      <td>0.006461</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.3500</th>\n",
       "      <td>0.230251</td>\n",
       "      <td>0.000952</td>\n",
       "      <td>0.568884</td>\n",
       "      <td>0.027024</td>\n",
       "      <td>0.221895</td>\n",
       "      <td>0.005658</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.3000</th>\n",
       "      <td>0.231176</td>\n",
       "      <td>0.000889</td>\n",
       "      <td>0.525235</td>\n",
       "      <td>0.060095</td>\n",
       "      <td>0.208697</td>\n",
       "      <td>0.007122</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.2500</th>\n",
       "      <td>0.232023</td>\n",
       "      <td>0.000889</td>\n",
       "      <td>0.443597</td>\n",
       "      <td>0.026007</td>\n",
       "      <td>0.186501</td>\n",
       "      <td>0.006854</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.2000</th>\n",
       "      <td>0.234182</td>\n",
       "      <td>0.000927</td>\n",
       "      <td>0.366152</td>\n",
       "      <td>0.035674</td>\n",
       "      <td>0.158805</td>\n",
       "      <td>0.006714</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.1500</th>\n",
       "      <td>0.237239</td>\n",
       "      <td>0.000906</td>\n",
       "      <td>0.316218</td>\n",
       "      <td>0.051414</td>\n",
       "      <td>0.131189</td>\n",
       "      <td>0.007345</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.1000</th>\n",
       "      <td>0.241528</td>\n",
       "      <td>0.001002</td>\n",
       "      <td>0.273978</td>\n",
       "      <td>0.051904</td>\n",
       "      <td>0.103434</td>\n",
       "      <td>0.008445</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0800</th>\n",
       "      <td>0.244230</td>\n",
       "      <td>0.001097</td>\n",
       "      <td>0.252096</td>\n",
       "      <td>0.062920</td>\n",
       "      <td>0.092031</td>\n",
       "      <td>0.011132</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0500</th>\n",
       "      <td>0.249226</td>\n",
       "      <td>0.001072</td>\n",
       "      <td>0.208501</td>\n",
       "      <td>0.054568</td>\n",
       "      <td>0.067516</td>\n",
       "      <td>0.008209</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0200</th>\n",
       "      <td>0.263328</td>\n",
       "      <td>0.000962</td>\n",
       "      <td>0.205381</td>\n",
       "      <td>0.058814</td>\n",
       "      <td>0.068957</td>\n",
       "      <td>0.006951</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0100</th>\n",
       "      <td>0.269798</td>\n",
       "      <td>0.000876</td>\n",
       "      <td>0.213110</td>\n",
       "      <td>0.076709</td>\n",
       "      <td>0.069983</td>\n",
       "      <td>0.008663</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0080</th>\n",
       "      <td>0.271123</td>\n",
       "      <td>0.000839</td>\n",
       "      <td>0.212592</td>\n",
       "      <td>0.077248</td>\n",
       "      <td>0.070108</td>\n",
       "      <td>0.008751</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0050</th>\n",
       "      <td>0.273009</td>\n",
       "      <td>0.000791</td>\n",
       "      <td>0.208199</td>\n",
       "      <td>0.074632</td>\n",
       "      <td>0.069947</td>\n",
       "      <td>0.008732</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0020</th>\n",
       "      <td>0.274830</td>\n",
       "      <td>0.000815</td>\n",
       "      <td>0.216458</td>\n",
       "      <td>0.081039</td>\n",
       "      <td>0.071424</td>\n",
       "      <td>0.010026</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0010</th>\n",
       "      <td>0.593474</td>\n",
       "      <td>0.001179</td>\n",
       "      <td>0.893687</td>\n",
       "      <td>0.002946</td>\n",
       "      <td>0.633710</td>\n",
       "      <td>0.002873</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0001</th>\n",
       "      <td>0.593369</td>\n",
       "      <td>0.000648</td>\n",
       "      <td>0.894531</td>\n",
       "      <td>0.002875</td>\n",
       "      <td>0.632990</td>\n",
       "      <td>0.004363</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       error_rate           delta_tpr           delta_tpr_rms          \n",
       "             mean       std      mean       std          mean       std\n",
       "alpha                                                                  \n",
       "inf      0.229155  0.000972  0.705258  0.019569      0.269118  0.006950\n",
       "0.6500   0.229150  0.000967  0.709943  0.015646      0.271260  0.005905\n",
       "0.6000   0.229150  0.000967  0.709943  0.015646      0.271260  0.005905\n",
       "0.5500   0.229141  0.000951  0.709851  0.015610      0.267288  0.006723\n",
       "0.5000   0.229178  0.001002  0.655175  0.022320      0.256251  0.005978\n",
       "0.4500   0.229368  0.000982  0.572453  0.025638      0.244670  0.005649\n",
       "0.4000   0.229743  0.000976  0.567343  0.028135      0.232569  0.006461\n",
       "0.3500   0.230251  0.000952  0.568884  0.027024      0.221895  0.005658\n",
       "0.3000   0.231176  0.000889  0.525235  0.060095      0.208697  0.007122\n",
       "0.2500   0.232023  0.000889  0.443597  0.026007      0.186501  0.006854\n",
       "0.2000   0.234182  0.000927  0.366152  0.035674      0.158805  0.006714\n",
       "0.1500   0.237239  0.000906  0.316218  0.051414      0.131189  0.007345\n",
       "0.1000   0.241528  0.001002  0.273978  0.051904      0.103434  0.008445\n",
       "0.0800   0.244230  0.001097  0.252096  0.062920      0.092031  0.011132\n",
       "0.0500   0.249226  0.001072  0.208501  0.054568      0.067516  0.008209\n",
       "0.0200   0.263328  0.000962  0.205381  0.058814      0.068957  0.006951\n",
       "0.0100   0.269798  0.000876  0.213110  0.076709      0.069983  0.008663\n",
       "0.0080   0.271123  0.000839  0.212592  0.077248      0.070108  0.008751\n",
       "0.0050   0.273009  0.000791  0.208199  0.074632      0.069947  0.008732\n",
       "0.0020   0.274830  0.000815  0.216458  0.081039      0.071424  0.010026\n",
       "0.0010   0.593474  0.001179  0.893687  0.002946      0.633710  0.002873\n",
       "0.0001   0.593369  0.000648  0.894531  0.002875      0.632990  0.004363"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "alphas = [\n",
    "    np.inf, 0.65, 0.6, 0.55, 0.5, 0.45, 0.4, 0.35, 0.3, 0.25, 0.2, 0.15, 0.1,\n",
    "    0.08, 0.05, 0.02, 0.01, 0.008, 0.005, 0.002, 0.001, 0.0001\n",
    "]\n",
    "\n",
    "return_vals = process_map(\n",
    "    partial(\n",
    "        utils.postprocess,\n",
    "        postprocessor_factory=postprocess.PostProcessorTPR,\n",
    "        evaluate_fn=partial(utils.evaluate,\n",
    "                            n_groups=n_groups,\n",
    "                            n_classes=n_classes,\n",
    "                            metrics=['delta_tpr', 'delta_tpr_rms']),\n",
    "        probas=probas_cal_,\n",
    "        labels=labels_,\n",
    "        groups=groups_,\n",
    "        n_post=n_post,\n",
    "        n_test=n_test,\n",
    "    ),\n",
    "    [(alpha, seed) for alpha in alphas for seed in seeds],\n",
    "    max_workers=max_workers,\n",
    ")\n",
    "results = [{\n",
    "    'alpha': alpha,\n",
    "    **result\n",
    "} for alpha, _, result, _ in return_vals if result is not None]\n",
    "\n",
    "(fig, ax), df = utils.plot_results(results, 'delta_tpr')\n",
    "ax.set_xlabel(\"$\\\\Delta_{\\\\mathrm{TPR}}$\")\n",
    "ax.set_title(\"BiasBios (gender)\")\n",
    "print(\"With calibration:\")\n",
    "plt.show()\n",
    "display(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eed3def",
   "metadata": {
    "papermill": {
     "duration": 0.008847,
     "end_time": "2023-05-18T18:33:00.543108",
     "exception": false,
     "start_time": "2023-05-18T18:33:00.534261",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fair",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 48929.440775,
   "end_time": "2023-05-18T18:33:03.062374",
   "environment_variables": {},
   "exception": null,
   "input_path": "tpr_biasbios_2_.ipynb",
   "output_path": "tpr_biasbios_2.ipynb",
   "parameters": {},
   "start_time": "2023-05-18T04:57:33.621599",
   "version": "2.3.4"
  },
  "vscode": {
   "interpreter": {
    "hash": "aac456e002ecb64114f81d56e2b750c6ea20a76d8646af90a7cd3cc27dac8d4c"
   }
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {
     "01605ad0405042eb8ece0864bb4b077f": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_e1c43a20352747bc8634191054d423e9",
        "IPY_MODEL_01ae05788b95426d98d0d19ca978511e",
        "IPY_MODEL_aba4515aad164445bff82e2713d22a6a"
       ],
       "layout": "IPY_MODEL_f499fcdc1c4d495f8e9d7d32809e03fb",
       "tabbable": null,
       "tooltip": null
      }
     },
     "01ae05788b95426d98d0d19ca978511e": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "success",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_6c122677dd2a4f7b8c8616a379c22c6e",
       "max": 220.0,
       "min": 0.0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_42c1b54101a642df9ddc3042c6bc0d21",
       "tabbable": null,
       "tooltip": null,
       "value": 220.0
      }
     },
     "13af2b8d6fde4042bbb9211c4c421993": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "3a163de7adda4f31b929a9390f779d97": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "42c1b54101a642df9ddc3042c6bc0d21": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "532998ea27604656b1d2406703e5a144": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "6c122677dd2a4f7b8c8616a379c22c6e": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "7a17a08771b64d73abcb45fea439f0f0": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "aba4515aad164445bff82e2713d22a6a": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_13af2b8d6fde4042bbb9211c4c421993",
       "placeholder": "​",
       "style": "IPY_MODEL_3a163de7adda4f31b929a9390f779d97",
       "tabbable": null,
       "tooltip": null,
       "value": " 220/220 [13:35:11&lt;00:00, 217.32s/it]"
      }
     },
     "e1c43a20352747bc8634191054d423e9": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_7a17a08771b64d73abcb45fea439f0f0",
       "placeholder": "​",
       "style": "IPY_MODEL_532998ea27604656b1d2406703e5a144",
       "tabbable": null,
       "tooltip": null,
       "value": "100%"
      }
     },
     "f499fcdc1c4d495f8e9d7d32809e03fb": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     }
    },
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}