{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "31a28e9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import glob\n",
    "from itertools import zip_longest"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d60f8d2",
   "metadata": {},
   "source": [
    "## table for to complement the metrics figures \n",
    "Columns: Methods\n",
    "\n",
    "Rows: Max/Min/Gap AUC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "b15195b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "methods = ['baseline', 'resampling', 'DomainInd', 'LAFTR', 'CFair', 'LNL', 'EnD', 'ODR', 'GroupDRO', 'SWAD', 'SAM']\n",
    "result_path = '../results/csv/'\n",
    "results = glob.glob(result_path + '*.csv')\n",
    "\n",
    "def get_aucs(dataset, sensitive):\n",
    "    filename = dataset + '-' + sensitive + '.csv'\n",
    "    mean_std_df = pd.read_csv(result_path + filename)\n",
    "    mean_std = mean_std_df.to_numpy()\n",
    "\n",
    "    rt = np.char.split(mean_std[:, 1:].astype('str'), u\"\\u00B1\")\n",
    "    rt = np.delete(rt, -2, 0)\n",
    "    means = np.zeros(rt.shape)\n",
    "    #stds = np.zeros(rt.shape)\n",
    "    \n",
    "    for i in range(rt.shape[0]):\n",
    "        for j in range(rt.shape[1]):\n",
    "            means[i][j] = float(rt[i][j][0])\n",
    "            #stds[i][j] = float(rt[i][j][1])\n",
    "    means = pd.DataFrame(means, columns = mean_std_df.columns[1:])\n",
    "\n",
    "    min_auc = means['Test worst_auc'].values * 100.\n",
    "\n",
    "    aucs = means[list(means.filter(regex=(\"auc\")))]\n",
    "    aucs = aucs.to_numpy()\n",
    "    max_auc = []\n",
    "    for i in range(aucs.shape[0]):\n",
    "        temp = aucs[i].tolist()\n",
    "        max_auc.append(max(temp)* 100.)\n",
    "    return np.asarray(max_auc), min_auc, np.asarray(max_auc)-min_auc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "1649e5ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_row(lst, minmax):\n",
    "    # bold min/max values\n",
    "\n",
    "    print(\"& \", lst[0], end=' ')\n",
    "    value_list = lst[1:]\n",
    "    if minmax == 2:\n",
    "        idx = [i for i, j in enumerate(value_list) if j == min(value_list)]\n",
    "    else:\n",
    "        idx = [i for i, j in enumerate(value_list) if j == max(value_list)]\n",
    "    for i, itm in enumerate(value_list):\n",
    "        if i not in idx:\n",
    "            print(\"& \", \"%.2f\" % itm, end=' ')\n",
    "        else:\n",
    "            print(\"& \", \"\\\\textbf{%.2f}\" % itm, end=' ')\n",
    "    print(\"\\\\\\\\\", sep=\" \")\n",
    "\n",
    "def print_head(methods):\n",
    "    head_list = list(df2.columns)\n",
    "    print(\"\\\\textbf{Attr.} & \\\\textbf{Metrics} \", end=\" \")\n",
    "    for method in methods:\n",
    "        print(\"& \", \"\\\\textbf{{{}}}\".format(method), end = ' ')\n",
    "    print(\"\\\\\\\\\")\n",
    "\n",
    "# for each dataset, aggregate the results of different attributes first (or prepare them into different dataframe), then print\n",
    "def print_table(dfs, dataset, attrs, methods):\n",
    "    \"\"\"Pretty-print a 2D array of data, optionally with row/col labels\"\"\"\n",
    "    \n",
    "    print(\"\\\\begin{table}[h]\")\n",
    "    print(\"\\\\caption{{Methods on {datas} dataset.}}\".format(datas=dataset))\n",
    "    print(\"\\\\begin{center}\")\n",
    "\n",
    "    num_methods = len(dfs[0].columns) - 1\n",
    "    print(\"\\\\resizebox{\\\\textwidth}{!}{%\")\n",
    "    print(\"\\\\begin{tabular}{c|c|\" + \"r\" * num_methods + \"}\")\n",
    "    \n",
    "    print(\"\\\\toprule\")\n",
    "    print_head(methods)\n",
    "\n",
    "    for i, df in enumerate(dfs):\n",
    "        print(\"\\\\midrule\")\n",
    "        print(\"\\\\multirow{{3}}{{*}}{{{}}} \".format(attrs[i]), end='')\n",
    "        for j in range(len(df)):\n",
    "            metric_list = df.iloc[j].values.tolist()\n",
    "            print_row(metric_list, j)\n",
    "        \n",
    "    print(\"\\\\bottomrule\")\n",
    "    print(\"\\\\end{tabular}\")\n",
    "    print(\"}\")\n",
    "    print(\"\\\\end{center}\")\n",
    "    print(\"\\\\end{table}\")\n",
    "    print('\\n' * 4)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "448b0db8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: merge results with different names\n",
    "\n",
    "exp_dict = {'HAM10000': [{'attribute': 'Sex', 'sens_num': 2}], 'HAM100004': [{'attribute': 'Age', 'sens_num': 4}], \n",
    "           'CXP': [{'attribute': 'Age', 'sens_num': 5}, {'attribute': 'Sex', 'sens_num': 2}, {'attribute': 'Race', 'sens_num': 2}],\n",
    "           'MIMIC_CXR': [{'attribute': 'Age', 'sens_num': 5}, {'attribute': 'Sex', 'sens_num': 2}, {'attribute': 'Race', 'sens_num': 2}],\n",
    "           'PAPILA': [{'attribute': 'Age', 'sens_num': 2}, {'attribute': 'Sex', 'sens_num': 2}],\n",
    "           'ADNI': [{'attribute': 'Age', 'sens_num': 2}, {'attribute': 'Sex', 'sens_num': 2}],\n",
    "           'OCT': [{'attribute': 'Age', 'sens_num': 2}],}\n",
    "\n",
    "for dataset, sens_exps in exp_dict.items():\n",
    "    dfs, attrs = [], []\n",
    "    for sens_exp in sens_exps:\n",
    "        attribute = sens_exp['attribute']\n",
    "        sens_num = sens_exp['sens_num']\n",
    "        auc_list = get_aucs(dataset, attribute)\n",
    "\n",
    "        df = pd.DataFrame(auc_list, columns = methods)\n",
    "        df['Metrics'] = ['Max.', 'Min.', 'Gap']\n",
    "        df = df[ ['Metrics'] + [ col for col in df.columns if col != 'Metrics' ] ]\n",
    "        dfs.append(df)\n",
    "        attrs.append(attribute)\n",
    "    print_table(dfs, dataset, attrs, methods)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch11",
   "language": "python",
   "name": "torch11"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
