{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e78fcdce",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from glob import glob\n",
    "from omegaconf import OmegaConf\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "matplotlib.rcParams['mathtext.fontset'] = 'stix'\n",
    "matplotlib.rcParams['font.family'] = 'STIXGeneral'\n",
    "# use the seaborn whitegrid style\n",
    "plt.style.use('seaborn-v0_8-whitegrid')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d16c67b",
   "metadata": {},
   "source": [
    "# Helper function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "729da8c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def summarize_hyperparam(env, mu_type, alg_names):\n",
    "    def summ_ep(df):\n",
    "        delta = df.delta.abs().mean()\n",
    "        delta_obs = df.delta_obs.abs().mean()\n",
    "        rew = df.sort_values(\"t\")[\"resource\"].values[-1]\n",
    "        accept_g0 = df[df.group_id == 0].action.mean()\n",
    "        accept_g1 = df[df.group_id == 1].action.mean()\n",
    "        return pd.Series(\n",
    "            {\n",
    "                \"delta\": delta,\n",
    "                \"delta_last\": df.sort_values(\"t\")[\"delta\"].values[-1],\n",
    "                \"delta_obs\": delta_obs,\n",
    "                \"reward\": rew,\n",
    "                \"accept_g0\": accept_g0,\n",
    "                \"accept_g1\": accept_g1,\n",
    "            }\n",
    "        )\n",
    "\n",
    "    files = [\n",
    "        f\"experiments/{env}/{mu_type}/{alg}/eval/eval_data.csv\" for alg in alg_names\n",
    "    ]\n",
    "    results = []\n",
    "    for file in files:\n",
    "        try:\n",
    "            alg = file.split(\"/\")[-3]\n",
    "            df = pd.read_csv(file)\n",
    "            # check if there is any delta_obs negative, and print alg if so\n",
    "            if (df.delta_obs < 0).any():\n",
    "                print(f\"Negative delta_obs found in {alg}\")\n",
    "            df = df.groupby(\"ep\").apply(summ_ep).aggregate([\"mean\", \"std\"])\n",
    "            alg_simplified = alg.split(\"(\")[0]#alg.replace(\"sellf_hard\", \"sellf\").split(\"_\")[0]\n",
    "            results.append(\n",
    "                {\n",
    "                    \"alg\": alg,\n",
    "                    \"alg_simplified\": alg_simplified,\n",
    "                    \"delta\": df[\"delta\"][\"mean\"],\n",
    "                    \"delta_std\": df[\"delta\"][\"std\"],\n",
    "                    \"reward\": df[\"reward\"][\"mean\"],\n",
    "                    \"reward_std\": df[\"reward\"][\"std\"],\n",
    "                    \"delta_obs\": df[\"delta_obs\"][\"mean\"],\n",
    "                    \"delta_obs_std\": df[\"delta_obs\"][\"std\"],\n",
    "                    \"accept_g0\": df[\"accept_g0\"][\"mean\"],\n",
    "                    \"accept_g1\": df[\"accept_g1\"][\"mean\"],\n",
    "                    \"delta_last\": df[\"delta_last\"][\"mean\"],\n",
    "                }\n",
    "            )\n",
    "        except:\n",
    "            continue\n",
    "    results = pd.DataFrame(results)\n",
    "\n",
    "    def evaluate(df):\n",
    "        deltas = df[\"delta_obs\"].values.copy()\n",
    "        rewards = df[\"reward\"].values.copy()\n",
    "\n",
    "        # round delta to 2 decimal places\n",
    "        deltas = np.round(deltas, 2)\n",
    "        deltas[deltas <= 0.05] = 0\n",
    "\n",
    "        delta_min = deltas.min()\n",
    "        rewards[deltas > delta_min] = 0\n",
    "\n",
    "        df[\"score\"] = rewards\n",
    "\n",
    "        return df\n",
    "\n",
    "    results = results.groupby(\"alg_simplified\", group_keys = False).apply(evaluate).reset_index()\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5cb8826",
   "metadata": {},
   "outputs": [],
   "source": [
    "def base_legend_func(alg_name):\n",
    "    alg_name = alg_name.split(\"(\")[0]\n",
    "    if alg_name == \"sellf\":\n",
    "        return \"SELLF\"\n",
    "    if alg_name == \"sellf_deep\":\n",
    "        return \"SELLF (ReLU NN)\"\n",
    "    if alg_name == \"sellf_censored_3\":\n",
    "        return \"SELLF cens.\"\n",
    "    if alg_name == \"sellf_censored_25\":\n",
    "        return \"SELLF (Semi-stoc.)\"\n",
    "    if alg_name == \"pocar_full\" or alg_name == \"pocar_full_v2\":\n",
    "        return \"POCAR (Oracle)\"\n",
    "    if alg_name == \"pocar\":\n",
    "        return \"POCAR\"\n",
    "    if alg_name == \"ppo\":\n",
    "        return \"PPO\"\n",
    "    if alg_name == \"elbert\":\n",
    "        return \"ELBERT\"\n",
    "    if alg_name == \"focops\":\n",
    "        return \"FOCOPS\"\n",
    "    return alg_name\n",
    "\n",
    "def base_color_func(alg_name):\n",
    "    if \"sellf_deep\" in alg_name:\n",
    "        return \"#984ea3\"\n",
    "    if \"sellf_censored\" in alg_name:\n",
    "        return \"#7570b3\"\n",
    "    if \"sellf\" in alg_name:\n",
    "        return \"#e78ac3\"\n",
    "    if \"full\" in alg_name:\n",
    "        return \"#fc8d62\"\n",
    "    if \"ppo\" in alg_name:\n",
    "        return \"#8da0cb\"\n",
    "    if \"pocar\" in alg_name:\n",
    "        return \"#a6d854\"\n",
    "    if \"elbert\" in alg_name:\n",
    "        return \"#1d9871\"\n",
    "    if \"focops\" in alg_name:\n",
    "        return \"#e5c494\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0308db9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_alg_names(alg_name):\n",
    "    config = OmegaConf.load(f\"configs/{alg_name}.yaml\")\n",
    "    alg_names = []\n",
    "    for params in config.algorithm_param_list:\n",
    "        params_info = \" \".join([f\"{k}:{v}\" for k, v in params.items()])\n",
    "        alg_names.append(f\"{alg_name}({params_info})\")\n",
    "    \n",
    "    if len(alg_names) == 0:\n",
    "        alg_names.append(f\"{alg_name}()\")\n",
    "    return alg_names\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4763b0a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_zorder(alg_name):\n",
    "    if \"sellf_censored\" in alg_name:\n",
    "        return 2\n",
    "    elif \"sellf\" in alg_name:\n",
    "        return 3\n",
    "\n",
    "    return 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b34f346d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def base_plot(selected_algs, env_name, mu_type, save_path):\n",
    "    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(6, 3))\n",
    "\n",
    "    df_list = [\n",
    "        pd.read_csv(f\"experiments/{env_name}/{mu_type}/{alg}/eval/eval_data.csv\")\n",
    "        for alg in selected_algs\n",
    "    ]\n",
    "\n",
    "    for i in range(len(df_list)):\n",
    "        df_list[i][\"accept_g0\"] = df_list[i][\"action\"] * (df_list[i][\"group_id\"] == 0)\n",
    "        df_list[i][\"accept_g1\"] = df_list[i][\"action\"] * (df_list[i][\"group_id\"] == 1)\n",
    "\n",
    "        # for each ep, compute cumulative sum of accept_g0 and accept_g1\n",
    "        df_list[i][\"n_accept_g0\"] = df_list[i].groupby(\"ep\")[\"accept_g0\"].cumsum()\n",
    "        df_list[i][\"n_accept_g1\"] = df_list[i].groupby(\"ep\")[\"accept_g1\"].cumsum()\n",
    "\n",
    "        df_list[i][\"delta\"] = df_list[i][\"delta\"].abs()\n",
    "\n",
    "        df_list[i] = (\n",
    "            df_list[i]\n",
    "            .groupby(\"t\")\n",
    "            .agg(\n",
    "                {\n",
    "                    \"resource\": [\"mean\", \"std\"],\n",
    "                    \"delta\": [\"mean\", \"std\"],\n",
    "                    \"n_accept_g0\": [\"mean\", \"std\"],\n",
    "                    \"n_accept_g1\": [\"mean\", \"std\"],\n",
    "                }\n",
    "            )\n",
    "        )\n",
    "\n",
    "    def get_zorder(alg_name):\n",
    "        if \"sellf_censored\" in alg_name:\n",
    "            return 2\n",
    "        elif \"sellf\" in alg_name:\n",
    "            return 3\n",
    "        return 1\n",
    "    \n",
    "    for alg, df in zip(selected_algs, df_list):\n",
    "        axs[0].plot(\n",
    "            df.index,\n",
    "            df[\"resource\"][\"mean\"],\n",
    "            label=base_legend_func(alg),\n",
    "            color=base_color_func(alg),\n",
    "            linestyle=\"--\" if \"full\" in alg else \"-\",\n",
    "            zorder=get_zorder(alg),\n",
    "        )\n",
    "        axs[0].fill_between(\n",
    "            df.index,\n",
    "            df[\"resource\"][\"mean\"] - df[\"resource\"][\"std\"],\n",
    "            df[\"resource\"][\"mean\"] + df[\"resource\"][\"std\"],\n",
    "            alpha=0.3,\n",
    "            color=base_color_func(alg),\n",
    "            zorder=get_zorder(alg),\n",
    "        )\n",
    "\n",
    "        axs[1].plot(\n",
    "            df.index,\n",
    "            df[\"delta\"][\"mean\"],\n",
    "            label=base_legend_func(alg),\n",
    "            color=base_color_func(alg),\n",
    "            linestyle=\"--\" if \"full\" in alg else \"-\",\n",
    "            zorder=get_zorder(alg),\n",
    "        )\n",
    "        axs[1].fill_between(\n",
    "            df.index,\n",
    "            df[\"delta\"][\"mean\"] - df[\"delta\"][\"std\"],\n",
    "            df[\"delta\"][\"mean\"] + df[\"delta\"][\"std\"],\n",
    "            alpha=0.3,\n",
    "            color=base_color_func(alg),\n",
    "            zorder=get_zorder(alg),\n",
    "        )\n",
    "\n",
    "        #diff_accept = df[\"n_accept_g0\"][\"mean\"]\n",
    "        #axs[2].plot(\n",
    "        #    df.index,\n",
    "        #    diff_accept,\n",
    "        #    label=base_legend_func(alg),\n",
    "        #    color=base_color_func(alg),\n",
    "        #    linestyle=\"--\" if \"full\" in alg else \"-\",\n",
    "        #    lw=2,\n",
    "        #)\n",
    "\n",
    "    for i in range(2):\n",
    "        axs[i].set_xlabel(\"Timestep\")\n",
    "\n",
    "    axs[0].set_title(\"Cumulative Reward\", fontsize = 12)\n",
    "    axs[1].set_title(\"Disparity ($|\\Delta_t|$)\", fontsize = 12)\n",
    "    #axs[2].set_title(\"Nº of acceptances for\\nunderprivileged group\")\n",
    "\n",
    "    # place legend outside bellow all plots\n",
    "    fig.subplots_adjust(bottom=0.3, wspace=0.33)\n",
    "\n",
    "    axs[1].legend(\n",
    "        loc=\"upper center\",\n",
    "        bbox_to_anchor=(-0.2, -0.2),\n",
    "        ncol=4 #len(selected_algs),\n",
    "    )\n",
    "\n",
    "    for i in range(2):\n",
    "        axs[i].grid(True)\n",
    "    plt.savefig(save_path, bbox_inches=\"tight\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4688bfe4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_table(result_table):\n",
    "    table = result_table[[\"alg_simplified\", \"delta\", \"delta_std\", \"reward\", \"reward_std\"]].copy()\n",
    "    # round 2 decimal places\n",
    "    table[\"delta\"] = table[\"delta\"].round(2)\n",
    "    table[\"delta_std\"] = table[\"delta_std\"].round(2)\n",
    "    table[\"reward\"] = table[\"reward\"].round(2)\n",
    "    table[\"reward_std\"] = table[\"reward_std\"].round(1)\n",
    "\n",
    "    print(\"alg \\t\\t\\t delta \\t\\t\\t reward\")\n",
    "    def alg_order(alg_name):\n",
    "        if \"ppo\" in alg_name:\n",
    "            return 0\n",
    "        if \"pocar_full\" in alg_name:\n",
    "            return 2\n",
    "        if \"pocar\" in alg_name:\n",
    "            return 1\n",
    "        if \"sellf\" in alg_name:\n",
    "            return 3\n",
    "        return 4\n",
    "\n",
    "    table[\"order\"] = table[\"alg_simplified\"].apply(alg_order)\n",
    "    table = table.sort_values(by=[\"order\", \"alg_simplified\"])\n",
    "    for i, row in table.iterrows():\n",
    "        skip = \"\\t\\t\\t\" if len(base_legend_func(row['alg_simplified'])) <= 5 else \"\\t\\t\"\n",
    "        print(f\"{base_legend_func(row['alg_simplified'])} {skip} & {row['delta']} ($\\pm$ {row['delta_std']}) \\t & {row['reward']} ($\\pm$ {row['reward_std']})\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06e9bcff",
   "metadata": {},
   "source": [
    "# Ablation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fafbebba",
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"fico_equal\"\n",
    "mu_type = \"accuracy\"\n",
    "train_files = glob(f\"experiments/ablation/**/{env_name}/{mu_type}/**/models/*.csv\")\n",
    "eval_files = glob(f\"experiments/ablation/**/{env_name}/{mu_type}/**/eval/*.csv\")\n",
    "alg_dict = dict([(f.split(\"/\")[-3], []) for f in train_files])\n",
    "alg_dict_eval = dict([(f.split(\"/\")[-3], []) for f in eval_files])\n",
    "for f in train_files:\n",
    "    alg_name = f.split(\"/\")[-3]\n",
    "    alg_dict[alg_name].append(f)\n",
    "for f in eval_files:\n",
    "    alg_name = f.split(\"/\")[-3]\n",
    "    alg_dict_eval[alg_name].append(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2f4ea14",
   "metadata": {},
   "outputs": [],
   "source": [
    "### sumarize learning metrics\n",
    "\n",
    "alg_dfs = {}\n",
    "\n",
    "for alg_name, file_list in alg_dict.items():\n",
    "    df_list = []\n",
    "    for f in file_list:\n",
    "        df = pd.read_csv(f)\n",
    "        df[\"t\"] = np.arange(len(df))\n",
    "        df_list.append(df)\n",
    "    df = pd.concat(df_list)\n",
    "    df[\"delta_diff\"] = (df[\"delta\"] - df[\"delta_obs\"])\n",
    "    df[\"term_g0\"] =(1 - df[\"accept_g0\"]) * df[\"renyi_div_g0\"]\n",
    "    df[\"term_g1\"] =(1 - df[\"accept_g1\"]) * df[\"renyi_div_g1\"]\n",
    "    df[\"term\"] = df[\"term_g0\"] + df[\"term_g1\"]\n",
    "\n",
    "    # apply moving average with window size 10 for every numeric column\n",
    "    for col in [\"reward\", \"delta\", \"delta_obs\", \"delta_diff\"]:\n",
    "        df[col] = df[col].rolling(window=1).mean()\n",
    "\n",
    "    df = df.groupby(\"t\").agg({\n",
    "        \"reward\": [\"mean\", \"std\"],\n",
    "        \"delta\": [\"mean\", \"std\"],\n",
    "        \"delta_obs\": [\"mean\", \"std\"],\n",
    "        \"delta_diff\": [\"mean\", \"std\"],\n",
    "        \"term\": [\"mean\", \"std\"],\n",
    "        \"accept_g0\": [\"mean\", \"std\"],\n",
    "        \"accept_g1\": [\"mean\", \"std\"],\n",
    "        \"max_weight_g0\" : [\"mean\", \"std\"],\n",
    "        \"max_weight_g1\" : [\"mean\", \"std\"],\n",
    "        \"aK_min_g0\" : [\"mean\", \"std\"],\n",
    "        \"aK_min_g1\" : [\"mean\", \"std\"],\n",
    "    })\n",
    "    alg_dfs[alg_name] = df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5aa454d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "alg_dfs_eval = {}\n",
    "for alg_name, file_list in alg_dict_eval.items():\n",
    "    df_list = []\n",
    "    for f in file_list:\n",
    "        df = pd.read_csv(f)\n",
    "        df[\"t\"] = np.arange(len(df))\n",
    "        df_list.append(df)\n",
    "    df = pd.concat(df_list)\n",
    "    df[\"delta\"] = df[\"delta\"].abs()\n",
    "    df = df.groupby(\"t\").agg({\"delta\" : \"mean\"})\n",
    "    alg_dfs_eval[alg_name] = {\n",
    "        \"mean\" : df[\"delta\"].mean(),\n",
    "        \"std\" : df[\"delta\"].std()\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c153f30c",
   "metadata": {},
   "outputs": [],
   "source": [
    "sorted_algs = sorted(alg_dfs, key = lambda x : float(x.split(\":\")[2].replace(\")\", \"\")))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41d20549",
   "metadata": {},
   "outputs": [],
   "source": [
    "colors = ['#c7e9b4','#7fcdbb','#41b6c4','#2c7fb8','#253494']\n",
    "\n",
    "def color_func(alg_name):\n",
    "    idx = sorted_algs.index(alg_name)\n",
    "    return colors[idx]\n",
    "\n",
    "def legend_func(alg_name):\n",
    "    return alg_name.split(\" \")[-1].replace(\")\", \"\").replace(\"beta_2:\", \"$\\\\beta_2=$\") \n",
    "\n",
    "\n",
    "fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(7.5, 3), gridspec_kw={'width_ratios': [1, 1, 0.7]})\n",
    "\n",
    "\n",
    "for alg_name in sorted_algs:\n",
    "    df = alg_dfs[alg_name]\n",
    "    axs[0].plot(\n",
    "        df[\"delta_diff\"][\"mean\"],\n",
    "        label=(alg_name),\n",
    "        color=color_func(alg_name),\n",
    "        lw = 2,\n",
    "    )\n",
    "\n",
    "    axs[0].fill_between(\n",
    "        df.index,\n",
    "        df[\"delta_diff\"][\"mean\"] - df[\"delta_diff\"][\"std\"],\n",
    "        df[\"delta_diff\"][\"mean\"] + df[\"delta_diff\"][\"std\"],\n",
    "        alpha=0.2,\n",
    "        color = color_func(alg_name),\n",
    "    )\n",
    "\n",
    "    axs[0].axhline(0, color=\"black\", linestyle=\"--\", linewidth=1, alpha=0.3)\n",
    "\n",
    "\n",
    "    axs[1].plot(\n",
    "        df[\"term\"][\"mean\"],\n",
    "        label=legend_func(alg_name),\n",
    "        color=color_func(alg_name), \n",
    "        lw = 2,\n",
    "    )\n",
    "\n",
    "    axs[1].fill_between(\n",
    "        df.index,\n",
    "        df[\"term\"][\"mean\"] - df[\"term\"][\"std\"],\n",
    "        df[\"term\"][\"mean\"] + df[\"term\"][\"std\"],\n",
    "        alpha=0.2,\n",
    "        color = color_func(alg_name),\n",
    "    )\n",
    "\n",
    "    axs[1].axhline(0, color=\"black\", linestyle=\"--\", linewidth=1, alpha=0.3)\n",
    "\n",
    "\n",
    "    df = alg_dfs_eval[alg_name]\n",
    "    axs[2].bar(\n",
    "        alg_name,\n",
    "        df[\"mean\"],\n",
    "        yerr=df[\"std\"],\n",
    "        color=color_func(alg_name),\n",
    "    )\n",
    "    axs[2].set_xticks(range(len(sorted_algs)))\n",
    "    axs[2].set_xticklabels(\n",
    "        [\"\" for s in sorted_algs],\n",
    "        #[s.split(\" \")[-1].replace(\")\", \"\").replace(\"beta_2:\", \"$\\\\beta_2=$\") for s in sorted_algs],\n",
    "        rotation=45,\n",
    "    )\n",
    "    #plt.show()\n",
    "\n",
    "axs[0].set_ylim(-0.01, 0.03)\n",
    "axs[1].set_ylim(0, 5)\n",
    "axs[0].set_title(r\"$\\Delta_t - \\tilde \\Delta_t $\", fontsize=14)\n",
    "axs[1].set_title(r\"$\\Sigma_i r^i \\mathbb{E}[\\mathrm{w}(x, i)^2]$\", fontsize=14)\n",
    "axs[2].set_title(r\"$|\\Delta_t|$\", fontsize=14)\n",
    "\n",
    "fig.subplots_adjust(bottom=0.3, wspace=0.33)\n",
    "\n",
    "axs[0].set_xlabel(\"Learning iteration\", fontsize = 12)\n",
    "axs[1].set_xlabel(\"Learning iteration\", fontsize = 12)\n",
    "axs[2].set_xlabel(r\"$\\beta_2$\", fontsize = 12)\n",
    "axs[1].legend(\n",
    "    loc=\"upper center\",\n",
    "    bbox_to_anchor=(0.35, -0.3),\n",
    "    ncol=5,\n",
    "    fontsize=12\n",
    ")\n",
    "\n",
    "for i in range(3):\n",
    "    axs[i].grid(True)\n",
    "\n",
    "# make the last only be horizontal grid\n",
    "axs[-1].xaxis.grid(False)\n",
    "\n",
    "# reduce the horizontal gap between plots\n",
    "\n",
    "\n",
    "# place y ticks of last plot on the right\n",
    "axs[2].yaxis.tick_right()\n",
    "axs[2].yaxis.set_label_position(\"right\")\n",
    "\n",
    "fig.subplots_adjust(wspace=0.15)\n",
    "\n",
    "plt.savefig(\"figures/ablation_beta2.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ebc4298b",
   "metadata": {},
   "source": [
    "### Analysis of weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa972b92",
   "metadata": {},
   "outputs": [],
   "source": [
    "colors = ['#c7e9b4','#7fcdbb','#41b6c4','#2c7fb8','#253494']\n",
    "\n",
    "def color_func(alg_name):\n",
    "    idx = sorted_algs.index(alg_name)\n",
    "    return colors[idx]\n",
    "\n",
    "def legend_func(alg_name):\n",
    "    return alg_name.split(\" \")[-1].replace(\")\", \"\").replace(\"beta_2:\", \"$\\\\beta_2=$\")  \n",
    "\n",
    "\n",
    "fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))\n",
    "\n",
    "\n",
    "for alg_name in sorted_algs:\n",
    "    df = alg_dfs[alg_name]\n",
    "\n",
    "    # first plot the weights\n",
    "    axs[0].plot(\n",
    "        df[\"max_weight_g0\"][\"mean\"],\n",
    "        label=legend_func(alg_name),\n",
    "        color=color_func(alg_name),\n",
    "        lw = 2,\n",
    "    )\n",
    "\n",
    "    axs[0].fill_between(\n",
    "        df.index,\n",
    "        df[\"max_weight_g0\"][\"mean\"] - df[\"max_weight_g0\"][\"std\"],\n",
    "        df[\"max_weight_g0\"][\"mean\"] + df[\"max_weight_g0\"][\"std\"],\n",
    "        alpha=0.2,\n",
    "        color = color_func(alg_name),\n",
    "    )\n",
    "\n",
    "    axs[1].plot(\n",
    "        df[\"max_weight_g1\"][\"mean\"],\n",
    "        label=legend_func(alg_name),\n",
    "        color=color_func(alg_name),\n",
    "        lw = 2,\n",
    "    )\n",
    "\n",
    "    axs[1].fill_between(\n",
    "        df.index,\n",
    "        df[\"max_weight_g1\"][\"mean\"] - df[\"max_weight_g1\"][\"std\"],\n",
    "        df[\"max_weight_g1\"][\"mean\"] + df[\"max_weight_g1\"][\"std\"],\n",
    "        alpha=0.2,\n",
    "        color = color_func(alg_name),\n",
    "    )\n",
    "\n",
    "    # now, plot the denominators\n",
    "    axs[2].plot(\n",
    "        df[\"aK_min_g0\"][\"mean\"],\n",
    "        label=legend_func(alg_name),\n",
    "        color=color_func(alg_name),\n",
    "        lw = 2,\n",
    "    )\n",
    "\n",
    "    axs[2].fill_between(\n",
    "        df.index,\n",
    "        df[\"aK_min_g0\"][\"mean\"] - df[\"aK_min_g0\"][\"std\"],\n",
    "        df[\"aK_min_g0\"][\"mean\"] + df[\"aK_min_g0\"][\"std\"],\n",
    "        alpha=0.2,\n",
    "        color = color_func(alg_name),\n",
    "    )\n",
    "\n",
    "    axs[3].plot(\n",
    "        df[\"aK_min_g1\"][\"mean\"],\n",
    "        label=legend_func(alg_name),\n",
    "        color=color_func(alg_name),\n",
    "        lw = 2,\n",
    "    )\n",
    "\n",
    "    axs[3].fill_between(\n",
    "        df.index,\n",
    "        df[\"aK_min_g1\"][\"mean\"] - df[\"aK_min_g1\"][\"std\"],\n",
    "        df[\"aK_min_g1\"][\"mean\"] + df[\"aK_min_g1\"][\"std\"],\n",
    "        alpha=0.2,\n",
    "        color = color_func(alg_name),\n",
    "    )\n",
    "\n",
    "\n",
    "#axs[0].set_ylim(-0.01, 0.03)\n",
    "#axs[1].set_ylim(0, 5)\n",
    "axs[0].set_title(r\"$\\max_x \\{\\mathrm{w}(x, 0)\\}$\")\n",
    "axs[1].set_title(r\"$\\max_x \\{\\mathrm{w}(x, 1)\\}$\")\n",
    "axs[2].set_title(r\"$\\min_x P(A[1:K] = 1 |x,0)$\")\n",
    "axs[3].set_title(r\"$\\min_X P(A[1:K] = 1 | x, 1)$\")\n",
    "#axs[2].set_title(r\"$|\\Delta_t|$\")\n",
    "#\n",
    "#fig.subplots_adjust(bottom=0.3, wspace=0.33)\n",
    "#\n",
    "#axs[0].set_xlabel(\"Learning iteration\")\n",
    "#axs[1].set_xlabel(\"Learning iteration\")\n",
    "#axs[2].set_xlabel(r\"$\\beta_2$\")\n",
    "axs[1].legend(\n",
    "    loc=\"upper center\",\n",
    "    bbox_to_anchor=(0.8, -0.1),\n",
    "    ncol=5,\n",
    ")\n",
    "\n",
    "for i in range(4):\n",
    "   axs[i].grid(True)\n",
    "\n",
    "#axs[2].set_xlim(0, 30)\n",
    "#axs[3].set_xlim(0, 30)\n",
    "\n",
    "\n",
    "plt.savefig(\"figures/ablation_weights.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b18b9410",
   "metadata": {},
   "source": [
    "# Bound"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "063a9b0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"fico_equal\"\n",
    "mu_type = \"accuracy\"\n",
    "train_files = glob(f\"experiments/bound/**/{env_name}/{mu_type}/**/models/*.csv\")\n",
    "alg_dict = dict([(f.split(\"/\")[-3], []) for f in train_files])\n",
    "for f in train_files:\n",
    "    alg_name = f.split(\"/\")[-3]\n",
    "    alg_dict[alg_name].append(f)\n",
    "\n",
    "\n",
    "alg_dfs = {}\n",
    "\n",
    "for alg_name, file_list in alg_dict.items():\n",
    "    df_list = []\n",
    "    for f in file_list:\n",
    "        df = pd.read_csv(f)\n",
    "        df[\"t\"] = np.arange(len(df))\n",
    "        df_list.append(df)\n",
    "    df = pd.concat(df_list)\n",
    "    df[\"delta_diff\"] = (df[\"delta\"] - df[\"delta_obs\"])\n",
    "    df[\"renyi_div_g0\"] = df[\"renyi_div_g0\"].fillna(0)\n",
    "    df[\"renyi_div_g1\"] = df[\"renyi_div_g1\"].fillna(0)\n",
    "    df[\"error_accepted_g0\"] = df[\"error_accepted_g0\"].fillna(0)\n",
    "    df[\"error_accepted_g1\"] = df[\"error_accepted_g1\"].fillna(0)\n",
    "\n",
    "    df = df.groupby(\"t\").agg({\n",
    "        \"reward\": [\"mean\", \"std\"],\n",
    "        \"delta\": [\"mean\", \"std\"],\n",
    "        \"delta_obs\": [\"mean\", \"std\"],\n",
    "        \"delta_diff\": [\"mean\", \"std\"],\n",
    "        \"renyi_div_g0\": [\"mean\", \"std\"],\n",
    "        \"renyi_div_g1\": [\"mean\", \"std\"],\n",
    "        \"error_accepted_g0\": [\"mean\", \"std\"],\n",
    "        \"error_accepted_g1\": [\"mean\", \"std\"],\n",
    "    })\n",
    "    alg_dfs[alg_name] = df\n",
    "\n",
    "sorted_algs = sorted(alg_dfs, key = lambda x : float(x.split(\":\")[2].replace(\")\", \"\")))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42b5f8dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "colors = ['#c7e9b4','#7fcdbb','#41b6c4','#2c7fb8','#253494']\n",
    "\n",
    "def color_func(alg_name):\n",
    "    idx = sorted_algs.index(alg_name)\n",
    "    return colors[idx]\n",
    "\n",
    "def legend_func(alg_name):\n",
    "    return alg_name.split(\" \")[-1].replace(\")\", \"\").replace(\"beta_2:\", \"$\\\\beta_2=$\") \n",
    "\n",
    "\n",
    "fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(9, 7), sharex=True, sharey=\"row\")\n",
    "\n",
    "\n",
    "for alg_name in sorted_algs:\n",
    "    df = alg_dfs[alg_name]\n",
    "\n",
    "    for g in range(2):\n",
    "        axs[0, g].plot(\n",
    "            df[f\"error_accepted_g{g}\"][\"mean\"],\n",
    "            label=legend_func(alg_name),\n",
    "            color=color_func(alg_name), \n",
    "            lw = 2,\n",
    "        )   \n",
    "\n",
    "        axs[1, g].plot(\n",
    "            df[f\"renyi_div_g{g}\"][\"mean\"],\n",
    "            label=legend_func(alg_name),\n",
    "            color=color_func(alg_name),\n",
    "            lw = 2,\n",
    "        )\n",
    "\n",
    "        axs[0, g].fill_between(\n",
    "            df.index,\n",
    "            df[f\"error_accepted_g{g}\"][\"mean\"] - df[f\"error_accepted_g{g}\"][\"std\"],\n",
    "            df[f\"error_accepted_g{g}\"][\"mean\"] + df[f\"error_accepted_g{g}\"][\"std\"],\n",
    "            alpha=0.2,\n",
    "            color = color_func(alg_name),\n",
    "        )\n",
    "\n",
    "        axs[1, g].fill_between(\n",
    "            df.index,\n",
    "            df[f\"renyi_div_g{g}\"][\"mean\"] - df[f\"renyi_div_g{g}\"][\"std\"],\n",
    "            df[f\"renyi_div_g{g}\"][\"mean\"] + df[f\"renyi_div_g{g}\"][\"std\"],\n",
    "            alpha=0.2,\n",
    "            color = color_func(alg_name),\n",
    "        )\n",
    "\n",
    "fig.subplots_adjust(bottom=0.3, wspace=0.33)\n",
    "\n",
    "axs[1, 0].legend(\n",
    "    loc=\"upper center\",\n",
    "    bbox_to_anchor=(1, -0.3),\n",
    "    ncol=5,\n",
    "    fontsize=12\n",
    ")\n",
    "\n",
    "groups_names = [\"Unprivileged\", \"Privileged\"]\n",
    "for g in range(2):\n",
    "    axs[0, g].set_title(f\"{groups_names[g]} Group\", fontsize=14)\n",
    "\n",
    "axs[0, 0].set_ylabel(\"Error on rejected pop.\", fontsize=12)\n",
    "axs[1, 0].set_ylabel(\"Renyi Divergence\", fontsize=12)\n",
    "\n",
    "\n",
    "axs[0, 1].set_ylim(-0.02, 0.06)\n",
    "\n",
    "fig.subplots_adjust(wspace=0.15)\n",
    "\n",
    "plt.savefig(\"figures/bound_terms.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2dcbbcc9",
   "metadata": {},
   "source": [
    "# Sec. 5 Experiments"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f717fa4b",
   "metadata": {},
   "source": [
    "## FICO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e16e188c",
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"fico_equal\"\n",
    "mu_type = \"tpr\"\n",
    "alg_names = [\n",
    "    \"ppo\",\n",
    "    \"pocar\",\n",
    "    \"pocar_full\",\n",
    "    \"sellf\",\n",
    "    \"sellf_censored_25\",\n",
    "    \"elbert\",\n",
    "    \"focops\",\n",
    "]\n",
    "alg_names = sum(\n",
    "    [get_alg_names(alg) for alg in alg_names],\n",
    "    [],\n",
    ")\n",
    "results = summarize_hyperparam(env_name, mu_type, alg_names)\n",
    "process_table(\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "selected_algs = (\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()['alg']\n",
    "    .tolist()\n",
    ")\n",
    "\n",
    "# place sellf first\n",
    "sellf_algs = [alg for alg in selected_algs if \"sellf\" in alg]\n",
    "selected_algs = [alg for alg in selected_algs if \"sellf\" not in alg]\n",
    "selected_algs = sellf_algs + selected_algs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04f38466",
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"fico_equal\"\n",
    "mu_type = \"tpr\"\n",
    "alg_names = [\n",
    "    \"ppo\",\n",
    "    \"pocar\",\n",
    "    \"pocar_full\",\n",
    "    \"sellf\",\n",
    "    \"sellf_censored_25\",\n",
    "    \"elbert\",\n",
    "    \"focops\",\n",
    "]\n",
    "alg_names = sum(\n",
    "    [get_alg_names(alg) for alg in alg_names],\n",
    "    [],\n",
    ")\n",
    "results = summarize_hyperparam(env_name, mu_type, alg_names)\n",
    "process_table(\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "selected_algs = (\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()['alg']\n",
    "    .tolist()\n",
    ")\n",
    "\n",
    "# place sellf first\n",
    "sellf_algs = [alg for alg in selected_algs if \"sellf\" in alg]\n",
    "selected_algs = [alg for alg in selected_algs if \"sellf\" not in alg]\n",
    "selected_algs = sellf_algs + selected_algs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dae5044d",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(6, 2.7), gridspec_kw={'width_ratios': [1, 1, 0.7]})\n",
    "\n",
    "df_list = [\n",
    "    pd.read_csv(f\"experiments/{env_name}/{mu_type}/{alg}/eval/eval_data.csv\")\n",
    "    for alg in selected_algs\n",
    "]\n",
    "\n",
    "# first, let's calculate the accept rate of each alg\n",
    "accept_rate_g0 = {}\n",
    "accept_rate_g1 = {}\n",
    "\n",
    "for i in range(len(df_list)):\n",
    "    df_list[i][\"accept_and_g0\"] = df_list[i][\"action\"] * (df_list[i][\"group_id\"] == 0)\n",
    "    df_list[i][\"accept_and_g1\"] = df_list[i][\"action\"] * (df_list[i][\"group_id\"] == 1)\n",
    "    df_list[i][\"g0\"] = (df_list[i][\"group_id\"] == 0)\n",
    "    df_list[i][\"g1\"] = (df_list[i][\"group_id\"] == 1)\n",
    "    accept_g0 = df_list[i].groupby(\"ep\")[\"accept_and_g0\"].sum()\n",
    "    accept_g1 = df_list[i].groupby(\"ep\")[\"accept_and_g1\"].sum()\n",
    "\n",
    "    n_g0 = df_list[i].groupby(\"ep\")[\"g0\"].sum()\n",
    "    n_g1 = df_list[i].groupby(\"ep\")[\"g1\"].sum()\n",
    "\n",
    "    accept_rate_g0[selected_algs[i]] = {\n",
    "        \"mean\": (accept_g0 / n_g0).mean(),\n",
    "        \"std\": (accept_g0 / n_g0).std(),\n",
    "    }\n",
    "    accept_rate_g1[selected_algs[i]] = {\n",
    "        \"mean\": (accept_g1 / n_g1).mean(),\n",
    "        \"std\": (accept_g1 / n_g1).std(),\n",
    "    }\n",
    "\n",
    "    df_list[i][\"delta\"] = df_list[i][\"delta\"].abs()\n",
    "\n",
    "    df_list[i] = (\n",
    "        df_list[i]\n",
    "        .groupby(\"t\")\n",
    "        .agg(\n",
    "            {\n",
    "                \"resource\": [\"mean\", \"std\"],\n",
    "                \"delta\": [\"mean\", \"std\"],\n",
    "            }\n",
    "        )\n",
    "    )\n",
    "\n",
    "for alg, df in zip(selected_algs, df_list):\n",
    "    \n",
    "\n",
    "    axs[0].plot(\n",
    "        df.index,\n",
    "        df[\"resource\"][\"mean\"],\n",
    "        label=base_legend_func(alg),\n",
    "        color=base_color_func(alg),\n",
    "        linestyle=\"--\" if \"full\" in alg else \"-\",\n",
    "        zorder = get_zorder(alg),\n",
    "    )\n",
    "    axs[0].fill_between(\n",
    "        df.index,\n",
    "        df[\"resource\"][\"mean\"] - df[\"resource\"][\"std\"],\n",
    "        df[\"resource\"][\"mean\"] + df[\"resource\"][\"std\"],\n",
    "        alpha=0.3,\n",
    "        color=base_color_func(alg),\n",
    "        zorder = get_zorder(alg),\n",
    "    )\n",
    "\n",
    "    axs[1].plot(\n",
    "        df.index,\n",
    "        df[\"delta\"][\"mean\"],\n",
    "        label=base_legend_func(alg),\n",
    "        color=base_color_func(alg),\n",
    "        linestyle=\"--\" if \"full\" in alg else \"-\",\n",
    "        zorder = get_zorder(alg),\n",
    "    )\n",
    "    axs[1].fill_between(\n",
    "        df.index,\n",
    "        df[\"delta\"][\"mean\"] - df[\"delta\"][\"std\"],\n",
    "        df[\"delta\"][\"mean\"] + df[\"delta\"][\"std\"],\n",
    "        alpha=0.3,\n",
    "        color=base_color_func(alg),\n",
    "        zorder = get_zorder(alg),\n",
    "    )\n",
    "\n",
    "    # axs[2].plot(\n",
    "    #     df.index,\n",
    "    #     df[\"n_accept_g0\"][\"mean\"],\n",
    "    #     label=base_legend_func(alg),\n",
    "    #     color=base_color_func(alg),\n",
    "    #     linestyle=\"--\" if \"full\" in alg else \"-\",\n",
    "    #     lw=2,\n",
    "    # )\n",
    "\n",
    "    # axs[2].fill_between(\n",
    "    #     df.index,\n",
    "    #     df[\"n_accept_g0\"][\"mean\"] - df[\"n_accept_g0\"][\"std\"],\n",
    "    #     df[\"n_accept_g0\"][\"mean\"] + df[\"n_accept_g0\"][\"std\"],\n",
    "    #     alpha=0.3,\n",
    "    #     color=base_color_func(alg),\n",
    "    # )\n",
    "\n",
    "    # scatter the acceptance rate\n",
    "    #axs[2].scatter(\n",
    "    #    [accept_rate_g0],\n",
    "    #    [accept_rate_g1],\n",
    "    #    label=base_legend_func(alg),\n",
    "    #    color=base_color_func(alg),\n",
    "    #    s=100,\n",
    "    #)\n",
    "\n",
    "    # scatter with error bars\n",
    "    axs[2].errorbar(\n",
    "        accept_rate_g0[alg][\"mean\"],\n",
    "        accept_rate_g1[alg][\"mean\"],\n",
    "        xerr=accept_rate_g0[alg][\"std\"],\n",
    "        yerr=accept_rate_g1[alg][\"std\"],\n",
    "        label=base_legend_func(alg),\n",
    "        color=base_color_func(alg),\n",
    "        fmt='o',\n",
    "        markersize=8,\n",
    "        capsize=5,\n",
    "    )\n",
    "        \n",
    "for i in range(2):\n",
    "    axs[i].set_xlabel(\"Timestep\", fontsize = 12)\n",
    "\n",
    "axs[2].set_xlabel(\"$P(A = 1 | Z = 0)$\", fontsize = 12)\n",
    "axs[2].set_ylabel(\"$P(A = 1 | Z = 1)$\", fontsize = 12)\n",
    "# place the y ticks of the last plot on the right\n",
    "axs[2].yaxis.tick_right()\n",
    "axs[2].yaxis.set_label_position(\"right\")\n",
    "\n",
    "axs[0].set_title(\"Cumulative Reward\")\n",
    "axs[1].set_title(\"Disparity $(|\\Delta_t|)$\")\n",
    "axs[2].set_title(\"Accept rate\")\n",
    "\n",
    "# place legend outside bellow all plots\n",
    "fig.subplots_adjust(bottom=0.3, wspace=0.33)\n",
    "\n",
    "axs[1].legend(\n",
    "    loc=\"upper center\",\n",
    "    bbox_to_anchor=(0.35, -0.3),\n",
    "    ncol= 4, #len(selected_algs),\n",
    ")\n",
    "\n",
    "for i in range(3):\n",
    "    axs[i].grid(True)\n",
    "\n",
    "fig.subplots_adjust(wspace = 0.25)\n",
    "plt.savefig(\"figures/fico_tpr.pdf\", bbox_inches=\"tight\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6c8c16e",
   "metadata": {},
   "source": [
    "## COMPAS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2eea49f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"compas\"\n",
    "mu_type = \"accuracy\"\n",
    "alg_names = [\n",
    "    \"ppo\",\n",
    "    \"pocar\",\n",
    "    \"pocar_full_v2\",\n",
    "    \"sellf\",\n",
    "    \"sellf_censored_25\",\n",
    "    \"elbert\",\n",
    "    \"focops\",\n",
    "]\n",
    "alg_names = sum(\n",
    "    [get_alg_names(alg) for alg in alg_names],\n",
    "    [],\n",
    ")\n",
    "results = summarize_hyperparam(env_name, mu_type, alg_names)\n",
    "process_table(\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "selected_algs = (\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()['alg']\n",
    "    .tolist()\n",
    ")\n",
    "\n",
    "# place sellf first\n",
    "sellf_algs = [alg for alg in selected_algs if \"sellf\" in alg]\n",
    "selected_algs = [alg for alg in selected_algs if \"sellf\" not in alg]\n",
    "selected_algs = sellf_algs + selected_algs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e9d1555",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_plot(selected_algs, env_name, mu_type, f\"figures/{env_name}_{mu_type}.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a4b905a",
   "metadata": {},
   "source": [
    "## ENEM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2075cb52",
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"enem\"\n",
    "mu_type = \"qualification\"\n",
    "alg_names = [\n",
    "    \"ppo\",\n",
    "    \"pocar\",\n",
    "    \"pocar_full\",\n",
    "    \"sellf\",\n",
    "    \"sellf_censored_25\",\n",
    "    \"elbert\",\n",
    "    \"focops\",\n",
    "]\n",
    "alg_names = sum(\n",
    "    [get_alg_names(alg) for alg in alg_names],\n",
    "    [],\n",
    ")\n",
    "results = summarize_hyperparam(env_name, mu_type, alg_names)\n",
    "process_table(\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "selected_algs = (\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()['alg']\n",
    "    .tolist()\n",
    ")\n",
    "\n",
    "# place sellf first\n",
    "sellf_algs = [alg for alg in selected_algs if \"sellf\" in alg]\n",
    "selected_algs = [alg for alg in selected_algs if \"sellf\" not in alg]\n",
    "selected_algs = sellf_algs + selected_algs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fed5332",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(6.5, 2.7))\n",
    "\n",
    "df_list = [\n",
    "    pd.read_csv(f\"experiments/{env_name}/{mu_type}/{alg}/eval/eval_data.csv\")\n",
    "    for alg in selected_algs\n",
    "]\n",
    "\n",
    "for i in range(len(df_list)):\n",
    "    df_list[i][\"accept_g0\"] = df_list[i][\"action\"] * (df_list[i][\"group_id\"] == 0)\n",
    "    df_list[i][\"accept_g1\"] = df_list[i][\"action\"] * (df_list[i][\"group_id\"] == 1)\n",
    "\n",
    "    # for each ep, compute cumulative sum of accept_g0 and accept_g1\n",
    "    df_list[i][\"n_accept_g0\"] = df_list[i].groupby(\"ep\")[\"accept_g0\"].cumsum()\n",
    "    df_list[i][\"n_accept_g1\"] = df_list[i].groupby(\"ep\")[\"accept_g1\"].cumsum()\n",
    "\n",
    "    # for each ep, calculate a rolling window of label with window 50\n",
    "    df_list[i][\"label\"] = (\n",
    "        df_list[i]\n",
    "        .groupby(\"ep\")[\"label\"]\n",
    "        .rolling(window=500, min_periods=30)\n",
    "        .mean()\n",
    "        .reset_index(0, drop=True)\n",
    "    )\n",
    "\n",
    "    df_list[i][\"delta\"] = df_list[i][\"delta\"].abs()\n",
    "\n",
    "    df_list[i] = (\n",
    "        df_list[i]\n",
    "        .groupby(\"t\")\n",
    "        .agg(\n",
    "            {\n",
    "                \"resource\": [\"mean\", \"std\"],\n",
    "                \"delta\": [\"mean\", \"std\"],\n",
    "                \"n_accept_g0\": [\"mean\", \"std\"],\n",
    "                \"n_accept_g1\": [\"mean\", \"std\"],\n",
    "                \"label\": [\"mean\", \"std\"],\n",
    "            }\n",
    "        )\n",
    "    )\n",
    "\n",
    "\n",
    "for alg, df in zip(selected_algs, df_list):\n",
    "    axs[0].plot(\n",
    "        df.index,\n",
    "        df[\"resource\"][\"mean\"],\n",
    "        label=base_legend_func(alg),\n",
    "        color=base_color_func(alg),\n",
    "        linestyle=\"--\" if \"full\" in alg else \"-\",\n",
    "        zorder = get_zorder(alg)\n",
    "    )\n",
    "    axs[0].fill_between(\n",
    "        df.index,\n",
    "        df[\"resource\"][\"mean\"] - df[\"resource\"][\"std\"],\n",
    "        df[\"resource\"][\"mean\"] + df[\"resource\"][\"std\"],\n",
    "        alpha=0.3,\n",
    "        color=base_color_func(alg),\n",
    "        zorder = get_zorder(alg)\n",
    "    )\n",
    "\n",
    "    axs[1].plot(\n",
    "        df.index,\n",
    "        df[\"delta\"][\"mean\"],\n",
    "        label=base_legend_func(alg),\n",
    "        color=base_color_func(alg),\n",
    "        linestyle=\"--\" if \"full\" in alg else \"-\",\n",
    "        zorder = get_zorder(alg),\n",
    "    )\n",
    "    axs[1].fill_between(\n",
    "        df.index,\n",
    "        df[\"delta\"][\"mean\"] - df[\"delta\"][\"std\"],\n",
    "        df[\"delta\"][\"mean\"] + df[\"delta\"][\"std\"],\n",
    "        alpha=0.3,\n",
    "        color=base_color_func(alg),\n",
    "        zorder =  get_zorder(alg),\n",
    "    )\n",
    "\n",
    "    axs[2].plot(\n",
    "        df.index,\n",
    "        df[\"label\"][\"mean\"],\n",
    "        label=base_legend_func(alg),\n",
    "        color=base_color_func(alg),\n",
    "        linestyle=\"--\" if \"full\" in alg else \"-\",\n",
    "        lw=2,\n",
    "        zorder =  get_zorder(alg),\n",
    "    )\n",
    "    axs[2].fill_between(\n",
    "        df.index,\n",
    "        df[\"label\"][\"mean\"] - df[\"label\"][\"std\"],\n",
    "        df[\"label\"][\"mean\"] + df[\"label\"][\"std\"],\n",
    "        alpha=0.3,\n",
    "        color=base_color_func(alg),\n",
    "        zorder = get_zorder(alg),\n",
    "    )\n",
    "\n",
    "for i in range(3):\n",
    "    axs[i].set_xlabel(\"Timestep\")\n",
    "\n",
    "axs[0].set_title(\"Cumulative Reward\", fontsize = 12)\n",
    "axs[1].set_title(\"Disparity ($|\\Delta_t|$)\", fontsize = 12)\n",
    "axs[2].set_title(\"$P(Y = 1 | Z = $unprivileged)\", fontsize = 12)\n",
    "\n",
    "# place legend outside bellow all plots\n",
    "fig.subplots_adjust(bottom=0.3, wspace=0.33)\n",
    "\n",
    "axs[1].legend(\n",
    "    loc=\"upper center\",\n",
    "    bbox_to_anchor=(0.35, -0.3),\n",
    "    ncol=4 #len(selected_algs),\n",
    ")\n",
    "\n",
    "for i in range(3):\n",
    "    axs[i].grid(True)\n",
    "# make y ticks of last plot on the right\n",
    "\n",
    "fig.subplots_adjust(wspace = 0.35)\n",
    "plt.savefig(\"figures/enem_qualification.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5431020",
   "metadata": {},
   "source": [
    "# Appendix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe33de74",
   "metadata": {},
   "outputs": [],
   "source": [
    "method_names = [\n",
    "    \"ppo\",\n",
    "    \"pocar\",\n",
    "    \"pocar_full\",\n",
    "    \"sellf\",\n",
    "    \"sellf_censored_25\",\n",
    "    \"elbert\",\n",
    "    \"focops\",\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33b7d5b5",
   "metadata": {},
   "source": [
    "## FICO"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4214a451",
   "metadata": {},
   "source": [
    "### Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d3cdf2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"fico_equal\"\n",
    "mu_type = \"accuracy\"\n",
    "method_names = [\n",
    "    \"ppo\",\n",
    "    \"pocar\",\n",
    "    \"pocar_full\",\n",
    "    \"sellf\",\n",
    "    \"sellf_censored_25\",\n",
    "    \"elbert\",\n",
    "    \"focops\",\n",
    "]\n",
    "alg_names = sum([get_alg_names(alg) for alg in method_names], [])\n",
    "results = summarize_hyperparam(env_name, mu_type, alg_names)\n",
    "process_table(results.sort_values(\"score\", ascending=False).groupby(\"alg_simplified\").first().reset_index())\n",
    "\n",
    "\n",
    "selected_algs = (\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()['alg']\n",
    "    .tolist()\n",
    ")\n",
    "\n",
    "# place sellf first\n",
    "sellf_algs = [alg for alg in selected_algs if \"sellf\" in alg]\n",
    "selected_algs = [alg for alg in selected_algs if \"sellf\" not in alg]\n",
    "selected_algs = sellf_algs + selected_algs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fab2401e",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_plot(selected_algs, env_name, mu_type, f\"figures/{env_name}_{mu_type}.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "681b1824",
   "metadata": {},
   "source": [
    "### Qualification parity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45284404",
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"fico_equal\"\n",
    "mu_type = \"qualification\"\n",
    "method_names = [\n",
    "    \"ppo\",\n",
    "    \"pocar\",\n",
    "    \"pocar_full\",\n",
    "    \"sellf\",\n",
    "    \"sellf_censored_25\",\n",
    "    \"elbert\",\n",
    "    \"focops\",\n",
    "]\n",
    "alg_names = sum([get_alg_names(alg) for alg in method_names], [])\n",
    "results = summarize_hyperparam(env_name, mu_type, alg_names)\n",
    "process_table(results.sort_values(\"score\", ascending=False).groupby(\"alg_simplified\").first().reset_index())\n",
    "\n",
    "\n",
    "selected_algs = (\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()['alg']\n",
    "    .tolist()\n",
    ")\n",
    "\n",
    "# place sellf first\n",
    "sellf_algs = [alg for alg in selected_algs if \"sellf\" in alg]\n",
    "selected_algs = [alg for alg in selected_algs if \"sellf\" not in alg]\n",
    "selected_algs = sellf_algs + selected_algs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8df78a87",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_plot(selected_algs, env_name, mu_type, f\"figures/{env_name}_{mu_type}.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c5b07c6",
   "metadata": {},
   "source": [
    "## COMPAS"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5aa46989",
   "metadata": {},
   "source": [
    "### Equal opportunity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f90e2e01",
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"compas\"\n",
    "mu_type = \"tpr\"\n",
    "alg_names = [\n",
    "    \"ppo\",\n",
    "    \"pocar\",\n",
    "    \"pocar_full\",\n",
    "    \"sellf\",\n",
    "    \"sellf_censored_25\",\n",
    "    \"sellf_censored_3\",\n",
    "    \"elbert\",\n",
    "    \"focops\",\n",
    "]\n",
    "alg_names = sum(\n",
    "    [get_alg_names(alg) for alg in alg_names],\n",
    "    [],\n",
    ")\n",
    "results = summarize_hyperparam(env_name, mu_type, alg_names)\n",
    "process_table(\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "selected_algs = (\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()['alg']\n",
    "    .tolist()\n",
    ")\n",
    "\n",
    "# place sellf first\n",
    "sellf_algs = [alg for alg in selected_algs if \"sellf\" in alg]\n",
    "selected_algs = [alg for alg in selected_algs if \"sellf\" not in alg]\n",
    "selected_algs = sellf_algs + selected_algs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba656d7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "results[results.alg.isin(selected_algs)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d591350",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_plot(selected_algs, env_name, mu_type, f\"figures/{env_name}_{mu_type}.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee20c5f5",
   "metadata": {},
   "source": [
    "### Qualification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48d58c97",
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"compas\"\n",
    "mu_type = \"qualification\"\n",
    "alg_names = [\n",
    "    \"ppo\",\n",
    "    \"pocar\",\n",
    "    \"pocar_full\",\n",
    "    \"sellf\",\n",
    "    \"sellf_censored_25\",\n",
    "    \"sellf_censored_3\",\n",
    "    \"elbert\",\n",
    "    \"focops\",\n",
    "]\n",
    "alg_names = sum(\n",
    "    [get_alg_names(alg) for alg in alg_names],\n",
    "    [],\n",
    ")\n",
    "results = summarize_hyperparam(env_name, mu_type, alg_names)\n",
    "process_table(\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "selected_algs = (\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()['alg']\n",
    "    .tolist()\n",
    ")\n",
    "\n",
    "# place sellf first\n",
    "sellf_algs = [alg for alg in selected_algs if \"sellf\" in alg]\n",
    "selected_algs = [alg for alg in selected_algs if \"sellf\" not in alg]\n",
    "selected_algs = sellf_algs + selected_algs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a850edc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_plot(selected_algs, env_name, mu_type, f\"figures/{env_name}_{mu_type}.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3094b7d",
   "metadata": {},
   "source": [
    "# ENEM"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91c50f4f",
   "metadata": {},
   "source": [
    "### Equality of opportunity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d7f02a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"enem\"\n",
    "mu_type = \"tpr\"\n",
    "alg_names = [\n",
    "    \"ppo\",\n",
    "    \"pocar\",\n",
    "    \"pocar_full\",\n",
    "    \"sellf\",\n",
    "    \"sellf_censored_25\",\n",
    "    \"sellf_censored_3\",\n",
    "    \"elbert\",\n",
    "    \"focops\",\n",
    "]\n",
    "alg_names = sum(\n",
    "    [get_alg_names(alg) for alg in alg_names],\n",
    "    [],\n",
    ")\n",
    "results = summarize_hyperparam(env_name, mu_type, alg_names)\n",
    "process_table(\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "selected_algs = (\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()['alg']\n",
    "    .tolist()\n",
    ")\n",
    "\n",
    "# place sellf first\n",
    "sellf_algs = [alg for alg in selected_algs if \"sellf\" in alg]\n",
    "selected_algs = [alg for alg in selected_algs if \"sellf\" not in alg]\n",
    "selected_algs = sellf_algs + selected_algs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53324a44",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_plot(selected_algs, env_name, mu_type, f\"figures/{env_name}_{mu_type}.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f673280d",
   "metadata": {},
   "source": [
    "### Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e6af4bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"enem\"\n",
    "mu_type = \"accuracy\"\n",
    "alg_names = [\n",
    "    \"ppo\",\n",
    "    \"pocar\",\n",
    "    \"pocar_full\",\n",
    "    \"sellf\",\n",
    "    \"sellf_censored_25\",\n",
    "    \"sellf_censored_3\",\n",
    "    \"elbert\",\n",
    "    \"focops\",\n",
    "]\n",
    "alg_names = sum(\n",
    "    [get_alg_names(alg) for alg in alg_names],\n",
    "    [],\n",
    ")\n",
    "results = summarize_hyperparam(env_name, mu_type, alg_names)\n",
    "process_table(\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "selected_algs = (\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()['alg']\n",
    "    .tolist()\n",
    ")\n",
    "\n",
    "# place sellf first\n",
    "sellf_algs = [alg for alg in selected_algs if \"sellf\" in alg]\n",
    "selected_algs = [alg for alg in selected_algs if \"sellf\" not in alg]\n",
    "selected_algs = sellf_algs + selected_algs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28606fe7",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_plot(selected_algs, env_name, mu_type, f\"figures/{env_name}_{mu_type}.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a52e889",
   "metadata": {},
   "source": [
    "# Deep predictor model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b4da7b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"fico_equal\"\n",
    "mu_type = \"tpr\"\n",
    "alg_names = sum([get_alg_names(alg) for alg in [\"sellf\", \"sellf_deep\", \"ppo\"]], [])\n",
    "results = summarize_hyperparam(env_name, mu_type, alg_names)\n",
    "process_table(results.sort_values(\"score\", ascending=False).groupby(\"alg_simplified\").first().reset_index())\n",
    "\n",
    "selected_algs = (\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()[\"alg\"]\n",
    "    .tolist()\n",
    ")\n",
    "\n",
    "# place sellf first\n",
    "sellf_alg = [alg for alg in selected_algs if \"sellf\" in alg][0]\n",
    "selected_algs.remove(sellf_alg)\n",
    "selected_algs = [sellf_alg] + selected_algs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "beb59b16",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_plot(selected_algs, env_name, mu_type, \"figures/fico_tpr_deep.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6660882d",
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"enem\"\n",
    "mu_type = \"qualification\"\n",
    "alg_names = sum([get_alg_names(alg) for alg in [\"sellf\", \"sellf_deep\", \"ppo\"]], [])\n",
    "results = summarize_hyperparam(env_name, mu_type, alg_names)\n",
    "process_table(results.sort_values(\"score\", ascending=False).groupby(\"alg_simplified\").first())\n",
    "\n",
    "selected_algs = (\n",
    "    results.sort_values(\"score\", ascending=False)\n",
    "    .groupby(\"alg_simplified\")\n",
    "    .first()[\"alg\"]\n",
    "    .tolist()\n",
    ")\n",
    "\n",
    "# place sellf first\n",
    "sellf_alg = [alg for alg in selected_algs if \"sellf\" in alg][0]\n",
    "selected_algs.remove(sellf_alg)\n",
    "selected_algs = [sellf_alg] + selected_algs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "762dd189",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_plot(selected_algs, env_name, mu_type, \"figures/enem_qualification_deep.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
