{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from time import time\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.metrics import f1_score, roc_auc_score\n",
    "from tqdm import tqdm\n",
    "# Global NumPy seed, set before any experiment code runs.\n",
    "np.random.seed(0)\n",
    "\n",
    "from rabit import Action, RecourseBoostingClassifier, RecourseExplainer, ExactRecourseExplainer\n",
    "from rabit.datasets import FicoDataset, CompasDataset, AdultDataset, BailDataset\n",
    "\n",
    "# Render DataFrames without truncating columns or rows (option/value pairs).\n",
    "pd.set_option('display.max_columns', None, 'display.max_rows', None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
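    "## Exp 7. Exact Recourse Algorithm\n",
    "\n",
    "Vanilla, OAF (only actionable features) and RABIT classifiers are trained on repeated splits of each dataset, and recourse for the test points not predicted as class 1 is computed with `ExactRecourseExplainer`. The table below reports the mean recourse cost per method and dataset."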
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_comparison(dataset, n_iter=1, n_estimators=100, cost_budget=0.2, n_instances=30,\n",
    "                   gamma=0.002, max_features=4, time_limit=60):\n",
    "    \"\"\"Compare Vanilla, OAF and RABIT classifiers using exact recourse.\n",
    "\n",
    "    For each of `n_iter` repetitions a fresh split is drawn with\n",
    "    `dataset.get_dataset(split=True)` and one RecourseBoostingClassifier is fitted per\n",
    "    method: 'Vanilla' (gamma=0.0), 'OAF' (gamma=0.0, only_actionable_features=True) and\n",
    "    'RABIT' (gamma=`gamma`). Each model is scored on the test split, then\n",
    "    ExactRecourseExplainer is run on the first `n_instances` test points that the model\n",
    "    does not predict as class 1.\n",
    "\n",
    "    Args:\n",
    "        dataset: rabit dataset object exposing `name`, `params` and `get_dataset`.\n",
    "        n_iter: number of repetitions.\n",
    "        n_estimators: `n_estimators` passed to RecourseBoostingClassifier.\n",
    "        cost_budget: cost budget passed to `Action`.\n",
    "        n_instances: maximum number of test points to explain per model.\n",
    "        gamma: gamma used for the RABIT model (Vanilla and OAF always use 0.0).\n",
    "        max_features: `max_features` passed to ExactRecourseExplainer.\n",
    "        time_limit: `time_limit` passed to `explain_recourse`.\n",
    "\n",
    "    Returns:\n",
    "        pd.DataFrame with one row per (repetition, method): the settings, the time to\n",
    "        construct and fit the model, test accuracy / f1 / AUC, and the validity / cost /\n",
    "        sparsity / plausibility reported by the recourse result.\n",
    "    \"\"\"\n",
    "    columns = ['dataset', 'method', 'n_estimators', 'gamma', 'cost_budget', 'time',\n",
    "               'accuracy', 'f1', 'AUC', 'validity', 'cost', 'sparsity', 'plausibility']\n",
    "    rows = []\n",
    "\n",
    "    print('Running {} dataset'.format(dataset.name))\n",
    "    for _ in tqdm(range(n_iter)):\n",
    "\n",
    "        X_tr, X_ts, y_tr, y_ts = dataset.get_dataset(split=True)\n",
    "        action = Action(dataset.params, cost_budget=cost_budget)\n",
    "        action = action.fit(X_tr, y_tr)\n",
    "\n",
    "        for method in ['Vanilla', 'OAF', 'RABIT']:\n",
    "            method_gamma = gamma if method == 'RABIT' else 0.0\n",
    "\n",
    "            start = time()\n",
    "            estimator = RecourseBoostingClassifier(action, n_estimators=n_estimators, gamma=method_gamma,\n",
    "                                                   only_actionable_features=(method == 'OAF'))\n",
    "            estimator = estimator.fit(X_tr, y_tr)\n",
    "            elapsed = time() - start\n",
    "\n",
    "            y_pred = estimator.predict(X_ts)\n",
    "            row = {\n",
    "                'dataset': dataset.name,\n",
    "                'method': method,\n",
    "                'n_estimators': n_estimators,\n",
    "                'gamma': method_gamma,\n",
    "                'cost_budget': cost_budget,\n",
    "                'time': elapsed,\n",
    "                'accuracy': estimator.score(X_ts, y_ts),\n",
    "                'f1': f1_score(y_ts, y_pred),\n",
    "                'AUC': roc_auc_score(y_ts, estimator.predict_proba(X_ts)[:, 1]),\n",
    "            }\n",
    "\n",
    "            # Only explain test points the model does not already predict as class 1.\n",
    "            X_target = X_ts[y_pred != 1][:n_instances]\n",
    "            explainer = ExactRecourseExplainer(estimator, action, max_features=max_features)\n",
    "            recourse = explainer.explain_recourse(X_target, time_limit=time_limit)\n",
    "            row['validity'] = recourse.get_validity()\n",
    "            row['cost'] = recourse.get_cost()\n",
    "            row['sparsity'] = recourse.get_sparsity()\n",
    "            row['plausibility'] = recourse.get_plausibility()\n",
    "            rows.append(row)\n",
    "\n",
    "    return pd.DataFrame(rows, columns=columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running FICO dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [8:32:36<00:00, 3075.69s/it] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running COMPAS dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [30:34<00:00, 183.41s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Adult dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [1:17:32<00:00, 465.21s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Bail dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [55:53<00:00, 335.36s/it]\n"
     ]
    }
   ],
   "source": [
    "# This sweep takes hours: create the output directory up front (to_csv will not create it)\n",
    "# and checkpoint the CSV after every dataset so a late failure keeps the finished work.\n",
    "os.makedirs('./results', exist_ok=True)\n",
    "\n",
    "results = []\n",
    "for dataset in [FicoDataset(), CompasDataset(), AdultDataset(), BailDataset()]:\n",
    "    results.append(run_comparison(dataset, n_iter=10, n_instances=50))\n",
    "    results_comparison = pd.concat(results)\n",
    "    results_comparison.to_csv('./results/results_comparison_rebuttal.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_table(results_comparison):\n",
    "        \n",
    "    results = {\n",
    "        'Vanilla': [],\n",
    "        'OAF': [],\n",
    "        'RABIT': [],\n",
    "    }\n",
    "    datasets = results_comparison['dataset'].unique()\n",
    "    for i, dataset in enumerate(datasets):\n",
    "        results_dataset = results_comparison[results_comparison['dataset'] == dataset]\n",
    "        for method in ['Vanilla', 'OAF', 'RABIT']:\n",
    "            results[method].append(results_dataset[results_dataset['method'] == method]['cost'].mean())\n",
    "\n",
    "    results = pd.DataFrame(results, index=datasets)\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "|         |   FICO |   COMPAS |   Adult |   Bail |\n",
      "|:--------|-------:|---------:|--------:|-------:|\n",
      "| Vanilla |  0.354 |    0.188 |   0.340 |  0.420 |\n",
      "| OAF     |  0.293 |    0.158 |   0.297 |  0.320 |\n",
      "| RABIT   |  0.110 |    0.096 |   0.257 |  0.213 |\n"
     ]
    }
   ],
   "source": [
    "print(make_table(results_comparison).T.to_markdown(floatfmt=\".3f\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
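    "## Exp 8. Ablation Study\n",
    "\n",
    "Each classifier (Vanilla, OAF, RABIT) is refined post hoc with `optimize_weights` over a grid of `eta` values, where `eta = 0` denotes the unrefined model, and recourse on the test split is computed with `RecourseExplainer`. In the tables below, each refined variant is reported at the `eta` whose mean accuracy is closest to that of unrefined RABIT."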
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _weights_metrics(estimator, recourse, X_ts, y_ts):\n",
    "    \"\"\"Test-split predictive metrics of `estimator` plus the summary metrics of `recourse`.\"\"\"\n",
    "    return {\n",
    "        'accuracy': estimator.score(X_ts, y_ts),\n",
    "        'f1': f1_score(y_ts, estimator.predict(X_ts)),\n",
    "        'AUC': roc_auc_score(y_ts, estimator.predict_proba(X_ts)[:, 1]),\n",
    "        'recourse': recourse.get_recourse(),\n",
    "        'validity': recourse.get_validity(),\n",
    "        'cost': recourse.get_cost(),\n",
    "    }\n",
    "\n",
    "\n",
    "def run_weights_rebuttal(dataset, n_iter=10, n_estimators=100, cost_budget=0.2,\n",
    "                         gamma=0.002, etas=(0.01, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64, 1.28)):\n",
    "    \"\"\"Ablation of post-hoc weight refinement (`optimize_weights`) over a grid of eta.\n",
    "\n",
    "    For each of `n_iter` repetitions a fresh split is drawn and a Vanilla, OAF and RABIT\n",
    "    RecourseBoostingClassifier is fitted (gamma=`gamma` for RABIT, 0.0 otherwise). Each\n",
    "    model is evaluated unrefined (recorded with eta=0.0) and then after\n",
    "    `optimize_weights(..., eta=eta)` for every eta in `etas`, with RecourseExplainer run\n",
    "    on the full test split.\n",
    "\n",
    "    Args:\n",
    "        dataset: rabit dataset object exposing `name`, `params` and `get_dataset`.\n",
    "        n_iter: number of repetitions.\n",
    "        n_estimators: `n_estimators` passed to RecourseBoostingClassifier.\n",
    "        cost_budget: cost budget passed to `Action`.\n",
    "        gamma: gamma used for the RABIT model.\n",
    "        etas: refinement strengths, applied in order; each `optimize_weights` call is\n",
    "            made on the estimator returned by the previous one.\n",
    "\n",
    "    Returns:\n",
    "        pd.DataFrame with one row per (repetition, method, eta). `time` is the fit time\n",
    "        for eta=0.0; refined rows add the calibration-sample generation time and the\n",
    "        duration of that eta's `optimize_weights` call.\n",
    "    \"\"\"\n",
    "    columns = ['dataset', 'method', 'n_estimators', 'eta', 'cost_budget', 'time',\n",
    "               'accuracy', 'f1', 'AUC', 'recourse', 'validity', 'cost']\n",
    "    rows = []\n",
    "\n",
    "    print('Running {} dataset'.format(dataset.name))\n",
    "    for _ in tqdm(range(n_iter)):\n",
    "\n",
    "        X_tr, X_ts, y_tr, y_ts = dataset.get_dataset(split=True)\n",
    "        action = Action(dataset.params, cost_budget=cost_budget)\n",
    "        action = action.fit(X_tr, y_tr)\n",
    "\n",
    "        for method in ['Vanilla', 'OAF', 'RABIT']:\n",
    "            method_gamma = gamma if method == 'RABIT' else 0.0\n",
    "            estimator = RecourseBoostingClassifier(action, n_estimators=n_estimators, gamma=method_gamma,\n",
    "                                                   only_actionable_features=(method == 'OAF'))\n",
    "            start_time = time()\n",
    "            estimator = estimator.fit(X_tr, y_tr)\n",
    "            fit_time = time() - start_time\n",
    "            explainer = RecourseExplainer(estimator, action)\n",
    "            recourse = explainer.explain_recourse(X_ts)\n",
    "            settings = {'dataset': dataset.name, 'method': method,\n",
    "                        'n_estimators': n_estimators, 'cost_budget': cost_budget}\n",
    "            rows.append({**settings, 'eta': 0.0, 'time': fit_time,\n",
    "                         **_weights_metrics(estimator, recourse, X_ts, y_ts)})\n",
    "\n",
    "            start_time = time()\n",
    "            X_cf = explainer.generate_recourse_calibration_samples(X_tr)\n",
    "            calibration_time = time() - start_time\n",
    "            for eta in etas:\n",
    "                start_time = time()\n",
    "                estimator = estimator.optimize_weights(X_tr, y_tr, X_cf, eta=eta)\n",
    "                refine_time = time() - start_time\n",
    "                # NOTE(review): `explainer` still wraps the estimator it was built with; this\n",
    "                # presumably relies on optimize_weights updating that object in place (and\n",
    "                # returning it) - confirm, otherwise recourse is measured on the unrefined model.\n",
    "                recourse = explainer.explain_recourse(X_ts)\n",
    "                rows.append({**settings, 'eta': eta,\n",
    "                             'time': fit_time + calibration_time + refine_time,\n",
    "                             **_weights_metrics(estimator, recourse, X_ts, y_ts)})\n",
    "\n",
    "    return pd.DataFrame(rows, columns=columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running FICO dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [16:53<00:00, 101.34s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running COMPAS dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [08:18<00:00, 49.89s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Adult dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [1:03:59<00:00, 383.99s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Bail dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [07:07<00:00, 42.76s/it]\n"
     ]
    }
   ],
   "source": [
    "# This sweep takes over an hour: create the output directory up front (to_csv will not\n",
    "# create it) and checkpoint the CSV after every dataset so a late failure keeps the work.\n",
    "os.makedirs('./results', exist_ok=True)\n",
    "\n",
    "results = []\n",
    "for dataset in [FicoDataset(), CompasDataset(), AdultDataset(), BailDataset()]:\n",
    "    results.append(run_weights_rebuttal(dataset))\n",
    "    results_weights = pd.concat(results)\n",
    "    results_weights.to_csv('./results/results_weights_rebuttal.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_table(results_weights):\n",
    "    \n",
    "    results_accuracy = {\n",
    "        'Vanilla': [],\n",
    "        'RABIT': [],\n",
    "        'Vanilla w/ Refinement': [],\n",
    "        'RABIT w/ Refinement': [],\n",
    "    }\n",
    "    results_recourse = {\n",
    "        'Vanilla': [],\n",
    "        'RABIT': [],\n",
    "        'Vanilla w/ Refinement': [],\n",
    "        'RABIT w/ Refinement': [],\n",
    "    }\n",
    "    \n",
    "    datasets = results_weights['dataset'].unique()\n",
    "    for i, dataset in enumerate(datasets):\n",
    "        results_dataset = results_weights[results_weights['dataset'] == dataset]\n",
    "\n",
    "        results_vanilla = results_dataset[results_dataset['method'] == 'Vanilla'].groupby('eta')\n",
    "        results_rabit = results_dataset[results_dataset['method'] == 'RABIT'].groupby('eta')\n",
    "        X_vanilla = results_vanilla['accuracy'].mean().values\n",
    "        Y_vanilla = results_vanilla['recourse'].mean().values\n",
    "        X_rabit = results_rabit['accuracy'].mean().values\n",
    "        Y_rabit = results_rabit['recourse'].mean().values\n",
    "        \n",
    "        x_vanilla_wo = X_vanilla[0]\n",
    "        x_rabit_wo = X_rabit[0]\n",
    "        X_vanilla_w = X_vanilla[1:]\n",
    "        X_rabit_w = X_rabit[1:]\n",
    "        eta_vanilla = np.argmin(abs(x_rabit_wo - X_vanilla_w))\n",
    "        x_vanilla_w = X_vanilla_w[eta_vanilla]\n",
    "        eta_rabit = np.argmin(abs(x_rabit_wo - X_rabit_w))\n",
    "        x_rabit_w = X_rabit_w[eta_rabit]\n",
    "        \n",
    "        results_accuracy['Vanilla'].append(x_vanilla_wo)\n",
    "        results_accuracy['RABIT'].append(x_rabit_wo)\n",
    "        results_accuracy['Vanilla w/ Refinement'].append(x_vanilla_w)\n",
    "        results_accuracy['RABIT w/ Refinement'].append(x_rabit_w)   \n",
    "        \n",
    "        y_vanilla_wo = Y_vanilla[0]\n",
    "        y_rabit_wo = Y_rabit[0]\n",
    "        y_vanilla_w = Y_vanilla[1:][eta_vanilla]\n",
    "        y_rabit_w = Y_rabit[1:][eta_rabit]\n",
    "        \n",
    "        results_recourse['Vanilla'].append(y_vanilla_wo)\n",
    "        results_recourse['RABIT'].append(y_rabit_wo)\n",
    "        results_recourse['Vanilla w/ Refinement'].append(y_vanilla_w)\n",
    "        results_recourse['RABIT w/ Refinement'].append(y_rabit_w)\n",
    "\n",
    "    results_accuracy = pd.DataFrame(results_accuracy, index=datasets)\n",
    "    results_recourse = pd.DataFrame(results_recourse, index=datasets)\n",
    "    return results_accuracy, results_recourse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "|                       |   FICO |   COMPAS |   Adult |   Bail |\n",
      "|:----------------------|-------:|---------:|--------:|-------:|\n",
      "| Vanilla               |  0.735 |    0.682 |   0.852 |  0.711 |\n",
      "| RABIT                 |  0.732 |    0.677 |   0.851 |  0.701 |\n",
      "| Vanilla w/ Refinement |  0.708 |    0.676 |   0.853 |  0.702 |\n",
      "| RABIT w/ Refinement   |  0.716 |    0.677 |   0.850 |  0.706 |\n"
     ]
    }
   ],
   "source": [
    "# Mean test accuracy: methods as rows, datasets as columns.\n",
    "results_accuracy, results_recourse = make_table(results_weights)\n",
    "print(results_accuracy.transpose().to_markdown(floatfmt=\".3f\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "|                       |   FICO |   COMPAS |   Adult |   Bail |\n",
      "|:----------------------|-------:|---------:|--------:|-------:|\n",
      "| Vanilla               |  0.541 |    0.832 |   0.274 |  0.364 |\n",
      "| RABIT                 |  0.825 |    0.936 |   0.427 |  0.628 |\n",
      "| Vanilla w/ Refinement |  0.694 |    0.887 |   0.685 |  0.840 |\n",
      "| RABIT w/ Refinement   |  0.859 |    0.986 |   0.847 |  0.868 |\n"
     ]
    }
   ],
   "source": [
    "# Mean `recourse` metric (from get_recourse): methods as rows, datasets as columns.\n",
    "print(results_recourse.transpose().to_markdown(floatfmt=\".3f\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "3.10.10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
