{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "3ae3dcaf-1ace-4127-a0b7-0eef36ccc671",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn\n",
    "from scipy import stats\n",
    "\n",
    "import numpy as np\n",
    "import river\n",
    "from river import neural_net\n",
    "from river import preprocessing as pp\n",
    "from river import optim, metrics\n",
    "from river.neural_net import activations as act\n",
    "\n",
    "\n",
    "# Load behaghel.csv and keep only the columns declared below. Each column's role is\n",
    "# listed exactly once; relevant_columns and the dtypes are derived from these lists.\n",
    "\n",
    "# non-feature columns: sample weight 'sw', assignment indicators A_*, outcome 'Y'\n",
    "other_variables = [\"sw\", \"A_public\", \"A_private\", \"A_standard\", \"Y\"]\n",
    "\n",
    "# numerical features\n",
    "Xnum = [\n",
    "  'age',\n",
    "  'Number_of_children',\n",
    "  'exper', # years experience on the job\n",
    "  'salaire.num', # salary target\n",
    "  'mois_saisie_occ', # when assigned\n",
    "  'ndem' # Num. unemployment spell\n",
    "]\n",
    "\n",
    "# categorical features\n",
    "Xbin = [\n",
    "  'College_education',\n",
    "  'nivetude2',\n",
    "  'Vocational',\n",
    "  'High_school_dropout',\n",
    "  'Manager',\n",
    "  'Technician',\n",
    "  'Skilled_clerical_worker',\n",
    "  'Unskilled_clerical_worker',\n",
    "  'Skilled_blue_colar',\n",
    "  'Unskilled_blue_colar',\n",
    "  'Woman',\n",
    "  'Married',\n",
    "  'French',\n",
    "  'African',\n",
    "  'Other_Nationality',\n",
    "  'Paris_region',\n",
    "  'North',\n",
    "  'Other_regions',\n",
    "  'Employment_component_level_1',\n",
    "  'Employment_component_level_2',\n",
    "  'Employment_component_missing',\n",
    "  'Economic_Layoff',\n",
    "  'Personnal_Layoff',\n",
    "  'End_of_Fixed_Term_Contract',\n",
    "  'End_of_Temporary_Work',\n",
    "  'Other_reasons_of_unemployment',\n",
    "  'Statistical_risk_level_2',\n",
    "  'Statistical_risk_level_3',\n",
    "  'Other_Statistical_risk',\n",
    "  'Search_for_a_full_time_position',\n",
    "  'Sensitive_suburban_area',\n",
    "  'Insertion',\n",
    "  'Interim',\n",
    "  'Conseil'\n",
    "]\n",
    "\n",
    "# columns kept, in this order: other_variables, then Xbin, then Xnum\n",
    "relevant_columns = other_variables + Xbin + Xnum\n",
    "\n",
    "column_dtypes = {\n",
    "    **{col: float for col in other_variables + Xnum},\n",
    "    **{col: \"category\" for col in Xbin},\n",
    "}\n",
    "\n",
    "# select and cast in one chain: no column assignment on a slice, so no SettingWithCopyWarning\n",
    "full_dat = pd.read_csv(\"behaghel.csv\")[relevant_columns].astype(column_dtypes)\n",
    "\n",
    "# positions of the categorical features within full_dat.columns\n",
    "categorical_indices = [i for i, col in enumerate(full_dat.columns) if col in Xbin]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "8753b698-c274-42ee-8d85-c42a186e9a9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Bootstrap sampling helpers: bootstrap_sample returns the covariate matrix X,\n",
    "# the noisy outcomes Y and the EIF pseudo-outcomes Y_eif.\n",
    "\n",
    "\n",
    "def row_to_dict(x_row: np.ndarray):\n",
    "    return {f\"x{i}\": float(x_row[i]) for i in range(x_row.shape[0])}\n",
    "\n",
    "def clip02(v: float, eps: float = 0.01) -> float:\n",
    "    return float(min(1.0 - eps, max(eps, v)))\n",
    "\n",
    "\n",
    "def _online_s_learner(X_stream: np.ndarray, Y: np.ndarray):\n",
    "    \"\"\"Fit one MLP online (S-learner) and return its pre-update potential-outcome predictions.\n",
    "\n",
    "    Column 0 of X_stream is the treatment indicator; the remaining columns are covariates.\n",
    "    For each row t the model first predicts with column 0 set to 0 and to 1, and only then\n",
    "    learns from (row t, Y[t]), so the predictions for row t never use row t itself.\n",
    "\n",
    "    Returns (mu_0_est, mu_1_est): float arrays of length len(X_stream); an entry stays NaN\n",
    "    when predict_one returns None.\n",
    "    \"\"\"\n",
    "    mlp = pp.StandardScaler() | neural_net.MLPRegressor(\n",
    "        hidden_dims=(64, 64, 64, 64),\n",
    "        activations=(act.ReLU(), act.ReLU(), act.ReLU(), act.ReLU(), act.Identity()),\n",
    "        optimizer=optim.Adam(1e-3),\n",
    "        seed=0)\n",
    "    n_rows = X_stream.shape[0]\n",
    "    mu_0_est = np.full(n_rows, np.nan)\n",
    "    mu_1_est = np.full(n_rows, np.nan)\n",
    "\n",
    "    for t in range(n_rows):\n",
    "        row = X_stream[t]\n",
    "        # counterfactual rows are copies: writing into X_stream[t] would mutate the stream\n",
    "        row_0 = row.copy()\n",
    "        row_0[0] = 0\n",
    "        row_1 = row.copy()\n",
    "        row_1[0] = 1\n",
    "\n",
    "        yhat_0 = mlp.predict_one(row_to_dict(row_0))\n",
    "        yhat_1 = mlp.predict_one(row_to_dict(row_1))\n",
    "        if yhat_0 is not None:\n",
    "            mu_0_est[t] = yhat_0\n",
    "        if yhat_1 is not None:\n",
    "            mu_1_est[t] = yhat_1\n",
    "\n",
    "        mlp.learn_one(row_to_dict(row), float(Y[t]))\n",
    "\n",
    "    return mu_0_est, mu_1_est\n",
    "\n",
    "\n",
    "def bootstrap_sample(n, df, random_seed = 42, k = 4, noise_std=1/4, bound = 100, Xbin = None, Xnum = None, categorical_indices = None, n_features = 15):\n",
    "    \"\"\"Draw a weighted bootstrap sample of df and build noisy EIF pseudo-outcomes for it.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n : int\n",
    "        Number of rows drawn with replacement, weighted by df[\"sw\"].\n",
    "    df : pandas.DataFrame\n",
    "        Must contain \"sw\", \"A_public\", \"Y\" and every column of Xbin + Xnum.\n",
    "    random_seed : int\n",
    "        Seeds the row resampling and one local numpy Generator used for the column\n",
    "        subsample and both noise draws; the global numpy RNG is left untouched.\n",
    "    k : int\n",
    "        Unused; kept for backward compatibility.\n",
    "    noise_std : float\n",
    "        Standard deviation of the Gaussian noise added to Y and to Y_eif.\n",
    "    bound : float\n",
    "        Each noise draw is truncated to [-bound, bound].\n",
    "    Xbin, Xnum : sequence of str, optional\n",
    "        Categorical and numerical covariate column names (None means empty).\n",
    "    categorical_indices : optional\n",
    "        Unused; kept for backward compatibility.\n",
    "    n_features : int\n",
    "        Number of covariate columns sampled without replacement from Xbin + Xnum\n",
    "        (capped at the number available).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    X : numpy.ndarray of float, shape (n, min(n_features, len(Xbin) + len(Xnum)))\n",
    "    Y : pandas.Series, the outcome plus truncated noise\n",
    "    Y_eif : pandas.Series, AIPW/EIF pseudo-outcome for the effect of A_public plus\n",
    "        independent truncated noise (NaN wherever mu_0_est or mu_1_est stayed NaN)\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If Xbin and Xnum are both empty, or if A_public is constant in the sample\n",
    "        (the propensity would be 0 or 1 and the EIF divides by it).\n",
    "    \"\"\"\n",
    "    feature_cols = list(Xbin or []) + list(Xnum or [])\n",
    "    if not feature_cols:\n",
    "        raise ValueError(\"Xbin and Xnum are both empty; there are no covariates to sample\")\n",
    "    rng = np.random.default_rng(random_seed)\n",
    "\n",
    "    # resample rows based on sample weights\n",
    "    bs_sample = df.sample(n=n, replace=True, weights=df[\"sw\"], random_state=random_seed)\n",
    "    A = bs_sample[\"A_public\"].to_numpy(dtype=float)\n",
    "\n",
    "    # add truncated noise to Y to preserve difficulty\n",
    "    noise_y = np.clip(rng.normal(loc=0, scale=noise_std, size=len(bs_sample)), -bound, bound)\n",
    "    Y = bs_sample[\"Y\"].to_numpy(dtype=float) + noise_y\n",
    "\n",
    "    # randomly sample n_features out of ALL covariate columns\n",
    "    X_all = bs_sample[feature_cols].to_numpy(dtype=float)\n",
    "    n_keep = min(n_features, X_all.shape[1])\n",
    "    cols = rng.choice(X_all.shape[1], size=n_keep, replace=False)\n",
    "    X = X_all[:, cols]\n",
    "\n",
    "    # online S-learner on (treatment, covariates) gives g_0 and g_1\n",
    "    mu_0_est, mu_1_est = _online_s_learner(np.column_stack((A, X)), Y)\n",
    "\n",
    "    # complete randomization assumed: one constant propensity for everyone\n",
    "    pi_1 = float(np.mean(A))\n",
    "    if not 0.0 < pi_1 < 1.0:\n",
    "        raise ValueError(f\"A_public is constant in the bootstrap sample (mean {pi_1}); need 0 < propensity < 1\")\n",
    "    pi_0 = 1.0 - pi_1\n",
    "\n",
    "    d = pd.DataFrame({\"mu_1_est\": mu_1_est, \"mu_0_est\": mu_0_est, \"outcome\": Y, \"treatment\": A})\n",
    "    g_1, g_0 = d[\"mu_1_est\"], d[\"mu_0_est\"]\n",
    "    Y, A = d[\"outcome\"], d[\"treatment\"]\n",
    "\n",
    "    # fresh noise draw for the pseudo-outcomes, independent of the noise already in Y\n",
    "    noise_eif = np.clip(rng.normal(loc=0, scale=noise_std, size=len(Y)), -bound, bound)\n",
    "    Y_eif = g_1 + (A == 1) * (Y - g_1) / pi_1 - (g_0 + (A == 0) * (Y - g_0) / pi_0) + noise_eif\n",
    "\n",
    "    return X, Y, Y_eif"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "83067b74-ff8b-450a-b11e-aab09b3c510f",
   "metadata": {},
   "outputs": [],
   "source": [
    "## assume complete randomization\n",
    "#X,Y, Y_eif = bootstrap_sample(n=10000, df=full_dat, Xbin=Xbin, Xnum = Xnum, random_seed = 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "5aec8426-81b4-4493-92f3-6cced62d6e51",
   "metadata": {},
   "outputs": [],
   "source": [
    "#print(X.shape)\n",
    "#plt.plot(Y_eif)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
