{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c2ff89e0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-28T01:49:56.551039Z",
     "start_time": "2023-09-28T01:49:56.534931Z"
    }
   },
   "source": [
    "Before running this notebook, train some models, for example:\n",
    "\n",
    "```\n",
    "python3 law.py 5 0.7 cuda:0 ROAD TEST_LAW_ROAD_10 20 10 10 10\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "266fd6b0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-29T02:11:08.465706Z",
     "start_time": "2023-09-29T02:11:08.446312Z"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "from matplotlib import pyplot as plt\n",
    "from sklearn import datasets as ds\n",
    "\n",
    "import sys\n",
    "import os\n",
    "sys.path.insert(1, os.path.join(sys.path[0], '..'))\n",
    "from utils.fairness_utils import pareto_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14146988",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-29T02:11:08.940842Z",
     "start_time": "2023-09-29T02:11:08.920333Z"
    }
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c36b4ef8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-29T02:11:09.261913Z",
     "start_time": "2023-09-29T02:11:09.167308Z"
    }
   },
   "outputs": [],
   "source": [
    "FOLDER = \"./../results_law/\"\n",
    "\n",
    "DATASET_NAME = \"LAW\"\n",
    "\n",
    "# Base name (without extension) of the ROAD result files read from FOLDER.\n",
    "fROAD = \"TEST_\" + DATASET_NAME + \"_ROAD_10_DP\"\n",
    "\n",
    "# colors\n",
    "palette = sns.color_palette(\"bright\", 10)\n",
    "\n",
    "# Per-method plotting settings; the loaded result frames are attached\n",
    "# below under the \"df\" and \"df_rdm\" keys.\n",
    "METHODS_DICT = {\n",
    "    \"road\": {\n",
    "        \"fname\": fROAD,\n",
    "        \"name\": \"ROAD (Ours)\",\n",
    "        \"linestyle\": \"-\",\n",
    "        \"color\": palette[2],\n",
    "    },\n",
    "}\n",
    "\n",
    "\n",
    "def _load_results(csv_path):\n",
    "    \"\"\"Read one results CSV and convert its identifier columns to strings.\n",
    "\n",
    "    The `lambda` and `tau` columns are cast to str and cut to their first\n",
    "    four characters; `run_id` is cast to str.\n",
    "    \"\"\"\n",
    "    results = pd.read_csv(csv_path)\n",
    "    for column in (\"lambda\", \"tau\"):\n",
    "        results[column] = results[column].astype('str').map(lambda value: value[:4])\n",
    "    results[\"run_id\"] = results[\"run_id\"].astype('str')\n",
    "    return results\n",
    "\n",
    "\n",
    "# Each method has two result files: <fname>.csv and <fname>_rdm.csv.\n",
    "for method_cfg in METHODS_DICT.values():\n",
    "    base_path = FOLDER + method_cfg[\"fname\"]\n",
    "    method_cfg[\"df\"] = _load_results(base_path + \".csv\")\n",
    "    method_cfg[\"df_rdm\"] = _load_results(base_path + \"_rdm.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43a69f5d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-29T02:11:09.335194Z",
     "start_time": "2023-09-29T02:11:09.301495Z"
    }
   },
   "outputs": [],
   "source": [
    "# Human-readable labels for the metric columns, used when labelling figures.\n",
    "FEATURE_NAMES = {\n",
    "    \"Global DI\": \"Global Fairness (DI)\",\n",
    "    \"Global Acc\": \"Accuracy\",\n",
    "    \"top1_DI\": \"Local Fairness (worst 1 DI)\",\n",
    "    \"top3_DI\": \"Local Fairness (worst 3 DI)\",\n",
    "    \"q_DI_0.8\": \"Local Fairness (0.8 quantile)\",\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6166f8fe",
   "metadata": {},
   "source": [
    "# Pareto curves"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75eb9017",
   "metadata": {},
   "source": [
    "## 1. Acc - Local Fairness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1439bb86",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-29T02:12:22.537774Z",
     "start_time": "2023-09-29T02:11:10.459367Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# WARNING: slow because of the Pareto-front computation.\n",
    "\n",
    "# Pareto plot of Global Acc vs. top1_DI for models within a fixed Global DI range\n",
    "sns.set_style('whitegrid')\n",
    "sns.set_palette(\"bright\")\n",
    "\n",
    "#### IMPORTANT: this defines the feature to \"fix\" and the corresponding range. E.g.: keeping only models\n",
    "# with GLOBAL DI strictly between 0.0 and 0.05\n",
    "iso_f = \"Global DI\"\n",
    "GDIRANGE = 0.0, 0.05\n",
    "\n",
    "### FEATURES FOR X AND Y AXES\n",
    "f1, f2 = \"top1_DI\", \"Global Acc\"\n",
    "\n",
    "# Draw on an explicit Axes so the axis handling after the loop does not\n",
    "# depend on at least one method having been plotted.\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "for key in METHODS_DICT:\n",
    "    print(key)\n",
    "    method = METHODS_DICT[key]\n",
    "    try:\n",
    "        df = method[\"df\"]\n",
    "        df2 = df[(df[iso_f] > GDIRANGE[0]) & (df[iso_f] < GDIRANGE[1])]\n",
    "        color = method[\"color\"]\n",
    "        linsty = method[\"linestyle\"]\n",
    "    except KeyError as err:\n",
    "        # Best effort: skip a method whose entry or result frame lacks a required key/column.\n",
    "        print(\"Skipping \" + key + \": missing \" + str(err))\n",
    "        continue\n",
    "\n",
    "    ### PARETO FRONT\n",
    "    dff = pareto_df(df2, f1, f2)\n",
    "    front = dff[dff[\"pareto\"] == 1]\n",
    "    sns.lineplot(data=front, x=f1, y=f2, color=color, linewidth=3, label=method[\"name\"], linestyle=linsty, ax=ax)\n",
    "    sns.scatterplot(data=front, x=f1, y=f2, color=color, legend=False, s=30, markers=\"o\", linestyle=linsty, ax=ax)\n",
    "\n",
    "ax.invert_xaxis()\n",
    "\n",
    "ax.legend(fontsize=12, loc='lower left', frameon=True)\n",
    "\n",
    "ax.set_xlabel('Local Unfairness (worst 1 DI)', fontsize=17)\n",
    "ax.set_ylabel(FEATURE_NAMES[f2], fontsize=17)\n",
    "ax.tick_params(axis='both', labelsize=12)\n",
    "ax.set_ylim((0.5, 0.72))\n",
    "\n",
    "ax.set_title(DATASET_NAME + ' (Global DI <' + str(GDIRANGE[1]) + ')', fontsize=17)\n",
    "\n",
    "fig.tight_layout()\n",
    "\n",
    "FNAME = \"../results_img/\"\n",
    "# Create the output folder if needed so saving does not fail on a fresh checkout.\n",
    "os.makedirs(FNAME, exist_ok=True)\n",
    "\n",
    "fig.savefig(FNAME + \"law_paretoTop1.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ebf0d12",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-28T17:56:20.686204Z",
     "start_time": "2023-09-28T17:54:59.256249Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# WARNING: slow because of the Pareto-front computation.\n",
    "\n",
    "# Pareto plot of Global Acc vs. top3_DI for models within a fixed Global DI range\n",
    "sns.set_style('whitegrid')\n",
    "sns.set_palette(\"bright\")\n",
    "\n",
    "#### IMPORTANT: this defines the feature to \"fix\" and the corresponding range. E.g.: keeping only models\n",
    "# with GLOBAL DI strictly between 0.0 and 0.05\n",
    "iso_f = \"Global DI\"\n",
    "GDIRANGE = 0.0, 0.05\n",
    "\n",
    "### FEATURES FOR X AND Y AXES\n",
    "f1, f2 = \"top3_DI\", \"Global Acc\"\n",
    "\n",
    "# Draw on an explicit Axes so the axis handling after the loop does not\n",
    "# depend on at least one method having been plotted.\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "for key in METHODS_DICT:\n",
    "    print(key)\n",
    "    method = METHODS_DICT[key]\n",
    "    try:\n",
    "        df = method[\"df\"]\n",
    "        df2 = df[(df[iso_f] > GDIRANGE[0]) & (df[iso_f] < GDIRANGE[1])]\n",
    "        color = method[\"color\"]\n",
    "    except KeyError as err:\n",
    "        # Best effort: skip a method whose entry or result frame lacks a required key/column.\n",
    "        print(\"Skipping \" + key + \": missing \" + str(err))\n",
    "        continue\n",
    "    # Use the method's configured line style when present; solid otherwise.\n",
    "    linsty = method.get(\"linestyle\", \"-\")\n",
    "\n",
    "    ### PARETO FRONT\n",
    "    dff = pareto_df(df2, f1, f2)\n",
    "    front = dff[dff[\"pareto\"] == 1]\n",
    "    sns.lineplot(data=front, x=f1, y=f2, color=color, linewidth=3, label=method[\"name\"], linestyle=linsty, ax=ax)\n",
    "    sns.scatterplot(data=front, x=f1, y=f2, color=color, legend=False, s=30, markers=\"o\", ax=ax)\n",
    "\n",
    "ax.invert_xaxis()\n",
    "\n",
    "ax.legend(fontsize=12, loc='lower left', frameon=True)\n",
    "\n",
    "ax.set_xlabel('Local Unfairness (worst 3 DI)', fontsize=17)\n",
    "ax.set_ylabel(FEATURE_NAMES[f2], fontsize=17)\n",
    "ax.tick_params(axis='both', labelsize=12)\n",
    "ax.set_ylim((0.39, 0.72))\n",
    "# Build the threshold shown in the title from GDIRANGE so it cannot drift from the filter above.\n",
    "ax.set_title(DATASET_NAME + ' (Global DI < ' + str(GDIRANGE[1]) + ' )', fontsize=17)\n",
    "fig.tight_layout()\n",
    "\n",
    "FNAME = \"../results_img/\"\n",
    "# Create the output folder if needed so saving does not fail on a fresh checkout.\n",
    "os.makedirs(FNAME, exist_ok=True)\n",
    "\n",
    "fig.savefig(FNAME + \"law_paretoTop3bFINAL.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "029487b9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
