{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import warnings\n",
    "from copy import deepcopy\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import plotly.graph_objects as go\n",
    "import torch\n",
    "import tqdm\n",
    "\n",
    "from data import *\n",
    "from model import *\n",
    "from utils import *\n",
    "from recourse_model import ROAR, LAROAR, RecourseCost, RobustRecourse\n",
    "\n",
    "# Suppress all warnings for the rest of the kernel session.\n",
    "# NOTE(review): this also hides any warnings raised while training models and\n",
    "# generating recourse below; consider a narrower filter when diagnosing problems.\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def append_res(d, rob, loss, cost, m1_validity, wc_validity, m1_expectation, wc_expectation):\n",
    "    d['Cost'].append(cost)\n",
    "    d['M1 Validity'].append(m1_validity)\n",
    "    d['WC Validity'].append(wc_validity)\n",
    "    d['M1 Expectation'].append(m1_expectation)\n",
    "    d['WC Expectation'].append(wc_expectation)\n",
    "    d['J'].append(rob) \n",
    "    d['Loss'].append(loss)\n",
    "    \n",
    "def get_res(d, alg, seed, alpha, lamb):\n",
    "    result = {\n",
    "        'alg': alg, \n",
    "        'seed': seed,\n",
    "        'alpha': alpha,\n",
    "        'lambda': lamb}\n",
    "    \n",
    "    for key in d.keys():\n",
    "        result[key] = np.mean(d[key])\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _evaluate_recourse(x_r, J, adversary, base_model, model_adv):\n",
    "    \"\"\"Score one recourse point ``x_r`` for a single instance.\n",
    "\n",
    "    The worst-case parameters come from ``adversary.calc_theta_adv(x_r)``; they are fed\n",
    "    to the cost object ``J`` and loaded into ``model_adv`` for the worst-case prediction.\n",
    "\n",
    "    Returns the metrics in ``append_res`` argument order (after ``d``):\n",
    "    (rob, loss, cost, M1 validity, WC validity, M1 expectation, WC expectation).\n",
    "    \"\"\"\n",
    "    x_row = x_r.reshape(1, -1)\n",
    "    weights_adv, bias_adv = adversary.calc_theta_adv(x_r)\n",
    "    bce_loss, cost, rob = J(x_r, weights_adv, bias_adv, True)\n",
    "\n",
    "    m1_validity = base_model.predict(x_row)[0]\n",
    "    m1_expectation = base_model.predict_proba(x_row)[0, 1]\n",
    "\n",
    "    # Overwrite the linear model's parameters with the worst-case ones and re-score x_r.\n",
    "    model_adv.model.coef_ = weights_adv.reshape(1, -1)\n",
    "    model_adv.model.intercept_ = bias_adv\n",
    "    wc_validity = model_adv.predict(x_row)[0]\n",
    "    wc_expectation = model_adv.predict_proba(x_row)[0, 1]\n",
    "\n",
    "    return rob, bce_loss, cost, m1_validity, wc_validity, m1_expectation, wc_expectation\n",
    "\n",
    "\n",
    "def recourse_model_runner(X: np.ndarray, recourse_model1: LAROAR, recourse_model2: ROAR, base_model: Model, m1, X_train, seed):\n",
    "    \"\"\"Generate and score LAROAR ('OPT') and ROAR recourses for every row of ``X``.\n",
    "\n",
    "    Args:\n",
    "        X: inputs that need recourse; row ``i`` is ``X[i]``.\n",
    "        recourse_model1: LAROAR model; its ``alpha``/``lamb`` label the results and its\n",
    "            ``calc_theta_adv`` supplies the worst-case parameters for BOTH algorithms.\n",
    "        recourse_model2: ROAR model.\n",
    "        base_model: model whose predictions give M1 validity/expectation. For an ``NN``\n",
    "            base model, a per-row LIME explanation (weights, bias) replaces the recourse\n",
    "            models' weights and bias.\n",
    "        m1: linear model (``LR`` in ``recourse_tradeoff``) whose ``.model.coef_`` /\n",
    "            ``.model.intercept_`` are overwritten with the worst-case parameters to get\n",
    "            WC validity/expectation. Falls back to ``base_model`` if ``None``.\n",
    "        X_train: training features passed to ``lime_explanation``.\n",
    "        seed: recorded in the returned summaries.\n",
    "\n",
    "    Returns:\n",
    "        (summary_opt, summary_roar): dicts from ``get_res`` holding per-metric means.\n",
    "    \"\"\"\n",
    "    alpha = recourse_model1.alpha\n",
    "    lamb = recourse_model1.lamb\n",
    "\n",
    "    # The worst-case model gets coef_/intercept_ assigned directly, which only matters for\n",
    "    # a linear model; a copy of an NN base model would presumably ignore them (making WC\n",
    "    # metrics equal M1 metrics), so the linear m1 is used as the container instead.\n",
    "    model_adv = deepcopy(m1 if m1 is not None else base_model)\n",
    "\n",
    "    results_opt = {'Cost': [], 'M1 Validity': [], 'WC Validity': [], 'M1 Expectation': [], 'WC Expectation': [], 'J': [], 'Loss': []}\n",
    "    results_roar = deepcopy(results_opt)\n",
    "\n",
    "    for i in tqdm.trange(len(X), desc=f'Eval alpha={alpha}; lambda={lamb}', colour='#0091ff'):\n",
    "        x_0 = X[i]\n",
    "        J = RecourseCost(x_0, lamb)\n",
    "\n",
    "        if isinstance(base_model, NN):\n",
    "            # Seed per row so the LIME explanation is reproducible.\n",
    "            np.random.seed(i)\n",
    "            weights_0, bias_0 = lime_explanation(base_model.predict, X_train, x_0)\n",
    "            weights_0, bias_0 = np.round(weights_0, 4), np.round(bias_0, 4)\n",
    "\n",
    "            recourse_model1.weights = weights_0\n",
    "            recourse_model1.bias = bias_0\n",
    "            recourse_model2.set_weights(weights_0)\n",
    "            recourse_model2.set_bias(bias_0)\n",
    "\n",
    "        # OPT (LAROAR)\n",
    "        x_r = recourse_model1.get_recourse(x_0, beta=1.)\n",
    "        metrics_opt = _evaluate_recourse(x_r, J, recourse_model1, base_model, model_adv)\n",
    "\n",
    "        # ROAR -- scored against the same LAROAR adversary as OPT.\n",
    "        x_r2, _ = recourse_model2.get_recourse(x_0)\n",
    "        metrics_roar = _evaluate_recourse(x_r2, J, recourse_model1, base_model, model_adv)\n",
    "\n",
    "        append_res(results_opt, *metrics_opt)\n",
    "        append_res(results_roar, *metrics_roar)\n",
    "\n",
    "    return get_res(results_opt, 'OPT', seed, alpha, lamb), get_res(results_roar, 'ROAR', seed, alpha, lamb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def recourse_tradeoff(params: dict, seeds: list, results: list):\n",
    "    if params['data'] in ['correction', 'german']:\n",
    "        data_model = CorrectionShift(\"../datasets/german.csv\", \"../datasets/corrected_german.csv\", seed=0)\n",
    "    elif params['data'] in ['temporal', 'business']:\n",
    "        data_model = TemporalShift(\"../datasets/SBAcase.11.13.17.csv\", seed=0)\n",
    "    elif params['data'] in ['geospatial', 'student']:\n",
    "        data_model = GeospatialShift(\"../datasets/student-por.csv\", seed=0)\n",
    "    elif params['data'] in ['synthetic', 'simulated']:\n",
    "        data_model = SyntheticData(alpha=1.5, beta=0, n=1000, seed=0)\n",
    "    \n",
    "    alpha = params['alpha']\n",
    "    \n",
    "    for seed in seeds:\n",
    "        data1, _ = data_model.get_data(seed)\n",
    "        X1_train, y1_train, X1_test, y1_test = data1\n",
    "\n",
    "        if params['base_model'] == 'lr':\n",
    "            base_model = LR()\n",
    "            m1 = deepcopy(base_model)\n",
    "        elif params['base_model'] == 'nn':\n",
    "            base_model = NN(X1_train.shape[1])\n",
    "            m1 = LR()\n",
    "\n",
    "        base_model.train(X1_train.values, y1_train.values)\n",
    "        m1.train(X1_train.values, y1_train.values)\n",
    "        \n",
    "        recourse_needed_X1_train = recourse_needed(base_model.predict, X1_train.values)\n",
    "        recourse_needed_X1_test = recourse_needed(base_model.predict, X1_test.values)\n",
    "        \n",
    "        print(len(recourse_needed_X1_train), len(X1_train), len(recourse_needed_X1_test), len(X1_test))\n",
    "        \n",
    "        weights, bias = None, None\n",
    "        if params['base_model'] == 'lr':\n",
    "            weights = base_model.model.coef_[0]\n",
    "            bias = base_model.model.intercept_\n",
    "        \n",
    "        recourse_model1 = LAROAR(\n",
    "            weights = weights,\n",
    "            bias = bias,\n",
    "            alpha = alpha,\n",
    "        )    \n",
    "        \n",
    "        recourse_model2 = ROAR(\n",
    "            weights = weights,\n",
    "            bias = bias,\n",
    "            alpha = alpha,\n",
    "        )\n",
    "        \n",
    "        lamb = recourse_model1.choose_lambda(recourse_needed_X1_train, base_model.predict, X1_train, base_model.predict_proba)\n",
    "        recourse_model1.lamb = lamb\n",
    "        recourse_model2.lamb = lamb\n",
    "        \n",
    "        result_opt, result_roar = recourse_model_runner(recourse_needed_X1_test, recourse_model1, recourse_model2, base_model, m1, X1_train, seed)\n",
    "        \n",
    "        results.append(result_opt)\n",
    "        results.append(result_roar)\n",
    "        \n",
    "        print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch's and numpy's global RNGs before any model is built or trained.\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)\n",
    "\n",
    "params = {}\n",
    "# data: 'synthetic'/'simulated', 'correction'/'german', 'temporal'/'business', 'geospatial'/'student'\n",
    "params['data'] = 'synthetic'\n",
    "# base_model: 'lr' or 'nn'\n",
    "params['base_model'] = 'lr'\n",
    "# alpha passed to both LAROAR and ROAR\n",
    "params['alpha'] = 0.5\n",
    "\n",
    "seeds = list(range(5))\n",
    "\n",
    "results = []\n",
    "recourse_tradeoff(params, seeds, results)\n",
    "df_results = pd.DataFrame(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the results frame from the raw summaries and list the rows ordered by algorithm.\n",
    "df_results = pd.DataFrame(data=results)\n",
    "df_results.sort_values(by='alg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-algorithm mean over seeds; metric columns are rendered as \"mean±std\" strings.\n",
    "metric_cols = ['Cost', 'M1 Validity', 'WC Validity', 'M1 Expectation', 'WC Expectation', 'J', 'Loss']\n",
    "grouped = df_results.groupby('alg')\n",
    "df_results_avg = grouped.mean()\n",
    "metric_std = grouped.std()[metric_cols]\n",
    "df_results_avg[metric_cols] = (\n",
    "    df_results_avg[metric_cols].round(2).astype(str)\n",
    "    + '±'\n",
    "    + metric_std.round(2).astype(str)\n",
    ")\n",
    "\n",
    "print(params['base_model'], params['data'])\n",
    "df_results_avg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output file is keyed by base model and dataset.\n",
    "out_path = f'../results_comparison/comparison_{params[\"base_model\"]}_{params[\"data\"]}.pkl'\n",
    "# Create the output directory if needed so the dump also works when it does not exist yet.\n",
    "os.makedirs(os.path.dirname(out_path), exist_ok=True)\n",
    "with open(out_path, 'wb') as f:\n",
    "    pickle.dump(df_results, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
