{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "feeb6b7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from matplotlib import pyplot as plt\n",
    "import fairlearn.metrics as flm\n",
    "import sklearn.metrics as skm\n",
    "import functools\n",
    "from fairlearn.metrics import MetricFrame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "24127f54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# reading in data for dataframe index \n",
    "seeds = list(range(0, 50))\n",
    "seeds = np.asarray([[val]*3 for val in seeds] * 1000).flatten()\n",
    "threshes = []\n",
    "columns = []\n",
    "\n",
    "for j in range(2):\n",
    "    path = '../CrimeCommunity/{}.txt'.format(j)\n",
    "    data = []\n",
    "    with open(path) as file:\n",
    "        for line in file:\n",
    "            l = line.rstrip()\n",
    "            if len(l) > 100:\n",
    "                l = l.replace(\"[\", \"\").replace(\"]\", \"\")\n",
    "                l = l.split(\", \")\n",
    "            data.append(l)\n",
    "    data = data[:-1]\n",
    "    threshe = data[1::5]\n",
    "    tru = data[2::5]\n",
    "    pre = data[3::5]\n",
    "    fai = data[4::5]\n",
    "    for i in range(len(tru)):\n",
    "        columns.append(tru[i])\n",
    "        columns.append(pre[i])\n",
    "        columns.append(fai[i])\n",
    "        threshes.append([round(float(threshe[i]), 4)]*3)\n",
    "threshes = np.asarray(threshes).flatten()\n",
    "labels = np.asarray([[\"trues\", \"preds\", \"fairs\"]] * 5000).flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "4548b77d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# use index to prep frame\n",
    "names = list(zip(threshes, seeds, labels))\n",
    "index = pd.MultiIndex.from_tuples(names, names=[\"threshold\", \"seed\", \"vals\"])\n",
    "columns = np.asarray(columns).astype(int)\n",
    "frame = pd.DataFrame(columns.T, columns=index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "d0dc7af8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th>threshold</th>\n",
       "      <th colspan=\"10\" halign=\"left\">0.0410</th>\n",
       "      <th>...</th>\n",
       "      <th colspan=\"10\" halign=\"left\">0.0001</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>seed</th>\n",
       "      <th colspan=\"3\" halign=\"left\">0</th>\n",
       "      <th colspan=\"3\" halign=\"left\">1</th>\n",
       "      <th colspan=\"3\" halign=\"left\">2</th>\n",
       "      <th>3</th>\n",
       "      <th>...</th>\n",
       "      <th>1</th>\n",
       "      <th colspan=\"3\" halign=\"left\">2</th>\n",
       "      <th colspan=\"3\" halign=\"left\">3</th>\n",
       "      <th colspan=\"3\" halign=\"left\">4</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>vals</th>\n",
       "      <th>trues</th>\n",
       "      <th>preds</th>\n",
       "      <th>fairs</th>\n",
       "      <th>trues</th>\n",
       "      <th>preds</th>\n",
       "      <th>fairs</th>\n",
       "      <th>trues</th>\n",
       "      <th>preds</th>\n",
       "      <th>fairs</th>\n",
       "      <th>trues</th>\n",
       "      <th>...</th>\n",
       "      <th>fairs</th>\n",
       "      <th>trues</th>\n",
       "      <th>preds</th>\n",
       "      <th>fairs</th>\n",
       "      <th>trues</th>\n",
       "      <th>preds</th>\n",
       "      <th>fairs</th>\n",
       "      <th>trues</th>\n",
       "      <th>preds</th>\n",
       "      <th>fairs</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>493</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>494</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>495</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>496</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>497</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>498 rows × 165 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "threshold 0.0410                                                        ...  \\\n",
       "seed           0                 1                 2                 3  ...   \n",
       "vals       trues preds fairs trues preds fairs trues preds fairs trues  ...   \n",
       "0              0     0     0     1     0     1     0     0     0     0  ...   \n",
       "1              0     0     0     0     0     0     0     0     0     0  ...   \n",
       "2              1     0     0     0     0     0     0     0     0     0  ...   \n",
       "3              1     0     0     0     0     0     0     0     0     1  ...   \n",
       "4              0     0     0     0     0     0     0     0     0     0  ...   \n",
       "..           ...   ...   ...   ...   ...   ...   ...   ...   ...   ...  ...   \n",
       "493            0     0     0     1     0     0     0     0     0     1  ...   \n",
       "494            1     1     1     0     0     0     0     1     0     0  ...   \n",
       "495            0     0     0     0     0     0     0     0     0     0  ...   \n",
       "496            0     0     0     0     0     0     1     0     1     1  ...   \n",
       "497            0     0     0     0     0     0     0     1     1     0  ...   \n",
       "\n",
       "threshold 0.0001                                                        \n",
       "seed           1     2                 3                 4              \n",
       "vals       fairs trues preds fairs trues preds fairs trues preds fairs  \n",
       "0              1     0     1     0     0     0     0     0     0     0  \n",
       "1              0     0     0     0     0     0     1     0     0     0  \n",
       "2              0     0     0     0     0     0     0     0     0     0  \n",
       "3              0     0     0     0     1     1     0     0     0     0  \n",
       "4              0     0     0     0     0     0     0     0     0     0  \n",
       "..           ...   ...   ...   ...   ...   ...   ...   ...   ...   ...  \n",
       "493            0     0     0     0     1     0     0     0     0     0  \n",
       "494            0     0     1     0     0     0     0     1     0     0  \n",
       "495            0     0     0     0     0     0     0     0     0     0  \n",
       "496            0     1     0     1     1     0     1     0     0     0  \n",
       "497            0     0     0     1     0     0     0     0     0     1  \n",
       "\n",
       "[498 rows x 165 columns]"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "frame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "dc567197",
   "metadata": {},
   "outputs": [],
   "source": [
    "sensitive = frame[0.041,0,\"fairs\"].to_list()\n",
    "for i in range(len(sensitive)):\n",
    "    sensitive[i] = \"a\" if sensitive[i]==0 else \"b\"\n",
    "    \n",
    "labels = [1, 0]\n",
    "\n",
    "precision_alt = functools.partial(skm.precision_score, labels=labels, zero_division=0, average='macro')\n",
    "recall_alt = functools.partial(skm.recall_score, labels=labels, average='macro')\n",
    "log_loss_alt = functools.partial(skm.log_loss, labels=labels)\n",
    "cm_alt = functools.partial(skm.confusion_matrix, labels=labels,\n",
    "                               normalize=None)  # normalize='true')  # sets nomlaization for KL divs\n",
    "dp_alt = functools.partial(flm.demographic_parity_difference, sensitive_features=sensitive)\n",
    "\n",
    "metrics = {\n",
    "    'support': flm.count,\n",
    "    'accuracy': skm.accuracy_score,\n",
    "    'precision': precision_alt,\n",
    "    'recall': recall_alt,\n",
    "    #'conf_matrix': cm_alt,\n",
    "    'log_loss': log_loss_alt,\n",
    "    'selection_rate' : flm.selection_rate,\n",
    "    'mean_prediction' : flm.mean_prediction,\n",
    "    'tnr' : flm.true_negative_rate,\n",
    "    'tpr' : flm.true_positive_rate,\n",
    "    'fpr' : flm.false_positive_rate,\n",
    "    'fnr' : flm.false_negative_rate#,\n",
    "    #'dp' : flm.demographic_parity_ratio\n",
    "}    \n",
    "\n",
    "# for some reason Metric Frame doesn't like these, so call directly (only for overall)\n",
    "# as it doesn't make sense to take the difference after already differed em.\n",
    "extra_metrics = {\n",
    "    'dp_ratio' : flm.demographic_parity_ratio,\n",
    "    'dp_diff' : flm.demographic_parity_difference,\n",
    "    'eodds_ratio' : flm.equalized_odds_ratio,\n",
    "    'eodds_diff' : flm.equalized_odds_difference\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 568,
   "id": "04448ca5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_fair_metrics(extra_metrics, trues, preds, sensitive):\n",
    "    to_ret = {}\n",
    "    for key in list(extra_metrics.keys()):\n",
    "        val = extra_metrics[key](y_true=trues, \n",
    "                                 y_pred=preds, \n",
    "                                 sensitive_features=sensitive)\n",
    "        to_ret[key] = val\n",
    "    return to_ret"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "644c3403",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'get_fair_metrics' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Input \u001b[0;32mIn [45]\u001b[0m, in \u001b[0;36m<cell line: 11>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     23\u001b[0m met_frame \u001b[38;5;241m=\u001b[39m MetricFrame(metrics\u001b[38;5;241m=\u001b[39mmetrics, y_true\u001b[38;5;241m=\u001b[39mframe[thresh, seed, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrues\u001b[39m\u001b[38;5;124m\"\u001b[39m], \n\u001b[1;32m     24\u001b[0m                         y_pred\u001b[38;5;241m=\u001b[39mframe[thresh, seed, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpreds\u001b[39m\u001b[38;5;124m\"\u001b[39m], \n\u001b[1;32m     25\u001b[0m                         sensitive_features\u001b[38;5;241m=\u001b[39msensitive)\n\u001b[1;32m     26\u001b[0m \u001b[38;5;66;03m# from lamentable Metric Frame behaviours here\u001b[39;00m\n\u001b[0;32m---> 27\u001b[0m fair_metrics \u001b[38;5;241m=\u001b[39m \u001b[43mget_fair_metrics\u001b[49m(extra_metrics, frame[thresh, seed, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrues\u001b[39m\u001b[38;5;124m\"\u001b[39m], \n\u001b[1;32m     28\u001b[0m                                 frame[thresh, seed, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpreds\u001b[39m\u001b[38;5;124m\"\u001b[39m], sensitive)\n\u001b[1;32m     29\u001b[0m overall_seed_frame \u001b[38;5;241m=\u001b[39m met_frame\u001b[38;5;241m.\u001b[39moverall\n\u001b[1;32m     30\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mlist\u001b[39m(fair_metrics\u001b[38;5;241m.\u001b[39mkeys()):\n",
      "\u001b[0;31mNameError\u001b[0m: name 'get_fair_metrics' is not defined"
     ]
    }
   ],
   "source": [
    "# TODO scriptify this for vaughan\n",
    "# this takes a hot minute (litterally)\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "overalls = []\n",
    "grp_diffs = []\n",
    "grp_meaned_frames = []\n",
    "alls = []\n",
    "all_grps = []\n",
    "all_diffs = []\n",
    "for thresh in np.unique(threshes): \n",
    "    # print(thresh)\n",
    "    per_seed_overall = []\n",
    "    per_seed_grp = []\n",
    "    per_seed_grp_diffs = []\n",
    "    \n",
    "    for seed in list(range(5)): # TODO update to number of seeds\n",
    "        # sensitives must be strings for fairlearn's humanitarian values and whatnot\n",
    "        sensitive = frame[thresh, seed,\"fairs\"].to_list()\n",
    "        for i in range(len(sensitive)):\n",
    "            sensitive[i] = \"a\" if sensitive[i]==0 else \"b\"\n",
    "        # metric frame handles all the skearns\n",
    "        met_frame = MetricFrame(metrics=metrics, y_true=frame[thresh, seed, \"trues\"], \n",
    "                                y_pred=frame[thresh, seed, \"preds\"], \n",
    "                                sensitive_features=sensitive)\n",
    "        # from lamentable Metric Frame behaviours here\n",
    "        fair_metrics = get_fair_metrics(extra_metrics, frame[thresh, seed, \"trues\"], \n",
    "                                        frame[thresh, seed, \"preds\"], sensitive)\n",
    "        overall_seed_frame = met_frame.overall\n",
    "        for key in list(fair_metrics.keys()):\n",
    "            overall_seed_frame[key] = fair_metrics[key]\n",
    "        \n",
    "        # record seed data\n",
    "        per_seed_overall.append(overall_seed_frame)\n",
    "        per_seed_grp.append(met_frame.by_group)\n",
    "        per_seed_grp_diffs.append(met_frame.difference())\n",
    "    \n",
    "    # record seed data\n",
    "    alls.append(pd.concat(per_seed_overall, axis=1).T)\n",
    "    all_grps.append(pd.concat(per_seed_grp, axis=0))\n",
    "    all_diffs.append(pd.concat(per_seed_grp_diffs, axis=1).T)\n",
    "    \n",
    "#     seed_data_grp = pd.concat(per_seed_grp, axis=0)\n",
    "#     seed_data_overall = pd.concat(per_seed_overall, axis=1).T\n",
    "#     overall_meaned = seed_data_overall.mean()\n",
    "#     grp_diff_meaned = grp_diffed(seed_data_grp)\n",
    "#     grp_means = grp_meaned(seed_data_grp).T\n",
    "#     overalls.append(overall_meaned)\n",
    "#     grp_diffs.append(grp_diff_meaned)\n",
    "#     grp_meaned_frames.append(grp_means)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 570,
   "id": "2355189e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# package data, adding thresholds and seeds for seaborn use later in plotting. Also write to csv so we dont\n",
    "# have to do this too often (cause lazy). These are the most important dfs, others can be ignored, frankly.\n",
    "num_grps = 2\n",
    "num_threshes = len(np.unique(threshes))\n",
    "\n",
    "all_data = pd.concat(alls)\n",
    "all_data[\"seed\"] = all_data.index.to_list()\n",
    "num_seeds = len(all_data[\"seed\"].value_counts())\n",
    "mean_index = np.asarray([[thresh]*num_seeds for thresh in np.unique(threshes)]).flatten()\n",
    "all_data[\"threshold\"] = mean_index\n",
    "\n",
    "all_diff_data = pd.concat(all_diffs)\n",
    "all_diff_data[\"seed\"] = all_data.index.to_list()\n",
    "all_diff_data[\"threshold\"] = mean_index\n",
    "\n",
    "all_grp_data = pd.concat(all_grps)\n",
    "all_grp_data[\"group\"] = all_grp_data.index.to_list()\n",
    "basic_seeds = np.asarray([[elem]*num_grps for elem in list(range(num_seeds))]).flatten()\n",
    "\n",
    "all_grp_data[\"seed\"] = np.asarray([[basic_seeds] * num_threshes]).flatten()\n",
    "mean_index = np.asarray([[thresh]*num_seeds*num_grps for thresh in np.unique(threshes)]).flatten()\n",
    "all_grp_data[\"threshold\"] = mean_index\n",
    "\n",
    "all_data.to_csv('./results/overall_performance_seedwise.csv')\n",
    "all_diff_data.to_csv('./results/group_differences_M-m_seedwise.csv')\n",
    "all_grp_data.to_csv('./results/group_performances_seedwise.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 469,
   "id": "eeb80919",
   "metadata": {},
   "outputs": [],
   "source": [
    "def grp_meaned(df):\n",
    "    grp_1 = seed_data.iloc[::2]\n",
    "    grp_1_mean = grp_1.mean()\n",
    "    grp_2 = seed_data.iloc[1::2]\n",
    "    grp_2_mean = grp_2.mean()\n",
    "    return pd.concat([grp_1_mean, grp_2_mean], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 490,
   "id": "caeccb4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "overs = pd.concat(overalls, axis=1).T\n",
    "diffs = pd.concat(grp_diffs, axis=1).T\n",
    "means = pd.concat(grp_meaned_frames, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 491,
   "id": "83015783",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>support</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>precision</th>\n",
       "      <th>recall</th>\n",
       "      <th>log_loss</th>\n",
       "      <th>selection_rate</th>\n",
       "      <th>mean_prediction</th>\n",
       "      <th>tnr</th>\n",
       "      <th>tpr</th>\n",
       "      <th>fpr</th>\n",
       "      <th>fnr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0.0000</th>\n",
       "      <td>441.8</td>\n",
       "      <td>0.916664</td>\n",
       "      <td>0.916664</td>\n",
       "      <td>0.916664</td>\n",
       "      <td>2.878340</td>\n",
       "      <td>0.004083</td>\n",
       "      <td>0.004083</td>\n",
       "      <td>0.996559</td>\n",
       "      <td>0.011429</td>\n",
       "      <td>0.003441</td>\n",
       "      <td>0.988571</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0000</th>\n",
       "      <td>56.2</td>\n",
       "      <td>0.448726</td>\n",
       "      <td>0.448726</td>\n",
       "      <td>0.448726</td>\n",
       "      <td>19.040348</td>\n",
       "      <td>0.066397</td>\n",
       "      <td>0.066397</td>\n",
       "      <td>0.988889</td>\n",
       "      <td>0.100981</td>\n",
       "      <td>0.011111</td>\n",
       "      <td>0.899019</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0001</th>\n",
       "      <td>441.8</td>\n",
       "      <td>0.916664</td>\n",
       "      <td>0.916664</td>\n",
       "      <td>0.916664</td>\n",
       "      <td>2.878340</td>\n",
       "      <td>0.004083</td>\n",
       "      <td>0.004083</td>\n",
       "      <td>0.996559</td>\n",
       "      <td>0.011429</td>\n",
       "      <td>0.003441</td>\n",
       "      <td>0.988571</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0001</th>\n",
       "      <td>56.2</td>\n",
       "      <td>0.448726</td>\n",
       "      <td>0.448726</td>\n",
       "      <td>0.448726</td>\n",
       "      <td>19.040348</td>\n",
       "      <td>0.066397</td>\n",
       "      <td>0.066397</td>\n",
       "      <td>0.988889</td>\n",
       "      <td>0.100981</td>\n",
       "      <td>0.011111</td>\n",
       "      <td>0.899019</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0002</th>\n",
       "      <td>441.8</td>\n",
       "      <td>0.916664</td>\n",
       "      <td>0.916664</td>\n",
       "      <td>0.916664</td>\n",
       "      <td>2.878340</td>\n",
       "      <td>0.004083</td>\n",
       "      <td>0.004083</td>\n",
       "      <td>0.996559</td>\n",
       "      <td>0.011429</td>\n",
       "      <td>0.003441</td>\n",
       "      <td>0.988571</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0997</th>\n",
       "      <td>56.2</td>\n",
       "      <td>0.448726</td>\n",
       "      <td>0.448726</td>\n",
       "      <td>0.448726</td>\n",
       "      <td>19.040348</td>\n",
       "      <td>0.066397</td>\n",
       "      <td>0.066397</td>\n",
       "      <td>0.988889</td>\n",
       "      <td>0.100981</td>\n",
       "      <td>0.011111</td>\n",
       "      <td>0.899019</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0998</th>\n",
       "      <td>441.8</td>\n",
       "      <td>0.916664</td>\n",
       "      <td>0.916664</td>\n",
       "      <td>0.916664</td>\n",
       "      <td>2.878340</td>\n",
       "      <td>0.004083</td>\n",
       "      <td>0.004083</td>\n",
       "      <td>0.996559</td>\n",
       "      <td>0.011429</td>\n",
       "      <td>0.003441</td>\n",
       "      <td>0.988571</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0998</th>\n",
       "      <td>56.2</td>\n",
       "      <td>0.448726</td>\n",
       "      <td>0.448726</td>\n",
       "      <td>0.448726</td>\n",
       "      <td>19.040348</td>\n",
       "      <td>0.066397</td>\n",
       "      <td>0.066397</td>\n",
       "      <td>0.988889</td>\n",
       "      <td>0.100981</td>\n",
       "      <td>0.011111</td>\n",
       "      <td>0.899019</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0999</th>\n",
       "      <td>441.8</td>\n",
       "      <td>0.916664</td>\n",
       "      <td>0.916664</td>\n",
       "      <td>0.916664</td>\n",
       "      <td>2.878340</td>\n",
       "      <td>0.004083</td>\n",
       "      <td>0.004083</td>\n",
       "      <td>0.996559</td>\n",
       "      <td>0.011429</td>\n",
       "      <td>0.003441</td>\n",
       "      <td>0.988571</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.0999</th>\n",
       "      <td>56.2</td>\n",
       "      <td>0.448726</td>\n",
       "      <td>0.448726</td>\n",
       "      <td>0.448726</td>\n",
       "      <td>19.040348</td>\n",
       "      <td>0.066397</td>\n",
       "      <td>0.066397</td>\n",
       "      <td>0.988889</td>\n",
       "      <td>0.100981</td>\n",
       "      <td>0.011111</td>\n",
       "      <td>0.899019</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2000 rows × 11 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        support  accuracy  precision    recall   log_loss  selection_rate  \\\n",
       "0.0000    441.8  0.916664   0.916664  0.916664   2.878340        0.004083   \n",
       "0.0000     56.2  0.448726   0.448726  0.448726  19.040348        0.066397   \n",
       "0.0001    441.8  0.916664   0.916664  0.916664   2.878340        0.004083   \n",
       "0.0001     56.2  0.448726   0.448726  0.448726  19.040348        0.066397   \n",
       "0.0002    441.8  0.916664   0.916664  0.916664   2.878340        0.004083   \n",
       "...         ...       ...        ...       ...        ...             ...   \n",
       "0.0997     56.2  0.448726   0.448726  0.448726  19.040348        0.066397   \n",
       "0.0998    441.8  0.916664   0.916664  0.916664   2.878340        0.004083   \n",
       "0.0998     56.2  0.448726   0.448726  0.448726  19.040348        0.066397   \n",
       "0.0999    441.8  0.916664   0.916664  0.916664   2.878340        0.004083   \n",
       "0.0999     56.2  0.448726   0.448726  0.448726  19.040348        0.066397   \n",
       "\n",
       "        mean_prediction       tnr       tpr       fpr       fnr  \n",
       "0.0000         0.004083  0.996559  0.011429  0.003441  0.988571  \n",
       "0.0000         0.066397  0.988889  0.100981  0.011111  0.899019  \n",
       "0.0001         0.004083  0.996559  0.011429  0.003441  0.988571  \n",
       "0.0001         0.066397  0.988889  0.100981  0.011111  0.899019  \n",
       "0.0002         0.004083  0.996559  0.011429  0.003441  0.988571  \n",
       "...                 ...       ...       ...       ...       ...  \n",
       "0.0997         0.066397  0.988889  0.100981  0.011111  0.899019  \n",
       "0.0998         0.004083  0.996559  0.011429  0.003441  0.988571  \n",
       "0.0998         0.066397  0.988889  0.100981  0.011111  0.899019  \n",
       "0.0999         0.004083  0.996559  0.011429  0.003441  0.988571  \n",
       "0.0999         0.066397  0.988889  0.100981  0.011111  0.899019  \n",
       "\n",
       "[2000 rows x 11 columns]"
      ]
     },
     "execution_count": 491,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "index = np.unique(threshes)\n",
    "overs['threshold'] = index\n",
    "diffs['threshold'] = index\n",
    "overs.set_index('threshold')\n",
    "diffs.set_index('threshold')\n",
    "mean_index = np.asarray([[thresh]*2 for thresh in np.unique(threshes)]).flatten()\n",
    "means.set_index(mean_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 494,
   "id": "c326cac2",
   "metadata": {},
   "outputs": [],
   "source": [
    "overs.to_csv('./results/overall_performance.csv')\n",
    "diffs.to_csv('./results/group_differences_M-m.csv')\n",
    "means.to_csv('./results/group_performances.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 389,
   "id": "e4086d46",
   "metadata": {},
   "outputs": [],
   "source": [
    "def grp_diffed(seed_data):\n",
    "    grp_1 = seed_data.iloc[::2]\n",
    "    grp_1_mean = grp_1.mean()\n",
    "    grp_2 = seed_data.iloc[1::2]\n",
    "    grp_2_mean = grp_2.mean()\n",
    "    grp_1_mean['support']\n",
    "    if grp_1_mean['support'] > grp_2_mean['support']:\n",
    "        diffed = grp_1_mean - grp_2_mean\n",
    "    else:\n",
    "        diffed = grp_2_mean - grp_1_mean\n",
    "    return diffed\n",
    "#grp_1_mean - grp_2_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "de1ea926",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 2, 3, 4])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(np.unique(frame.columns.get_level_values('seed').to_numpy()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35a50622",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "CPoF",
   "language": "python",
   "name": "cpof"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
