{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Automatically re-import edited local modules (such as viz_utils) before running each cell.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from glob import glob\n",
    "from viz_utils import plot_pareto, plot_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# glob() returns paths in arbitrary, filesystem-dependent order; sort them so the row\n",
    "# order of the concatenated results is reproducible.\n",
    "files = sorted(glob(\"baselines/DP_FERMI/results/logistic-regression*.parquet\"))\n",
    "assert files, \"No DP-FERMI result parquets found in baselines/DP_FERMI/results/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate every run into one frame and renumber its rows.\n",
    "results_df = (\n",
    "    pd.concat([pd.read_parquet(path) for path in files])\n",
    "    .reset_index(drop=True)\n",
    "    # NOTE(review): epsilon = 0.5 runs are dropped without a stated reason; document why.\n",
    "    .query(\"epsilon != 0.5\")\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr:last-of-type th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th colspan=\"3\" halign=\"left\">test_misclassification_error</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>min</th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>epsilon</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1.0</th>\n",
       "      <td>0.178834</td>\n",
       "      <td>0.235694</td>\n",
       "      <td>0.053328</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3.0</th>\n",
       "      <td>0.186924</td>\n",
       "      <td>0.236501</td>\n",
       "      <td>0.058341</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9.0</th>\n",
       "      <td>0.182548</td>\n",
       "      <td>0.236439</td>\n",
       "      <td>0.056419</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10000.0</th>\n",
       "      <td>0.205490</td>\n",
       "      <td>0.231703</td>\n",
       "      <td>0.010691</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        test_misclassification_error                    \n",
       "                                 min      mean       std\n",
       "epsilon                                                 \n",
       "1.0                         0.178834  0.235694  0.053328\n",
       "3.0                         0.186924  0.236501  0.058341\n",
       "9.0                         0.182548  0.236439  0.056419\n",
       "10000.0                     0.205490  0.231703  0.010691"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Min / mean / std of test error over all runs, per privacy budget.\n",
    "(results_df\n",
    "    .groupby(\"epsilon\")\n",
    "    .agg({\"test_misclassification_error\": [\"min\", \"mean\", \"std\"]}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the demographic-parity runs and rename columns for plot_results.\n",
    "_results_df = (\n",
    "    results_df\n",
    "    .query(\"fairness_metric == 'demographic_parity'\")\n",
    "    .rename(columns={\n",
    "        \"test_misclassification_error\": \"test_error\",\n",
    "        \"test_demographic_parity\": \"test_disparity\",\n",
    "        \"epsilon\": \"budget\",\n",
    "        \"model_number\": \"seed\",\n",
    "    })\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adult / demographic-parity plot for the runs renamed above.\n",
    "plot_results(\n",
    "    'adult',\n",
    "    'DemParity',\n",
    "    results_df=_results_df,\n",
    "    skip_others=False,\n",
    "    plot_only_baselines=True,\n",
    "    q=None,\n",
    "    label=\"DP-FERMI Calculated\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Choosing hyper-parameters for further training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best run per privacy budget, ranked by test error (ties broken by test demographic parity).\n",
    "# NOTE(review): this ranks on *test* metrics, so treat it as a sanity check only; the\n",
    "# configurations exported further down are picked from the validation Pareto front instead.\n",
    "best_run_per_epsilon = (\n",
    "    results_df\n",
    "    .sort_values([\"test_misclassification_error\", \"test_demographic_parity\"])\n",
    "    .groupby(\"epsilon\")\n",
    "    .head(1)\n",
    "    .set_index(\"epsilon\")\n",
    "    .sort_index()\n",
    "    .loc[:, [\"lambda\", \"lr_theta\", \"lr_W\", \"C\", \"lipschitz_theta\",\n",
    "             \"test_misclassification_error\", \"test_demographic_parity\"]]\n",
    ")\n",
    "best_run_per_epsilon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-set Pareto front for epsilon = 1. seaborn only applies `markers`/`dashes` to the\n",
    "# levels of a `style` variable, so style=\"epsilon\" is needed for them to take effect.\n",
    "pareto_eps1 = plot_pareto(\n",
    "    results_df.query(\"epsilon == 1\"),\n",
    "    \"test_misclassification_error\",\n",
    "    \"test_demographic_parity\",\n",
    "    budget_col_name=\"epsilon\",\n",
    "    seed_col_name=\"model_number\",\n",
    ")\n",
    "sns.lineplot(data=pareto_eps1, x=\"test_misclassification_error\", y=\"test_demographic_parity\",\n",
    "             hue=\"epsilon\", style=\"epsilon\", markers=True, dashes=False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pareto front on *validation* metrics per (epsilon, seed), with q=20 passed to plot_pareto.\n",
    "# Keep one row per (epsilon, model_number, valid_misclassification_error_q); that column is\n",
    "# presumably the error bucket plot_pareto adds when q is set -- verify in viz_utils.\n",
    "_plot_df = (\n",
    "    plot_pareto(\n",
    "        results_df,\n",
    "        x_col=\"valid_misclassification_error\",\n",
    "        y_col=\"valid_demographic_parity\",\n",
    "        budget_col_name=\"epsilon\",\n",
    "        seed_col_name=\"model_number\",\n",
    "        q=20,\n",
    "    )\n",
    "    .drop_duplicates(subset=[\"epsilon\", \"model_number\", \"valid_misclassification_error_q\"])\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation-selected Pareto points, evaluated on test metrics, one line per budget.\n",
    "# style=\"epsilon\" is required for seaborn to honour markers=True / dashes=False.\n",
    "sns.lineplot(data=_plot_df, x=\"test_misclassification_error\", y=\"test_demographic_parity\",\n",
    "             hue=\"epsilon\", style=\"epsilon\", markers=True, dashes=False, palette=\"tab10\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Hyper-parameters to run for 200 epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per distinct (epsilon, hyper-parameter) combination on the validation Pareto\n",
    "# front. `epsilon` is read from its column rather than from the auto-generated index\n",
    "# columns of reset_index() (e.g. `level_2`), which depend on plot_pareto's index layout.\n",
    "# `lambda` is renamed to `lambd` to match bash_command's argument name.\n",
    "params = (\n",
    "    _plot_df[[\"epsilon\", \"lr_theta\", \"lr_W\", \"lambda\", \"C\", \"lipschitz_theta\"]]\n",
    "    .drop_duplicates()\n",
    "    .rename(columns={\"lambda\": \"lambd\"})\n",
    "    .reset_index(drop=True)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 169,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the selected hyper-parameter configurations.\n",
    "params_path = \"baselines/DP_FERMI/scripts/adult_200epochs_params.parquet\"\n",
    "params.to_parquet(params_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 163,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bash_command(epsilon, lr_theta, lr_W, lambd, C, lipschitz_theta,\n",
    "                 epochs=200, dataset=\"adult\", save_parquets_path=\"./results_200epochs\"):\n",
    "    \"\"\"Return the shell command that trains one DP-FERMI configuration.\n",
    "\n",
    "    The first six arguments are the columns of `params` (`lambd` rather than `lambda`,\n",
    "    which is a Python keyword). The keyword arguments default to the settings of the\n",
    "    200-epoch Adult sweep, so existing `bash_command(**row)` calls produce the same command.\n",
    "    \"\"\"\n",
    "    return (\n",
    "        f'PYTHONPATH=\".\" python dp_fermi/dp_fermi.py --epochs {epochs} '\n",
    "        f'--dataset \"{dataset}\" --batch_size 1024 --num_layers 1 --save_parquets '\n",
    "        f'--epsilon {epsilon} --lambd {lambd} --C {C} '\n",
    "        f'--lipschitz_theta {lipschitz_theta} --lr_theta {lr_theta} --lr_W {lr_W} '\n",
    "        f'--save_parquets_path {save_parquets_path}'\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 166,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PYTHONPATH=\".\" python dp_fermi/dp_fermi.py --epochs 200 --dataset \"adult\" --batch_size 1024 --num_layers 1 --save_parquets --epsilon 1.0 --lambd 2.253601063219744 --C 2.5966404402749075 --lipschitz_theta 3.4958000183677833 --lr_theta 0.005 --lr_W 0.01 --save_parquets_path ./results_200epochs\n"
     ]
    }
   ],
   "source": [
    "# One training command per selected configuration; show the first as a sanity check.\n",
    "commands = [bash_command(**config) for config in params.to_dict(orient=\"records\")]\n",
    "print(commands[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 167,
   "metadata": {},
   "outputs": [],
   "source": [
    "user_1 = 'example/example'\n",
    "user_2 = 'exmp'\n",
    "\n",
    "# sbatch requires the first line of the script to be the \"#!\" interpreter line, so the\n",
    "# preamble must start directly with the shebang (no leading newline).\n",
    "preamble = f\"\"\"#!/bin/bash\n",
    "#SBATCH --account edok\n",
    "#SBATCH --partition edok\n",
    "#SBATCH --qos edok\n",
    "#SBATCH --gres=gpu:1\n",
    "#SBATCH --cpus-per-task=4\n",
    "#SBATCH --exclude=edok2,edok3\n",
    "#SBATCH --mem=64GB\n",
    "\n",
    "source /data/u/{user_1}/bin/activate pytorch\n",
    "cd /h/321/{user_2}/DP_FERMI\n",
    "\n",
    "echo \"Job start at $(date)\"\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "script_path = f\"/h/321/{user_2}/FairPATE/baselines/DP_FERMI/scripts/adult_200epochs.sh\"\n",
    "with open(script_path, \"w\") as f:\n",
    "    f.write(preamble)\n",
    "    # One command per line, with a trailing newline so the last command is terminated.\n",
    "    f.write(\"\\n\".join(commands) + \"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results for 200 epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_dp_fermi_results(pattern, exclude_epsilons=()):\n",
    "    \"\"\"Load DP-FERMI result parquets in the column layout passed to `plot_results`.\n",
    "\n",
    "    Reads every parquet matching the glob `pattern` (in sorted order, so the result is\n",
    "    reproducible), drops runs whose `epsilon` is in `exclude_epsilons`, keeps only\n",
    "    `fairness_metric == 'demographic_parity'` rows, and renames\n",
    "    test_misclassification_error -> test_error, test_demographic_parity -> test_disparity,\n",
    "    epsilon -> budget and model_number -> seed.\n",
    "\n",
    "    Raises FileNotFoundError if `pattern` matches no files (instead of pandas'\n",
    "    less specific \"No objects to concatenate\" error).\n",
    "    \"\"\"\n",
    "    paths = sorted(glob(pattern))\n",
    "    if not paths:\n",
    "        raise FileNotFoundError(f\"No result parquets match {pattern!r}\")\n",
    "    return (\n",
    "        pd.concat([pd.read_parquet(path) for path in paths])\n",
    "        .reset_index(drop=True)\n",
    "        .loc[lambda df: ~df[\"epsilon\"].isin(list(exclude_epsilons))]\n",
    "        .query(\"fairness_metric == 'demographic_parity'\")\n",
    "        .rename(columns={\n",
    "            \"test_misclassification_error\": \"test_error\",\n",
    "            \"test_demographic_parity\": \"test_disparity\",\n",
    "            \"epsilon\": \"budget\",\n",
    "            \"model_number\": \"seed\",\n",
    "        })\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# As in the initial sweep, runs with epsilon = 0.5 are excluded.\n",
    "results_200_df = load_dp_fermi_results(\n",
    "    f\"/h/321/{user_2}/DP_FERMI/results_200epochs/logistic-regression*.parquet\",\n",
    "    exclude_epsilons=[0.5],\n",
    ")\n",
    "plot_results('adult', 'DemParity', results_df=results_200_df, skip_others=False,\n",
    "             plot_only_baselines=True, q=20, label=\"DP-FERMI Calculated\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Single-Run 200-Epoch Results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Adult"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adult_single_run_df = load_dp_fermi_results(\n",
    "    \"/data/projects/fairPATE/DP_FERMI_tabular_200/adult_logistic-regression*.parquet\"\n",
    ")\n",
    "plot_results('adult', 'DemParity', results_df=adult_single_run_df, skip_others=False,\n",
    "             plot_only_baselines=True, q=20, label=\"DP-FERMI Calculated\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parkinsons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "parkinsons_single_run_df = load_dp_fermi_results(\n",
    "    \"/data/projects/fairPATE/DP_FERMI_tabular_200/parkinsons_logistic-regression*.parquet\"\n",
    ")\n",
    "plot_results('parkinsons', 'DemParity', results_df=parkinsons_single_run_df, skip_others=False,\n",
    "             plot_only_baselines=True, q=None, label=\"DP-FERMI Calculated\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Credit-Card"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "credit_card_single_run_df = load_dp_fermi_results(\n",
    "    \"/data/projects/fairPATE/DP_FERMI_tabular_200/credit-card_logistic-regression*.parquet\"\n",
    ")\n",
    "plot_results('credit-card', 'DemParity', results_df=credit_card_single_run_df, skip_others=False,\n",
    "             plot_only_baselines=True, q=None, label=\"DP-FERMI Calculated\");"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
