{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "from time import time\n",
    "from sklearn.metrics import f1_score, roc_auc_score\n",
    "np.random.seed(0)\n",
    "\n",
    "from rabit import Action, RecourseBoostingClassifier, RecourseExplainer\n",
    "from rabit.datasets import FicoDataset, CompasDataset, AdultDataset, BailDataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exp 1. Baseline Comparison and Trade-off Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_comparison(dataset, n_iter=10, n_estimators=100, cost_budget=0.2):\n",
    "    \n",
    "    results = {\n",
    "        'dataset': [],\n",
    "        'method': [],\n",
    "        'n_estimators': [],\n",
    "        'gamma': [],\n",
    "        'cost_budget': [],\n",
    "        'time': [],\n",
    "        'accuracy': [],\n",
    "        'f1': [],\n",
    "        'AUC': [],\n",
    "        'recourse': [],\n",
    "        'validity': [],\n",
    "        'cost': [],\n",
    "        'sparsity': [],\n",
    "        'plausibility': [],\n",
    "        'unfairness': [],\n",
    "    }\n",
    "    \n",
    "    print('Running {} dataset'.format(dataset.name))\n",
    "    for _ in tqdm(range(n_iter)):\n",
    "\n",
    "        X_tr, X_ts, y_tr, y_ts = dataset.get_dataset(split=True)\n",
    "        action = Action(dataset.params, cost_budget=cost_budget)\n",
    "        action = action.fit(X_tr, y_tr)\n",
    "        \n",
    "        done_vanilla = False\n",
    "        for gamma in [0.0, 0.0, 0.0005, 0.001, 0.0015, 0.002, 0.0025, 0.003, 0.0035, 0.004, 0.0045, 0.005]:\n",
    "            results['dataset'].append(dataset.name)\n",
    "            if gamma == 0.0:\n",
    "                if done_vanilla:\n",
    "                    method = 'OAF'\n",
    "                else:\n",
    "                    method = 'Vanilla'\n",
    "                    done_vanilla = True\n",
    "            else:\n",
    "                method = 'RABIT'\n",
    "            results['method'].append(method)\n",
    "            results['n_estimators'].append(n_estimators)\n",
    "            results['gamma'].append(gamma)\n",
    "            results['cost_budget'].append(cost_budget)\n",
    "\n",
    "            start = time()\n",
    "            estimator = RecourseBoostingClassifier(action, n_estimators=n_estimators, gamma=gamma, only_actionable_features=(method == 'OAF'))        \n",
    "            estimator = estimator.fit(X_tr, y_tr)\n",
    "            results['time'].append(time() - start)\n",
    "            results['accuracy'].append(estimator.score(X_ts, y_ts))\n",
    "            results['f1'].append(f1_score(y_ts, estimator.predict(X_ts)))\n",
    "            results['AUC'].append(roc_auc_score(y_ts, estimator.predict_proba(X_ts)[:, 1]))\n",
    "            \n",
    "            explainer = RecourseExplainer(estimator, action)\n",
    "            recourse = explainer.explain_recourse(X_ts)\n",
    "            results['recourse'].append(recourse.get_recourse())\n",
    "            results['validity'].append(recourse.get_validity())\n",
    "            results['cost'].append(recourse.get_cost())\n",
    "            results['sparsity'].append(recourse.get_sparsity())\n",
    "            results['plausibility'].append(recourse.get_plausibility())\n",
    "            \n",
    "            if dataset.name in ['COMPAS', 'Adult']:\n",
    "                sensitive_indices = dataset.get_sensitive_indices()\n",
    "                results['unfairness'].append(recourse.get_unfairness(sensitive_indices))\n",
    "            else:\n",
    "                results['unfairness'].append(0.0)\n",
    "                                \n",
    "    results = pd.DataFrame(results)    \n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running FICO dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [27:40<00:00, 166.06s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running COMPAS dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [06:36<00:00, 39.70s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Adult dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [1:17:20<00:00, 464.02s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Bail dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [08:52<00:00, 53.24s/it]\n"
     ]
    }
   ],
   "source": [
    "results = []\n",
    "\n",
    "for dataset in [FicoDataset(), CompasDataset(), AdultDataset(), BailDataset()]:\n",
    "    result = run_comparison(dataset)\n",
    "    results.append(result) \n",
    "    \n",
    "results_comparison = pd.concat(results)\n",
    "results_comparison.to_csv('./results/results_comparison.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exp 2. Leaf Refinement"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_weights(dataset, n_iter=10, n_estimators=100, cost_budget=0.2):\n",
    "    \n",
    "    results = {\n",
    "        'dataset': [],\n",
    "        'method': [],\n",
    "        'n_estimators': [],\n",
    "        'eta': [],\n",
    "        'cost_budget': [],\n",
    "        'accuracy': [],\n",
    "        'f1': [],\n",
    "        'AUC': [],\n",
    "        'recourse': [],\n",
    "        'validity': [],\n",
    "        'cost': [],\n",
    "    }\n",
    "\n",
    "    print('Running {} dataset'.format(dataset.name))\n",
    "    for _ in tqdm(range(n_iter)):\n",
    "\n",
    "        X_tr, X_vl, X_ts, y_tr, y_vl, y_ts = dataset.get_dataset(split=True, test_size=0.25, validation_size=0.25)\n",
    "        sensitive_indices = dataset.get_sensitive_indices()\n",
    "        action = Action(dataset.params, cost_budget=cost_budget)\n",
    "        action = action.fit(X_tr, y_tr)\n",
    "\n",
    "        for method in ['Vanilla', 'OAF', 'RABIT']:\n",
    "            gamma = 0.002 if method == 'RABIT' else 0.0\n",
    "            estimator = RecourseBoostingClassifier(action, n_estimators=n_estimators, gamma=gamma, only_actionable_features=(method == 'OAF'))        \n",
    "            estimator = estimator.fit(X_tr, y_tr)            \n",
    "            explainer = RecourseExplainer(estimator, action)\n",
    "            X_cf = explainer.generate_recourse_calibration_samples(X_vl)\n",
    "                            \n",
    "            for eta in [0.01, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64, 1.28]:\n",
    "                estimator = estimator.optimize_weights(X_vl, y_vl, X_cf, eta=eta)\n",
    "                recourse = explainer.explain_recourse(X_ts)\n",
    "                results['dataset'].append(dataset.name)\n",
    "                results['method'].append(method)\n",
    "                results['n_estimators'].append(n_estimators)\n",
    "                results['eta'].append(eta)\n",
    "                results['cost_budget'].append(cost_budget)\n",
    "                results['accuracy'].append(estimator.score(X_ts, y_ts))\n",
    "                results['f1'].append(f1_score(y_ts, estimator.predict(X_ts)))\n",
    "                results['AUC'].append(roc_auc_score(y_ts, estimator.predict_proba(X_ts)[:, 1]))\n",
    "                results['recourse'].append(recourse.get_recourse())\n",
    "                results['validity'].append(recourse.get_validity())\n",
    "                results['cost'].append(recourse.get_cost())\n",
    "            \n",
    "    results = pd.DataFrame(results)    \n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running FICO dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [12:22<00:00, 74.30s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running COMPAS dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [06:57<00:00, 41.77s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Adult dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [42:25<00:00, 254.51s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Bail dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [05:52<00:00, 35.28s/it]\n"
     ]
    }
   ],
   "source": [
    "results = []\n",
    "\n",
    "for dataset in [FicoDataset(), CompasDataset(), AdultDataset(), BailDataset()]:\n",
    "    result = run_weights(dataset)\n",
    "    results.append(result) \n",
    "    \n",
    "results_weights = pd.concat(results)\n",
    "results_weights.to_csv('./results/results_weights.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exp 3. Sensitivity Analyses (Appendix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running FICO dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [16:20<00:00, 98.08s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running FICO dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [41:54<00:00, 251.41s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running FICO dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [54:34<00:00, 327.49s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running FICO dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [1:09:52<00:00, 419.22s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running COMPAS dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [04:48<00:00, 28.90s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running COMPAS dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [10:08<00:00, 60.89s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running COMPAS dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [12:29<00:00, 74.99s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running COMPAS dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [14:40<00:00, 88.07s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Adult dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [44:31<00:00, 267.13s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Adult dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [1:50:06<00:00, 660.69s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Adult dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [2:25:08<00:00, 870.89s/it] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Adult dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [2:59:20<00:00, 1076.09s/it] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Bail dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [05:05<00:00, 30.58s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Bail dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [11:40<00:00, 70.09s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Bail dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [15:15<00:00, 91.51s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Bail dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [19:45<00:00, 118.55s/it]\n"
     ]
    }
   ],
   "source": [
    "results = []\n",
    "\n",
    "for dataset in [FicoDataset(), CompasDataset(), AdultDataset(), BailDataset()]:\n",
    "    for n_estimators in [50, 150, 200, 250]:\n",
    "        result = run_comparison(dataset, n_estimators=n_estimators)\n",
    "        results.append(result) \n",
    "    \n",
    "results_trees = pd.concat(results)\n",
    "results_trees = pd.concat([results_trees, results_comparison])\n",
    "results_trees.to_csv('./results/results_trees.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running FICO dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [29:25<00:00, 176.58s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running FICO dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [28:25<00:00, 170.59s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running FICO dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [28:00<00:00, 168.08s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running FICO dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [28:39<00:00, 171.98s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running COMPAS dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [07:13<00:00, 43.38s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running COMPAS dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [07:03<00:00, 42.39s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running COMPAS dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [06:58<00:00, 41.90s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running COMPAS dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [06:54<00:00, 41.46s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Adult dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [1:17:04<00:00, 462.49s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Adult dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [1:15:09<00:00, 450.93s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Adult dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [1:15:57<00:00, 455.72s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Adult dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [1:17:28<00:00, 464.83s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Bail dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [08:37<00:00, 51.77s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Bail dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [08:33<00:00, 51.36s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Bail dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [08:33<00:00, 51.39s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Bail dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [08:30<00:00, 51.03s/it]\n"
     ]
    }
   ],
   "source": [
    "results = []\n",
    "\n",
    "for dataset in [FicoDataset(), CompasDataset(), AdultDataset(), BailDataset()]:\n",
    "    for cost_budget in [0.1, 0.3, 0.4, 0.5]:\n",
    "        result = run_comparison(dataset, cost_budget=cost_budget)\n",
    "        results.append(result) \n",
    "    \n",
    "results_budget = pd.concat(results)\n",
    "results_budget = pd.concat([results_budget, results_comparison])\n",
    "results_budget.to_csv('./results/results_budget.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exp 4. Intercept Adjustment (Appendix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_intercept(dataset, n_iter=10, n_estimators=100, cost_budget=0.2):\n",
    "    \n",
    "    results = {\n",
    "        'dataset': [],\n",
    "        'method': [],\n",
    "        'n_estimators': [],\n",
    "        'epsilon': [],\n",
    "        'cost_budget': [],\n",
    "        'accuracy': [],\n",
    "        'f1': [],\n",
    "        'AUC': [],\n",
    "        'recourse': [],\n",
    "        'validity': [],\n",
    "        'cost': [],\n",
    "    }\n",
    "\n",
    "    print('Running {} dataset'.format(dataset.name))\n",
    "    for _ in tqdm(range(n_iter)):\n",
    "\n",
    "        X_tr, X_vl, X_ts, y_tr, y_vl, y_ts = dataset.get_dataset(split=True, test_size=0.25, validation_size=0.25)\n",
    "        action = Action(dataset.params, cost_budget=cost_budget)\n",
    "        action = action.fit(X_tr, y_tr)\n",
    "\n",
    "        for method in ['Vanilla', 'OAF', 'RABIT']:\n",
    "            gamma = 0.002 if method == 'RABIT' else 0.0\n",
    "            estimator = RecourseBoostingClassifier(action, n_estimators=n_estimators, gamma=gamma, only_actionable_features=(method == 'OAF'))        \n",
    "            estimator = estimator.fit(X_tr, y_tr)            \n",
    "            explainer = RecourseExplainer(estimator, action)\n",
    "            X_cf = explainer.generate_recourse_calibration_samples(X_vl)\n",
    "\n",
    "            for epsilon in [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]:\n",
    "                estimator = estimator.optimize_intercept(X_cf, epsilon=epsilon)\n",
    "                recourse = explainer.explain_recourse(X_ts)\n",
    "                results['dataset'].append(dataset.name)\n",
    "                results['method'].append(method)\n",
    "                results['n_estimators'].append(n_estimators)\n",
    "                results['epsilon'].append(epsilon)\n",
    "                results['cost_budget'].append(cost_budget)\n",
    "                results['accuracy'].append(estimator.score(X_ts, y_ts))\n",
    "                results['f1'].append(f1_score(y_ts, estimator.predict(X_ts)))\n",
    "                results['AUC'].append(roc_auc_score(y_ts, estimator.predict_proba(X_ts)[:, 1]))\n",
    "                results['recourse'].append(recourse.get_recourse())\n",
    "                results['validity'].append(recourse.get_validity())\n",
    "                results['cost'].append(recourse.get_cost())\n",
    "                estimator.intercept_ = 0.0\n",
    "            \n",
    "    results = pd.DataFrame(results)    \n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running FICO dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [14:06<00:00, 84.68s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running COMPAS dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [07:46<00:00, 46.61s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Adult dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [43:41<00:00, 262.12s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Bail dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [05:22<00:00, 32.30s/it]\n"
     ]
    }
   ],
   "source": [
    "results = []\n",
    "\n",
    "for dataset in [FicoDataset(), CompasDataset(), AdultDataset(), BailDataset()]:\n",
    "    result = run_intercept(dataset)\n",
    "    results.append(result) \n",
    "    \n",
    "results_intercept = pd.concat(results)\n",
    "results_intercept.to_csv('./results/results_intercept.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exp 5. Brittleness Analysis (Appendix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_brittleness(estimator, X, var, params, n_repeat=100):\n",
    "    X_rep = np.repeat(X, n_repeat, axis=0)\n",
    "    y_rep = np.repeat(estimator.predict(X), n_repeat)\n",
    "    pert = np.random.multivariate_normal(np.zeros(X.shape[1]), 0.1 * np.diag(var), size=(X_rep.shape[0],))\n",
    "    pert[:, params['is_immutable']] = 0.0\n",
    "    pert[:, params['is_unincreasable']] = np.clip(pert[:, params['is_unincreasable']], None, 0.0)\n",
    "    pert[:, params['is_irreducible']] = np.clip(pert[:, params['is_irreducible']], 0.0, None)\n",
    "    return (estimator.predict(X_rep + pert) != y_rep).mean()\n",
    "\n",
    "\n",
    "def run_brittleness(dataset, n_iter=10, n_estimators=100, cost_budget=0.2):\n",
    "    \n",
    "    results = {\n",
    "        'dataset': [],\n",
    "        'method': [],\n",
    "        'gamma': [],\n",
    "        'brittleness': []\n",
    "    }\n",
    "\n",
    "    print('Running {} dataset'.format(dataset.name))\n",
    "    for _ in tqdm(range(n_iter)):\n",
    "\n",
    "        X_tr, X_ts, y_tr, y_ts = dataset.get_dataset(split=True, test_size=0.25)\n",
    "        action = Action(dataset.params, cost_budget=cost_budget)\n",
    "        action = action.fit(X_tr, y_tr)\n",
    "\n",
    "        for method in ['Vanilla', 'OAF', 'RABIT']:\n",
    "            gamma = 0.002 if method == 'RABIT' else 0.0\n",
    "            results['dataset'].append(dataset.name)\n",
    "            results['method'].append(method)\n",
    "            results['gamma'].append(gamma)\n",
    "\n",
    "            estimator = RecourseBoostingClassifier(action, n_estimators=n_estimators, gamma=gamma, only_actionable_features=(method == 'OAF'))        \n",
    "            estimator = estimator.fit(X_tr, y_tr)\n",
    "            brittleness = get_brittleness(estimator, X_ts, X_tr.var(axis=0), dataset.params)\n",
    "            results['brittleness'].append(brittleness)\n",
    "            \n",
    "    results = pd.DataFrame(results)    \n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running FICO dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [07:45<00:00, 46.59s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running COMPAS dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [01:32<00:00,  9.25s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Adult dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [15:53<00:00, 95.39s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Bail dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [02:19<00:00, 13.97s/it]\n"
     ]
    }
   ],
   "source": [
    "results = []\n",
    "\n",
    "for dataset in [FicoDataset(), CompasDataset(), AdultDataset(), BailDataset()]:\n",
    "    result = run_brittleness(dataset)\n",
    "    results.append(result) \n",
    "    \n",
    "results_brittleness = pd.concat(results)\n",
    "results_brittleness.to_csv('./results/results_brittleness.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "3.10.10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
