{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "764053f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "\n",
    "import sys, os\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)\n",
    "    \n",
    "import numpy as np\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from xgboost import XGBRegressor, XGBClassifier\n",
    "from econml.dml import NonParamDML, LinearDML\n",
    "from econml.dr import LinearDRLearner, ForestDRLearner\n",
    "from econml.metalearners import XLearner\n",
    "from econml import metalearners\n",
    "\n",
    "import sklearn\n",
    "from sklearn.experimental import enable_iterative_imputer\n",
    "from sklearn.impute import SimpleImputer, IterativeImputer\n",
    "from sklearn.linear_model import LogisticRegression, LinearRegression\n",
    "from sklearn.svm import SVC, SVR\n",
    "    \n",
    "from src.data import data_module\n",
    "from src.data.utils import split_eval_cate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2569a18",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_regressor(missing_value):\n",
    "    #return LinearRegression()\n",
    "    #return LassoCV()\n",
    "    #return SVR()\n",
    "    return XGBRegressor(missing=missing_value, eval_metric='logloss')\n",
    "\n",
    "def get_classifier(missing_value):\n",
    "    #return SVC(probability=True)\n",
    "    #return LogisticRegression()\n",
    "    return XGBClassifier(use_label_encoder=False, missing=missing_value, eval_metric='logloss')\n",
    "\n",
    "def get_imputer(missing_value):\n",
    "    return IterativeImputer(max_iter=1500, tol=15e-4, random_state=None, missing_values=missing_value)\n",
    "    #return SimpleImputer(missing_values=0, strategy='mean')\n",
    "    \n",
    "learners = {\n",
    "    'T': lambda missing_value: metalearners.TLearner(\n",
    "        models=get_regressor(missing_value)\n",
    "    ),\n",
    "    'X': lambda missing_value : XLearner(\n",
    "        models=get_regressor(missing_value),\n",
    "        propensity_model=get_classifier(missing_value),\n",
    "        cate_models=get_regressor(missing_value)\n",
    "    )}\n",
    "\n",
    "def evaluate(ground_truth, estimate, W):\n",
    "    PEHE = np.sqrt(((estimate - ground_truth)**2).mean())\n",
    "    PEHE_0 = np.sqrt(((estimate[W==0] - ground_truth[W==0])**2).mean())\n",
    "    PEHE_1 = np.sqrt(((estimate[W==1] - ground_truth[W==1])**2).mean())\n",
    "    return PEHE, PEHE_0, PEHE_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f262a322",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# EXP. SETTINGS\n",
    "runs = 10\n",
    "sims = 1\n",
    "\n",
    "assert runs % sims == 0\n",
    "\n",
    "\n",
    "train_size = 5000\n",
    "\n",
    "d = 20\n",
    "z_d_dim = 10\n",
    "amount_of_missingness = [.1, .2, .3, .4, .5]\n",
    "missing_value=-1\n",
    "\n",
    "# DEBUG SETTINGS\n",
    "verbose=False\n",
    "\n",
    "\n",
    "\n",
    "# PREP\n",
    "amount_of_missingness = amount_of_missingness if isinstance(amount_of_missingness, list) else [amount_of_missingness]\n",
    "z_d_dim = z_d_dim if isinstance(z_d_dim, list) else [z_d_dim]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2183b446",
   "metadata": {},
   "outputs": [],
   "source": [
    "# WE FIX X and Y\n",
    "X = data_module._generate_covariates(d, train_size*2)\n",
    "Y0, Y1, CATE = data_module._generate_outcomes(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7016ce53",
   "metadata": {},
   "outputs": [],
   "source": [
    "learner = 'T'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aee9e6b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "res_dict = {\n",
    "    'ground_truth': np.zeros((sims, len(amount_of_missingness), len(z_d_dim), int(runs/sims), train_size)),\n",
    "    'all': np.zeros((sims, len(amount_of_missingness), len(z_d_dim),  int(runs/sims), train_size)),\n",
    "    'nothing': np.zeros((sims,len(amount_of_missingness), len(z_d_dim), int(runs/sims), train_size)),\n",
    "    'smartly': np.zeros((sims, len(amount_of_missingness), len(z_d_dim), int(runs/sims), train_size)),\n",
    "    'wrongly': np.zeros((sims, len(amount_of_missingness), len(z_d_dim), int(runs/sims), train_size)),\n",
    "}\n",
    "\n",
    "past_w =  np.zeros((sims, len(amount_of_missingness), len(z_d_dim), int(runs/sims), train_size))\n",
    "\n",
    "\n",
    "# i -> sim index\n",
    "# j -> a_o_m index\n",
    "# k -> z_d_d index\n",
    "# l -> run index\n",
    "\n",
    "# dicts: [name] -> [i, j, k, l, train_size]\n",
    "\n",
    "for i, _ in enumerate(tqdm(range(sims), desc=\"Simulation\")):\n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "    for j, a_o_m in enumerate(tqdm(amount_of_missingness, desc=\"Missingness\", leave=False)):\n",
    "        for k, z_d_d in enumerate(tqdm(z_d_dim, desc=\"Z dim\", leave=False)):\n",
    "            \n",
    "            \n",
    "\n",
    "\n",
    "            assert 10 < train_size < len(X)\n",
    "\n",
    "\n",
    "            for l, _ in enumerate(tqdm(range(int(runs/sims)), leave=False, desc=\"Runs per sim\")):\n",
    "                \n",
    "                # GENERATE MISSINGNESS DEPENDENT VARIABLES\n",
    "                Z_down = data_module._Z_down(a_o_m, X, z_d_d)\n",
    "                W = data_module._treatments(Z_down, X, z_d_d)\n",
    "                Y = data_module._generate_observed_outcomes(Y0, Y1, W)\n",
    "                Z_up = data_module._Z_up(a_o_m, X, z_d_d, W)\n",
    "                X_ = data_module._complete_covariates(X, z_d_d, Z_up, Z_down, missing_value)\n",
    "\n",
    "                if verbose:\n",
    "                    report_data_metrics(X, Y, W, CATE, Z_up, Z_down)\n",
    "\n",
    "\n",
    "                idxs = np.random.choice(range(len(X)), size=train_size, replace=False)\n",
    "                include_idx = set(idxs)\n",
    "                mask = np.array([(i in include_idx) for i in range(len(X))])\n",
    "\n",
    "\n",
    "                X_train, Y_train, W_train, CATE_train = X_[mask], Y[mask], W[mask], CATE[mask]\n",
    "                X_test, Y_test, W_test, CATE_test = X_[~mask], Y[~mask], W[~mask], CATE[~mask]\n",
    "\n",
    "                res_dict['ground_truth'][i,j,k,l] = CATE_test\n",
    "                past_w[i,j,k,l] = W_test\n",
    "                \n",
    "                \n",
    "\n",
    "\n",
    "\n",
    "                # IMPUTE ALL\n",
    "                imputer = get_imputer(missing_value)\n",
    "                imputer.fit(X_train)\n",
    "                X_train_preprocessed = imputer.transform(X_train)\n",
    "                X_test_preprocessed = imputer.transform(X_test)\n",
    "\n",
    "                est_impute_all = learners[learner](missing_value)\n",
    "                est_impute_all.fit(Y_train, W_train, X=X_train_preprocessed)\n",
    "\n",
    "\n",
    "                #PEHE_impute_all.append(evaluate(CATE_test, te, W_test))\n",
    "                res_dict['all'][i,j,k,l] = est_impute_all.effect(X_test_preprocessed)\n",
    "\n",
    "                if verbose:\n",
    "                    print('all', X_train.min())\n",
    "\n",
    "\n",
    "\n",
    "                # IMPUTE NOTHING\n",
    "                treatment_effects_impute_nothing = []\n",
    "                X_train_preprocessed = X_train.copy()\n",
    "                X_test_preprocessed = X_test.copy()\n",
    "\n",
    "\n",
    "                est_impute_nothing = learners[learner](missing_value)\n",
    "                est_impute_nothing.fit(Y_train, W_train, X=X_train_preprocessed)\n",
    "\n",
    "                #PEHE_impute_nothing.append(evaluate(CATE_test, te, W_test))\n",
    "                res_dict['nothing'][i,j,k,l] = est_impute_nothing.effect(X_test_preprocessed)\n",
    "\n",
    "                if verbose:\n",
    "                    print('nothing', X_train.min())\n",
    "\n",
    "\n",
    "\n",
    "                # IMPUTE SMARTLY\n",
    "                treatment_effects_impute_smartly = []\n",
    "                imputer_smart = get_imputer(missing_value)\n",
    "                imputer_smart.fit(X_train[:,z_d_d:])\n",
    "\n",
    "                X_train_preprocessed = X_train.copy()\n",
    "                X_test_preprocessed = X_test.copy()\n",
    "\n",
    "                X_train_preprocessed[:,z_d_d:] = imputer_smart.transform(X_train[:,z_d_d:])\n",
    "                X_test_preprocessed[:,z_d_d:] = imputer_smart.transform(X_test[:,z_d_d:])\n",
    "\n",
    "                est_impute_smartly = learners[learner](missing_value)\n",
    "                est_impute_smartly.fit(Y_train, W_train, X=X_train_preprocessed)\n",
    "\n",
    "                #PEHE_impute_smartly.append(evaluate(CATE_test, te, W_test))\n",
    "                res_dict['smartly'][i,j,k,l] = est_impute_smartly.effect(X_test_preprocessed)\n",
    "\n",
    "                if verbose:\n",
    "                    print('smart down', X_train[:,:z_d_d].min())\n",
    "                    print('smart up', X_train[:,z_d_d:].min())\n",
    "\n",
    "\n",
    "\n",
    "                # IMPUTE WRONGLY\n",
    "                treatment_effects_impute_wrongly = []\n",
    "                imputer_wrongly = get_imputer(missing_value)\n",
    "                imputer_wrongly.fit(X_train[:,:z_d_d])\n",
    "\n",
    "                X_train_preprocessed = X_train.copy()\n",
    "                X_test_preprocessed = X_test.copy()\n",
    "\n",
    "                X_train_preprocessed[:,:z_d_d] = imputer_wrongly.transform(X_train[:,:z_d_d])\n",
    "                X_test_preprocessed[:,:z_d_d] = imputer_wrongly.transform(X_test[:,:z_d_d])\n",
    "\n",
    "                est_impute_wrongly = learners[learner](missing_value)\n",
    "                est_impute_wrongly.fit(Y_train, W_train, X=X_train_preprocessed)\n",
    "\n",
    "\n",
    "                #PEHE_impute_wrongly.append(evaluate(CATE_test, te, W_test))\n",
    "                res_dict['wrongly'][i,j,k,l] = est_impute_wrongly.effect(X_test_preprocessed)\n",
    "\n",
    "                if verbose:\n",
    "                    print('wrong down', X_train[:,:z_d_d].min())\n",
    "                    print('wrong up', X_train[:,z_d_d:].min())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09e3b466",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sens_eval(setting_key, result_dict, aggr_tuple):\n",
    "    temp = (result_dict[setting_key] - result_dict['ground_truth']) ** 2\n",
    "    swapped = np.swapaxes(temp, 0, 1)\n",
    "    \n",
    "    means = swapped.mean(axis=4).mean(axis=3).mean(axis=aggr_tuple)\n",
    "    stds = swapped.mean(axis=4).std(axis=3).mean(axis=aggr_tuple)\n",
    "    \n",
    "    return np.round(means, decimals=5), np.round(stds, decimals=4)\n",
    "\n",
    "for k in res_dict.keys():\n",
    "    if k != 'ground_truth':\n",
    "        print(f'# set: {k} \\t {\" \".join(str(a) for a in list(zip(*sens_eval(k, res_dict, (2,1)))))}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0649aba",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k in res_dict.keys():\n",
    "    if k != 'ground_truth':\n",
    "        means, stds = sens_eval(k, res_dict, (2,1))\n",
    "        print(f'# set: {k} \\t means: {\" \".join(str(a) for a in list(zip(amount_of_missingness, means)))}\\n#\\t\\t\\t stds:  {\" \".join(str(a) for a in list(zip(amount_of_missingness, stds)))}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee4e58b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# T-Learner\n",
    "# set: all \t means: (0.1, 0.6387) (0.2, 0.97497) (0.3, 1.0933) (0.4, 0.77673) (0.5, 0.6489)\n",
    "#\t\t stds:  (0.1, 0.0267) (0.2, 0.0711) (0.3, 0.0584) (0.4, 0.0589) (0.5, 0.0424)\n",
    "# set: nothing \t means: (0.1, 1.10829) (0.2, 1.00463) (0.3, 0.94244) (0.4, 0.88775) (0.5, 0.82824)\n",
    "#\t\t stds:  (0.1, 0.0732) (0.2, 0.0638) (0.3, 0.051) (0.4, 0.0555) (0.5, 0.0632)\n",
    "# set: smartly \t means: (0.1, 0.60229) (0.2, 0.56311) (0.3, 0.54543) (0.4, 0.40746) (0.5, 0.35365)\n",
    "#\t\t stds:  (0.1, 0.0527) (0.2, 0.0555) (0.3, 0.0811) (0.4, 0.0371) (0.5, 0.0273)\n",
    "# set: wrongly \t means: (0.1, 1.18941) (0.2, 1.10463) (0.3, 1.05624) (0.4, 0.95405) (0.5, 0.87472)\n",
    "#\t\t stds:  (0.1, 0.0787) (0.2, 0.0531) (0.3, 0.061) (0.4, 0.0537) (0.5, 0.0524)\n",
    "\n",
    "\n",
    "# X-Learner\n",
    "# set: all \t means: (0.1, 0.46102) (0.2, 0.79628) (0.3, 0.83915) (0.4, 0.56256) (0.5, 0.4036)\n",
    "#\t\t stds:  (0.1, 0.0389) (0.2, 0.0624) (0.3, 0.0722) (0.4, 0.0522) (0.5, 0.0304)\n",
    "# set: nothing \t means: (0.1, 0.03348) (0.2, 0.04115) (0.3, 0.05752) (0.4, 0.05961) (0.5, 0.08526)\n",
    "#\t\t stds:  (0.1, 0.0047) (0.2, 0.0148) (0.3, 0.0272) (0.4, 0.0204) (0.5, 0.0254)\n",
    "# set: smartly \t means: (0.1, 0.21161) (0.2, 0.20825) (0.3, 0.22869) (0.4, 0.17132) (0.5, 0.16462)\n",
    "#\t\t stds:  (0.1, 0.0373) (0.2, 0.0283) (0.3, 0.0293) (0.4, 0.0125) (0.5, 0.0135)\n",
    "# set: wrongly \t means: (0.1, 0.04181) (0.2, 0.04966) (0.3, 0.11158) (0.4, 0.08371) (0.5, 0.04647)\n",
    "#\t\t stds:  (0.1, 0.0099) (0.2, 0.0138) (0.3, 0.097) (0.4, 0.1261) (0.5, 0.0473)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9450d56a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mcm (3.7.9)",
   "language": "python",
   "name": "mcm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
