{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "498e3ab9",
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "\n",
    "import sys, os\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)\n",
    "    \n",
    "import numpy as np\n",
    "\n",
    "from tqdm.notebook import tqdm, tnrange\n",
    "\n",
    "\n",
    "from xgboost import XGBRegressor, XGBClassifier\n",
    "\n",
    "import causalml\n",
    "from causalml.inference.meta import LRSRegressor\n",
    "from causalml.inference.meta import XGBTRegressor, MLPTRegressor\n",
    "from causalml.inference.meta import BaseXRegressor\n",
    "from causalml.inference.meta import BaseRRegressor\n",
    "from causalml.inference.meta import BaseSRegressor, BaseTLearner\n",
    "from causalml.inference.meta import BaseDRLearner, XGBDRRegressor, BaseDRRegressor\n",
    "from causalml.propensity import GradientBoostedPropensityModel\n",
    "\n",
    "import sklearn\n",
    "from sklearn.experimental import enable_iterative_imputer\n",
    "from sklearn.impute import SimpleImputer, IterativeImputer\n",
    "from sklearn.linear_model import LogisticRegression, LinearRegression\n",
    "from sklearn.svm import SVC, SVR\n",
    "    \n",
    "from src.data.data_module import generate_data_exp, _generate_covariates\n",
    "from src.data.utils import split_eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2080beb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_regressor(missing_value):\n",
    "    \"\"\"Return a new, unfitted XGBRegressor configured with `missing=missing_value`.\"\"\"\n",
    "    xgb_params = {'missing': missing_value, 'eval_metric': 'logloss'}\n",
    "    return XGBRegressor(**xgb_params)\n",
    "\n",
    "# Meta-learner factories keyed by short name. Each factory takes the missing-value\n",
    "# sentinel and returns a fresh, unfitted learner; the experiment loop then calls\n",
    "# `.estimate_ate(X, W, Y, p=ps)` on it, so every entry must expose that causalml-style API.\n",
    "learners = {\n",
    "    'T': lambda missing_value: BaseTLearner(learner=get_regressor(missing_value)),\n",
    "    'X': lambda missing_value: BaseXRegressor(learner=get_regressor(missing_value)),\n",
    "    # causalml S-learner (imported at the top of the notebook).\n",
    "    'S': lambda missing_value: BaseSRegressor(learner=get_regressor(missing_value)),\n",
    "    # causalml R-learner; the same XGBoost regressor is used for outcome and effect models.\n",
    "    'R': lambda missing_value: BaseRRegressor(learner=get_regressor(missing_value)),\n",
    "    # NOTE(review): `GPS` is not defined or imported anywhere in this notebook, so\n",
    "    # selecting this entry raises NameError -- confirm where it is supposed to come from.\n",
    "    'GPS': lambda missing_value: GPS(\n",
    "        missing=missing_value,\n",
    "        eval_metric='logloss',\n",
    "        use_label_encoder=False\n",
    "    ),\n",
    "    'DR': lambda missing_value: BaseDRLearner(\n",
    "        learner=get_regressor(missing_value),\n",
    "    )\n",
    "}\n",
    "\n",
    "\n",
    "def xgb_prop(missing_value, X, W):\n",
    "    \"\"\"Fit a GradientBoostedPropensityModel on (X, W) and return its `fit_predict` output.\"\"\"\n",
    "    propensity_model = GradientBoostedPropensityModel(\n",
    "        missing=missing_value,\n",
    "        eval_metric='logloss',\n",
    "        use_label_encoder=False,\n",
    "    )\n",
    "    propensity_scores = propensity_model.fit_predict(X, W)\n",
    "    return propensity_scores\n",
    "\n",
    "def _no_propensity(missing_value, X, W):\n",
    "    \"\"\"Propensity strategy that supplies no scores (the learner receives `p=None`).\"\"\"\n",
    "    return None\n",
    "\n",
    "\n",
    "# Propensity-score strategies keyed by name; each is called as f(missing_value, X, W).\n",
    "prop_learners = {\n",
    "    'none': _no_propensity,\n",
    "    'xgb': xgb_prop,\n",
    "}\n",
    "\n",
    "def get_imputer(missing_value):\n",
    "    \"\"\"Return a new, unfitted IterativeImputer that treats `missing_value` as missing.\"\"\"\n",
    "    imputer_params = dict(\n",
    "        missing_values=missing_value,\n",
    "        max_iter=1500,\n",
    "        tol=15e-4,\n",
    "        random_state=None,\n",
    "    )\n",
    "    return IterativeImputer(**imputer_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adcff454",
   "metadata": {},
   "outputs": [],
   "source": [
    "# EXP. SETTINGS\n",
    "runs = 100\n",
    "sims = 10\n",
    "\n",
    "assert runs % sims == 0\n",
    "\n",
    "n = 5000\n",
    "d = 20\n",
    "z_d_dim = 10\n",
    "amount_of_missingness = .3\n",
    "missing_value=-1\n",
    "\n",
    "learner = 'T'\n",
    "prop_learner = 'xgb'\n",
    "\n",
    "# DEBUG SETTINGS\n",
    "verbose=False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc72d549",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _impute_columns(X_raw, cols=None):\n",
    "    \"\"\"Return a copy of `X_raw` with `missing_value` entries imputed.\n",
    "\n",
    "    `cols` is a column slice: the imputer is fitted on, and applied to, only those\n",
    "    columns and the remaining columns are left untouched. `None` imputes the\n",
    "    whole matrix at once.\n",
    "    \"\"\"\n",
    "    X_out = X_raw.copy()\n",
    "    imputer = get_imputer(missing_value)\n",
    "    if cols is None:\n",
    "        imputer.fit(X_out)\n",
    "        return imputer.transform(X_out)\n",
    "    imputer.fit(X_out[:, cols])\n",
    "    X_out[:, cols] = imputer.transform(X_out[:, cols])\n",
    "    return X_out\n",
    "\n",
    "\n",
    "def _estimate_ate(X_in_use, W_i, Y_i):\n",
    "    \"\"\"Build a fresh `learner`, compute `prop_learner` propensities and return the ATE estimate.\"\"\"\n",
    "    estimator = learners[learner](missing_value)\n",
    "    ps = prop_learners[prop_learner](missing_value, X_in_use, W_i)\n",
    "    return estimator.estimate_ate(X_in_use, W_i, Y_i, p=ps)[0]\n",
    "\n",
    "\n",
    "ground_truth = []\n",
    "\n",
    "ATE_impute_all = []\n",
    "ATE_impute_nothing = []\n",
    "ATE_impute_smartly = []\n",
    "ATE_impute_wrongly = []\n",
    "\n",
    "runs_per_sim = runs // sims\n",
    "\n",
    "for _ in tqdm(range(sims)):\n",
    "    X, X_, Y0, Y1, Y, CATE, W, Z_up, Z_down = generate_data_exp(\n",
    "        n*2, d, \n",
    "        z_d_dim, \n",
    "        amount_of_missingness, \n",
    "        missing_value=missing_value)\n",
    "\n",
    "    if verbose:\n",
    "        # NOTE(review): report_data_metrics is not defined or imported in this notebook.\n",
    "        report_data_metrics(X, Y, W, CATE, Z_up, Z_down)\n",
    "\n",
    "    # Computed on the full generated dataset, so it is identical for every subsample below.\n",
    "    true_ate = Y1.mean() - Y0.mean()\n",
    "\n",
    "    for _ in tqdm(range(runs_per_sim), leave=False):\n",
    "\n",
    "        # 50% subsample without replacement -> n rows out of the 2*n generated\n",
    "        idxs = np.random.choice(n*2, size=n, replace=False)\n",
    "        X__i, Y_i, W_i = X_[idxs], Y[idxs], W[idxs]\n",
    "\n",
    "        ground_truth.append(true_ate)\n",
    "\n",
    "        # IMPUTE ALL: one imputer over every column\n",
    "        X_in_use = _impute_columns(X__i)\n",
    "        ATE_impute_all.append(_estimate_ate(X_in_use, W_i, Y_i))\n",
    "\n",
    "        if verbose:\n",
    "            print('all', X_in_use.min())\n",
    "\n",
    "        # IMPUTE NOTHING: the learner sees the raw missing-value sentinel\n",
    "        X_in_use = X__i.copy()\n",
    "        ATE_impute_nothing.append(_estimate_ate(X_in_use, W_i, Y_i))\n",
    "\n",
    "        if verbose:\n",
    "            print('nothing', X_in_use.min())\n",
    "\n",
    "        # IMPUTE SMARTLY: only the columns from z_d_dim onwards\n",
    "        X_in_use = _impute_columns(X__i, slice(z_d_dim, None))\n",
    "        ATE_impute_smartly.append(_estimate_ate(X_in_use, W_i, Y_i))\n",
    "\n",
    "        if verbose:\n",
    "            print('smart down', X_in_use[:,:z_d_dim].min())\n",
    "            print('smart up', X_in_use[:,z_d_dim:].min())\n",
    "\n",
    "        # IMPUTE WRONGLY: only the first z_d_dim columns\n",
    "        X_in_use = _impute_columns(X__i, slice(None, z_d_dim))\n",
    "        ATE_impute_wrongly.append(_estimate_ate(X_in_use, W_i, Y_i))\n",
    "\n",
    "        if verbose:\n",
    "            print('wrong down', X_in_use[:,:z_d_dim].min())\n",
    "            print('wrong up', X_in_use[:,z_d_dim:].min())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8c36ab2",
   "metadata": {},
   "outputs": [],
   "source": [
    "print('# SETUP')\n",
    "print(f'# amount_of_missingness = {amount_of_missingness}')\n",
    "\n",
    "# Score each strategy's ATE estimates against the ground truth with split_eval.\n",
    "all_means, all_stds = split_eval(ATE_impute_all, ground_truth, sims)\n",
    "no_means, no_stds = split_eval(ATE_impute_nothing, ground_truth, sims)\n",
    "smart_means, smart_stds = split_eval(ATE_impute_smartly, ground_truth, sims)\n",
    "wrong_means, wrong_stds = split_eval(ATE_impute_wrongly, ground_truth, sims)\n",
    "\n",
    "# One row per strategy: label, mean of the means, mean of the stds.\n",
    "report_rows = (\n",
    "    ('#   ALL IMPUTATION   :', all_means, all_stds),\n",
    "    ('#   NO IMPUTATION    :', no_means, no_stds),\n",
    "    ('#   SMART IMPUTATION :', smart_means, smart_stds),\n",
    "    ('#   WRONG IMPUTATION :', wrong_means, wrong_stds),\n",
    ")\n",
    "for row_label, row_means, row_stds in report_rows:\n",
    "    print(f'{row_label}\\t {row_means.mean()}\\t{row_stds.mean()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9286455",
   "metadata": {},
   "outputs": [],
   "source": [
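    "# RESULTS (recorded output of the reporting cell; columns: mean of split_eval means, mean of split_eval stds)\n",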
    "\n",
    "# X\n",
    "# T-Learner\n",
    "# SETUP\n",
    "# amount_of_missingness = 0.3\n",
    "#   ALL IMPUTATION   :\t 0.09515771581695376\t0.010809105881710744\n",
    "#   NO IMPUTATION    :\t 0.06420910047705318\t0.027903473513318455\n",
    "#   SMART IMPUTATION :\t 0.040317712250933016\t0.01410509008640046\n",
    "#   WRONG IMPUTATION :\t 0.09311239665206224\t0.026522649552125133\n",
    "\n",
    "\n",
    "# V\n",
    "# X-Learner\n",
    "# SETUP\n",
    "# amount_of_missingness = 0.3\n",
    "#   ALL IMPUTATION   :\t 0.04725865475246275\t0.00920325287310282\n",
    "#   NO IMPUTATION    :\t 0.07268430878500698\t0.02402610881856112\n",
    "#   SMART IMPUTATION :\t 0.03096103028420011\t0.014224934591948447\n",
    "#   WRONG IMPUTATION :\t 0.09842917616331746\t0.04098717021403818\n",
    "\n",
    "# V\n",
    "# DR-Learner\n",
    "# SETUP\n",
    "# amount_of_missingness = 0.3\n",
    "#   ALL IMPUTATION   :\t 0.0651557198852841\t    0.008415489835708705\n",
    "#   NO IMPUTATION    :\t 0.09022430074549292\t0.018791844604201892\n",
    "#   SMART IMPUTATION :\t 0.0381744174377323\t    0.009683497712183653\n",
    "#   WRONG IMPUTATION :\t 0.09028986507376442\t0.01902110820473588\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f35fc23",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mcm (3.7.9)",
   "language": "python",
   "name": "mcm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
