{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c0221296",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import wilcoxon\n",
    "results_path = \"/.../MultipleFairness/Results/\"\n",
    "os.chdir(results_path)\n",
    "\n",
    "data_name = \"acs_west_poverty\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95a4e4dd",
   "metadata": {},
   "source": [
    "# Table generation for Ipopt vs MILP comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9daafefc",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Expect results data to have columns\n",
    "| scoreObj | auc | dpViol | eoddsViol | prpViol | gap | objVal | objBound | \n",
    "Require 3 .csv files (saved from comparison script)\n",
    "1. {task}_eye.csv: Contains the solution using the identity (no bin transitions)\n",
    "2. {task}_base.csv: Contains the solution obtained from solving the QCQP (without reformulation) using Ipopt\n",
    "3. {task}_bin.csv: Contains the solution obtained from solving the reformulation using MILP\n",
    "\"\"\"\n",
    "\n",
    "f = lambda x: str(round(x[0],3)) + \" ± \" + str(round(x[1],3))\n",
    "\n",
    "def get_stats(data_name):\n",
    "    eye = pd.read_csv(data_name+\"_eye.csv\")\n",
    "    base = pd.read_csv(data_name+\"_base.csv\")\n",
    "    milp = pd.read_csv(data_name+\"_bin.csv\")\n",
    "    obj_base = np.mean(base[\"scoreObj\"])\n",
    "    obj_base_std = np.std(base[\"scoreObj\"])\n",
    "    obj_bin = np.mean(milp[\"scoreObj\"])\n",
    "    obj_bin_std = np.std(milp[\"scoreObj\"])\n",
    "    gap_base = 100*(base[\"scoreObj\"]-milp[\"objBound\"])/base[\"scoreObj\"]\n",
    "    gap_base_mu = np.mean(gap_base)\n",
    "    gap_base_sigma = np.std(gap_base)\n",
    "    gap_bin = 100*(milp[\"scoreObj\"]-milp[\"objBound\"])/milp[\"scoreObj\"]\n",
    "    gap_bin_mu = np.mean(gap_bin)\n",
    "    gap_bin_sigma = np.std(gap_bin)\n",
    "    auc = np.mean(eye[\"auc\"])\n",
    "    auc_base = np.mean(base[\"auc\"])\n",
    "    auc_bin = np.mean(milp[\"auc\"])\n",
    "    p = wilcoxon(gap_base, gap_bin, alternative = \"greater\").pvalue\n",
    "    return obj_base, obj_bin, f([gap_base_mu, gap_base_sigma]), f([gap_bin_mu, gap_bin_sigma]), p, auc_base, auc_bin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2cc6cc6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets_tested = [\"acs_west_income\", \"acs_west_insurance\", \"acs_west_mobility\", \n",
    "                   \"acs_west_poverty\", \"acs_west_public\",  \"acs_west_travel\",\n",
    "                  \"heartDisease\",\"compas\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5a14eb59",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = pd.DataFrame()\n",
    "for data_name in datasets_tested:\n",
    "    results = results.append(pd.Series(get_stats(data_name)), ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e0091a8f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ObjINT</th>\n",
       "      <th>ObjIP</th>\n",
       "      <th>DeltaINT</th>\n",
       "      <th>DeltaIP</th>\n",
       "      <th>p</th>\n",
       "      <th>AUCINT</th>\n",
       "      <th>AUCIP</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>ACS</th>\n",
       "      <td>2.0809</td>\n",
       "      <td>1.9682</td>\n",
       "      <td>15.076 ± 6.461</td>\n",
       "      <td>10.621 ± 3.402</td>\n",
       "      <td>0.0029</td>\n",
       "      <td>0.9041</td>\n",
       "      <td>0.9044</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>acs_west_insurance</th>\n",
       "      <td>0.9769</td>\n",
       "      <td>0.9599</td>\n",
       "      <td>3.432 ± 0.225</td>\n",
       "      <td>1.715 ± 0.169</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>0.7411</td>\n",
       "      <td>0.7413</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>acs_west_mobility</th>\n",
       "      <td>2.4580</td>\n",
       "      <td>2.3781</td>\n",
       "      <td>5.37 ± 0.803</td>\n",
       "      <td>2.193 ± 0.138</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>0.7971</td>\n",
       "      <td>0.7973</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>acs_west_poverty</th>\n",
       "      <td>2.0693</td>\n",
       "      <td>2.0526</td>\n",
       "      <td>3.756 ± 0.435</td>\n",
       "      <td>2.972 ± 0.324</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>0.8440</td>\n",
       "      <td>0.8440</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>acs_west_public</th>\n",
       "      <td>8.9361</td>\n",
       "      <td>1.9665</td>\n",
       "      <td>79.711 ± 0.782</td>\n",
       "      <td>7.878 ± 2.207</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>0.5420</td>\n",
       "      <td>0.8149</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>acs_west_travel</th>\n",
       "      <td>2.3935</td>\n",
       "      <td>2.3859</td>\n",
       "      <td>2.554 ± 0.254</td>\n",
       "      <td>2.242 ± 0.28</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>0.7725</td>\n",
       "      <td>0.7725</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>heartDisease</th>\n",
       "      <td>1.8871</td>\n",
       "      <td>1.3035</td>\n",
       "      <td>26.385 ± 17.401</td>\n",
       "      <td>3.81 ± 0.864</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>0.8302</td>\n",
       "      <td>0.8629</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>compas</th>\n",
       "      <td>7.4551</td>\n",
       "      <td>3.1300</td>\n",
       "      <td>62.88 ± 13.407</td>\n",
       "      <td>17.055 ± 7.482</td>\n",
       "      <td>0.0010</td>\n",
       "      <td>0.5143</td>\n",
       "      <td>0.7378</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                    ObjINT   ObjIP         DeltaINT         DeltaIP       p  \\\n",
       "ACS                 2.0809  1.9682   15.076 ± 6.461  10.621 ± 3.402  0.0029   \n",
       "acs_west_insurance  0.9769  0.9599    3.432 ± 0.225   1.715 ± 0.169  0.0010   \n",
       "acs_west_mobility   2.4580  2.3781     5.37 ± 0.803   2.193 ± 0.138  0.0010   \n",
       "acs_west_poverty    2.0693  2.0526    3.756 ± 0.435   2.972 ± 0.324  0.0010   \n",
       "acs_west_public     8.9361  1.9665   79.711 ± 0.782   7.878 ± 2.207  0.0010   \n",
       "acs_west_travel     2.3935  2.3859    2.554 ± 0.254    2.242 ± 0.28  0.0010   \n",
       "heartDisease        1.8871  1.3035  26.385 ± 17.401    3.81 ± 0.864  0.0010   \n",
       "compas              7.4551  3.1300   62.88 ± 13.407  17.055 ± 7.482  0.0010   \n",
       "\n",
       "                    AUCINT   AUCIP  \n",
       "ACS                 0.9041  0.9044  \n",
       "acs_west_insurance  0.7411  0.7413  \n",
       "acs_west_mobility   0.7971  0.7973  \n",
       "acs_west_poverty    0.8440  0.8440  \n",
       "acs_west_public     0.5420  0.8149  \n",
       "acs_west_travel     0.7725  0.7725  \n",
       "heartDisease        0.8302  0.8629  \n",
       "compas              0.5143  0.7378  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results.columns = [\"ObjINT\", \"ObjIP\", \"DeltaINT\", \"DeltaIP\", \"p\",\"AUCINT\", \"AUCIP\"]\n",
    "results.index = [\"ACS \", \"acs_west_insurance\", \"acs_west_mobility\", \n",
    "                   \"acs_west_poverty\", \"acs_west_public\",  \"acs_west_travel\",\n",
    "                  \"heartDisease\",\"compas\"]\n",
    "results.round(4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e6a3aa6",
   "metadata": {},
   "source": [
    "# Method vs. MFOpt comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "857981bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_path = \"/.../MultipleFairness/Data/Pleiss/\"\n",
    "\n",
    "\"\"\"\n",
    "Expect results data to have columns\n",
    "| scoreObj | auc | dpViol | eoddsViol | prpViol | gap | objVal | objBound | \n",
    "Require 3 .csv files (saved from comparison script)\n",
    "1. {train/test}_base.csv: Contains the solution using the identity (no bin transitions)\n",
    "2. {train/test}_method.csv: Contains the solution using the identity but based on the scores generated from a method we are comparing against (e.g. FairLogLoss Rezaei)\n",
    "3. {train/test}_opt.csv: Contains the solution obtained from solving the reformulation using MILP, attempting to reduce fairness violations by half \n",
    "\"\"\"\n",
    "test_base = pd.read_csv(results_path+\"test_base.csv\")\n",
    "test_method = pd.read_csv(results_path+\"test_method.csv\")\n",
    "test_opt = pd.read_csv(results_path+\"test_opt.csv\")\n",
    "\n",
    "train_base = pd.read_csv(results_path+\"train_base.csv\")\n",
    "train_method = pd.read_csv(results_path+\"train_method.csv\")\n",
    "train_opt = pd.read_csv(results_path+\"train_opt.csv\")\n",
    "\n",
    "results = {\"train base\": train_base, \"train method\": train_method, \"train_opt\": train_opt,\n",
    "          \"test_base\": test_base, \"test_method\": test_method, \"test_opt\": test_opt}\n",
    "\n",
    "metrics = [\"auc\", \"dpViol\", \"eoddsViol\", \"prpViol\"]\n",
    "\n",
    "# To get error margins \n",
    "f = lambda x: str(x[0]) + \" ± \" + str(x[1])\n",
    "results_full = pd.DataFrame()\n",
    "for k, frame in results.items():\n",
    "    result_frame = frame\n",
    "    result_list = [f(np.array([np.mean(result_frame[metric]), np.std(result_frame[metric])]).round(4).tolist()) for metric in metrics]\n",
    "    results_full = pd.concat([results_full, pd.DataFrame(result_list, columns = [k])], axis=1)\n",
    "results_full.index = metrics\n",
    "\n",
    "# To get 1-SD intervals\n",
    "metrics = [\"auc\", \"dpViol\", \"eoddsViol\", \"prpViol\"]\n",
    "z = 1\n",
    "f = lambda x: str(round(x[0]-z*x[1],4))+\", \"+str(round(x[0]+z*x[1],4))\n",
    "results_full = pd.DataFrame()\n",
    "for k, frame in results.items():\n",
    "    result_frame = frame\n",
    "    result_list = [f(np.array([np.mean(result_frame[metric]), np.std(result_frame[metric])]).round(4).tolist()) for metric in metrics]\n",
    "    results_full = pd.concat([results_full, pd.DataFrame(result_list, columns = [k])], axis=1)\n",
    "results_full.index = metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "id": "a1181722",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>test_base</th>\n",
       "      <th>test_method</th>\n",
       "      <th>test_opt</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>auc</th>\n",
       "      <td>0.8286, 0.8352</td>\n",
       "      <td>0.8062, 0.8236</td>\n",
       "      <td>0.8278, 0.8342</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dpViol</th>\n",
       "      <td>0.0196, 0.0228</td>\n",
       "      <td>0.0121, 0.0153</td>\n",
       "      <td>0.0095, 0.0117</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>eoddsViol</th>\n",
       "      <td>0.0287, 0.0371</td>\n",
       "      <td>0.0192, 0.0268</td>\n",
       "      <td>0.0114, 0.017</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>prpViol</th>\n",
       "      <td>0.1287, 0.1643</td>\n",
       "      <td>0.261, 0.5684</td>\n",
       "      <td>0.1254, 0.184</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                test_base     test_method        test_opt\n",
       "auc        0.8286, 0.8352  0.8062, 0.8236  0.8278, 0.8342\n",
       "dpViol     0.0196, 0.0228  0.0121, 0.0153  0.0095, 0.0117\n",
       "eoddsViol  0.0287, 0.0371  0.0192, 0.0268   0.0114, 0.017\n",
       "prpViol    0.1287, 0.1643   0.261, 0.5684   0.1254, 0.184"
      ]
     },
     "execution_count": 99,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check results for test\n",
    "results_full[[\"test_base\",\"test_method\",\"test_opt\"]]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c4cbb7c",
   "metadata": {},
   "source": [
    "# Method vs. Base Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "65233689",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_test_pval(base, opt):\n",
    "    auc = wilcoxon(base[\"auc\"], opt[\"auc\"], alternative = \"less\").pvalue\n",
    "    dpViol = wilcoxon(base[\"dpViol\"], opt[\"dpViol\"], alternative = \"greater\").pvalue\n",
    "    eoddsViol = wilcoxon(base[\"eoddsViol\"], opt[\"eoddsViol\"], alternative = \"greater\").pvalue\n",
    "    prpViol = wilcoxon(base[\"prpViol\"], opt[\"prpViol\"], alternative = \"greater\").pvalue\n",
    "    return [auc, dpViol, eoddsViol, prpViol]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b1725c4c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>test_base</th>\n",
       "      <th>test_opt</th>\n",
       "      <th>pVal</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>auc</th>\n",
       "      <td>0.7932 ± 0.0016</td>\n",
       "      <td>0.7923 ± 0.0017</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dpViol</th>\n",
       "      <td>0.03 ± 0.0041</td>\n",
       "      <td>0.0207 ± 0.0026</td>\n",
       "      <td>0.000001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>eoddsViol</th>\n",
       "      <td>0.0403 ± 0.0061</td>\n",
       "      <td>0.0247 ± 0.0025</td>\n",
       "      <td>0.000001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>prpViol</th>\n",
       "      <td>0.159 ± 0.0245</td>\n",
       "      <td>0.1803 ± 0.0328</td>\n",
       "      <td>0.975780</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 test_base         test_opt      pVal\n",
       "auc        0.7932 ± 0.0016  0.7923 ± 0.0017  1.000000\n",
       "dpViol       0.03 ± 0.0041  0.0207 ± 0.0026  0.000001\n",
       "eoddsViol  0.0403 ± 0.0061  0.0247 ± 0.0025  0.000001\n",
       "prpViol     0.159 ± 0.0245  0.1803 ± 0.0328  0.975780"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_path = \"/.../MultipleFairness/Data/acs_west_public/\"\n",
    "\n",
    "\"\"\"\n",
    "Expect results data to have columns\n",
    "| scoreObj | auc | dpViol | eoddsViol | prpViol | gap | objVal | objBound | \n",
    "Require 3 .csv files (saved from comparison script)\n",
    "1. {train/test}_base.csv: Contains the solution using the identity (no bin transitions)\n",
    "2. {train/test}_opt.csv: Contains the solution obtained from solving the reformulation using MILP, attempting to reduce fairness violations by half\n",
    "\n",
    "We are mainly concerned with finding the performance on the testing data since if we find a feasible point in the training data we know exactly what \n",
    "the fairness violations should be based on the constraint\n",
    "\"\"\"\n",
    "\n",
    "test_base = pd.read_csv(results_path+\"test_base.csv\")\n",
    "test_opt = pd.read_csv(results_path+\"test_opt.csv\")\n",
    "\n",
    "train_base = pd.read_csv(results_path+\"train_base.csv\")\n",
    "train_opt = pd.read_csv(results_path+\"train_opt.csv\")\n",
    "\n",
    "results = {\"train base\": train_base, \"train_opt\": train_opt,\n",
    "          \"test_base\": test_base, \"test_opt\": test_opt}\n",
    "\n",
    "metrics = [\"auc\", \"dpViol\", \"eoddsViol\", \"prpViol\"]\n",
    "\n",
    "f = lambda x: str(x[0]) + \" ± \" + str(x[1])\n",
    "results_full = pd.DataFrame()\n",
    "for k, frame in results.items():\n",
    "    result_frame = frame\n",
    "    result_list = [f(np.array([np.mean(result_frame[metric]), np.std(result_frame[metric])]).round(4).tolist()) for metric in metrics]\n",
    "    results_full = pd.concat([results_full, pd.DataFrame(result_list, columns = [k])], axis=1)\n",
    "results_full.index = metrics\n",
    "pvals = pd.DataFrame({\"pVal\": compare_test_pval(test_base, test_opt)}, index = metrics).round(6)\n",
    "pd.concat([results_full[[\"test_base\",\"test_opt\"]], pvals], axis=1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
