{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tqdm\n",
    "import pickle\n",
    "import warnings\n",
    "import numpy as np\n",
    "import scipy as sp\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import plotly.graph_objects as go\n",
    "from copy import deepcopy\n",
    "from scipy.spatial import distance\n",
    "\n",
    "from data import *\n",
    "from model import *\n",
    "from utils import *\n",
    "from recourse_model import LAROAR, RecourseCost\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "# pd.set_option('display.max_colwidth', 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def append_hist(d, i, x_0, theta_0, x_r, theta_r):\n",
    "    \"\"\"Append one evaluation record to the column-wise history dict.\n",
    "\n",
    "    Every argument after ``d`` is appended to the container stored in ``d``\n",
    "    under the key of the same name ('i', 'x_0', 'theta_0', 'x_r', 'theta_r').\n",
    "    ``d`` is modified in place and nothing is returned.\n",
    "    \"\"\"\n",
    "    record = (\n",
    "        ('i', i),\n",
    "        ('x_0', x_0),\n",
    "        ('theta_0', theta_0),\n",
    "        ('x_r', x_r),\n",
    "        ('theta_r', theta_r),\n",
    "    )\n",
    "    # Same key order as the columns of the history dict built by the caller.\n",
    "    for column, value in record:\n",
    "        d[column].append(value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def recourse_model_runner(dataset: np.ndarray, model: LAROAR):\n",
    "    \"\"\"Compute recourse for every row of ``dataset`` and log the results.\n",
    "\n",
    "    For each row the function records the current model parameters\n",
    "    (``model.weights`` stacked with ``model.bias``), the point returned by\n",
    "    ``model.get_recourse(x_0, beta=1.)`` and the parameters returned by\n",
    "    ``model.calc_theta_adv`` for that point.\n",
    "\n",
    "    Returns a DataFrame with one row per input row and the columns\n",
    "    'i', 'x_0', 'theta_0', 'x_r' and 'theta_r'.\n",
    "    \"\"\"\n",
    "    columns = ('i', 'x_0', 'theta_0', 'x_r', 'theta_r')\n",
    "    hist = {column: [] for column in columns}\n",
    "\n",
    "    progress = tqdm.tqdm(\n",
    "        range(len(dataset)),\n",
    "        desc=f'Eval alpha={model.alpha}; lambda={model.lamb}',\n",
    "        colour='#0091ff',\n",
    "    )\n",
    "    for row_idx in progress:\n",
    "        sample = dataset[row_idx]\n",
    "        # Read the parameters on every iteration, before recourse is computed for this row.\n",
    "        params_before = np.hstack((model.weights, model.bias))\n",
    "\n",
    "        recourse_point = model.get_recourse(sample, beta=1.)\n",
    "        adv_weights, adv_bias = model.calc_theta_adv(recourse_point)\n",
    "        params_after = np.hstack((adv_weights, adv_bias))\n",
    "\n",
    "        append_hist(hist, row_idx, sample, params_before, recourse_point, params_after)\n",
    "\n",
    "    return pd.DataFrame(hist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def recourse_tradeoff(params: dict):\n",
    "    \"\"\"Train the base classifier, tune LAROAR and evaluate recourse on the test split.\n",
    "\n",
    "    Args:\n",
    "        params: Experiment settings. ``params['data']`` selects the dataset\n",
    "            ('correction'/'german', 'temporal'/'business',\n",
    "            'geospatial'/'student' or 'synthetic'/'simulated') and\n",
    "            ``params['alpha']`` is passed to ``LAROAR`` as ``alpha``.\n",
    "\n",
    "    Returns:\n",
    "        The DataFrame produced by ``recourse_model_runner`` for the output of\n",
    "        ``recourse_needed`` on the test features.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``params['data']`` is not one of the names above.\n",
    "    \"\"\"\n",
    "    data_name = params['data']\n",
    "    if data_name in ['correction', 'german']:\n",
    "        data_model = CorrectionShift(\"../datasets/german.csv\", \"../datasets/corrected_german.csv\", seed=0)\n",
    "    elif data_name in ['temporal', 'business']:\n",
    "        data_model = TemporalShift(\"../datasets/SBAcase.11.13.17.csv\", seed=0)\n",
    "    elif data_name in ['geospatial', 'student']:\n",
    "        data_model = GeospatialShift(\"../datasets/student-por.csv\", seed=0)\n",
    "    elif data_name in ['synthetic', 'simulated']:\n",
    "        data_model = SyntheticData(alpha=1.5, beta=0, n=200, seed=0)\n",
    "    else:\n",
    "        # Without this branch an unknown name left data_model unbound and the\n",
    "        # function failed further down with an unrelated UnboundLocalError.\n",
    "        raise ValueError(\n",
    "            f'Unknown dataset {data_name!r}; expected one of: '\n",
    "            'correction/german, temporal/business, geospatial/student, synthetic/simulated'\n",
    "        )\n",
    "\n",
    "    data1, _ = data_model.get_data(0)\n",
    "    # The test labels are unpacked but not used by this experiment.\n",
    "    X1_train, y1_train, X1_test, _y1_test = data1\n",
    "\n",
    "    base_model = LR()\n",
    "    base_model.train(X1_train.values, y1_train.values)\n",
    "\n",
    "    recourse_needed_X1_train = recourse_needed(base_model.predict, X1_train.values)\n",
    "    recourse_needed_X1_test = recourse_needed(base_model.predict, X1_test.values)\n",
    "\n",
    "    alpha = params['alpha']\n",
    "    # coef_[0] takes the first row of the fitted coefficient matrix.\n",
    "    weights, bias = base_model.model.coef_[0], base_model.model.intercept_\n",
    "\n",
    "    recourse_model = LAROAR(\n",
    "        weights=weights,\n",
    "        bias=bias,\n",
    "        alpha=alpha,\n",
    "    )\n",
    "\n",
    "    # lambda is chosen on the training split, then fixed for the test evaluation.\n",
    "    lamb = recourse_model.choose_lambda(recourse_needed_X1_train, base_model.predict, X1_train.values, base_model.predict_proba)\n",
    "    recourse_model.lamb = lamb\n",
    "\n",
    "    df_hist = recourse_model_runner(recourse_needed_X1_test, recourse_model)\n",
    "\n",
    "    return df_hist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Choosing lambda\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "8it [00:00, 241.79it/s]\n",
      "Eval alpha=0.1; lambda=0.8: 100%|\u001b[38;2;0;145;255m██████████\u001b[0m| 21/21 [00:00<00:00, 15578.42it/s]\n"
     ]
    }
   ],
   "source": [
    "# Experiment configuration. Each pair below lists two interchangeable dataset names:\n",
    "# 'synthetic/simulated', 'correction/german', 'temporal/business', 'geospatial/student'\n",
    "params = {\n",
    "    'data': 'synthetic',\n",
    "    'alpha': 0.5,\n",
    "}\n",
    "\n",
    "df_hist = recourse_tradeoff(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1.61, 1.47, 0.07]])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stack the logged theta_r vectors (one per evaluated row) and keep the distinct ones.\n",
    "theta_rs = np.stack(df_hist['theta_r'].to_numpy())\n",
    "unique_theta_rs = np.unique(theta_rs, axis=0)\n",
    "np.round(unique_theta_rs, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1.51, 1.57, 0.17]])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stack the logged theta_0 vectors (one per evaluated row) and keep the distinct ones.\n",
    "theta_0s = np.stack(df_hist['theta_0'].to_numpy())\n",
    "unique_theta_0s = np.unique(theta_0s, axis=0)\n",
    "np.round(unique_theta_0s, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1.61, 1.47, 0.07]])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pick two distinct rows of unique_theta_rs with a fixed seed (a single row when fewer than two exist).\n",
    "n = len(unique_theta_rs)\n",
    "rng = np.random.default_rng(0)\n",
    "idx = rng.choice(range(n), size=max(1, min(n, 2)), replace=False)\n",
    "two_unique_theta_rs = unique_theta_rs[idx]\n",
    "np.round(two_unique_theta_rs, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1.51374515, 1.47217731, 0.07049114])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# For the synthetic data, add one more vector: the first selected theta_r with\n",
    "# its first coordinate replaced by the first coordinate of the first unique theta_0.\n",
    "# 'simulated' is accepted too because recourse_tradeoff treats it as the same dataset.\n",
    "if params['data'] in ['synthetic', 'simulated']:\n",
    "    theta_r = unique_theta_rs[idx[0]].copy()\n",
    "    theta_r[0] = unique_theta_0s[0][0]\n",
    "\n",
    "    # Rebuild from unique_theta_rs[idx] instead of from two_unique_theta_rs itself,\n",
    "    # so re-running this cell does not stack an additional row each time.\n",
    "    two_unique_theta_rs = np.vstack((unique_theta_rs[idx], theta_r))\n",
    "\n",
    "# Only a top-level trailing expression is displayed; inside the if-block it was silent.\n",
    "two_unique_theta_rs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the selected parameter vectors; the file name carries the dataset name.\n",
    "np.save(\n",
    "    file=f'../theta_preds/theta_preds_{params[\"data\"]}.npy',\n",
    "    arr=two_unique_theta_rs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1.61, 1.47, 0.07],\n",
       "       [1.51, 1.47, 0.07]])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the saved array back and show it rounded to two decimals.\n",
    "np.round(np.load(f'../theta_preds/theta_preds_{params[\"data\"]}.npy'), 2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
