{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import nn\n",
    "import copy\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multiple binary classifiers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Directory holding the per-shift result CSVs (all for seed 0); change here to point elsewhere.\n",
    "RESULTS_DIR = \"//exps/div_explore/toy_v2\"\n",
    "# Concatenation order (sc, ai, ci, 3shifts) matches the original loading code.\n",
    "RESULT_TAGS = [\"sc\", \"ai\", \"ci\", \"3shifts\"]\n",
    "all_df = pd.concat(\n",
    "    [pd.read_csv(f\"{RESULTS_DIR}/results_seed0_{tag}.csv\") for tag in RESULT_TAGS],\n",
    "    ignore_index=True,\n",
    ")\n",
    "\n",
    "# Columns identifying one simulated configuration; rank and variance are computed across the rows sharing them.\n",
    "config_cols = ['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat']\n",
    "# n is stored as a string here (presumably for categorical plotting); features are cast back to float before training.\n",
    "all_df['n'] = all_df['n'].astype('str')\n",
    "err_by_config = all_df.groupby(config_cols)['wga_te_err']\n",
    "# rank 1 = lowest wga_te_err in the configuration; ties broken by row order (\"first\").\n",
    "all_df['rank'] = err_by_config.rank(\"first\")\n",
    "all_df['wga_te_err_var'] = err_by_config.transform('var')\n",
    "\n",
    "def get_gt_rank(x, mode='tie', filter_thre=0.05):\n",
    "    \"\"\"Return the method(s) whose 'wga_te_err' is within `filter_thre` of the group's best.\n",
    "\n",
    "    Args:\n",
    "        x: DataFrame with columns 'method' and 'wga_te_err' (one configuration group).\n",
    "        mode: unused; kept so existing calls passing `mode=` keep working.\n",
    "        filter_thre: absolute tolerance for counting a method as tied with the minimum error.\n",
    "\n",
    "    Returns:\n",
    "        Winning method names joined with '|', in row order ('' when `x` has no rows).\n",
    "    \"\"\"\n",
    "    min_err = x['wga_te_err'].min()\n",
    "    winners = x.loc[x['wga_te_err'] <= min_err + filter_thre, 'method'].to_list()\n",
    "    return '|'.join(winners)\n",
    "\n",
    "# Near-best method(s) per configuration, as one '|'-joined string per group (get_gt_rank's default 0.05 tolerance).\n",
    "winner_keys = ['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat']\n",
    "winners_series = (\n",
    "    all_df.groupby(winner_keys)[['method', 'wga_te_err']]\n",
    "    .apply(get_gt_rank)\n",
    "    .reset_index(name='winners')\n",
    ")\n",
    "\n",
    "# Broadcast each configuration's winners back onto all of its rows.\n",
    "all_df = all_df.merge(winners_series, on=winner_keys)\n",
    "\n",
    "# Keep each configuration's rank-1 row (lowest wga_te_err). .copy() makes it an independent frame,\n",
    "# so adding the 'one_hot' column below does not raise SettingWithCopyWarning.\n",
    "all_df_rank = all_df[all_df[\"rank\"]==1.0].copy()\n",
    "# Fixed one-hot position per method; a method missing from this map raises KeyError.\n",
    "to_one_hot = {\n",
    "    \"GroupDRO\":     [1,0,0,0,0],\n",
    "    \"ERM\":          [0,1,0,0,0],\n",
    "    \"undersample\":  [0,0,1,0,0],\n",
    "    \"oversample\":   [0,0,0,1,0],\n",
    "    \"remax-margin\": [0,0,0,0,1],\n",
    "}\n",
    "all_df_rank[\"one_hot\"] = all_df_rank[\"method\"].map(lambda method: to_one_hot[method])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the rank-1 (lowest wga_te_err) row selected for each configuration.\n",
    "all_df_rank"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# convert to multiple binary classification\n",
    "# TODO: double-check two method case\n",
    "\n",
    "def prepare_data(n, tr_size=3000, test_size=0.2, seed=0):\n",
    "    \"\"\"Split positions 0..n-1 into train and test index sequences.\n",
    "\n",
    "    A `test_size` fraction is held out first; the remaining positions are then subsampled\n",
    "    to `tr_size` training positions. The test split does not depend on `tr_size`.\n",
    "\n",
    "    Args:\n",
    "        n: number of rows to split.\n",
    "        tr_size: number of training positions to keep. <= 0 keeps all non-test positions,\n",
    "            and so does a value >= the number of non-test positions (previously an error).\n",
    "        test_size: fraction of the n positions held out for testing.\n",
    "        seed: random_state for both train_test_split calls.\n",
    "\n",
    "    Returns:\n",
    "        (tr_idx, te_idx): sequences of integer positions.\n",
    "    \"\"\"\n",
    "    from sklearn.model_selection import train_test_split\n",
    "    tr_idx, te_idx = train_test_split(range(n), test_size=test_size, random_state=seed)\n",
    "\n",
    "    if tr_size > 0:\n",
    "        if tr_size < len(tr_idx):\n",
    "            # An integer train_size yields exactly tr_size rows; the previous float test_size\n",
    "            # (num - tr_size) / num goes through ceil() and can lose a row to rounding error.\n",
    "            tr_idx, _ = train_test_split(tr_idx, train_size=tr_size, random_state=seed)\n",
    "        print(\"train size:\", len(tr_idx))\n",
    "\n",
    "    return tr_idx, te_idx\n",
    "\n",
    "def train_binary_classifier(all_df, tr_idx, te_idx, method1, method2, filter_thre=0.05, mode=\"filter\", remove_col=-1):\n",
    "    \"\"\"Train an MLP predicting, from configuration descriptors, which of two methods wins.\n",
    "\n",
    "    Per configuration, loss_diff = wga_te_err(method1) - wga_te_err(method2), and:\n",
    "      * mode=\"filter\": label 1 if loss_diff < 0 (method1 better) else 0; rows with\n",
    "        |loss_diff| <= filter_thre are dropped from both train and test sets.\n",
    "      * mode=\"tie\": label 1 if loss_diff < -filter_thre, 0 if loss_diff > filter_thre,\n",
    "        2 (tie) otherwise; no rows are dropped.\n",
    "\n",
    "    Args:\n",
    "        all_df: long-format results with columns 'method', 'wga_te_err' and the descriptors\n",
    "            'n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat' (one row per configuration and method).\n",
    "        tr_idx, te_idx: positional indices into each method's rows (e.g. from prepare_data).\n",
    "        method1, method2: the two method names to compare.\n",
    "        filter_thre: tolerance on loss_diff used by both modes.\n",
    "        mode: \"filter\" or \"tie\".\n",
    "        remove_col: descriptor column index to drop for ablation; -1 keeps all descriptors.\n",
    "\n",
    "    Returns:\n",
    "        (clf, X_test, test_acc): fitted MLPClassifier, standardized test features, test accuracy.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `mode` is unknown, or if the two methods' rows do not list the same\n",
    "            configurations in the same order (loss_diff is computed positionally).\n",
    "    \"\"\"\n",
    "    import warnings\n",
    "    from sklearn.preprocessing import StandardScaler\n",
    "    from sklearn.neural_network import MLPClassifier\n",
    "\n",
    "    if mode not in (\"filter\", \"tie\"):\n",
    "        raise ValueError(f\"unknown mode {mode!r}; expected 'filter' or 'tie'\")\n",
    "\n",
    "    config_cols = ['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat']\n",
    "    # .copy() so the columns added below do not write through a slice of the caller's frame.\n",
    "    pair_df = all_df[all_df[\"method\"]==method1].copy()\n",
    "    m2_df = all_df[all_df[\"method\"]==method2]\n",
    "\n",
    "    # The subtraction below pairs rows by position, so both methods must cover the same\n",
    "    # configurations in the same order; otherwise loss_diff would silently be garbage.\n",
    "    aligned = len(pair_df) == len(m2_df) and pair_df[config_cols].reset_index(drop=True).equals(\n",
    "        m2_df[config_cols].reset_index(drop=True))\n",
    "    if not aligned:\n",
    "        raise ValueError(f\"rows of {method1!r} and {method2!r} are not aligned on {config_cols}\")\n",
    "    pair_df[\"loss_diff\"] = pair_df['wga_te_err'].to_numpy() - m2_df['wga_te_err'].to_numpy()\n",
    "\n",
    "    if mode == \"filter\":\n",
    "        pair_df['over_rank'] = pair_df['loss_diff'] < 0\n",
    "    else:  # mode == \"tie\"\n",
    "        def lossdiff2rank(x):\n",
    "            # Uses filter_thre (previously hard-coded 0.05, which equals the default).\n",
    "            if x < -filter_thre:\n",
    "                return 1\n",
    "            elif x > filter_thre:\n",
    "                return 0\n",
    "            else:\n",
    "                return 2\n",
    "        pair_df['over_rank'] = pair_df['loss_diff'].map(lossdiff2rank)\n",
    "    pair_df['over_rank'] = pair_df['over_rank'].astype('int')\n",
    "\n",
    "    df_tr = pair_df.iloc[tr_idx]\n",
    "    df_te = pair_df.iloc[te_idx]\n",
    "\n",
    "    if mode == \"filter\":\n",
    "        # Keep only clear wins/losses for the binary task.\n",
    "        df_tr = df_tr[(df_tr['loss_diff'] < -filter_thre)|(df_tr['loss_diff'] > filter_thre)]\n",
    "        df_te = df_te[(df_te['loss_diff'] < -filter_thre)|(df_te['loss_diff'] > filter_thre)]\n",
    "\n",
    "    # 'n' is stored as a string upstream, hence the float cast of the descriptors.\n",
    "    X_train = df_tr[config_cols].to_numpy().astype('float')\n",
    "    y_train = df_tr['over_rank'].to_numpy().astype('int')\n",
    "    X_test = df_te[config_cols].to_numpy().astype('float')\n",
    "    y_test = df_te['over_rank'].to_numpy().astype('int')\n",
    "\n",
    "    if remove_col != -1:\n",
    "        cols = list(range(X_train.shape[1]))\n",
    "        cols.remove(remove_col)  # ValueError for an out-of-range index\n",
    "        X_train, X_test = X_train[:,cols], X_test[:,cols]\n",
    "\n",
    "    scaler = StandardScaler()\n",
    "    X_train = scaler.fit_transform(X_train)\n",
    "    X_test = scaler.transform(X_test)\n",
    "\n",
    "    # Suppress warnings (e.g. sklearn convergence warnings) for this fit only, rather than\n",
    "    # installing a global \"ignore\" filter for the whole kernel session.\n",
    "    with warnings.catch_warnings():\n",
    "        warnings.simplefilter(\"ignore\")\n",
    "        clf = MLPClassifier(random_state=1, max_iter=1000000, verbose=False, tol=1e-3, n_iter_no_change=2000, alpha=0.01, hidden_layer_sizes=(100,10,)).fit(X_train, y_train)\n",
    "    train_acc = clf.score(X_train, y_train)\n",
    "    test_acc = clf.score(X_test, y_test)\n",
    "    print(method1, method2, \"remove: \", remove_col, \"train size: \", len(df_tr), \"train acc: \", train_acc, \"test acc: \", test_acc)\n",
    "\n",
    "    return clf, X_test, test_acc\n",
    "\n",
    "\n",
    "import itertools\n",
    "res = []\n",
    "methods = np.array([\"ERM\", \"GroupDRO\", \"oversample\", \"undersample\", \"remax-margin\"])\n",
    "# All pairs instead: combinations = np.array(list(itertools.combinations(methods, 2)))\n",
    "# Pairs to compare; in each [method1, method2], label 1 means method1 had the lower error.\n",
    "combinations = np.array([[\"oversample\", \"undersample\"], [\"undersample\", \"remax-margin\"], [\"undersample\", \"ERM\"], [\"remax-margin\", \"GroupDRO\"]])\n",
    "\n",
    "num_exps = all_df[all_df['method']==methods[0]].shape[0]\n",
    "# tr_idx / te_idx index each method's rows positionally, so every compared method needs num_exps rows.\n",
    "rows_per_method = all_df['method'].value_counts()\n",
    "for method_name in np.unique(combinations):\n",
    "    assert rows_per_method.get(method_name, 0) == num_exps, f\"{method_name}: {rows_per_method.get(method_name, 0)} rows, expected {num_exps}\"\n",
    "tr_idx, te_idx = prepare_data(num_exps)\n",
    "# Descriptor indices ablated one at a time (0..5), plus -1 = keep all descriptors.\n",
    "remove_list = list(range(6)) + [-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reset results here so re-running this cell does not append duplicate rows.\n",
    "res = []\n",
    "# One fitted classifier per (m1, m2, remove_col); keying by (m1, m2) alone kept only the last ablation's model.\n",
    "zoo = {}\n",
    "for m1, m2 in combinations:\n",
    "    for remove_col in remove_list:\n",
    "        clf, _, test_acc = train_binary_classifier(all_df, tr_idx, te_idx, m1, m2, filter_thre=0.05, mode=\"tie\", remove_col=remove_col)\n",
    "        zoo[(m1, m2, remove_col)] = clf\n",
    "        res.append({\n",
    "            \"m1\": m1,\n",
    "            \"m2\": m2,\n",
    "            \"test acc\": test_acc,\n",
    "            \"remove\": remove_col,\n",
    "        })\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# df_res is built in the next cell; display the raw results here so this cell also works on a fresh kernel.\n",
    "pd.DataFrame(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Results table: one row per (method pair, removed descriptor) run.\n",
    "df_res = pd.DataFrame(res)\n",
    "# Abbreviate each pair to the upper-cased initials of its two methods, e.g. oversample|undersample -> \"O-U\".\n",
    "pair_names = df_res[\"m1\"] + \"|\" + df_res[\"m2\"]\n",
    "df_res[\"method\"] = pair_names.map(lambda pair: \"-\".join(name[0].upper() for name in pair.split(\"|\")))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the results table, then display it (a bare name followed by other statements is never rendered).\n",
    "df_res.to_csv(\"toy_v2_binary_results_3000.csv\", index=False)\n",
    "df_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Legend label per removed descriptor, looked up by remove_col + 1 (slot 0 = 'N/A' for remove_col == -1).\n",
    "# Order follows the descriptor columns n, sc, ci, ai, var_causal, d_feat (ci shown as d_ls, ai as d_cs).\n",
    "col_name = ['N/A', r'$n$', r'$d_{sc}$', r'$d_{ls}$', r'$d_{cs}$', r'$r$', r'$d$']\n",
    "df_res['Removed descriptor'] = df_res['remove'].map(lambda removed: col_name[int(removed) + 1])\n",
    "\n",
    "# Two-line x-tick label for each pair abbreviation.\n",
    "# NOTE(review): 'R-G' is m1=remax-margin, m2=GroupDRO, but its label lists GroupDRO first.\n",
    "pair_display_names = {\n",
    "    \"O-U\": \"Oversample/ \\n Undersample\",\n",
    "    \"U-R\": \"Undersample/ \\n Logit corr.\",\n",
    "    \"U-E\": \"Undersample/ \\n ERM\",\n",
    "    \"R-G\": \"GroupDRO/ \\n Logit corr.\",\n",
    "}\n",
    "df_res['method_mapped'] = df_res['method'].map(lambda abbrev: pair_display_names[abbrev])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the U-R pair (rebinding df_res, which later cells then see) and plot test accuracy per pair.\n",
    "df_res = df_res[df_res['method'] != \"U-R\"]\n",
    "fig = plt.figure(figsize=(7, 4))\n",
    "# Theme is applied after the figure exists but before its axes are created.\n",
    "sns.set(style=\"whitegrid\")\n",
    "sns.set(font_scale=1.3)  # resets the style as well, hence set_style below\n",
    "sns.set_style(\"whitegrid\")\n",
    "ax = fig.add_subplot()\n",
    "sns.barplot(data=df_res, x=\"method_mapped\", y=\"test acc\", hue=\"Removed descriptor\", palette='RdYlBu', hue_order=col_name, ax=ax)\n",
    "ax.set_ylabel(\"Test 0-1 ACC.\")\n",
    "ax.set_xlabel(\"\")\n",
    "ax.legend(ncols=1, bbox_to_anchor=(1.01, 0.99), title=\"Removed \\n descriptor\")\n",
    "ax.set_ylim(0.7, 0.95)\n",
    "fig.tight_layout()\n",
    "# NOTE(review): the next cell writes pattern3.pdf again (7x5), overwriting this figure.\n",
    "fig.savefig(\"pattern3.pdf\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Taller variant of the previous figure; the U-R filter is a no-op if that cell already ran.\n",
    "df_res = df_res[df_res['method'] != \"U-R\"]\n",
    "fig = plt.figure(figsize=(7, 5))\n",
    "# Theme is applied after the figure exists but before its axes are created.\n",
    "sns.set(style=\"whitegrid\")\n",
    "sns.set(font_scale=1.3)  # resets the style as well, hence set_style below\n",
    "sns.set_style(\"whitegrid\")\n",
    "ax = fig.add_subplot()\n",
    "sns.barplot(data=df_res, x=\"method_mapped\", y=\"test acc\", hue=\"Removed descriptor\", palette='RdYlBu', hue_order=col_name, ax=ax)\n",
    "ax.set_ylabel(\"Test 0-1 ACC.\")\n",
    "ax.set_xlabel(\"\")\n",
    "ax.legend(ncols=1, bbox_to_anchor=(1.01, 0.85), title=\"Removed \\n descriptor\")\n",
    "ax.set_ylim(0.7, 0.95)\n",
    "fig.tight_layout()\n",
    "# Same file name as the previous cell: this version overwrites it.\n",
    "fig.savefig(\"pattern3.pdf\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compact version of the per-pair accuracy figure with larger fonts, saved as pattern1.pdf.\n",
    "fig = plt.figure(figsize=(5.3, 5))\n",
    "# Theme is applied after the figure exists but before its axes are created.\n",
    "sns.set(style=\"whitegrid\")\n",
    "sns.set(font_scale=1.6)  # resets the style as well, hence set_style below\n",
    "sns.set_style(\"whitegrid\")\n",
    "ax = fig.add_subplot()\n",
    "sns.barplot(data=df_res, x=\"method_mapped\", y=\"test acc\", hue=\"Removed descriptor\", palette='RdYlBu', hue_order=col_name, ax=ax)\n",
    "ax.set_ylabel(\"Test 0-1 ACC.\")\n",
    "ax.set_xlabel(\"\")\n",
    "ax.legend(ncols=1, bbox_to_anchor=(1.01, 0.85), title=\"Removed \\n descriptor\")\n",
    "ax.set_ylim(0.65, 0.95)\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"pattern1.pdf\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlaid bars per method pair, one per removed descriptor in plot_list. Bars are drawn in\n",
    "# descending accuracy so the shorter ones land on top and stay visible.\n",
    "plt.figure(figsize=(1.5, 2.5))\n",
    "sns.set(style=\"whitegrid\")\n",
    "plot_list = [-1,0,1,4,5]  # remove_col values to show; -1 = nothing removed\n",
    "feat_list = ['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat', 'none']  # feat_list[-1] == 'none'\n",
    "# Local name, so the global `methods` array from the experiment-setup cell is not overwritten.\n",
    "pair_abbrevs = df_res['method'].unique()\n",
    "palette = sns.color_palette(\"deep\")[:len(plot_list)]\n",
    "plot_list_palette = dict(zip(plot_list, palette))\n",
    "\n",
    "for pair_i, pair in enumerate(pair_abbrevs):\n",
    "    pair_rows = df_res[df_res['method']==pair].sort_values(by=\"test acc\", ascending=False)\n",
    "    for _, row in pair_rows.iterrows():\n",
    "        if row[\"remove\"] in plot_list:\n",
    "            c = plot_list_palette[row[\"remove\"]]\n",
    "            plt.bar(pair, row[\"test acc\"], color=c, label=feat_list[row[\"remove\"]], alpha=0.9)\n",
    "    if pair_i == 0:\n",
    "        # Build the legend from the first pair only, so each removed feature is listed once.\n",
    "        plt.legend(title='Removed feature', bbox_to_anchor=(1.02, 1))\n",
    "\n",
    "plt.ylim(0.65, 0.95);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Larger version of the overlaid per-pair bars; drawn in descending accuracy so shorter bars stay visible.\n",
    "plt.figure(figsize=(6, 4))\n",
    "sns.set(style=\"whitegrid\")\n",
    "plot_list = [-1,0,1,4,5]  # remove_col values to show; -1 = nothing removed\n",
    "feat_list = ['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat', 'none']  # feat_list[-1] == 'none'\n",
    "# Local name, so the global `methods` array from the experiment-setup cell is not overwritten.\n",
    "pair_abbrevs = df_res['method'].unique()\n",
    "palette = sns.color_palette(\"deep\")[:len(plot_list)]\n",
    "plot_list_palette = dict(zip(plot_list, palette))\n",
    "\n",
    "for pair_i, pair in enumerate(pair_abbrevs):\n",
    "    pair_rows = df_res[df_res['method']==pair].sort_values(by=\"test acc\", ascending=False)\n",
    "    for _, row in pair_rows.iterrows():\n",
    "        if row[\"remove\"] in plot_list:\n",
    "            c = plot_list_palette[row[\"remove\"]]\n",
    "            plt.bar(pair, row[\"test acc\"], color=c, label=feat_list[row[\"remove\"]], alpha=0.9)\n",
    "    if pair_i == 0:\n",
    "        # Build the legend from the first pair only, so each removed feature is listed once.\n",
    "        plt.legend(title='Removed feature', bbox_to_anchor=(1.02, 1))\n",
    "\n",
    "plt.ylim(0.6, 1.0);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the current results table.\n",
    "df_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grouped barplot of test accuracy by the first method of each pair, for the removed descriptors in plot_list.\n",
    "# Filter into a separate frame so df_res itself is not narrowed for later cells.\n",
    "# NOTE(review): x=\"m1\" merges pairs sharing a first method (U-E and U-R) unless U-R was removed earlier.\n",
    "plot_list = [-1,0,1,4,5]\n",
    "df_res_subset = df_res[df_res[\"remove\"].isin(plot_list)]\n",
    "sns.barplot(data=df_res_subset, x=\"m1\", y=\"test acc\",  hue=\"remove\")\n",
    "plt.ylim(0.65, 0.95);\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect df_res again; earlier cells rebind it to filtered subsets, so it may have fewer rows than the saved CSV.\n",
    "df_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlaid bars per method pair, one per removed descriptor in plot_list. Bars are drawn in\n",
    "# descending accuracy so the shorter ones land on top and stay visible.\n",
    "plt.figure(figsize=(1.5, 2.5))\n",
    "sns.set(style=\"whitegrid\")\n",
    "plot_list = [-1,0,1,4,5]  # remove_col values to show; -1 = nothing removed\n",
    "feat_list = ['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat', 'none']  # feat_list[-1] == 'none'\n",
    "# Local name, so the global `methods` array from the experiment-setup cell is not overwritten.\n",
    "pair_abbrevs = df_res['method'].unique()\n",
    "palette = sns.color_palette(\"deep\")[:len(plot_list)]\n",
    "plot_list_palette = dict(zip(plot_list, palette))\n",
    "\n",
    "for pair_i, pair in enumerate(pair_abbrevs):\n",
    "    pair_rows = df_res[df_res['method']==pair].sort_values(by=\"test acc\", ascending=False)\n",
    "    for _, row in pair_rows.iterrows():\n",
    "        if row[\"remove\"] in plot_list:\n",
    "            c = plot_list_palette[row[\"remove\"]]\n",
    "            plt.bar(pair, row[\"test acc\"], color=c, label=feat_list[row[\"remove\"]], alpha=0.9)\n",
    "    if pair_i == 0:\n",
    "        # Build the legend from the first pair only, so each removed feature is listed once.\n",
    "        plt.legend(title='Removed feature', bbox_to_anchor=(1.02, 1))\n",
    "\n",
    "plt.ylim(0.65, 0.95);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "div_backup",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
