{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import os\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "sns.set_theme(palette=\"colorblind\", style='whitegrid', font_scale=1.25)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def group_func(x):\n",
    "    \"\"\"Trimmed mean that drops the single smallest and the single largest value.\n",
    "\n",
    "    ``x`` may be a Series or a DataFrame (trimmed column-wise), e.g. a group\n",
    "    handed over by ``groupby(...).apply``. With two or fewer rows nothing is\n",
    "    left after trimming, so the plain mean is returned instead of dividing by zero.\n",
    "    \"\"\"\n",
    "    if len(x) <= 2:\n",
    "        return x.mean()\n",
    "    return (x.sum() - x.min() - x.max()) / (len(x) - 2)\n",
    "\n",
    "\n",
    "def group_func_2(x, n_trim=5, column=\"violation\"):\n",
    "    \"\"\"Mean of DataFrame ``x`` after dropping the ``n_trim`` rows with the smallest\n",
    "    and the ``n_trim`` rows with the largest ``column`` value.\n",
    "\n",
    "    A single stable sort followed by a slice guarantees no row is removed twice,\n",
    "    even when ``column`` contains ties (``nlargest``/``nsmallest`` could both pick\n",
    "    the same tied rows). Groups with at most ``2 * n_trim`` rows have nothing left\n",
    "    after trimming and fall back to the plain mean.\n",
    "    \"\"\"\n",
    "    if len(x) <= 2 * n_trim:\n",
    "        return x.mean()\n",
    "    ordered = x.sort_values(column, kind=\"stable\")\n",
    "    return ordered.iloc[n_trim:len(x) - n_trim].mean()\n",
    "\n",
    "\n",
    "# Display labels for the fairness metrics, keyed by the metric directory name.\n",
    "METRIC_MAP = {\n",
    "    \"Demographic_Parity\": \"Dem_Parity\",\n",
    "    \"Equal_Opportunity\": \"Eq_Opp\",\n",
    "    \"Equalized_Odds\": \"Eq_Odds\",\n",
    "    \"Predictive_Equality\": \"Pred_Eq\",\n",
    "    \"Predictive_Parity\": \"Pred_Parity\",\n",
    "    \"Disparate_Impact\": \"Disp_Impact\"\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>c</th>\n",
       "      <th>dataset_sens_attrs</th>\n",
       "      <th>use_classwise</th>\n",
       "      <th>cp_method</th>\n",
       "      <th>fairness_metric</th>\n",
       "      <th>type</th>\n",
       "      <th>eff_mean</th>\n",
       "      <th>eff_std</th>\n",
       "      <th>violation_mean</th>\n",
       "      <th>violation_std</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.05</td>\n",
       "      <td>ACSEducation_</td>\n",
       "      <td>False</td>\n",
       "      <td>aps</td>\n",
       "      <td>Dem_Parity</td>\n",
       "      <td>base</td>\n",
       "      <td>2.9816</td>\n",
       "      <td>0.0018</td>\n",
       "      <td>0.4578</td>\n",
       "      <td>0.0022</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.05</td>\n",
       "      <td>ACSEducation_</td>\n",
       "      <td>False</td>\n",
       "      <td>aps</td>\n",
       "      <td>Dem_Parity</td>\n",
       "      <td>ours</td>\n",
       "      <td>5.9217</td>\n",
       "      <td>0.0139</td>\n",
       "      <td>0.0455</td>\n",
       "      <td>0.0081</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.05</td>\n",
       "      <td>ACSEducation_</td>\n",
       "      <td>False</td>\n",
       "      <td>aps</td>\n",
       "      <td>Eq_Odds</td>\n",
       "      <td>base</td>\n",
       "      <td>2.9816</td>\n",
       "      <td>0.0018</td>\n",
       "      <td>0.4654</td>\n",
       "      <td>0.0233</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.05</td>\n",
       "      <td>ACSEducation_</td>\n",
       "      <td>False</td>\n",
       "      <td>aps</td>\n",
       "      <td>Eq_Odds</td>\n",
       "      <td>ours</td>\n",
       "      <td>6.0000</td>\n",
       "      <td>0.0000</td>\n",
       "      <td>0.0000</td>\n",
       "      <td>0.0000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.05</td>\n",
       "      <td>ACSEducation_</td>\n",
       "      <td>False</td>\n",
       "      <td>aps</td>\n",
       "      <td>Eq_Opp</td>\n",
       "      <td>base</td>\n",
       "      <td>2.9816</td>\n",
       "      <td>0.0018</td>\n",
       "      <td>0.4356</td>\n",
       "      <td>0.0551</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3195</th>\n",
       "      <td>0.8</td>\n",
       "      <td>Pokec_z_region_gender</td>\n",
       "      <td>True</td>\n",
       "      <td>cfgnn</td>\n",
       "      <td>Disp_Impact</td>\n",
       "      <td>ours</td>\n",
       "      <td>2.7876</td>\n",
       "      <td>0.1404</td>\n",
       "      <td>0.8005</td>\n",
       "      <td>0.0367</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3196</th>\n",
       "      <td>0.8</td>\n",
       "      <td>Pokec_z_region_gender</td>\n",
       "      <td>True</td>\n",
       "      <td>daps</td>\n",
       "      <td>Disp_Impact</td>\n",
       "      <td>base</td>\n",
       "      <td>2.4093</td>\n",
       "      <td>0.0686</td>\n",
       "      <td>0.6579</td>\n",
       "      <td>0.0309</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3197</th>\n",
       "      <td>0.8</td>\n",
       "      <td>Pokec_z_region_gender</td>\n",
       "      <td>True</td>\n",
       "      <td>daps</td>\n",
       "      <td>Disp_Impact</td>\n",
       "      <td>ours</td>\n",
       "      <td>2.7427</td>\n",
       "      <td>0.1714</td>\n",
       "      <td>0.7492</td>\n",
       "      <td>0.0734</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3198</th>\n",
       "      <td>0.8</td>\n",
       "      <td>Pokec_z_region_gender</td>\n",
       "      <td>True</td>\n",
       "      <td>tps</td>\n",
       "      <td>Disp_Impact</td>\n",
       "      <td>base</td>\n",
       "      <td>2.2648</td>\n",
       "      <td>0.0663</td>\n",
       "      <td>0.6400</td>\n",
       "      <td>0.0384</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3199</th>\n",
       "      <td>0.8</td>\n",
       "      <td>Pokec_z_region_gender</td>\n",
       "      <td>True</td>\n",
       "      <td>tps</td>\n",
       "      <td>Disp_Impact</td>\n",
       "      <td>ours</td>\n",
       "      <td>2.6267</td>\n",
       "      <td>0.1838</td>\n",
       "      <td>0.7449</td>\n",
       "      <td>0.0610</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>3200 rows × 10 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         c     dataset_sens_attrs  use_classwise cp_method fairness_metric  \\\n",
       "0     0.05          ACSEducation_          False       aps      Dem_Parity   \n",
       "1     0.05          ACSEducation_          False       aps      Dem_Parity   \n",
       "2     0.05          ACSEducation_          False       aps         Eq_Odds   \n",
       "3     0.05          ACSEducation_          False       aps         Eq_Odds   \n",
       "4     0.05          ACSEducation_          False       aps          Eq_Opp   \n",
       "...    ...                    ...            ...       ...             ...   \n",
       "3195   0.8  Pokec_z_region_gender           True     cfgnn     Disp_Impact   \n",
       "3196   0.8  Pokec_z_region_gender           True      daps     Disp_Impact   \n",
       "3197   0.8  Pokec_z_region_gender           True      daps     Disp_Impact   \n",
       "3198   0.8  Pokec_z_region_gender           True       tps     Disp_Impact   \n",
       "3199   0.8  Pokec_z_region_gender           True       tps     Disp_Impact   \n",
       "\n",
       "      type  eff_mean  eff_std  violation_mean  violation_std  \n",
       "0     base    2.9816   0.0018          0.4578         0.0022  \n",
       "1     ours    5.9217   0.0139          0.0455         0.0081  \n",
       "2     base    2.9816   0.0018          0.4654         0.0233  \n",
       "3     ours    6.0000   0.0000          0.0000         0.0000  \n",
       "4     base    2.9816   0.0018          0.4356         0.0551  \n",
       "...    ...       ...      ...             ...            ...  \n",
       "3195  ours    2.7876   0.1404          0.8005         0.0367  \n",
       "3196  base    2.4093   0.0686          0.6579         0.0309  \n",
       "3197  ours    2.7427   0.1714          0.7492         0.0734  \n",
       "3198  base    2.2648   0.0663          0.6400         0.0384  \n",
       "3199  ours    2.6267   0.1838          0.7449         0.0610  \n",
       "\n",
       "[3200 rows x 10 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "TRIALS_DIR = os.path.expanduser(\"fairness_trials\")\n",
    "SKIPPED_CP_FILES = {\"aps_no_rand.csv\", \"tps_classwise.csv\", \"dtps.csv\"}\n",
    "EXCLUDED_C_VALUES = {\"0.01\"}\n",
    "RAW_COLUMNS = [\"Unnamed: 0\", \"c\", \"base_eff\", \"base_coverage\", \"base_violation\", \"eff\", \"coverage\", \"violation\"]\n",
    "STAT_COLUMNS = [\"eff\", \"coverage\", \"violation\"]\n",
    "ID_COLUMNS = [\"dataset_sens_attrs\", \"use_classwise\", \"cp_method\", \"fairness_metric\"]\n",
    "GROUP_KEYS = [\"c\", *ID_COLUMNS, \"type\"]\n",
    "\n",
    "\n",
    "def parse_tensor_str(value):\n",
    "    \"\"\"Parse a logged scalar of the form 'tensor(<number>)' and round it to 3 decimals.\"\"\"\n",
    "    return round(float(value[len(\"tensor(\"):-1]), 3)\n",
    "\n",
    "\n",
    "def load_trial_file(rfile):\n",
    "    \"\"\"Load one trial CSV, or return None when its CP method is excluded.\n",
    "\n",
    "    The path must look like\n",
    "    <TRIALS_DIR>/<dataset_sens_attrs>/<metric>/<use_classwise>/<cp_method>.csv;\n",
    "    those path components are added to the returned frame as columns.\n",
    "    \"\"\"\n",
    "    # relpath + os.sep instead of split(\"/\") so the parsing also works on Windows.\n",
    "    parts = os.path.relpath(rfile, TRIALS_DIR).split(os.sep)\n",
    "    dataset_sens_attrs, metric, use_classwise, cp_file = parts[0], parts[1], parts[2], parts[3]\n",
    "    if cp_file in SKIPPED_CP_FILES:\n",
    "        return None\n",
    "\n",
    "    df = pd.read_csv(rfile, header=None, names=RAW_COLUMNS).dropna()\n",
    "    df = df.assign(c=df[\"c\"].astype(str))\n",
    "    df = df[~df[\"c\"].isin(EXCLUDED_C_VALUES)]\n",
    "    return df.assign(\n",
    "        dataset_sens_attrs=dataset_sens_attrs,\n",
    "        use_classwise=ast.literal_eval(use_classwise),  # directory name expected to be \"True\"/\"False\"\n",
    "        cp_method=cp_file.split(\".\")[0],\n",
    "        fairness_metric=METRIC_MAP.get(metric, metric),\n",
    "    )\n",
    "\n",
    "\n",
    "# Sorted so the concatenation order does not depend on the filesystem.\n",
    "result_files = sorted(\n",
    "    os.path.join(dp, f)\n",
    "    for dp, _, fn in os.walk(TRIALS_DIR)\n",
    "    for f in fn\n",
    "    if f.endswith(\".csv\")\n",
    ")\n",
    "all_results = [df for df in map(load_trial_file, result_files) if df is not None]\n",
    "assert all_results, f\"no usable trial CSVs found under {TRIALS_DIR!r}\"\n",
    "\n",
    "# Stack all CSV rows and convert the logged 'tensor(...)' strings to floats.\n",
    "raw_df = pd.concat(all_results, ignore_index=True).drop(columns=[\"Unnamed: 0\"])\n",
    "for col in STAT_COLUMNS:\n",
    "    for name in (col, f\"base_{col}\"):\n",
    "        raw_df[name] = raw_df[name].apply(parse_tensor_str)\n",
    "\n",
    "# Long format: our results and the baseline results stacked, tagged by \"type\".\n",
    "our_res_df = raw_df[[\"c\", *STAT_COLUMNS, *ID_COLUMNS]].assign(type=\"ours\")\n",
    "base_res_df = (\n",
    "    raw_df[[\"c\", *(f\"base_{col}\" for col in STAT_COLUMNS), *ID_COLUMNS]]\n",
    "    .rename(columns={f\"base_{col}\": col for col in STAT_COLUMNS})\n",
    "    .assign(type=\"base\")\n",
    ")\n",
    "long_res_df = pd.concat([our_res_df, base_res_df], ignore_index=True)\n",
    "\n",
    "# Mean and standard deviation over all rows that share a configuration.\n",
    "grouped_res_df = long_res_df.groupby(GROUP_KEYS)\n",
    "mean_res_df = grouped_res_df.mean().reset_index()\n",
    "std_res_df = grouped_res_df.std().reset_index()\n",
    "\n",
    "res_df = mean_res_df[GROUP_KEYS].assign(\n",
    "    eff_mean=mean_res_df[\"eff\"].round(4),\n",
    "    eff_std=std_res_df[\"eff\"].round(4),\n",
    "    violation_mean=mean_res_df[\"violation\"].round(4),\n",
    "    violation_std=std_res_df[\"violation\"].round(4),\n",
    ")\n",
    "res_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "PROCESSED_DIR = \"./processed_trials\"\n",
    "os.makedirs(PROCESSED_DIR, exist_ok=True)\n",
    "\n",
    "id_cols = [\"c\", \"dataset_sens_attrs\", \"use_classwise\", \"cp_method\", \"fairness_metric\", \"type\"]\n",
    "melted_res_df = res_df.melt(id_cols, var_name=\"stat_type\", value_name=\"stats\")\n",
    "\n",
    "# One spreadsheet per (dataset_sens_attrs, use_classwise) combination.\n",
    "for dataset_sens_attrs in res_df[\"dataset_sens_attrs\"].unique():\n",
    "    for use_classwise in [False, True]:\n",
    "        subset = melted_res_df[\n",
    "            (melted_res_df[\"dataset_sens_attrs\"] == dataset_sens_attrs)\n",
    "            & (melted_res_df[\"use_classwise\"] == use_classwise)\n",
    "        ]\n",
    "        if subset.empty:\n",
    "            continue\n",
    "        label = f\"{dataset_sens_attrs}_{use_classwise}\"\n",
    "        results_table = (\n",
    "            subset.assign(index=label)  # .assign avoids writing into a filtered slice\n",
    "            .groupby([\"index\", \"c\", \"fairness_metric\", \"cp_method\", \"type\", \"stat_type\"])[\"stats\"]\n",
    "            .apply(lambda x: x.values[0])\n",
    "            .reset_index()\n",
    "            .pivot(index=[\"index\", \"c\", \"fairness_metric\"], columns=[\"cp_method\", \"type\", \"stat_type\"], values=[\"stats\"])\n",
    "            .droplevel(0, axis=0)\n",
    "        )\n",
    "        results_table.to_excel(os.path.join(PROCESSED_DIR, f\"{label}.xlsx\"))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "FIGURES_DIR = \"./figures\"\n",
    "os.makedirs(FIGURES_DIR, exist_ok=True)\n",
    "\n",
    "METRICS = list(map(lambda x: METRIC_MAP[x], [\"Predictive_Parity\"]))\n",
    "C = [\"0.05\", \"0.1\", \"0.15\", \"0.2\"]\n",
    "\n",
    "for dataset_sens_attrs in res_df[\"dataset_sens_attrs\"].unique():\n",
    "    if \"ACSEducation\" not in dataset_sens_attrs:\n",
    "        continue\n",
    "    for use_classwise in [True]:\n",
    "        df = res_df[(res_df[\"dataset_sens_attrs\"] == dataset_sens_attrs) & (res_df[\"use_classwise\"] == use_classwise)]\n",
    "\n",
    "        print(f\"{dataset_sens_attrs}_{use_classwise}\")\n",
    "\n",
    "        eff_lim = 6 if \"ACSEducation\" in dataset_sens_attrs else 4\n",
    "\n",
    "        if \"Credit\" in dataset_sens_attrs or \"Pokec\" in dataset_sens_attrs:\n",
    "            CP_METHODS = [\"aps\", \"cfgnn\", \"daps\", \"tps\"]\n",
    "        else:\n",
    "            CP_METHODS = [\"aps\", \"tps\"]\n",
    "\n",
    "        # Restrict to the plotted metrics, CP methods and thresholds; with hue_order=C\n",
    "        # the hue levels then line up with the threshold markers drawn below.\n",
    "        filtered_df = df[\n",
    "            df[\"fairness_metric\"].isin(METRICS)\n",
    "            & df[\"cp_method\"].isin(CP_METHODS)\n",
    "            & df[\"c\"].isin(C)\n",
    "        ]\n",
    "        ours_df = filtered_df[filtered_df[\"type\"] == \"ours\"]\n",
    "        base_df = filtered_df[filtered_df[\"type\"] == \"base\"]\n",
    "\n",
    "        grid = sns.catplot(\n",
    "            data=ours_df,\n",
    "            x=\"fairness_metric\",\n",
    "            y=\"eff_mean\",\n",
    "            hue=\"c\",\n",
    "            hue_order=C,\n",
    "            col=\"cp_method\",\n",
    "            sharex=False,\n",
    "            ci=None,\n",
    "            kind=\"bar\",\n",
    "            order=METRICS,\n",
    "        )\n",
    "        grid.figure.set_size_inches(5, 5)\n",
    "        grid.set(xlabel=\"\", ylabel=\"Efficiency\", ylim=(0, eff_lim))\n",
    "        for key, ax in grid.axes_dict.items():\n",
    "            ax.set_title(key.upper())\n",
    "        sns.move_legend(grid, \"center right\", bbox_to_anchor=(1.4, 0.5), ncol=1)\n",
    "        grid.legend.set_title(\"Closeness\\nThreshold:\")\n",
    "\n",
    "        grid2 = sns.catplot(\n",
    "            data=ours_df,\n",
    "            x=\"fairness_metric\",\n",
    "            y=\"violation_mean\",\n",
    "            hue=\"c\",\n",
    "            hue_order=C,\n",
    "            col=\"cp_method\",\n",
    "            sharex=False,\n",
    "            ci=None,\n",
    "            kind=\"bar\",\n",
    "            order=METRICS,\n",
    "        )\n",
    "        grid2.figure.set_size_inches(5, 5)\n",
    "        grid2.set(xlabel=\"\", ylabel=\"Actual Fairness Disparity\", ylim=(0, 0.5))\n",
    "        for key, ax in grid2.axes_dict.items():\n",
    "            ax.set_title(key.upper())\n",
    "        sns.move_legend(grid2, \"center right\", bbox_to_anchor=(1.4, 0.5), ncol=1)\n",
    "        grid2.legend.set_title(\"Closeness\\nThreshold:\")\n",
    "\n",
    "        # Dashed line = baseline value for the bar's (CP method, metric); on the violation\n",
    "        # plot a dot also marks the bar's closeness threshold c. Panels are looked up by\n",
    "        # name (axes_dict) rather than by position, so each one gets its own baseline.\n",
    "        for facet, stat_col, mark_threshold in ((grid, \"eff_mean\", False), (grid2, \"violation_mean\", True)):\n",
    "            base_means = base_df.groupby([\"cp_method\", \"fairness_metric\"])[stat_col].mean()\n",
    "            for cp_method, ax in facet.axes_dict.items():\n",
    "                # Assumes seaborn draws one BarContainer per hue level, in hue_order (= C).\n",
    "                for hue_idx, bars in enumerate(ax.containers):\n",
    "                    for bar in bars:\n",
    "                        x_start, width = bar.get_x(), bar.get_width()\n",
    "                        if width == 0:\n",
    "                            continue\n",
    "                        center = x_start + width / 2\n",
    "                        # Assumes categorical x positions 0..len(METRICS)-1 with bars dodged by < 0.5.\n",
    "                        metric = METRICS[int(round(center))]\n",
    "                        base_value = base_means.get((cp_method, metric), float(\"nan\"))\n",
    "                        ax.plot([x_start, x_start + width], [base_value] * 2, \"--\", c=\"k\")\n",
    "                        if mark_threshold:\n",
    "                            ax.plot(center, float(C[hue_idx]), \"o\", c=\"k\", markersize=3.5)\n",
    "\n",
    "        name_prefix = f\"{str(dataset_sens_attrs).strip('_')}_{use_classwise}\"\n",
    "        grid.savefig(os.path.join(FIGURES_DIR, f\"{name_prefix}_Proxy_efficiency.pdf\"))\n",
    "        grid2.savefig(os.path.join(FIGURES_DIR, f\"{name_prefix}_Proxy_violation.pdf\"))\n",
    "\n",
    "        plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "FIGURES_DIR = \"./figures\"\n",
    "os.makedirs(FIGURES_DIR, exist_ok=True)\n",
    "\n",
    "METRICS = list(map(lambda x: METRIC_MAP[x], [\"Demographic_Parity\", \"Equal_Opportunity\", \"Equalized_Odds\", \"Predictive_Equality\"]))\n",
    "C = [\"0.05\", \"0.1\", \"0.15\", \"0.2\"]\n",
    "\n",
    "for dataset_sens_attrs in res_df[\"dataset_sens_attrs\"].unique():\n",
    "    if \"Credit\" not in dataset_sens_attrs:\n",
    "        continue\n",
    "    for use_classwise in [False, True]:\n",
    "        df = res_df[(res_df[\"dataset_sens_attrs\"] == dataset_sens_attrs) & (res_df[\"use_classwise\"] == use_classwise)]\n",
    "\n",
    "        print(f\"{dataset_sens_attrs}_{use_classwise}\")\n",
    "\n",
    "        eff_lim = 6 if \"ACSEducation\" in dataset_sens_attrs else 4\n",
    "\n",
    "        if \"Credit\" in dataset_sens_attrs or \"Pokec\" in dataset_sens_attrs:\n",
    "            CP_METHODS = [\"aps\", \"cfgnn\", \"daps\", \"tps\"]\n",
    "        else:\n",
    "            CP_METHODS = [\"aps\", \"tps\"]\n",
    "\n",
    "        # Restrict to the plotted metrics, CP methods and thresholds; with hue_order=C\n",
    "        # the hue levels then line up with the threshold markers drawn below.\n",
    "        filtered_df = df[\n",
    "            df[\"fairness_metric\"].isin(METRICS)\n",
    "            & df[\"cp_method\"].isin(CP_METHODS)\n",
    "            & df[\"c\"].isin(C)\n",
    "        ]\n",
    "        ours_df = filtered_df[filtered_df[\"type\"] == \"ours\"]\n",
    "        base_df = filtered_df[filtered_df[\"type\"] == \"base\"]\n",
    "\n",
    "        grid = sns.catplot(\n",
    "            data=ours_df,\n",
    "            x=\"fairness_metric\",\n",
    "            y=\"eff_mean\",\n",
    "            hue=\"c\",\n",
    "            hue_order=C,\n",
    "            col=\"cp_method\",\n",
    "            sharex=False,\n",
    "            ci=None,\n",
    "            kind=\"bar\",\n",
    "            order=METRICS,\n",
    "            col_wrap=2,\n",
    "        )\n",
    "        grid.set(xlabel=\"\", ylabel=\"Efficiency\", ylim=(0, eff_lim))\n",
    "        for key, ax in grid.axes_dict.items():\n",
    "            ax.set_title(key.upper())\n",
    "        sns.move_legend(grid, \"upper center\", bbox_to_anchor=(0.5, 1.05), ncol=len(C))\n",
    "        grid.legend.set_title(\"Closeness Threshold:\")\n",
    "\n",
    "        grid2 = sns.catplot(\n",
    "            data=ours_df,\n",
    "            x=\"fairness_metric\",\n",
    "            y=\"violation_mean\",\n",
    "            hue=\"c\",\n",
    "            hue_order=C,\n",
    "            col=\"cp_method\",\n",
    "            sharex=False,\n",
    "            ci=None,\n",
    "            kind=\"bar\",\n",
    "            order=METRICS,\n",
    "            col_wrap=2,\n",
    "        )\n",
    "        grid2.set(xlabel=\"\", ylabel=\"Actual Fairness Disparity\")\n",
    "        if \"Credit\" in dataset_sens_attrs or \"ACSEducation\" in dataset_sens_attrs:\n",
    "            grid2.set(ylim=(0, 0.5), yticks=[0, 0.1, 0.2, 0.3, 0.4, 0.5])\n",
    "        else:\n",
    "            grid2.set(ylim=(0, 0.4), yticks=[0, 0.1, 0.2, 0.3, 0.4])\n",
    "        for key, ax in grid2.axes_dict.items():\n",
    "            ax.set_title(key.upper())\n",
    "        sns.move_legend(grid2, \"upper center\", bbox_to_anchor=(0.5, 1.05), ncol=len(C))\n",
    "        grid2.legend.set_title(\"Closeness Threshold:\")\n",
    "\n",
    "        # Dashed line = baseline value for the bar's (CP method, metric); on the violation\n",
    "        # plot a dot also marks the bar's closeness threshold c. Panels are looked up by\n",
    "        # name (axes_dict) rather than by position, so each one gets its own baseline.\n",
    "        for facet, stat_col, mark_threshold in ((grid, \"eff_mean\", False), (grid2, \"violation_mean\", True)):\n",
    "            base_means = base_df.groupby([\"cp_method\", \"fairness_metric\"])[stat_col].mean()\n",
    "            for cp_method, ax in facet.axes_dict.items():\n",
    "                # Assumes seaborn draws one BarContainer per hue level, in hue_order (= C).\n",
    "                for hue_idx, bars in enumerate(ax.containers):\n",
    "                    for bar in bars:\n",
    "                        x_start, width = bar.get_x(), bar.get_width()\n",
    "                        if width == 0:\n",
    "                            continue\n",
    "                        center = x_start + width / 2\n",
    "                        # Assumes categorical x positions 0..len(METRICS)-1 with bars dodged by < 0.5.\n",
    "                        metric = METRICS[int(round(center))]\n",
    "                        base_value = base_means.get((cp_method, metric), float(\"nan\"))\n",
    "                        ax.plot([x_start, x_start + width], [base_value] * 2, \"--\", c=\"k\")\n",
    "                        if mark_threshold:\n",
    "                            ax.plot(center, float(C[hue_idx]), \"o\", c=\"k\", markersize=3.5)\n",
    "\n",
    "        name_prefix = f\"{str(dataset_sens_attrs).strip('_')}_{use_classwise}\"\n",
    "        grid.savefig(os.path.join(FIGURES_DIR, f\"{name_prefix}_efficiency.pdf\"))\n",
    "        grid2.savefig(os.path.join(FIGURES_DIR, f\"{name_prefix}_violation.pdf\"))\n",
    "\n",
    "        plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "METRICS = list(map(lambda x: METRIC_MAP[x], [\"Disparate_Impact\"]))\n",
    "\n",
    "for dataset_sens_attrs in res_df[\"dataset_sens_attrs\"].unique():\n",
    "    for use_classwise in [False, True]:\n",
    "        df = res_df[\n",
    "            (res_df[\"dataset_sens_attrs\"] == dataset_sens_attrs)\n",
    "            & (res_df[\"use_classwise\"] == use_classwise)\n",
    "            & (res_df[\"fairness_metric\"].isin(METRICS))\n",
    "        ]\n",
    "        if df.empty:\n",
    "            continue\n",
    "        label = f\"{dataset_sens_attrs}_{use_classwise}\"\n",
    "        print(f\"{label}\\n\")\n",
    "\n",
    "        # One single-row LaTeX table per statistic, columns = (cp_method, type).\n",
    "        # NOTE(review): x.values[0] keeps only the first row of each group, so if several\n",
    "        # closeness thresholds c exist for this metric only one of them is reported.\n",
    "        for stat_col in [\"eff_mean\", \"violation_mean\"]:\n",
    "            results_table = (\n",
    "                df.assign(index=label)\n",
    "                .groupby([\"index\", \"cp_method\", \"type\"])[stat_col]\n",
    "                .apply(lambda x: x.values[0])\n",
    "                .reset_index()\n",
    "                .pivot(index=[\"index\"], columns=[\"cp_method\", \"type\"], values=[stat_col])\n",
    "            )\n",
    "            print(results_table.to_latex())\n",
    "        print()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fairgraph",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
