{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import warnings\n",
    "from copy import deepcopy\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import plotly.graph_objects as go\n",
    "import tqdm\n",
    "\n",
    "from data import *\n",
    "from model import *\n",
    "from utils import *\n",
    "from recourse_model import LAROAR, ROAR, RecourseCost\n",
    "\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "def append_res(d, alg, seed, alpha, lamb, i, x_0, theta_0, beta, x_r, theta_r, p, theta_p, J_r, J_c, robustness, consistency):\n",
    "    \"\"\"Append one evaluation record to the column-oriented results dict `d`.\n",
    "\n",
    "    Every argument after `d` is appended to the list stored under its column\n",
    "    key (`lamb` goes to the 'lambda' column). The fields x_0, theta_0, x_r,\n",
    "    theta_r and theta_p are rounded to 4 decimals through their `.round`\n",
    "    method before being stored. `d` is mutated in place; nothing is returned.\n",
    "    \"\"\"\n",
    "    # (column key, value, round to 4 decimals?) listed in column order.\n",
    "    record = (\n",
    "        ('alg', alg, False),\n",
    "        ('seed', seed, False),\n",
    "        ('alpha', alpha, False),\n",
    "        ('lambda', lamb, False),\n",
    "        ('i', i, False),\n",
    "        ('x_0', x_0, True),\n",
    "        ('theta_0', theta_0, True),\n",
    "        ('beta', beta, False),\n",
    "        ('x_r', x_r, True),\n",
    "        ('theta_r', theta_r, True),\n",
    "        ('p', p, False),\n",
    "        ('theta_p', theta_p, True),\n",
    "        ('J_r', J_r, False),\n",
    "        ('J_c', J_c, False),\n",
    "        ('robustness', robustness, False),\n",
    "        ('consistency', consistency, False),\n",
    "    )\n",
    "    for column, value, needs_rounding in record:\n",
    "        d[column].append(value.round(4) if needs_rounding else value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "def recourse_model_runner(seed, X: np.ndarray, recourse_model1: LAROAR, recourse_model2: ROAR, predictions: List, X_train: np.ndarray, base_model: LR | NN):\n",
    "    \"\"\"Score LAROAR and ROAR recourse for every row of X and average the costs.\n",
    "\n",
    "    For each sample x_0 and each entry of `predictions` (a parameter vector:\n",
    "    weights followed by one bias entry) the following rows are recorded:\n",
    "\n",
    "    * 'OPT': `recourse_model1.get_recourse(x_0, beta, theta_p)` for beta on the\n",
    "      grid 0, 0.01, ..., 1.\n",
    "    * 'OPT{2k}': `c*x + (1-c)*x_c`, mixing the k-th grid recourse x with the\n",
    "      beta=0 recourse x_c; the mixing weight c is stored in the 'beta' column.\n",
    "    * 'OPT{2k+1}': `c*x_r + (1-c)*x`, mixing the beta=1 recourse x_r with x;\n",
    "      c is again stored in the 'beta' column.\n",
    "    * 'ROAR': `recourse_model2.get_recourse(x_0)`, stored with beta = 1.\n",
    "\n",
    "    Every candidate x is scored with `RecourseCost(x_0, lamb)`:\n",
    "    J_r uses the parameters returned by `recourse_model1.calc_theta_adv(x)`,\n",
    "    J_c uses the predicted parameters. 'robustness' is J_r minus the J_r of the\n",
    "    beta=1 recourse; 'consistency' is J_c minus the J_c of the beta=0 recourse.\n",
    "\n",
    "    Args:\n",
    "        seed: value stored in the 'seed' column of every row.\n",
    "        X: samples to compute recourse for; iterated with X[i].\n",
    "        recourse_model1: LAROAR model; supplies alpha, lamb, weights, bias,\n",
    "            get_recourse and calc_theta_adv.\n",
    "        recourse_model2: ROAR model; only get_recourse is used.\n",
    "        predictions: parameter vectors to evaluate consistency against.\n",
    "        X_train: unused; kept so existing callers keep working.\n",
    "        base_model: unused; kept so existing callers keep working.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with the numeric columns averaged per ('alg', 'beta', 'p').\n",
    "    \"\"\"\n",
    "    alpha = recourse_model1.alpha\n",
    "    lamb = recourse_model1.lamb\n",
    "\n",
    "    columns = ('alg', 'seed', 'alpha', 'lambda', 'i', 'x_0', 'theta_0', 'beta', 'x_r', 'theta_r', 'p', 'theta_p', 'J_r', 'J_c', 'robustness', 'consistency')\n",
    "    results_opt = {column: [] for column in columns}\n",
    "    results_roar = {column: [] for column in columns}\n",
    "\n",
    "    # Loop-invariant grids: built once instead of once per sample/prediction/beta.\n",
    "    betas = np.arange(0., 1.01, 0.01).round(4)\n",
    "    mix_weights = np.arange(0., 1.01, 0.01).round(2)\n",
    "\n",
    "    n = len(X)\n",
    "\n",
    "    weights_0, bias_0 = recourse_model1.weights, recourse_model1.bias\n",
    "    theta_0 = np.hstack((weights_0, bias_0))\n",
    "    for i in tqdm.trange(n, desc=f'Eval alpha={alpha}; lambda={lamb}', colour='#0091ff'):\n",
    "        x_0 = X[i]\n",
    "        J = RecourseCost(x_0, lamb)\n",
    "\n",
    "        for p, prediction in enumerate(predictions):\n",
    "            weights_p, bias_p = prediction[:-1], prediction[[-1]]\n",
    "            theta_p = (weights_p, bias_p)\n",
    "            theta_p_flat = np.hstack(theta_p)\n",
    "\n",
    "            # Reference cost for robustness: the beta=1 recourse scored under\n",
    "            # the parameters calc_theta_adv returns for it.\n",
    "            # NOTE(review): x_r / opt_rob do not depend on the prediction; they\n",
    "            # are left inside this loop so the sequence of calls on\n",
    "            # recourse_model1 is unchanged. Hoist once get_recourse is\n",
    "            # confirmed deterministic and side-effect free.\n",
    "            x_r = recourse_model1.get_recourse(x_0, beta=1.)\n",
    "            weights_r, bias_r = recourse_model1.calc_theta_adv(x_r)\n",
    "            opt_rob = J(x_r, weights_r, bias_r)\n",
    "\n",
    "            # Reference cost for consistency: the beta=0 recourse scored under\n",
    "            # the predicted parameters.\n",
    "            x_c = recourse_model1.get_recourse(x_0, beta=0., theta_p=theta_p)\n",
    "            opt_con = J(x_c, weights_p, bias_p)\n",
    "\n",
    "            def record(results, alg, beta, x):\n",
    "                \"\"\"Score candidate recourse x and append one row to `results`.\"\"\"\n",
    "                weights_adv, bias_adv = recourse_model1.calc_theta_adv(x)\n",
    "                J_r = J(x, weights_adv, bias_adv)\n",
    "                J_c = J(x, weights_p, bias_p)\n",
    "                append_res(\n",
    "                    results, alg, seed, alpha, lamb, i, x_0, theta_0, beta, x,\n",
    "                    np.hstack((weights_adv, bias_adv)), p, theta_p_flat,\n",
    "                    J_r[0], J_c[0], (J_r - opt_rob)[0], (J_c - opt_con)[0],\n",
    "                )\n",
    "\n",
    "            # OPT: sweep beta, then the two convex combinations for that beta.\n",
    "            for k, beta in enumerate(betas):\n",
    "                x = recourse_model1.get_recourse(x_0, beta=beta, theta_p=theta_p)\n",
    "                record(results_opt, 'OPT', beta, x)\n",
    "\n",
    "                for c in mix_weights:\n",
    "                    record(results_opt, f'OPT{2 * k}', c, c*x + (1-c)*x_c)\n",
    "                    record(results_opt, f'OPT{2 * k + 1}', c, c*x_r + (1-c)*x)\n",
    "\n",
    "            # ROAR: the adversarial parameters still come from recourse_model1.\n",
    "            x_r2, _ = recourse_model2.get_recourse(x_0)\n",
    "            record(results_roar, 'ROAR', 1., x_r2)\n",
    "\n",
    "    df_hist = pd.concat((pd.DataFrame(results_opt), pd.DataFrame(results_roar)))\n",
    "    # numeric_only drops the array-valued columns (x_0, theta_*, x_r) before averaging.\n",
    "    df_results = df_hist.groupby(['alg', 'beta', 'p'], as_index=False).mean(numeric_only=True)\n",
    "\n",
    "    return df_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "def recourse_tradeoff(params: dict, seeds):\n",
    "    \"\"\"Run the robustness/consistency trade-off experiment once per seed.\n",
    "\n",
    "    Args:\n",
    "        params: experiment settings. 'data' selects the dataset (either alias\n",
    "            of correction/german, temporal/business, geospatial/student,\n",
    "            synthetic/simulated); 'alpha' is passed to both recourse models;\n",
    "            'base_model' is only used to build the theta_preds file name (the\n",
    "            base model trained here is always `LR()`).\n",
    "        seeds: iterable of seeds passed to `data_model.get_data`.\n",
    "\n",
    "    Returns:\n",
    "        A list with one aggregated DataFrame (from `recourse_model_runner`)\n",
    "        per seed.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if params['data'] is not a known dataset name.\n",
    "    \"\"\"\n",
    "    dataset = params['data']\n",
    "    if dataset in ['correction', 'german']:\n",
    "        data_model = CorrectionShift(\"../datasets/german.csv\", \"../datasets/corrected_german.csv\", seed=0)\n",
    "    elif dataset in ['temporal', 'business']:\n",
    "        data_model = TemporalShift(\"../datasets/SBAcase.11.13.17.csv\", seed=0)\n",
    "    elif dataset in ['geospatial', 'student']:\n",
    "        data_model = GeospatialShift(\"../datasets/student-por.csv\", seed=0)\n",
    "    elif dataset in ['synthetic', 'simulated']:\n",
    "        data_model = SyntheticData(alpha=1.5, beta=0, n=1000, seed=0)\n",
    "    else:\n",
    "        # Fail fast instead of hitting an unbound `data_model` inside the loop.\n",
    "        raise ValueError(\n",
    "            f'Unknown dataset {dataset!r}; expected one of: correction/german, '\n",
    "            'temporal/business, geospatial/student, synthetic/simulated'\n",
    "        )\n",
    "    alpha = params['alpha']\n",
    "\n",
    "    results = []\n",
    "    # Built on the first iteration and reused for every later seed. With seeds\n",
    "    # starting at 0 this matches the previous `seed == 0` behaviour, and it no\n",
    "    # longer breaks when the first seed is not 0.\n",
    "    predictions = None\n",
    "\n",
    "    for seed in seeds:\n",
    "        data1, _ = data_model.get_data(seed)\n",
    "        X1_train, y1_train, X1_test, _ = data1\n",
    "\n",
    "        base_model = LR()\n",
    "        base_model.train(X1_train.values, y1_train.values)\n",
    "\n",
    "        weights_0 = base_model.model.coef_[0]\n",
    "        bias_0 = base_model.model.intercept_\n",
    "        theta_0 = np.hstack((weights_0, bias_0))\n",
    "\n",
    "        if predictions is None:\n",
    "            theta_preds_near = np.load(f'../theta_preds/theta_preds_{params[\"base_model\"]}_{params[\"data\"]}.npy')\n",
    "            predictions = [theta_0]\n",
    "            for theta_p in theta_preds_near:\n",
    "                # Each loaded prediction is paired with its reflection through theta_0.\n",
    "                predictions.append(theta_p)\n",
    "                predictions.append(theta_0 - (theta_p - theta_0))\n",
    "\n",
    "        recourse_needed_X1_train = recourse_needed(base_model.predict, X1_train.values)\n",
    "        recourse_needed_X1_test = recourse_needed(base_model.predict, X1_test.values)\n",
    "\n",
    "        recourse_model1 = LAROAR(\n",
    "            weights = weights_0,\n",
    "            bias = bias_0,\n",
    "            alpha = alpha,\n",
    "        )\n",
    "\n",
    "        recourse_model2 = ROAR(\n",
    "            weights = weights_0,\n",
    "            bias = bias_0,\n",
    "            alpha = alpha,\n",
    "        )\n",
    "\n",
    "        # Both models share the lambda chosen by the LAROAR model on the training split.\n",
    "        lamb = recourse_model1.choose_lambda(recourse_needed_X1_train, base_model.predict, X1_train.values)\n",
    "        recourse_model1.lamb = lamb\n",
    "        recourse_model2.lamb = lamb\n",
    "\n",
    "        df_results = recourse_model_runner(seed, recourse_needed_X1_test, recourse_model1, recourse_model2, predictions, X1_train.values, base_model)\n",
    "        results.append(df_results)\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Eval alpha=0.5; lambda=1.0: 100%|\u001b[38;2;0;145;255m██████████\u001b[0m| 39/39 [14:55<00:00, 22.95s/it]\n",
      "Eval alpha=0.5; lambda=1.0: 100%|\u001b[38;2;0;145;255m██████████\u001b[0m| 36/36 [14:13<00:00, 23.71s/it]\n",
      "Eval alpha=0.5; lambda=1.0: 100%|\u001b[38;2;0;145;255m██████████\u001b[0m| 40/40 [15:32<00:00, 23.31s/it]\n",
      "Eval alpha=0.5; lambda=1.0: 100%|\u001b[38;2;0;145;255m██████████\u001b[0m| 36/36 [13:36<00:00, 22.67s/it]\n",
      "Eval alpha=0.5; lambda=1.0: 100%|\u001b[38;2;0;145;255m██████████\u001b[0m| 38/38 [15:29<00:00, 24.46s/it]\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): `torch` is not imported explicitly in this notebook; it is\n",
    "# presumably exposed by one of the star imports in the first cell - confirm.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Experiment configuration.\n",
    "params = {\n",
    "    # 'synthetic/simulated', 'correction/german', 'temporal/business', 'geospatial/student'\n",
    "    'data': 'temporal',\n",
    "    # 'lr', 'nn'\n",
    "    'base_model': 'lr',\n",
    "    'alpha': 0.5,\n",
    "}\n",
    "seeds = range(5)\n",
    "\n",
    "# One aggregated frame per seed, stacked into a single frame.\n",
    "results = recourse_tradeoff(params, seeds)\n",
    "df_results = pd.concat(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_path = Path(f'../results/tradeoff_{params[\"base_model\"]}_{params[\"data\"]}.pkl')\n",
    "# Create the output folder first so a long run is not lost to a missing directory.\n",
    "results_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "df_results.to_pickle(results_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
