{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import tomli\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from datavis import METHOD_MAPPER, DATASET_MAPPER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root folder of evaluation outputs; later cells read its 'quality' and 'learning' subfolders.\n",
    "EVAL_PATH = \"./eval/\"\n",
    "# Shared results registry, populated as the notebook runs (e.g. RESULTS['methods']).\n",
    "RESULTS = dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_config(path) -> dict:\n",
    "    \"\"\"Parse a TOML file (opened in binary mode) and return its contents.\"\"\"\n",
    "    with open(path, 'rb') as f:\n",
    "        return tomli.load(f)\n",
    "\n",
    "\n",
    "def load_json(js_path) -> dict:\n",
    "    \"\"\"Parse a JSON file, always decoding it as UTF-8 rather than the locale default.\"\"\"\n",
    "    with open(js_path, 'r', encoding='utf-8') as f:\n",
    "        return json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['codi', 'fairsmote', 'fairtabddpm', 'fairtabgan', 'goggle', 'great', 'smote', 'stasy', 'tabddpm', 'tabsyn']\n"
     ]
    }
   ],
   "source": [
    "def print_config():\n",
    "    \"\"\"Load assess.toml, record its considered methods in RESULTS['methods'], and print them.\"\"\"\n",
    "    methods = load_config('assess.toml')['methods']['considered']\n",
    "    RESULTS['methods'] = methods\n",
    "    print(methods)\n",
    "\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_score(dataset, stats=False):\n",
    "    \"\"\"Load per-method quality scores for `dataset`, with the DCR results merged in.\n",
    "\n",
    "    Reads <EVAL_PATH>/quality/<dataset>/score.json and dcr.json, which are indexed by\n",
    "    the same method names; each method's DCR values are stored under its 'dcr' key.\n",
    "\n",
    "    Args:\n",
    "        dataset: dataset folder name under <EVAL_PATH>/quality.\n",
    "        stats: if True, summarise every metric instead of returning the raw values.\n",
    "\n",
    "    Returns:\n",
    "        stats=False: {method: {metric: values}} as loaded, plus 'dcr'.\n",
    "        stats=True:  {method: {metric: (mean, std)}}, both rounded to 3 decimals.\n",
    "    \"\"\"\n",
    "    quality_dir = os.path.join(EVAL_PATH, 'quality', dataset)\n",
    "    score = load_json(os.path.join(quality_dir, 'score.json'))\n",
    "    dcr = load_json(os.path.join(quality_dir, 'dcr.json'))\n",
    "    for method in score:\n",
    "        score[method]['dcr'] = dcr[method]\n",
    "\n",
    "    if not stats:\n",
    "        return score\n",
    "\n",
    "    # every metric of every method is summarised; no membership guard can skip one\n",
    "    return {\n",
    "        method: {\n",
    "            metric: (round(np.mean(values), 3), round(np.std(values), 3))\n",
    "            for metric, values in metrics.items()\n",
    "        }\n",
    "        for method, metrics in score.items()\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_tuple(t):\n",
    "    \"\"\"Map a (mean, std) pair to (1 - mean, std), both rounded to 3 decimals.\"\"\"\n",
    "    mean, std = t[0], t[1]\n",
    "    return (round(1 - mean, 3), round(std, 3))\n",
    "\n",
    "\n",
    "def _drop_leading_zero(text):\n",
    "    \"\"\"Strip only a leading zero: '0.145' -> '.145', '-0.145' -> '-.145', '10.500' unchanged.\"\"\"\n",
    "    if text.startswith('0.'):\n",
    "        return text[1:]\n",
    "    if text.startswith('-0.'):\n",
    "        return '-' + text[2:]\n",
    "    return text\n",
    "\n",
    "\n",
    "def convert_to_latex(t):\n",
    "    \"\"\"Render a (mean, std) pair as LaTeX math '$mean_{\\\\pm std}$' with 3 decimals.\n",
    "\n",
    "    Leading zeros are dropped per number (0.145 -> .145). A blanket str.replace('0.', '.')\n",
    "    would also corrupt values such as 10.500 -> 1.500.\n",
    "    \"\"\"\n",
    "    mean = _drop_leading_zero(f\"{t[0]:.3f}\")\n",
    "    std = _drop_leading_zero(f\"{t[1]:.3f}\")\n",
    "    return f\"${mean}_{{\\\\pm {std}}}$\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_tex_score_df_dataset(dataset, idx_order, metrics=('shape', 'trend')):\n",
    "    \"\"\"Per-method table of LaTeX-formatted quality scores for one dataset.\n",
    "\n",
    "    Every cell starts as the (mean, std) pair from read_score(..., stats=True),\n",
    "    is mapped to (1 - mean, std) by convert_tuple and rendered by convert_to_latex.\n",
    "\n",
    "    Args:\n",
    "        dataset: dataset key, e.g. 'adult', 'bank' or 'compass'.\n",
    "        idx_order: method keys in the desired row order.\n",
    "        metrics: metric keys to keep (any iterable; defaults to shape and trend).\n",
    "\n",
    "    Returns:\n",
    "        DataFrame indexed by display names from METHOD_MAPPER. 'shape' and 'trend'\n",
    "        become '<Dataset> - Density' / '<Dataset> - Correlation'; any other metric\n",
    "        column keeps its raw key (e.g. 'dcr').\n",
    "    \"\"\"\n",
    "    score = read_score(dataset, stats=True)\n",
    "    # list(): a tuple inside [] would be looked up as one (MultiIndex) key, not a column list\n",
    "    score_df = pd.DataFrame(score).T[list(metrics)]\n",
    "\n",
    "    ds_name = DATASET_MAPPER[dataset]\n",
    "    score_df = score_df.rename(columns={\n",
    "        'shape': f'{ds_name} - Density',\n",
    "        'trend': f'{ds_name} - Correlation',\n",
    "    })\n",
    "\n",
    "    # fix the row order on the raw method keys, then switch to display names\n",
    "    score_df = score_df.reindex(idx_order).rename(index=METHOD_MAPPER)\n",
    "\n",
    "    return score_df.map(convert_tuple).map(convert_to_latex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row order shared by the result tables: the other methods first, then the 'fair*' methods.\n",
    "idx_order = [\n",
    "    'codi', 'goggle', 'great', 'smote', 'stasy', 'tabddpm', 'tabsyn',\n",
    "    'fairsmote', 'fairtabgan', 'fairtabddpm',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Density/correlation tables per dataset (the COMPAS data uses the key 'compass').\n",
    "adult_df, compas_df, bank_df = (\n",
    "    get_tex_score_df_dataset(name, idx_order) for name in ('adult', 'compass', 'bank')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Adult - Density</th>\n",
       "      <th>Bank - Density</th>\n",
       "      <th>COMPAS - Density</th>\n",
       "      <th>Ave. - Density</th>\n",
       "      <th>Adult - Correlation</th>\n",
       "      <th>Bank - Correlation</th>\n",
       "      <th>COMPAS - Correlation</th>\n",
       "      <th>Ave. - Correlation</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>CoDi</th>\n",
       "      <td>$.145_{\\pm .000}$</td>\n",
       "      <td>$.154_{\\pm .000}$</td>\n",
       "      <td>$.205_{\\pm .001}$</td>\n",
       "      <td>$16.8\\%$</td>\n",
       "      <td>$.495_{\\pm .000}$</td>\n",
       "      <td>$.344_{\\pm .001}$</td>\n",
       "      <td>$.550_{\\pm .001}$</td>\n",
       "      <td>$46.3\\%$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Goggle</th>\n",
       "      <td>$.158_{\\pm .001}$</td>\n",
       "      <td>$.131_{\\pm .000}$</td>\n",
       "      <td>$.164_{\\pm .001}$</td>\n",
       "      <td>$15.1\\%$</td>\n",
       "      <td>$.447_{\\pm .016}$</td>\n",
       "      <td>$.232_{\\pm .000}$</td>\n",
       "      <td>$.332_{\\pm .022}$</td>\n",
       "      <td>$33.7\\%$</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          Adult - Density     Bank - Density   COMPAS - Density  \\\n",
       "CoDi    $.145_{\\pm .000}$  $.154_{\\pm .000}$  $.205_{\\pm .001}$   \n",
       "Goggle  $.158_{\\pm .001}$  $.131_{\\pm .000}$  $.164_{\\pm .001}$   \n",
       "\n",
       "       Ave. - Density Adult - Correlation Bank - Correlation  \\\n",
       "CoDi         $16.8\\%$   $.495_{\\pm .000}$  $.344_{\\pm .001}$   \n",
       "Goggle       $15.1\\%$   $.447_{\\pm .016}$  $.232_{\\pm .000}$   \n",
       "\n",
       "       COMPAS - Correlation Ave. - Correlation  \n",
       "CoDi      $.550_{\\pm .001}$           $46.3\\%$  \n",
       "Goggle    $.332_{\\pm .022}$           $33.7\\%$  "
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_df = pd.concat([adult_df, compas_df, bank_df], axis=1)\n",
    "\n",
    "# Per-row average over the three datasets for each metric family. The means are parsed\n",
    "# back out of the '$.145_{...}$' strings, so they carry the 3-decimal rounding.\n",
    "for family in ['Density', 'Correlation']:\n",
    "    family_cols = [col for col in all_df.columns if family in col]\n",
    "    family_avg = all_df[family_cols].apply(\n",
    "        lambda row: np.mean([float(cell.split('_')[0][1:]) for cell in row]), axis=1\n",
    "    )\n",
    "    # as a LaTeX percentage with one decimal (backslash escaped explicitly)\n",
    "    all_df[f'Ave. - {family}'] = (family_avg * 100).map(lambda pct: f\"${pct:.1f}\\\\%$\")\n",
    "\n",
    "# per-dataset columns followed by their average, for each family\n",
    "all_df = all_df[[\n",
    "    'Adult - Density', 'Bank - Density', 'COMPAS - Density', 'Ave. - Density',\n",
    "    'Adult - Correlation', 'Bank - Correlation', 'COMPAS - Correlation', 'Ave. - Correlation',\n",
    "]]\n",
    "all_df.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CoDi     & $.145_{\\pm .000}$ & $.154_{\\pm .000}$ & $.205_{\\pm .001}$ & $16.8\\%$ & $.495_{\\pm .000}$ & $.344_{\\pm .001}$ & $.550_{\\pm .001}$ & $46.3\\%$ \\\\\n",
      "Goggle   & $.158_{\\pm .001}$ & $.131_{\\pm .000}$ & $.164_{\\pm .001}$ & $15.1\\%$ & $.447_{\\pm .016}$ & $.232_{\\pm .000}$ & $.332_{\\pm .022}$ & $33.7\\%$ \\\\\n",
      "GReaT    & $.077_{\\pm .000}$ & $.102_{\\pm .001}$ & $.093_{\\pm .001}$ & $9.1\\%$ & $.208_{\\pm .015}$ & $.213_{\\pm .010}$ & $.165_{\\pm .016}$ & $19.5\\%$ \\\\\n",
      "SMOTE    & $.025_{\\pm .000}$ & $.020_{\\pm .000}$ & $.021_{\\pm .001}$ & $2.2\\%$ & $.054_{\\pm .004}$ & $.042_{\\pm .005}$ & $.047_{\\pm .005}$ & $4.8\\%$ \\\\\n",
      "STaSy    & $.102_{\\pm .000}$ & $.182_{\\pm .000}$ & $.108_{\\pm .001}$ & $13.1\\%$ & $.163_{\\pm .000}$ & $.221_{\\pm .001}$ & $.138_{\\pm .001}$ & $17.4\\%$ \\\\\n",
      "TabDDPM  & $.037_{\\pm .001}$ & $.028_{\\pm .001}$ & $.057_{\\pm .001}$ & $4.1\\%$ & $.055_{\\pm .001}$ & $.052_{\\pm .001}$ & $.090_{\\pm .001}$ & $6.6\\%$ \\\\\n",
      "TabSyn   & $.010_{\\pm .000}$ & $.009_{\\pm .000}$ & $.027_{\\pm .001}$ & $1.5\\%$ & $.035_{\\pm .001}$ & $.033_{\\pm .002}$ & $.054_{\\pm .001}$ & $4.1\\%$ \\\\\n",
      "FairCB   & $.076_{\\pm .000}$ & $.066_{\\pm .000}$ & $.039_{\\pm .000}$ & $6.0\\%$ & $.125_{\\pm .000}$ & $.111_{\\pm .000}$ & $.074_{\\pm .000}$ & $10.3\\%$ \\\\\n",
      "FairTGAN & $.034_{\\pm .000}$ & $.030_{\\pm .000}$ & $.055_{\\pm .001}$ & $4.0\\%$ & $.080_{\\pm .001}$ & $.053_{\\pm .000}$ & $.087_{\\pm .001}$ & $7.3\\%$ \\\\\n",
      "Ours     & $.126_{\\pm .001}$ & $.121_{\\pm .000}$ & $.109_{\\pm .001}$ & $11.9\\%$ & $.201_{\\pm .000}$ & $.174_{\\pm .005}$ & $.174_{\\pm .003}$ & $18.3\\%$ \\\\\n"
     ]
    }
   ],
   "source": [
    "methods = all_df.index\n",
    "max_str_len = max(len(m) for m in methods)\n",
    "\n",
    "# One LaTeX row per method: padded name, the 8 cells joined by '&', then the row break.\n",
    "for method, row in all_df.iterrows():\n",
    "    cells = ' & '.join(str(row.values[k]) for k in range(8))\n",
    "    print(method.ljust(max_str_len), '&', cells, '\\\\\\\\')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DCR-only table per dataset, with the raw 'dcr' column renamed to '<Dataset> - DCR'.\n",
    "adult_df = get_tex_score_df_dataset('adult', idx_order, metrics=['dcr']).rename(columns={'dcr': 'Adult - DCR'})\n",
    "compas_df = get_tex_score_df_dataset('compass', idx_order, metrics=['dcr']).rename(columns={'dcr': 'COMPAS - DCR'})\n",
    "bank_df = get_tex_score_df_dataset('bank', idx_order, metrics=['dcr']).rename(columns={'dcr': 'Bank - DCR'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_df = pd.concat([adult_df, compas_df, bank_df], axis=1)\n",
    "dcr_cols = [col for col in all_df.columns if 'DCR' in col]\n",
    "# average over datasets, parsed back from the '$.331_{...}$' strings (3-decimal precision)\n",
    "dcr_avg = all_df[dcr_cols].apply(\n",
    "    lambda row: np.mean([float(cell.split('_')[0][1:]) for cell in row]), axis=1\n",
    ")\n",
    "# as a LaTeX percentage with one decimal (backslash escaped explicitly)\n",
    "all_df['Ave. - DCR'] = (dcr_avg * 100).map(lambda pct: f\"${pct:.1f}\\\\%$\")\n",
    "all_df = all_df[['Adult - DCR', 'Bank - DCR', 'COMPAS - DCR', 'Ave. - DCR']]\n",
    "\n",
    "# Prepend a 'Real' row of '-' placeholders. Building it with concat avoids enlarging a\n",
    "# column-subset copy through .loc and the manual reorder that followed it.\n",
    "real_row = pd.DataFrame([['-'] * len(all_df.columns)], index=['Real'], columns=all_df.columns)\n",
    "all_df = pd.concat([real_row, all_df])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Adult - DCR</th>\n",
       "      <th>Bank - DCR</th>\n",
       "      <th>COMPAS - DCR</th>\n",
       "      <th>Ave. - DCR</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Real</th>\n",
       "      <td>-</td>\n",
       "      <td>-</td>\n",
       "      <td>-</td>\n",
       "      <td>-</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>CoDi</th>\n",
       "      <td>$.331_{\\pm .003}$</td>\n",
       "      <td>$.348_{\\pm .001}$</td>\n",
       "      <td>$.400_{\\pm .002}$</td>\n",
       "      <td>$36.0\\%$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Goggle</th>\n",
       "      <td>$.336_{\\pm .002}$</td>\n",
       "      <td>$.354_{\\pm .001}$</td>\n",
       "      <td>$.356_{\\pm .003}$</td>\n",
       "      <td>$34.9\\%$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>GReaT</th>\n",
       "      <td>$.320_{\\pm .002}$</td>\n",
       "      <td>$.348_{\\pm .003}$</td>\n",
       "      <td>$.377_{\\pm .003}$</td>\n",
       "      <td>$34.8\\%$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>SMOTE</th>\n",
       "      <td>$.327_{\\pm .004}$</td>\n",
       "      <td>$.265_{\\pm .001}$</td>\n",
       "      <td>$.273_{\\pm .003}$</td>\n",
       "      <td>$28.8\\%$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>STaSy</th>\n",
       "      <td>$.344_{\\pm .001}$</td>\n",
       "      <td>$.345_{\\pm .002}$</td>\n",
       "      <td>$.362_{\\pm .001}$</td>\n",
       "      <td>$35.0\\%$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TabDDPM</th>\n",
       "      <td>$.339_{\\pm .000}$</td>\n",
       "      <td>$.350_{\\pm .002}$</td>\n",
       "      <td>$.367_{\\pm .005}$</td>\n",
       "      <td>$35.2\\%$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TabSyn</th>\n",
       "      <td>$.339_{\\pm .002}$</td>\n",
       "      <td>$.351_{\\pm .005}$</td>\n",
       "      <td>$.367_{\\pm .006}$</td>\n",
       "      <td>$35.2\\%$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>FairCB</th>\n",
       "      <td>$.054_{\\pm .000}$</td>\n",
       "      <td>$.031_{\\pm .001}$</td>\n",
       "      <td>$.012_{\\pm .001}$</td>\n",
       "      <td>$3.2\\%$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>FairTGAN</th>\n",
       "      <td>$.348_{\\pm .002}$</td>\n",
       "      <td>$.348_{\\pm .004}$</td>\n",
       "      <td>$.374_{\\pm .006}$</td>\n",
       "      <td>$35.7\\%$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Ours</th>\n",
       "      <td>$.344_{\\pm .001}$</td>\n",
       "      <td>$.350_{\\pm .002}$</td>\n",
       "      <td>$.370_{\\pm .004}$</td>\n",
       "      <td>$35.5\\%$</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                Adult - DCR         Bank - DCR       COMPAS - DCR Ave. - DCR\n",
       "Real                      -                  -                  -          -\n",
       "CoDi      $.331_{\\pm .003}$  $.348_{\\pm .001}$  $.400_{\\pm .002}$   $36.0\\%$\n",
       "Goggle    $.336_{\\pm .002}$  $.354_{\\pm .001}$  $.356_{\\pm .003}$   $34.9\\%$\n",
       "GReaT     $.320_{\\pm .002}$  $.348_{\\pm .003}$  $.377_{\\pm .003}$   $34.8\\%$\n",
       "SMOTE     $.327_{\\pm .004}$  $.265_{\\pm .001}$  $.273_{\\pm .003}$   $28.8\\%$\n",
       "STaSy     $.344_{\\pm .001}$  $.345_{\\pm .002}$  $.362_{\\pm .001}$   $35.0\\%$\n",
       "TabDDPM   $.339_{\\pm .000}$  $.350_{\\pm .002}$  $.367_{\\pm .005}$   $35.2\\%$\n",
       "TabSyn    $.339_{\\pm .002}$  $.351_{\\pm .005}$  $.367_{\\pm .006}$   $35.2\\%$\n",
       "FairCB    $.054_{\\pm .000}$  $.031_{\\pm .001}$  $.012_{\\pm .001}$    $3.2\\%$\n",
       "FairTGAN  $.348_{\\pm .002}$  $.348_{\\pm .004}$  $.374_{\\pm .006}$   $35.7\\%$\n",
       "Ours      $.344_{\\pm .001}$  $.350_{\\pm .002}$  $.370_{\\pm .004}$   $35.5\\%$"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# DCR table, with the 'Real' placeholder row first\n",
    "all_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Real     & - & - & - \\\\\n",
      "CoDi     & $.331_{\\pm .003}$ & $.348_{\\pm .001}$ & $.400_{\\pm .002}$ \\\\\n",
      "Goggle   & $.336_{\\pm .002}$ & $.354_{\\pm .001}$ & $.356_{\\pm .003}$ \\\\\n",
      "GReaT    & $.320_{\\pm .002}$ & $.348_{\\pm .003}$ & $.377_{\\pm .003}$ \\\\\n",
      "SMOTE    & $.327_{\\pm .004}$ & $.265_{\\pm .001}$ & $.273_{\\pm .003}$ \\\\\n",
      "STaSy    & $.344_{\\pm .001}$ & $.345_{\\pm .002}$ & $.362_{\\pm .001}$ \\\\\n",
      "TabDDPM  & $.339_{\\pm .000}$ & $.350_{\\pm .002}$ & $.367_{\\pm .005}$ \\\\\n",
      "TabSyn   & $.339_{\\pm .002}$ & $.351_{\\pm .005}$ & $.367_{\\pm .006}$ \\\\\n",
      "FairCB   & $.054_{\\pm .000}$ & $.031_{\\pm .001}$ & $.012_{\\pm .001}$ \\\\\n",
      "FairTGAN & $.348_{\\pm .002}$ & $.348_{\\pm .004}$ & $.374_{\\pm .006}$ \\\\\n",
      "Ours     & $.344_{\\pm .001}$ & $.350_{\\pm .002}$ & $.370_{\\pm .004}$ \\\\\n"
     ]
    }
   ],
   "source": [
    "# Print the DCR table as LaTeX rows, padding names to this table's own longest label\n",
    "# rather than depending on a max_str_len computed in another cell.\n",
    "# NOTE(review): only the three per-dataset columns are emitted; 'Ave. - DCR' is not. Confirm intended.\n",
    "max_str_len = max(len(m) for m in all_df.index)\n",
    "for method, row in all_df.iterrows():\n",
    "    cells = ' & '.join(str(row.values[k]) for k in range(3))\n",
    "    print(method.ljust(max_str_len), '&', cells, '\\\\\\\\')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_auc(dataset, method, mode='original', stats=False):\n",
    "    \"\"\"Test-set AUC values under the 'CatBoost' entry for one (dataset, method) pair.\n",
    "\n",
    "    Reads <EVAL_PATH>/learning/<dataset>/best_<mode>_<method>.json. With stats=True\n",
    "    returns (mean, std) rounded to 3 decimals, otherwise the raw AUC values.\n",
    "    \"\"\"\n",
    "    result_file = os.path.join(EVAL_PATH, 'learning', dataset, f'best_{mode}_{method}.json')\n",
    "    catboost = load_json(result_file)['CatBoost']\n",
    "    test_results = {name: entry['Test'] for name, entry in catboost.items()}\n",
    "    aucs = test_results['AUC']\n",
    "\n",
    "    if not stats:\n",
    "        return aucs\n",
    "    return (round(np.mean(aucs), 3), round(np.std(aucs), 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_all_auc(mode='original'):\n",
    "    \"\"\"(mean, std) test AUC for every considered method plus 'real', per dataset.\n",
    "\n",
    "    Returns {dataset: {method: (mean, std)}} for adult, bank and compass.\n",
    "    \"\"\"\n",
    "    assert mode in ['original', 'uniform']\n",
    "    candidates = RESULTS['methods'] + ['real']\n",
    "    return {\n",
    "        dataset: {\n",
    "            method: read_auc(dataset, method, mode=mode, stats=True)\n",
    "            for method in candidates\n",
    "        }\n",
    "        for dataset in ['adult', 'bank', 'compass']\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_latex_percentage(num):\n",
    "    \"\"\"Format a value already in percent as a LaTeX percentage with one decimal.\n",
    "\n",
    "    E.g. 86.2 gives $86.2%$ with the percent sign backslash-escaped for LaTeX; the\n",
    "    backslash is escaped explicitly in the source instead of relying on an invalid escape.\n",
    "    \"\"\"\n",
    "    return f\"${num:.1f}\\\\%$\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Real     & $.928_{\\pm .000}$ & $.936_{\\pm .000}$ & $.799_{\\pm .000}$ & $88.8\\%$ & \\\\\n",
      "CoDi     & $.858_{\\pm .002}$ & $.827_{\\pm .015}$ & $.678_{\\pm .008}$ & $78.8\\%$ & \\\\\n",
      "Goggle   & $.738_{\\pm .004}$ & $.729_{\\pm .005}$ & $.653_{\\pm .010}$ & $70.7\\%$ & \\\\\n",
      "GReaT    & $.900_{\\pm .003}$ & $.690_{\\pm .026}$ & $.714_{\\pm .010}$ & $76.8\\%$ & \\\\\n",
      "SMOTE    & $.914_{\\pm .000}$ & $.928_{\\pm .001}$ & $.772_{\\pm .002}$ & $87.1\\%$ & \\\\\n",
      "STaSy    & $.883_{\\pm .000}$ & $.894_{\\pm .003}$ & $.722_{\\pm .014}$ & $83.3\\%$ & \\\\\n",
      "TabDDPM  & $.906_{\\pm .001}$ & $.917_{\\pm .002}$ & $.745_{\\pm .002}$ & $85.6\\%$ & \\\\\n",
      "TabSyn   & $.910_{\\pm .000}$ & $.918_{\\pm .001}$ & $.745_{\\pm .001}$ & $85.8\\%$ & \\\\\n",
      "FairCB   & $.914_{\\pm .000}$ & $.907_{\\pm .001}$ & $.765_{\\pm .001}$ & $86.2\\%$ & \\\\\n",
      "FairTGAN & $.885_{\\pm .002}$ & $.873_{\\pm .006}$ & $.703_{\\pm .002}$ & $82.0\\%$ & \\\\\n",
      "Ours     & $.895_{\\pm .001}$ & $.913_{\\pm .002}$ & $.732_{\\pm .001}$ & $84.7\\%$ & \\\\\n"
     ]
    }
   ],
   "source": [
    "mode = 'uniform'  # evaluation mode: 'original' or 'uniform'\n",
    "auc_df = pd.DataFrame(read_all_auc(mode))\n",
    "# mean over datasets of the per-dataset mean AUCs (first element of each (mean, std) tuple)\n",
    "auc_df['average'] = auc_df.apply(lambda x: np.mean([i[0] for i in x.values]), axis=1)\n",
    "\n",
    "# render each per-dataset (mean, std) as LaTeX; a whole-column map avoids chained\n",
    "# assignment (df[col][row] = ...), which does nothing under pandas Copy-on-Write\n",
    "for dataset in ['adult', 'bank', 'compass']:\n",
    "    auc_df[dataset] = auc_df[dataset].map(convert_to_latex)\n",
    "\n",
    "# average as a percentage, then as LaTeX\n",
    "auc_df['average'] = (auc_df['average'] * 100).map(to_latex_percentage)\n",
    "\n",
    "auc_df = auc_df.rename(columns={\n",
    "    'average': 'Average',\n",
    "    'adult': 'Adult',\n",
    "    'bank': 'Bank',\n",
    "    'compass': 'COMPAS',\n",
    "})\n",
    "\n",
    "# fixed row order on the raw method keys, then display names\n",
    "auc_df = auc_df.reindex(['real'] + idx_order).rename(index=METHOD_MAPPER)\n",
    "\n",
    "methods = auc_df.index\n",
    "max_str_len = max(len(m) for m in methods)\n",
    "\n",
    "# one LaTeX row per method: Adult & Bank & COMPAS & Average, then the row break\n",
    "# (no dangling '&' before the row break, which would add an empty extra cell)\n",
    "for method, row in auc_df.iterrows():\n",
    "    cells = ' & '.join(str(row.values[k]) for k in range(4))\n",
    "    print(method.ljust(max_str_len), '&', cells, '\\\\\\\\')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_fair(dataset, method, mode='original', metric='DPR', stats=False):\n",
    "    \"\"\"Test-set fairness values under the 'CatBoost' entry for one (dataset, method) pair.\n",
    "\n",
    "    Reads <EVAL_PATH>/learning/<dataset>/best_<mode>_<method>.json and returns the\n",
    "    `metric` values for the dataset's sensitive attribute ('sex' for adult and\n",
    "    compass, 'age-group' for bank).\n",
    "\n",
    "    Args:\n",
    "        metric: fairness metric key, e.g. 'DPR' or 'EOR'.\n",
    "        stats: if True, return (mean, std) rounded to 3 decimals instead of the raw values.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no sensitive attribute is configured for `dataset`.\n",
    "    \"\"\"\n",
    "    sensitive_attr = {'adult': 'sex', 'compass': 'sex', 'bank': 'age-group'}\n",
    "    if dataset not in sensitive_attr:\n",
    "        raise ValueError(f'no sensitive attribute configured for dataset {dataset!r}')\n",
    "\n",
    "    js_path = os.path.join(EVAL_PATH, 'learning', dataset, f'best_{mode}_{method}.json')\n",
    "    catboost = load_json(js_path)['CatBoost']\n",
    "    test_results = {key: entry['Test'] for key, entry in catboost.items()}\n",
    "    scores = test_results[metric][sensitive_attr[dataset]]\n",
    "\n",
    "    if stats:\n",
    "        return (round(np.mean(scores), 3), round(np.std(scores), 3))\n",
    "    return scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "uniform\n"
     ]
    }
   ],
   "source": [
    "# Echo the evaluation mode (set in the AUC cell) that the fairness cell will use\n",
    "print(mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Real     & $.307_{\\pm .004}$ & $.415_{\\pm .008}$ & $.790_{\\pm .005}$ & $50.4\\%$ & $.192_{\\pm .008}$ & $.368_{\\pm .008}$ & $.857_{\\pm .007}$ & $47.2\\%$ \\\\\n",
      "CoDi     & $.345_{\\pm .030}$ & $.402_{\\pm .063}$ & $.901_{\\pm .011}$ & $54.9\\%$ & $.328_{\\pm .038}$ & $.556_{\\pm .125}$ & $.923_{\\pm .010}$ & $60.2\\%$ \\\\\n",
      "Goggle   & $.838_{\\pm .027}$ & $.666_{\\pm .004}$ & $.854_{\\pm .094}$ & $78.6\\%$ & $.879_{\\pm .034}$ & $.715_{\\pm .007}$ & $.878_{\\pm .093}$ & $82.4\\%$ \\\\\n",
      "GReaT    & $.256_{\\pm .010}$ & $.622_{\\pm .269}$ & $.814_{\\pm .024}$ & $56.4\\%$ & $.169_{\\pm .017}$ & $.405_{\\pm .079}$ & $.802_{\\pm .055}$ & $45.9\\%$ \\\\\n",
      "SMOTE    & $.331_{\\pm .017}$ & $.413_{\\pm .025}$ & $.777_{\\pm .028}$ & $50.7\\%$ & $.276_{\\pm .045}$ & $.359_{\\pm .026}$ & $.826_{\\pm .041}$ & $48.7\\%$ \\\\\n",
      "STaSy    & $.304_{\\pm .072}$ & $.532_{\\pm .060}$ & $.780_{\\pm .072}$ & $53.9\\%$ & $.240_{\\pm .106}$ & $.542_{\\pm .065}$ & $.739_{\\pm .073}$ & $50.7\\%$ \\\\\n",
      "TabDDPM  & $.281_{\\pm .004}$ & $.341_{\\pm .017}$ & $.750_{\\pm .031}$ & $45.7\\%$ & $.191_{\\pm .010}$ & $.339_{\\pm .010}$ & $.816_{\\pm .040}$ & $44.9\\%$ \\\\\n",
      "TabSyn   & $.286_{\\pm .005}$ & $.319_{\\pm .018}$ & $.815_{\\pm .016}$ & $47.3\\%$ & $.182_{\\pm .011}$ & $.285_{\\pm .022}$ & $.847_{\\pm .016}$ & $43.8\\%$ \\\\\n",
      "FairCB   & $.305_{\\pm .005}$ & $.623_{\\pm .011}$ & $.762_{\\pm .023}$ & $56.3\\%$ & $.227_{\\pm .006}$ & $.680_{\\pm .021}$ & $.824_{\\pm .023}$ & $57.7\\%$ \\\\\n",
      "FairTGAN & $.481_{\\pm .021}$ & $.812_{\\pm .115}$ & $.692_{\\pm .026}$ & $66.2\\%$ & $.544_{\\pm .036}$ & $.567_{\\pm .169}$ & $.732_{\\pm .024}$ & $61.4\\%$ \\\\\n",
      "Ours     & $.503_{\\pm .017}$ & $.690_{\\pm .085}$ & $.781_{\\pm .006}$ & $65.8\\%$ & $.580_{\\pm .027}$ & $.631_{\\pm .087}$ & $.781_{\\pm .015}$ & $66.4\\%$ \\\\\n"
     ]
    }
   ],
   "source": [
    "def build_fair_table(metric, eval_mode):\n",
    "    \"\"\"LaTeX table of one fairness metric: a column per dataset plus a percentage 'average'.\n",
    "\n",
    "    Cells come from read_fair(..., stats=True); rows are ordered as ['real'] + idx_order\n",
    "    and renamed via METHOD_MAPPER.\n",
    "    \"\"\"\n",
    "    datasets = ['adult', 'bank', 'compass']\n",
    "    table = pd.DataFrame({\n",
    "        dataset: {\n",
    "            method: read_fair(dataset, method, metric=metric, stats=True, mode=eval_mode)\n",
    "            for method in ['real'] + RESULTS['methods']\n",
    "        }\n",
    "        for dataset in datasets\n",
    "    })\n",
    "    # mean over datasets of the per-dataset means, taken while cells are still (mean, std) tuples\n",
    "    table['average'] = table.apply(lambda x: np.mean([i[0] for i in x.values]), axis=1)\n",
    "    # whole-column map instead of chained df[col][row] assignment (a no-op under Copy-on-Write)\n",
    "    for dataset in datasets:\n",
    "        table[dataset] = table[dataset].map(convert_to_latex)\n",
    "    table['average'] = (table['average'] * 100).map(to_latex_percentage)\n",
    "    return table.reindex(['real'] + idx_order).rename(index=METHOD_MAPPER)\n",
    "\n",
    "\n",
    "dpr_df = build_fair_table('DPR', mode)\n",
    "eor_df = build_fair_table('EOR', mode)\n",
    "fair_df = pd.concat([dpr_df, eor_df], axis=1)\n",
    "\n",
    "# one LaTeX row per method: DPR (Adult, Bank, COMPAS, average) then EOR (same four);\n",
    "# the padding width is computed here instead of reusing a value from another cell\n",
    "max_str_len = max(len(m) for m in fair_df.index)\n",
    "for method, row in fair_df.iterrows():\n",
    "    cells = ' & '.join(str(row.values[k]) for k in range(8))\n",
    "    print(method.ljust(max_str_len), '&', cells, '\\\\\\\\')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
