{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a610c13b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "\n",
    "import sys, os\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)\n",
    "    \n",
    "import numpy as np\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from xgboost import XGBRegressor, XGBClassifier\n",
    "from econml.dml import NonParamDML, LinearDML\n",
    "from econml.dr import LinearDRLearner, ForestDRLearner\n",
    "from econml.metalearners import XLearner\n",
    "from econml import metalearners\n",
    "\n",
    "import sklearn\n",
    "from sklearn.experimental import enable_iterative_imputer\n",
    "from sklearn.impute import SimpleImputer, IterativeImputer\n",
    "from sklearn.linear_model import LogisticRegression, LinearRegression\n",
    "from sklearn.svm import SVC, SVR\n",
    "    \n",
    "from src.data.data_module import generate_data_exp, _generate_covariates, report_data_metrics\n",
    "from src.data.utils import split_eval_cate\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3db56c93",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_regressor(missing_value):\n",
    "    #return LinearRegression()\n",
    "    #return LassoCV()\n",
    "    #return SVR()\n",
    "    return XGBRegressor(missing=missing_value, eval_metric='logloss')\n",
    "\n",
    "def get_classifier(missing_value):\n",
    "    #return SVC(probability=True)\n",
    "    #return LogisticRegression()\n",
    "    return XGBClassifier(use_label_encoder=False, missing=missing_value, eval_metric='logloss')\n",
    "\n",
    "def get_imputer(missing_value):\n",
    "    return IterativeImputer(max_iter=1500, tol=15e-4, random_state=None, missing_values=missing_value)\n",
    "    #return SimpleImputer(missing_values=0, strategy='mean')\n",
    "    \n",
    "learners = {\n",
    "    'T': lambda missing_value: metalearners.TLearner(\n",
    "        models=get_regressor(missing_value)\n",
    "    ),\n",
    "    'X': lambda missing_value : XLearner(\n",
    "        models=get_regressor(missing_value),\n",
    "        propensity_model=get_classifier(missing_value),\n",
    "        cate_models=get_regressor(missing_value)\n",
    "    ),\n",
    "    'S': lambda missing_value : metalearners.SLearner(\n",
    "        overall_model=get_regressor(missing_value),\n",
    "    ),\n",
    "    'R': lambda missing_value : NonParamDML(\n",
    "        model_y=get_regressor(missing_value),\n",
    "        model_t=get_classifier(missing_value),\n",
    "        model_final=get_regressor(missing_value),\n",
    "        discrete_treatment=True\n",
    "    ),\n",
    "    'DR': lambda missing_value : ForestDRLearner(\n",
    "        model_propensity=get_classifier(missing_value),\n",
    "        model_regression=get_regressor(missing_value)\n",
    "    ),\n",
    "    'DML': lambda missing_value: NonParamDML(\n",
    "        model_y=get_regressor(missing_value),\n",
    "        model_t=get_classifier(missing_value),\n",
    "        model_final=get_regressor(missing_value),\n",
    "        discrete_treatment=True)\n",
    "}\n",
    "\n",
    "def evaluate(ground_truth, estimate, W):\n",
    "    PEHE = np.sqrt(((estimate - ground_truth)**2).mean())\n",
    "    PEHE_0 = np.sqrt(((estimate[W==0] - ground_truth[W==0])**2).mean())\n",
    "    PEHE_1 = np.sqrt(((estimate[W==1] - ground_truth[W==1])**2).mean())\n",
    "    return PEHE, PEHE_0, PEHE_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8431aad6",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# EXP. SETTINGS\n",
    "runs = 100\n",
    "sims = 10\n",
    "\n",
    "assert runs % sims == 0\n",
    "\n",
    "\n",
    "train_size = 5000\n",
    "\n",
    "d = 20\n",
    "z_d_dim = 10\n",
    "amount_of_missingness = .3\n",
    "missing_value=-1\n",
    "\n",
    "learner = 'DR'\n",
    "\n",
    "\n",
    "# DEBUG SETTINGS\n",
    "verbose=False\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e61592b",
   "metadata": {},
   "outputs": [],
   "source": [
    "ground_truth_cate = []\n",
    "ground_truth_cate_0 = []\n",
    "ground_truth_cate_1 = []\n",
    "\n",
    "\n",
    "PEHE_impute_all = []\n",
    "PEHE_impute_nothing = []\n",
    "PEHE_impute_smartly = []\n",
    "PEHE_impute_wrongly = []\n",
    "\n",
    "effect_impute_all = []\n",
    "effect_impute_nothing = []\n",
    "effect_impute_smartly = []\n",
    "effect_impute_wrongly = []\n",
    "\n",
    "effect_impute_all_0 = []\n",
    "effect_impute_nothing_0 = []\n",
    "effect_impute_smartly_0 = []\n",
    "effect_impute_wrongly_0 = []\n",
    "\n",
    "effect_impute_all_1 = []\n",
    "effect_impute_nothing_1 = []\n",
    "effect_impute_smartly_1 = []\n",
    "effect_impute_wrongly_1 = []\n",
    "for _ in tqdm(range(sims)):\n",
    "    \n",
    "\n",
    "    X, X_, Y0, Y1, Y, CATE, W, Z_up, Z_down = generate_data_exp(\n",
    "        train_size*2, d, \n",
    "        z_d_dim, \n",
    "        amount_of_missingness, \n",
    "        missing_value=missing_value)\n",
    "    \n",
    "    if verbose:\n",
    "        report_data_metrics(X, Y, W, CATE, Z_up, Z_down)\n",
    "\n",
    "\n",
    "    assert 10 < train_size < len(X)\n",
    "    \n",
    "    \n",
    "    for _ in tqdm(range(int(runs/sims)), leave=False):\n",
    "\n",
    "\n",
    "        idxs = np.random.choice(range(len(X)), size=train_size, replace=False)\n",
    "        include_idx = set(idxs)\n",
    "        mask = np.array([(i in include_idx) for i in range(len(X))])\n",
    "\n",
    "\n",
    "        X_train, Y_train, W_train, CATE_train = X_[mask], Y[mask], W[mask], CATE[mask]\n",
    "        X_test, Y_test, W_test, CATE_test = X_[~mask], Y[~mask], W[~mask], CATE[~mask]\n",
    "        \n",
    "        ground_truth_cate.append(CATE_test)\n",
    "        ground_truth_cate_0.append(CATE_test[W_test == 0])\n",
    "        ground_truth_cate_1.append(CATE_test[W_test == 1])\n",
    "\n",
    "\n",
    "\n",
    "        # IMPUTE ALL\n",
    "        imputer = get_imputer(missing_value)\n",
    "        imputer.fit(X_train)\n",
    "        X_train_preprocessed = imputer.transform(X_train)\n",
    "        X_test_preprocessed = imputer.transform(X_test)\n",
    "\n",
    "        est_impute_all = learners[learner](missing_value)\n",
    "        est_impute_all.fit(Y_train, W_train, X=X_train_preprocessed)\n",
    "\n",
    "\n",
    "        #PEHE_impute_all.append(evaluate(CATE_test, te, W_test))\n",
    "        effect_impute_all.append(est_impute_all.effect(X_test_preprocessed))\n",
    "        effect_impute_all_0.append(est_impute_all.effect(X_test_preprocessed[W_test == 0]))\n",
    "        effect_impute_all_1.append(est_impute_all.effect(X_test_preprocessed[W_test == 1]))\n",
    "\n",
    "        if verbose:\n",
    "            print('all', X_train.min())\n",
    "\n",
    "\n",
    "\n",
    "        # IMPUTE NOTHING\n",
    "        treatment_effects_impute_nothing = []\n",
    "        X_train_preprocessed = X_train.copy()\n",
    "        X_test_preprocessed = X_test.copy()\n",
    "\n",
    "\n",
    "        est_impute_nothing = learners[learner](missing_value)\n",
    "        est_impute_nothing.fit(Y_train, W_train, X=X_train_preprocessed)\n",
    "\n",
    "        #PEHE_impute_nothing.append(evaluate(CATE_test, te, W_test))\n",
    "        effect_impute_nothing.append(est_impute_nothing.effect(X_test_preprocessed))\n",
    "        effect_impute_nothing_0.append(est_impute_nothing.effect(X_test_preprocessed[W_test == 0]))\n",
    "        effect_impute_nothing_1.append(est_impute_nothing.effect(X_test_preprocessed[W_test == 1]))\n",
    "\n",
    "        if verbose:\n",
    "            print('nothing', X_train.min())\n",
    "\n",
    "\n",
    "\n",
    "        # IMPUTE SMARTLY\n",
    "        treatment_effects_impute_smartly = []\n",
    "        imputer_smart = get_imputer(missing_value)\n",
    "        imputer_smart.fit(X_train[:,z_d_dim:])\n",
    "\n",
    "        X_train_preprocessed = X_train.copy()\n",
    "        X_test_preprocessed = X_test.copy()\n",
    "\n",
    "        X_train_preprocessed[:,z_d_dim:] = imputer_smart.transform(X_train[:,z_d_dim:])\n",
    "        X_test_preprocessed[:,z_d_dim:] = imputer_smart.transform(X_test[:,z_d_dim:])\n",
    "\n",
    "        est_impute_smartly = learners[learner](missing_value)\n",
    "        est_impute_smartly.fit(Y_train, W_train, X=X_train_preprocessed)\n",
    "\n",
    "        #PEHE_impute_smartly.append(evaluate(CATE_test, te, W_test))\n",
    "        effect_impute_smartly.append(est_impute_smartly.effect(X_test_preprocessed))\n",
    "        effect_impute_smartly_0.append(est_impute_smartly.effect(X_test_preprocessed[W_test == 0]))\n",
    "        effect_impute_smartly_1.append(est_impute_smartly.effect(X_test_preprocessed[W_test == 1]))\n",
    "\n",
    "        if verbose:\n",
    "            print('smart down', X_train[:,:z_d_dim].min())\n",
    "            print('smart up', X_train[:,z_d_dim:].min())\n",
    "\n",
    "\n",
    "\n",
    "        # IMPUTE WRONGLY\n",
    "        treatment_effects_impute_wrongly = []\n",
    "        imputer_wrongly = get_imputer(missing_value)\n",
    "        imputer_wrongly.fit(X_train[:,:z_d_dim])\n",
    "\n",
    "        X_train_preprocessed = X_train.copy()\n",
    "        X_test_preprocessed = X_test.copy()\n",
    "\n",
    "        X_train_preprocessed[:,:z_d_dim] = imputer_wrongly.transform(X_train[:,:z_d_dim])\n",
    "        X_test_preprocessed[:,:z_d_dim] = imputer_wrongly.transform(X_test[:,:z_d_dim])\n",
    "\n",
    "        est_impute_wrongly = learners[learner](missing_value)\n",
    "        est_impute_wrongly.fit(Y_train, W_train, X=X_train_preprocessed)\n",
    "\n",
    "\n",
    "        #PEHE_impute_wrongly.append(evaluate(CATE_test, te, W_test))\n",
    "        effect_impute_wrongly.append(est_impute_wrongly.effect(X_test_preprocessed))\n",
    "        effect_impute_wrongly_0.append(est_impute_wrongly.effect(X_test_preprocessed[W_test == 0]))\n",
    "        effect_impute_wrongly_1.append(est_impute_wrongly.effect(X_test_preprocessed[W_test == 1]))\n",
    "\n",
    "        if verbose:\n",
    "            print('wrong down', X_train[:,:z_d_dim].min())\n",
    "            print('wrong up', X_train[:,z_d_dim:].min())\n",
    "            \n",
    "        \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a38ef57",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_means, all_stds = split_eval_cate(effect_impute_all, ground_truth_cate, sims)\n",
    "no_means, no_stds = split_eval_cate(effect_impute_nothing, ground_truth_cate, sims)\n",
    "smart_means, smart_stds = split_eval_cate(effect_impute_smartly, ground_truth_cate, sims)\n",
    "wrong_means, wrong_stds = split_eval_cate(effect_impute_wrongly, ground_truth_cate, sims)\n",
    "\n",
    "all_means_0, all_stds_0 = split_eval_cate(effect_impute_all_0, ground_truth_cate_0, sims)\n",
    "no_means_0, no_stds_0 = split_eval_cate(effect_impute_nothing_0, ground_truth_cate_0, sims)\n",
    "smart_means_0, smart_stds_0 = split_eval_cate(effect_impute_smartly_0, ground_truth_cate_0, sims)\n",
    "wrong_means_0, wrong_stds_0 = split_eval_cate(effect_impute_wrongly_0, ground_truth_cate_0, sims)\n",
    "\n",
    "all_means_1, all_stds_1 = split_eval_cate(effect_impute_all_1, ground_truth_cate_1, sims)\n",
    "no_means_1, no_stds_1 = split_eval_cate(effect_impute_nothing_1, ground_truth_cate_1, sims)\n",
    "smart_means_1, smart_stds_1 = split_eval_cate(effect_impute_smartly_1, ground_truth_cate_1, sims)\n",
    "wrong_means_1, wrong_stds_1 = split_eval_cate(effect_impute_wrongly_1, ground_truth_cate_1, sims)\n",
    "\n",
    "print('# SETUP')\n",
    "print(f'# amount_of_missingness = {amount_of_missingness}')\n",
    "\n",
    "print(f'#   ALL IMPUTATION  :\\t{np.mean(all_means)}\\t{np.mean(all_stds)}')\n",
    "print(f'#   NO IMPUTATION   :\\t{np.mean(no_means)}\\t{np.mean(no_stds)}')\n",
    "print(f'#   SMART IMPUTATION:\\t{np.mean(smart_means)}\\t{np.mean(smart_stds)}')\n",
    "print(f'#   WRONG IMPUTATION:\\t{np.mean(wrong_means)}\\t{np.mean(wrong_stds)}')\n",
    "print(f'#   ---------------- PEHE_0')\n",
    "print(f'#   ALL IMPUTATION  :\\t{np.mean(all_means_0)}\\t{np.mean(all_stds_0)}')\n",
    "print(f'#   NO IMPUTATION   :\\t{np.mean(no_means_0)}\\t{np.mean(no_stds_0)}')\n",
    "print(f'#   SMART IMPUTATION:\\t{np.mean(smart_means_0)}\\t{np.mean(smart_stds_0)}')\n",
    "print(f'#   WRONG IMPUTATION:\\t{np.mean(wrong_means_0)}\\t{np.mean(wrong_stds_0)}')\n",
    "print(f'#   ---------------- PEHE_1')\n",
    "print(f'#   ALL IMPUTATION  :\\t{np.mean(all_means_1)}\\t{np.mean(all_stds_1)}')\n",
    "print(f'#   NO IMPUTATION   :\\t{np.mean(no_means_1)}\\t{np.mean(no_stds_1)}')\n",
    "print(f'#   SMART IMPUTATION:\\t{np.mean(smart_means_1)}\\t{np.mean(smart_stds_1)}')\n",
    "print(f'#   WRONG IMPUTATION:\\t{np.mean(wrong_means_1)}\\t{np.mean(wrong_stds_1)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34a61e67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# RESULTS (temp, see below)\n",
    "\n",
    "# V\n",
    "# T-Learner (100 runs | n=5000:5000 | XGBoost)\n",
    "# SETUP\n",
    "# amount_of_missingness = 0.3\n",
    "#   ALL IMPUTATION  :\t0.7603202228807558\t0.05107655476272003\n",
    "#   NO IMPUTATION   :\t0.6906018297012612\t0.07293375356277434\n",
    "#   SMART IMPUTATION:\t0.4605443542126303\t0.045559141435647814\n",
    "#   WRONG IMPUTATION:\t0.9158624744688488\t0.06479239894963464\n",
    "#   ---------------- PEHE_0\n",
    "#   ALL IMPUTATION  :\t0.7371179935236377\t0.0813447454038048\n",
    "#   NO IMPUTATION   :\t0.7015184811570604\t0.10091128810622667\n",
    "#   SMART IMPUTATION:\t0.572054086285261\t0.07933149623910837\n",
    "#   WRONG IMPUTATION:\t0.9351743987244323\t0.12057528790964936\n",
    "#   ---------------- PEHE_1\n",
    "#   ALL IMPUTATION  :\t0.7726199464001329\t0.05520143480388414\n",
    "#   NO IMPUTATION   :\t0.6881436442651043\t0.0901607628080812\n",
    "#   SMART IMPUTATION:\t0.40971083141461895\t0.045350667352497306\n",
    "#   WRONG IMPUTATION:\t0.9183522971586949\t0.07513038656333801\n",
    "\n",
    "\n",
    "\n",
    "# V\n",
    "# DR-Learner (30 runs | n=5000:5000 | XGBoost)\n",
    "# SETUP\n",
    "# amount_of_missingness = 0.3\n",
    "#   ALL IMPUTATION  :\t1.3674804785864336\t1.7318280165796296\n",
    "#   NO IMPUTATION   :\t0.9409536956704727\t1.9431432104562432\n",
    "#   SMART IMPUTATION:\t0.20421566866874663\t0.22411214110370406\n",
    "#   WRONG IMPUTATION:\t4.365789991043024\t8.82361783733213\n",
    "#   ---------------- PEHE_0\n",
    "#   ALL IMPUTATION  :\t1.2083191574779995\t1.6106075455876951\n",
    "#   NO IMPUTATION   :\t0.8130941858636677\t1.2876995138351313\n",
    "#   SMART IMPUTATION:\t0.1787563949407937\t0.20287619043864713\n",
    "#   WRONG IMPUTATION:\t4.2306427399228905\t9.056379962065844\n",
    "#   ---------------- PEHE_1\n",
    "#   ALL IMPUTATION  :\t1.4419080960316477\t1.8027164353381573\n",
    "#   NO IMPUTATION   :\t0.997339557310883\t2.3179472556207465\n",
    "#   SMART IMPUTATION:\t0.216936220213687\t0.23653557233533617\n",
    "#   WRONG IMPUTATION:\t4.432220414868022\t8.749143751110289\n",
    "\n",
    "\n",
    "# V\n",
    "# X-Learner (100 runs | n=5000:5000 | XGBoost)\n",
    "# SETUP\n",
    "# amount_of_missingness = 0.3\n",
    "#   ALL IMPUTATION  :\t0.6149946402626558\t0.06390551930603257\n",
    "#   NO IMPUTATION   :\t0.3027909087842893\t0.08597057400122958\n",
    "#   SMART IMPUTATION:\t0.2116881614969457\t0.03239196614270469\n",
    "#   WRONG IMPUTATION:\t0.4912969947275013\t0.10914276845610779\n",
    "#   ---------------- PEHE_0\n",
    "#   ALL IMPUTATION  :\t0.6272957845922045\t0.07083868937874815\n",
    "#   NO IMPUTATION   :\t0.2907498976394652\t0.10774536989843311\n",
    "#   SMART IMPUTATION:\t0.2556986684662631\t0.06254388476455\n",
    "#   WRONG IMPUTATION:\t0.5198024390023097\t0.15692764578598367\n",
    "#   ---------------- PEHE_1\n",
    "#   ALL IMPUTATION  :\t0.6097365193131001\t0.06839144126845112\n",
    "#   NO IMPUTATION   :\t0.3091613781100213\t0.09853250057067761\n",
    "#   SMART IMPUTATION:\t0.19157833135487273\t0.03664993710259606\n",
    "#   WRONG IMPUTATION:\t0.4803861176358356\t0.13089151115646364"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a94ea7ca",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mcm (3.7.9)",
   "language": "python",
   "name": "mcm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
