{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "91b8ce4c",
   "metadata": {},
   "source": [
    "Before executing this notebook, please train some models first, for example with: `python3 german.py 5 0.7 cuda:0 ROAD TEST_GERMAN_ROAD_10 20 10 10 10`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "266fd6b0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-28T03:18:06.873409Z",
     "start_time": "2023-09-28T03:18:06.850374Z"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "from matplotlib import pyplot as plt\n",
    "from sklearn import datasets as ds\n",
    "import sys, os\n",
    "sys.path.insert(1, os.path.join(sys.path[0], '..'))\n",
    "from utils.fairness_utils import pareto_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14146988",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-28T03:18:07.193058Z",
     "start_time": "2023-09-28T03:18:07.169255Z"
    }
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c36b4ef8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-28T03:18:07.606232Z",
     "start_time": "2023-09-28T03:18:07.563650Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Folder holding the result CSV files, and the dataset name used in file names and titles.\n",
    "FOLDER = \"./../results_german/\"\n",
    "DATASET_NAME = \"GERMAN\"\n",
    "\n",
    "# Base name (without extension) of the ROAD result files.\n",
    "fROAD = \"TEST_\" + DATASET_NAME + \"_ROAD_10_DP\"\n",
    "\n",
    "# Colors\n",
    "palette = sns.color_palette(\"bright\", 10)\n",
    "\n",
    "# One entry per method: result file base name, legend label, line style and color.\n",
    "METHODS_DICT = {\n",
    "    \"road\": {\n",
    "        \"fname\": fROAD,\n",
    "        \"name\": \"ROAD (Ours)\",\n",
    "        \"linestyle\": \"-\",\n",
    "        \"color\": palette[2],\n",
    "    },\n",
    "}\n",
    "\n",
    "# For every method, read the main table (\"df\") and the \"_rdm\" table (\"df_rdm\").\n",
    "# \"lambda\" and \"tau\" become strings cut to their first 4 characters and\n",
    "# \"run_id\" becomes a string.\n",
    "for method_cfg in METHODS_DICT.values():\n",
    "    loaded = {}\n",
    "    for slot, extension in ((\"df\", \".csv\"), (\"df_rdm\", \"_rdm.csv\")):\n",
    "        table = pd.read_csv(FOLDER + method_cfg[\"fname\"] + extension)\n",
    "        for column in (\"lambda\", \"tau\"):\n",
    "            table[column] = table[column].astype('str').map(lambda x: x[:4])\n",
    "        table[\"run_id\"] = table[\"run_id\"].astype('str')\n",
    "        loaded[slot] = table\n",
    "    # Both tables are registered only after both files were read.\n",
    "    method_cfg.update(loaded)\n",
    "\n",
    "# Axis labels printed on the figures, keyed by result column name.\n",
    "FEATURE_NAMES = {\n",
    "    \"Global DI\": \"Global Fairness (DI)\",\n",
    "    \"Global Acc\": \"Accuracy\",\n",
    "    \"top1_DI\": \"Local Fairness (worst 1 DI)\",\n",
    "    \"top3_DI\": \"Local Fairness (worst 3 DI)\",\n",
    "    \"q_DI_0.8\": \"Local Fairness (0.8 quantile)\",\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6166f8fe",
   "metadata": {},
   "source": [
    "# Pareto curves"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75eb9017",
   "metadata": {},
   "source": [
    "## 1. Acc - Local Fairness"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90c6a6f2",
   "metadata": {},
   "source": [
    "### 1.a Demographic subgroups"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d62782a3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-28T01:56:11.171425Z",
     "start_time": "2023-09-28T01:55:39.969701Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "#### WARNING: slow because of the Pareto front computation\n",
    "\n",
    "# Pareto plot of Global Acc vs. top1_DI, restricted to a fixed range of Global DI\n",
    "sns.set_style('whitegrid')\n",
    "sns.set_palette(\"bright\")\n",
    "\n",
    "#### IMPORTANT: this defines the feature to \"fix\" and the corresponding range. E.g.: keeping only models\n",
    "# with GLOBAL DI between 0.0 (included) and 0.05 (excluded)\n",
    "iso_f = \"Global DI\"\n",
    "GDIRANGE = 0.0, 0.05\n",
    "\n",
    "### FEATURES FOR X AND Y AXES\n",
    "f1, f2 = \"top1_DI\", \"Global Acc\"\n",
    "\n",
    "# Draw on an explicit Axes so that the x-axis can still be inverted when every\n",
    "# method is skipped below (relying on the last lineplot's return value would\n",
    "# raise a NameError in that case).\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "for key in METHODS_DICT:\n",
    "    print(key)\n",
    "    method = METHODS_DICT[key]\n",
    "    try:\n",
    "        df = method[\"df\"]\n",
    "        df2 = df[(df[iso_f] >= GDIRANGE[0]) & (df[iso_f] < GDIRANGE[1])]\n",
    "        color = method[\"color\"]\n",
    "        linsty = method[\"linestyle\"]\n",
    "    except KeyError:\n",
    "        # Methods with a missing table, column or plotting setting are skipped.\n",
    "        continue\n",
    "\n",
    "    ### PARETO FRONT: keep only the rows that pareto_df flags with pareto == 1\n",
    "    dff = pareto_df(df2, f1, f2)\n",
    "    front = dff[dff[\"pareto\"] == 1]\n",
    "    sns.lineplot(data=front, x=f1, y=f2, color=color, linewidth=3, label=method[\"name\"], linestyle=linsty, ax=ax)\n",
    "    sns.scatterplot(data=front, x=f1, y=f2, color=color, legend=False, s=30, markers=\"o\", linestyle=linsty, ax=ax)\n",
    "\n",
    "ax.invert_xaxis()\n",
    "\n",
    "plt.legend(fontsize=12)\n",
    "\n",
    "plt.xlabel('Local Unfairness (worst 1 DI)', fontsize=17)\n",
    "plt.ylabel(FEATURE_NAMES[f2], fontsize=17)\n",
    "plt.xticks(fontsize=12)\n",
    "plt.yticks(fontsize=12)\n",
    "plt.ylim((0.67, 0.82))\n",
    "plt.title(DATASET_NAME + ' (Global DI <' + str(GDIRANGE[1]) + ')', fontsize=17)\n",
    "plt.tight_layout()\n",
    "\n",
    "FNAME = \"../results_img/\"\n",
    "\n",
    "# savefig does not create missing directories, so make sure the target folder exists.\n",
    "os.makedirs(FNAME, exist_ok=True)\n",
    "plt.savefig(FNAME + \"german_paretoTop1.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "243e7621",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-28T01:57:18.464275Z",
     "start_time": "2023-09-28T01:56:52.779003Z"
    }
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "#### WARNING: slow because of the Pareto front computation\n",
    "\n",
    "# Pareto plot of Global Acc vs. top3_DI, restricted to a fixed range of Global DI\n",
    "sns.set_style('whitegrid')\n",
    "sns.set_palette(\"bright\")\n",
    "\n",
    "#### IMPORTANT: this defines the feature to \"fix\" and the corresponding range. E.g.: keeping only models\n",
    "# with GLOBAL DI between 0.0 (included) and 0.05 (excluded)\n",
    "iso_f = \"Global DI\"\n",
    "GDIRANGE = 0.0, 0.05\n",
    "\n",
    "### FEATURES FOR X AND Y AXES\n",
    "f1, f2 = \"top3_DI\", \"Global Acc\"\n",
    "\n",
    "# Draw on an explicit Axes so that the x-axis can still be inverted when every\n",
    "# method is skipped below.\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "for key in METHODS_DICT:\n",
    "    print(key)\n",
    "    method = METHODS_DICT[key]\n",
    "    try:\n",
    "        df = method[\"df\"]\n",
    "        # Lower bound is inclusive, like in the top-1 plot: models whose\n",
    "        # Global DI is exactly GDIRANGE[0] belong to the selected range.\n",
    "        df2 = df[(df[iso_f] >= GDIRANGE[0]) & (df[iso_f] < GDIRANGE[1])]\n",
    "        color = method[\"color\"]\n",
    "    except KeyError:\n",
    "        # Methods with a missing table, column or color are skipped.\n",
    "        continue\n",
    "    # Use the method's configured line style; fall back to a solid line so\n",
    "    # that entries without a \"linestyle\" key are still plotted.\n",
    "    linsty = method.get(\"linestyle\", \"-\")\n",
    "\n",
    "    ### PARETO FRONT: keep only the rows that pareto_df flags with pareto == 1\n",
    "    dff = pareto_df(df2, f1, f2)\n",
    "    front = dff[dff[\"pareto\"] == 1]\n",
    "    sns.lineplot(data=front, x=f1, y=f2, color=color, linewidth=3, label=method[\"name\"], linestyle=linsty, ax=ax)\n",
    "    sns.scatterplot(data=front, x=f1, y=f2, color=color, legend=False, s=30, markers=\"o\", ax=ax)\n",
    "\n",
    "ax.invert_xaxis()\n",
    "\n",
    "plt.legend(fontsize=12)\n",
    "\n",
    "plt.xlabel('Local Unfairness (worst 3 DI)', fontsize=17)\n",
    "plt.ylabel(FEATURE_NAMES[f2], fontsize=17)\n",
    "plt.xticks(fontsize=12)\n",
    "plt.yticks(fontsize=12)\n",
    "plt.ylim((0.67, 0.8))\n",
    "plt.title(DATASET_NAME + ' (Global DI <' + str(GDIRANGE[1]) + ')', fontsize=17)\n",
    "plt.tight_layout()\n",
    "\n",
    "# Same output folder as the top-1 plot (relative to this notebook's folder).\n",
    "FNAME = \"../results_img/\"\n",
    "\n",
    "# Saving is intentionally disabled; uncomment to write the figure.\n",
    "#plt.savefig(FNAME + \"german_paretoTop3.pdf\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e8c7ad8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
