{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import warnings\n",
    "from copy import deepcopy\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import plotly.graph_objects as go\n",
    "import tqdm\n",
    "\n",
    "from data import *\n",
    "from model import *\n",
    "from utils import *\n",
    "from recourse_model import LAROAR, ROAR, RecourseCost\n",
    "\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def append_res(d, alg, seed, alpha, lamb, i, x_0, theta_0, beta, x_r, theta_r, p, theta_p, J_r, J_c, robustness, consistency):\n",
    "    \"\"\"Append one result row to the column-oriented results dict `d`.\n",
    "\n",
    "    Every argument after `d` is appended to the list stored under its column\n",
    "    key (`lamb` goes to the 'lambda' column). The array-valued entries\n",
    "    `x_0`, `theta_0`, `x_r`, `theta_r` and `theta_p` are rounded to 4\n",
    "    decimals before being stored. `d` is mutated in place; nothing is returned.\n",
    "    \"\"\"\n",
    "    def _row():\n",
    "        # Yielded lazily so the columns are filled strictly in this order.\n",
    "        yield 'alg', alg\n",
    "        yield 'seed', seed\n",
    "        yield 'alpha', alpha\n",
    "        yield 'lambda', lamb\n",
    "        yield 'i', i\n",
    "        yield 'x_0', x_0.round(4)\n",
    "        yield 'theta_0', theta_0.round(4)\n",
    "        yield 'beta', beta\n",
    "        yield 'x_r', x_r.round(4)\n",
    "        yield 'theta_r', theta_r.round(4)\n",
    "        yield 'p', p\n",
    "        yield 'theta_p', theta_p.round(4)\n",
    "        yield 'J_r', J_r\n",
    "        yield 'J_c', J_c\n",
    "        yield 'robustness', robustness\n",
    "        yield 'consistency', consistency\n",
    "\n",
    "    for column, value in _row():\n",
    "        d[column].append(value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def recourse_model_runner(seed, X: np.ndarray, recourse_model1: LAROAR, recourse_model2: ROAR, predictions: List, X_train: np.ndarray, base_model: LR | NN):\n",
    "    \"\"\"Score recourses for every row of `X` and average the scores per algorithm.\n",
    "\n",
    "    For each instance a LIME surrogate of `base_model` is fitted and installed\n",
    "    in both recourse models. For each model in `predictions` a second LIME\n",
    "    surrogate (theta_p) is fitted and clipped to within `alpha` of the base\n",
    "    surrogate. The following candidates are then scored:\n",
    "\n",
    "    * 'OPT': the `recourse_model1` recourse for each beta in 0, 0.01, ..., 1.\n",
    "    * 'OPT{2k}': convex combinations of the k-th beta recourse with the beta=0\n",
    "      recourse; the 'beta' column holds the mixing coefficient.\n",
    "    * 'OPT{2k+1}': convex combinations of the beta=1 recourse with the k-th\n",
    "      beta recourse; the 'beta' column holds the mixing coefficient.\n",
    "    * 'ROAR': the `recourse_model2` recourse, stored with beta=1.\n",
    "\n",
    "    Args:\n",
    "        seed: Run identifier copied into every result row.\n",
    "        X: Instances to generate recourse for, one per row.\n",
    "        recourse_model1: Its `alpha` and `lamb` are read; its `weights` and\n",
    "            `bias` are overwritten for every instance.\n",
    "        recourse_model2: Updated through `set_weights` / `set_bias` for every\n",
    "            instance.\n",
    "        predictions: Models exposing `predict`; each is explained with LIME to\n",
    "            obtain theta_p.\n",
    "        X_train: Training data handed to `lime_explanation`.\n",
    "        base_model: Model whose `predict` is explained with LIME to obtain theta_0.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame holding the mean of the numeric columns per (alg, beta, p).\n",
    "    \"\"\"\n",
    "    alpha = recourse_model1.alpha\n",
    "    lamb = recourse_model1.lamb\n",
    "\n",
    "    columns = ('alg', 'seed', 'alpha', 'lambda', 'i', 'x_0', 'theta_0', 'beta', 'x_r', 'theta_r', 'p', 'theta_p', 'J_r', 'J_c', 'robustness', 'consistency')\n",
    "    results_opt = {column: [] for column in columns}\n",
    "    results_roar = {column: [] for column in columns}\n",
    "\n",
    "    # The same 0..1 grid serves as the beta values and as the mixing\n",
    "    # coefficients, so it is built once instead of once per loop iteration.\n",
    "    grid = np.arange(0., 1.01, 0.01).round(2)\n",
    "\n",
    "    n = len(X)\n",
    "\n",
    "    for i in tqdm.trange(n, desc=f'Eval alpha={alpha}; lambda={lamb}', colour='#0091ff'):\n",
    "        x_0 = X[i]\n",
    "        J = RecourseCost(x_0, lamb)\n",
    "\n",
    "        # LIME approximation of the base model, seeded per instance.\n",
    "        np.random.seed(i)\n",
    "        weights_0, bias_0 = lime_explanation(base_model.predict, X_train, x_0)\n",
    "        weights_0, bias_0 = np.round(weights_0, 4), np.round(bias_0, 4)\n",
    "        theta_0 = np.hstack((weights_0, bias_0))\n",
    "\n",
    "        recourse_model1.weights = weights_0\n",
    "        recourse_model1.bias = bias_0\n",
    "\n",
    "        recourse_model2.set_weights(weights_0)\n",
    "        recourse_model2.set_bias(bias_0)\n",
    "\n",
    "        for p, prediction in enumerate(predictions):\n",
    "            # LIME approximation of the prediction model, seeded per prediction\n",
    "            # and clipped to stay within alpha of theta_0.\n",
    "            np.random.seed(p)\n",
    "            weights_p, bias_p = lime_explanation(prediction.predict, X_train, x_0)\n",
    "            weights_p = weights_p.clip(weights_0-alpha, weights_0+alpha).round(4)\n",
    "            bias_p = bias_p.clip(bias_0-alpha, bias_0+alpha).round(4)\n",
    "            theta_p = (weights_p, bias_p)\n",
    "            theta_p_flat = np.hstack(theta_p)\n",
    "\n",
    "            # Reference cost for robustness: the beta=1 recourse scored under\n",
    "            # the parameters returned by calc_theta_adv.\n",
    "            x_r = recourse_model1.get_recourse(x_0, beta=1., theta_p=theta_p)\n",
    "            weights_r, bias_r = recourse_model1.calc_theta_adv(x_r)\n",
    "            opt_rob = J(x_r, weights_r, bias_r)\n",
    "\n",
    "            # Reference cost for consistency: the beta=0 recourse scored under theta_p.\n",
    "            x_c = recourse_model1.get_recourse(x_0, beta=0., theta_p=theta_p)\n",
    "            opt_con = J(x_c, weights_p, bias_p)\n",
    "\n",
    "            def _record(results, alg, beta, x):\n",
    "                # Score candidate x against both references and store one row.\n",
    "                weights_adv, bias_adv = recourse_model1.calc_theta_adv(x)\n",
    "                theta_adv = np.hstack((weights_adv, bias_adv))\n",
    "                J_r = J(x, weights_adv, bias_adv)\n",
    "                J_c = J(x, weights_p, bias_p)\n",
    "                rob = J_r - opt_rob\n",
    "                con = J_c - opt_con\n",
    "                append_res(results, alg, seed, alpha, lamb, i, x_0, theta_0, beta, x, theta_adv, p, theta_p_flat, J_r[0], J_c[0], rob[0], con[0])\n",
    "\n",
    "            # OPT\n",
    "            for k, beta in enumerate(grid):\n",
    "                x = recourse_model1.get_recourse(x_0, beta=beta, theta_p=theta_p)\n",
    "                _record(results_opt, 'OPT', beta, x)\n",
    "\n",
    "                # Convex combinations towards the consistent (beta=0) and the\n",
    "                # robust (beta=1) recourse.\n",
    "                for c in grid:\n",
    "                    _record(results_opt, f'OPT{2*k}', c, c*x + (1-c)*x_c)\n",
    "                    _record(results_opt, f'OPT{2*k+1}', c, c*x_r + (1-c)*x)\n",
    "\n",
    "            # ROAR\n",
    "            x_r2, _ = recourse_model2.get_recourse(x_0)\n",
    "            _record(results_roar, 'ROAR', 1., x_r2)\n",
    "\n",
    "    df_hist = pd.concat((pd.DataFrame(results_opt), pd.DataFrame(results_roar)))\n",
    "    # Array-valued columns are non-numeric and are dropped by numeric_only.\n",
    "    df_results = df_hist.groupby(['alg', 'beta', 'p'], as_index=False).mean(numeric_only=True)\n",
    "\n",
    "    return df_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def recourse_tradeoff(params: dict, seeds, results):\n",
    "    \"\"\"Run the recourse trade-off evaluation once per seed.\n",
    "\n",
    "    Args:\n",
    "        params: Experiment settings. `params['data']` selects the dataset\n",
    "            (synthetic/simulated, correction/german, temporal/business or\n",
    "            geospatial/student) and `params['alpha']` is passed to both\n",
    "            recourse models.\n",
    "        seeds: Iterable of seeds passed to `data_model.get_data`. The perturbed\n",
    "            prediction models are only built while seed 0 is processed, so 0\n",
    "            should be the first seed.\n",
    "        results: List that receives one aggregated DataFrame per seed; it is\n",
    "            mutated in place and nothing is returned.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `params['data']` does not name a known dataset.\n",
    "    \"\"\"\n",
    "    dataset = params['data']\n",
    "    if dataset in ['correction', 'german']:\n",
    "        data_model = CorrectionShift(\"../datasets/german.csv\", \"../datasets/corrected_german.csv\", seed=0)\n",
    "    elif dataset in ['temporal', 'business']:\n",
    "        data_model = TemporalShift(\"../datasets/SBAcase.11.13.17.csv\", seed=0)\n",
    "    elif dataset in ['geospatial', 'student']:\n",
    "        data_model = GeospatialShift(\"../datasets/student-por.csv\", seed=0)\n",
    "    elif dataset in ['synthetic', 'simulated']:\n",
    "        data_model = SyntheticData(alpha=1.5, beta=0, n=1000, seed=0)\n",
    "    else:\n",
    "        # Fail here with a clear message instead of an unbound-variable error below.\n",
    "        raise ValueError(f'Unknown dataset {dataset!r}; expected one of: synthetic/simulated, correction/german, temporal/business, geospatial/student')\n",
    "\n",
    "    alpha = params['alpha']\n",
    "\n",
    "    # NOTE(review): params['base_model'] is not read here; an NN is always trained.\n",
    "    # NOTE(review): this model is replaced on the first loop iteration. The\n",
    "    # training is kept because it presumably advances the torch RNG state that\n",
    "    # later steps depend on - confirm before removing.\n",
    "    data1, _ = data_model.get_data(0)\n",
    "    X1_train, y1_train, X1_test, _ = data1\n",
    "\n",
    "    base_model = NN(X1_train.shape[1])\n",
    "    base_model.train(X1_train.values, y1_train.values)\n",
    "\n",
    "    # Dedicated generator so the perturbations do not depend on the global torch RNG.\n",
    "    generator = torch.Generator().manual_seed(0)\n",
    "    predictions = []\n",
    "    for seed in seeds:\n",
    "        data1, _ = data_model.get_data(seed)\n",
    "        X1_train, y1_train, X1_test, _ = data1\n",
    "\n",
    "        base_model = NN(X1_train.shape[1])\n",
    "        base_model.train(X1_train.values, y1_train.values)\n",
    "\n",
    "        if seed == 0:\n",
    "            # Build 5 perturbed copies of the seed-0 model: every weight and bias\n",
    "            # entry is shifted by -0.07, 0 or +0.07.\n",
    "            for _ in range(5):\n",
    "                prediction = deepcopy(base_model)\n",
    "                for module in prediction.model:\n",
    "                    for attr in ('weight', 'bias'):\n",
    "                        tensor = getattr(module, attr, None)\n",
    "                        # Skip layers without the parameter, including ones where it is None.\n",
    "                        if tensor is not None:\n",
    "                            tensor.data += 0.07*torch.randint(-1, 2, tensor.data.shape, generator=generator)\n",
    "                predictions.append(prediction)\n",
    "\n",
    "        recourse_needed_X1_train = recourse_needed(base_model.predict, X1_train.values)\n",
    "        recourse_needed_X1_test = recourse_needed(base_model.predict, X1_test.values)\n",
    "\n",
    "        # Weights and bias start empty; recourse_model_runner sets them per instance.\n",
    "        weights, bias = None, None\n",
    "\n",
    "        recourse_model1 = LAROAR(\n",
    "            weights = weights,\n",
    "            bias = bias,\n",
    "            alpha = alpha,\n",
    "        )\n",
    "\n",
    "        recourse_model2 = ROAR(\n",
    "            weights = weights,\n",
    "            bias = bias,\n",
    "            alpha = alpha,\n",
    "        )\n",
    "\n",
    "        # Both recourse models share the lambda chosen on the training instances.\n",
    "        lamb = recourse_model1.choose_lambda(recourse_needed_X1_train, base_model.predict, X1_train.values)\n",
    "        recourse_model1.lamb = lamb\n",
    "        recourse_model2.lamb = lamb\n",
    "\n",
    "        df_results = recourse_model_runner(seed, recourse_needed_X1_test, recourse_model1, recourse_model2, predictions, X1_train.values, base_model)\n",
    "        results.append(df_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Choosing lambda\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 6/10 [00:31<00:20,  5.20s/it]\n",
      "Eval alpha=0.1; lambda=0.6: 100%|\u001b[38;2;0;145;255m██████████\u001b[0m| 39/39 [20:45<00:00, 31.94s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Choosing lambda\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 1/10 [02:23<21:31, 143.52s/it]\n",
      "Eval alpha=0.1; lambda=0.1: 100%|\u001b[38;2;0;145;255m██████████\u001b[0m| 36/36 [23:47<00:00, 39.65s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Choosing lambda\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 1/10 [01:42<15:22, 102.46s/it]\n",
      "Eval alpha=0.1; lambda=0.1: 100%|\u001b[38;2;0;145;255m██████████\u001b[0m| 40/40 [26:43<00:00, 40.10s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Choosing lambda\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 1/10 [01:01<09:10, 61.14s/it]\n",
      "Eval alpha=0.1; lambda=0.1: 100%|\u001b[38;2;0;145;255m██████████\u001b[0m| 36/36 [14:02<00:00, 23.41s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Choosing lambda\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 1/10 [00:11<01:45, 11.74s/it]\n",
      "Eval alpha=0.1; lambda=0.1: 100%|\u001b[38;2;0;145;255m██████████\u001b[0m| 38/38 [22:10<00:00, 35.01s/it]\n"
     ]
    }
   ],
   "source": [
    "# Seed torch before any model is created.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "params = {\n",
    "    # one of: 'synthetic/simulated', 'correction/german', 'temporal/business', 'geospatial/student'\n",
    "    'data': 'temporal',\n",
    "    # one of: 'lr', 'nn'\n",
    "    'base_model': 'nn',\n",
    "    'alpha': 0.1,\n",
    "}\n",
    "seeds = range(5)\n",
    "\n",
    "# recourse_tradeoff fills this list in place.\n",
    "results = []\n",
    "recourse_tradeoff(params, seeds, results)\n",
    "df_results = pd.concat(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output folder first so a long run is not lost to a missing directory.\n",
    "results_path = Path(f'../results/tradeoff_{params[\"base_model\"]}_{params[\"data\"]}.pkl')\n",
    "results_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "df_results.to_pickle(results_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
