{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "443bfa88",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "from xgboost import XGBRegressor, XGBClassifier\n",
    "from econml.dml import NonParamDML, LinearDML\n",
    "from econml.dr import LinearDRLearner\n",
    "from econml.metalearners import XLearner\n",
    "from econml import metalearners\n",
    "\n",
    "from sklearn.experimental import enable_iterative_imputer\n",
    "from sklearn.impute import SimpleImputer, IterativeImputer\n",
    "from sklearn.linear_model import LogisticRegression, LinearRegression\n",
    "from sklearn.svm import SVC, SVR\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "8462e5b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_data(n, d, z_d_dim, amount_of_missingness, treatment_balance, missing_value=-1):\n",
    "    # GENERATE DATA\n",
    "    \n",
    "\n",
    "    assert 0 < n\n",
    "    assert 0 < z_d_dim <= 10\n",
    "    assert 0 < amount_of_missingness < .5\n",
    "    assert .5 <= treatment_balance < 1.\n",
    "\n",
    "\n",
    "    # COVARIATES\n",
    "    #X = np.random.rand(n, drandom.multivariate_normal)         # Fully observed X\n",
    "    A = np.random.rand(d,d)\n",
    "    cov = np.dot(A, A.transpose())\n",
    "\n",
    "    X = np.random.multivariate_normal(np.zeros(d), cov, size=n)\n",
    "    X /= (X.max() - X.min())\n",
    "\n",
    "    # DOWN\n",
    "    alpha = 1 - amount_of_missingness\n",
    "    p = (1+np.sqrt(2 * alpha-1))/2\n",
    "\n",
    "    theta_down = np.random.rand(z_d_dim)\n",
    "    Z_down = np.logical_xor(\n",
    "        X[:,:z_d_dim] + np.random.randn(n, z_d_dim) * .01 > 0, \n",
    "        theta_down > p\n",
    "    ).astype(int)\n",
    "    Z_down = np.abs(Z_down-1)        # 0 = missing, 1 = present\n",
    "    \n",
    "\n",
    "    # X_tilde_down\n",
    "    X_ = X.copy()\n",
    "    X_[:,:z_d_dim][Z_down==0] = missing_value\n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "    # X^down, Z^down -> W\n",
    "    # TREATMENTS\n",
    "    theta_w = np.random.rand(z_d_dim)\n",
    "    p = (1+np.sqrt(2 * .6-1))/2\n",
    "    _B = np.logical_xor(\n",
    "        X[:,:z_d_dim] + np.random.randn(n, z_d_dim) * .01 > 0, \n",
    "        theta_w > p\n",
    "    ).astype(int)\n",
    "    _B = np.abs(_B-1).mean(1)\n",
    "    W = np.random.binomial(1, _B)\n",
    "    \n",
    "    \n",
    "\n",
    "\n",
    "    # UP\n",
    "    theta_up_1 = np.random.rand(d - z_d_dim)\n",
    "    theta_up_0 = np.random.rand(d - z_d_dim)\n",
    "\n",
    "    alpha = 1 - amount_of_missingness\n",
    "    p = (1+np.sqrt(2 * alpha-1))/2\n",
    "\n",
    "    Z_up_0 = np.logical_xor(\n",
    "        X[np.where(W==0)[0],z_d_dim:] > 0, \n",
    "        theta_up_0 > p\n",
    "    ).astype(int)\n",
    "    Z_up_0 = np.abs(Z_up_0 - 1)\n",
    "\n",
    "    Z_up_1 = np.logical_xor(\n",
    "        X[np.where(W==1)[0],z_d_dim:] > 0, \n",
    "        theta_up_1 > p\n",
    "    ).astype(int)\n",
    "    Z_up_1 = np.abs(Z_up_1 - 1)\n",
    "    \n",
    "\n",
    "\n",
    "    # X_tilde_up\n",
    "    X_[np.where(W==0)[0], z_d_dim:] *= Z_up_0\n",
    "    X_[np.where(W==1)[0], z_d_dim:] *= Z_up_1\n",
    "    \n",
    "    X_[X_ == 0] = missing_value\n",
    "\n",
    "\n",
    "    # OUTCOMES\n",
    "    theta_y0 = np.random.randn(d)\n",
    "    theta_y1 = np.random.randn(d)\n",
    "    \n",
    "    \n",
    "\n",
    "    Y0 = np.sum(np.abs(X * theta_y0), 1)\n",
    "    Y1 = np.sum(np.abs(X * theta_y1), 1)\n",
    "\n",
    "    Y = np.array([Y0[i] if w == 0 else Y1[i] for i, w in enumerate(W)]) + np.random.randn(n)*.1\n",
    "\n",
    "    CATE = Y1 - Y0\n",
    "    \n",
    "    return X, X_, Y0, Y1, Y, CATE, W, Z_up_1, Z_up_0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "f7f5eadf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_regressor(missing_value):\n",
    "    #return LinearRegression()\n",
    "    #return LassoCV()\n",
    "    #return SVR()\n",
    "    return XGBRegressor(missing=missing_value, eval_metric='logloss')\n",
    "\n",
    "def get_classifier(missing_value):\n",
    "    #return SVC(probability=True)\n",
    "    #return LogisticRegression()\n",
    "    return XGBClassifier(use_label_encoder=False, missing=missing_value, eval_metric='logloss')\n",
    "\n",
    "def get_imputer(missing_value):\n",
    "    return IterativeImputer(max_iter=1000, random_state=None, missing_values=missing_value)\n",
    "    #return SimpleImputer(missing_values=0, strategy='mean')\n",
    "    \n",
    "learners = {\n",
    "    'T': lambda missing_value: metalearners.TLearner(models= get_regressor(missing_value)),\n",
    "    'X': lambda missing_value : XLearner(\n",
    "        models=get_regressor(missing_value),\n",
    "        propensity_model=get_classifier(missing_value),\n",
    "        cate_models=get_regressor(missing_value)\n",
    "    ),\n",
    "    'S': lambda missing_value : metalearners.SLearner(\n",
    "        overall_model=get_regressor(missing_value),\n",
    "    ),\n",
    "    'R': lambda missing_value : NonParamDML(\n",
    "        model_y=get_regressor(missing_value),\n",
    "        model_t=get_classifier(missing_value),\n",
    "        model_final=get_regressor(missing_value),\n",
    "        discrete_treatment=True\n",
    "    ),\n",
    "}\n",
    "\n",
    "def evaluate(ground_truth, estimate, W):\n",
    "    PEHE = np.sqrt(((estimate - ground_truth)**2).mean())\n",
    "    PEHE_0 = np.sqrt(((estimate[W==0] - ground_truth[W==0])**2).mean())\n",
    "    PEHE_1 = np.sqrt(((estimate[W==1] - ground_truth[W==1])**2).mean())\n",
    "    return PEHE, PEHE_0, PEHE_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "acf3730c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# EXP. SETTINGS\n",
    "train_size = 5000\n",
    "runs = 10\n",
    "\n",
    "n = 10000\n",
    "d = 20\n",
    "z_d_dim = 10\n",
    "amount_of_missingness = .1\n",
    "treatment_balance = .5\n",
    "missing_value=-1\n",
    "\n",
    "learner = 'X'\n",
    "\n",
    "\n",
    "# DEBUG SETTINGS\n",
    "verbose=False\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "30c6a79b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█████▉                                                     | 1/10 [01:19<11:51, 79.04s/it][IterativeImputer] Early stopping criterion not reached.\n",
      " 70%|█████████████████████████████████████████▎                 | 7/10 [07:02<02:17, 45.80s/it][IterativeImputer] Early stopping criterion not reached.\n",
      " 90%|█████████████████████████████████████████████████████      | 9/10 [08:46<00:46, 46.33s/it][IterativeImputer] Early stopping criterion not reached.\n",
      "100%|██████████████████████████████████████████████████████████| 10/10 [10:12<00:00, 61.29s/it]\n"
     ]
    }
   ],
   "source": [
    "PEHE_impute_all = []\n",
    "PEHE_impute_nothing = []\n",
    "PEHE_impute_smartly = []\n",
    "PEHE_impute_wrongly = []\n",
    "for _ in tqdm(range(runs)):\n",
    "    \n",
    "\n",
    "    X, X_, Y0, Y1, Y, CATE, W,_,_ = generate_data(n, d, z_d_dim, amount_of_missingness, treatment_balance, missing_value=missing_value)\n",
    "\n",
    "\n",
    "    assert 10 < train_size < len(X)\n",
    "\n",
    "    X_train, Y_train, W_train, CATE_train = X_[:train_size], Y[:train_size], W[:train_size], CATE[:train_size]\n",
    "    X_test, Y_test, W_test, CATE_test = X_[train_size:], Y[train_size:], W[train_size:], CATE[train_size:]\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "    # IMPUTE ALL\n",
    "    imputer = get_imputer(missing_value)\n",
    "    imputer.fit(X_train)\n",
    "    X_train_preprocessed = imputer.transform(X_train)\n",
    "    X_test_preprocessed = imputer.transform(X_test)\n",
    "    \n",
    "    est_impute_all = learners[learner](missing_value)\n",
    "    est_impute_all.fit(Y_train, W_train, X=X_train_preprocessed)\n",
    "    \n",
    "    te = est_impute_all.effect(X_test_preprocessed)\n",
    "    PEHE_impute_all.append(evaluate(CATE_test, te, W_test))\n",
    "    \n",
    "    if verbose:\n",
    "        print('all', X_train.min())\n",
    "\n",
    "\n",
    "\n",
    "    # IMPUTE NOTHING\n",
    "    treatment_effects_impute_nothing = []\n",
    "    X_train_preprocessed = X_train.copy()\n",
    "    X_test_preprocessed = X_test.copy()\n",
    "    \n",
    "    \n",
    "    est_impute_nothing = learners[learner](missing_value)\n",
    "    est_impute_nothing.fit(Y_train, W_train, X=X_train_preprocessed)\n",
    "    \n",
    "    te = est_impute_nothing.effect(X_test_preprocessed)\n",
    "    PEHE_impute_nothing.append(evaluate(CATE_test, te, W_test))\n",
    "    \n",
    "    if verbose:\n",
    "        print('nothing', X_train.min())\n",
    "    \n",
    "    \n",
    "    \n",
    "    # IMPUTE SMARTLY\n",
    "    treatment_effects_impute_smartly = []\n",
    "    imputer_smart = get_imputer(missing_value)\n",
    "    imputer_smart.fit(X_train[:,z_d_dim:])\n",
    "    \n",
    "    X_train_preprocessed = X_train.copy()\n",
    "    X_test_preprocessed = X_test.copy()\n",
    "    \n",
    "    X_train_preprocessed[:,z_d_dim:] = imputer_smart.transform(X_train[:,z_d_dim:])\n",
    "    X_test_preprocessed[:,z_d_dim:] = imputer_smart.transform(X_test[:,z_d_dim:])\n",
    "    \n",
    "    est_impute_smartly = learners[learner](missing_value)\n",
    "    est_impute_smartly.fit(Y_train, W_train, X=X_train_preprocessed)\n",
    "\n",
    "    te = est_impute_smartly.effect(X_test_preprocessed)\n",
    "    PEHE_impute_smartly.append(evaluate(CATE_test, te, W_test))\n",
    "    \n",
    "    if verbose:\n",
    "        print('smart down', X_train[:,:z_d_dim].min())\n",
    "        print('smart up', X_train[:,z_d_dim:].min())\n",
    "    \n",
    "    \n",
    "    \n",
    "    # IMPUTE WRONGLY\n",
    "    treatment_effects_impute_wrongly = []\n",
    "    imputer_wrongly = get_imputer(missing_value)\n",
    "    imputer_wrongly.fit(X_train[:,:z_d_dim])\n",
    "    \n",
    "    X_train_preprocessed = X_train.copy()\n",
    "    X_test_preprocessed = X_test.copy()\n",
    "    \n",
    "    X_train_preprocessed[:,:z_d_dim] = imputer_wrongly.transform(X_train[:,:z_d_dim])\n",
    "    X_test_preprocessed[:,:z_d_dim] = imputer_wrongly.transform(X_test[:,:z_d_dim])\n",
    "    \n",
    "    est_impute_wrongly = learners[learner](missing_value)\n",
    "    est_impute_wrongly.fit(Y_train, W_train, X=X_train_preprocessed)\n",
    "\n",
    "    te = est_impute_wrongly.effect(X_test_preprocessed)\n",
    "    PEHE_impute_wrongly.append(evaluate(CATE_test, te, W_test))\n",
    "    \n",
    "    if verbose:\n",
    "        print('wrong down', X_train[:,:z_d_dim].min())\n",
    "        print('wrong up', X_train[:,z_d_dim:].min())\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "029231e4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ALL IMPUTATION 0.2641162154929491 0.11288001339236028\n",
      "NO IMPUTATION 0.2793748577304963 0.11555372203809124\n",
      "SMART IMPUTATION 0.24833807460990198 0.12194151808293517\n",
      "WRONG IMPUTATION 0.27935659139470287 0.11426520897085507\n",
      "-----------------\n",
      "ALL IMPUTATION 0.2861844708655413 0.13755625212984332\n",
      "NO IMPUTATION 0.2896694279534494 0.13411197726048069\n",
      "SMART IMPUTATION 0.27540835342994074 0.1562677549217369\n",
      "WRONG IMPUTATION 0.2921668215014991 0.1422390730587516\n",
      "-----------------\n",
      "ALL IMPUTATION 0.23831874782808704 0.0864665280728755\n",
      "NO IMPUTATION 0.267208085932947 0.09782183979627333\n",
      "SMART IMPUTATION 0.2152925363455575 0.0815447247922131\n",
      "WRONG IMPUTATION 0.2632941604831127 0.08599820621825298\n",
      "-----------------\n"
     ]
    }
   ],
   "source": [
    "PEHE_impute_all = np.array(PEHE_impute_all)\n",
    "PEHE_impute_nothing = np.array(PEHE_impute_nothing)\n",
    "PEHE_impute_smartly = np.array(PEHE_impute_smartly)\n",
    "PEHE_impute_wrongly = np.array(PEHE_impute_wrongly)\n",
    "\n",
    "for i in range(3):\n",
    "    print('ALL IMPUTATION', np.mean(PEHE_impute_all[:,i]), np.std(PEHE_impute_all[:,i]))\n",
    "    print('NO IMPUTATION', np.mean(PEHE_impute_nothing[:,i]), np.std(PEHE_impute_nothing[:,i]))\n",
    "    print('SMART IMPUTATION', np.mean(PEHE_impute_smartly[:,i]), np.std(PEHE_impute_smartly[:,i]))\n",
    "    print('WRONG IMPUTATION', np.mean(PEHE_impute_wrongly[:,i]), np.std(PEHE_impute_wrongly[:,i]))\n",
    "    print('-----------------')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3087169c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# RESULTS (temp, see below)\n",
    "\n",
    "# T-Learner (10 runs | n=5000:10000) XGBoost\n",
    "#                    PEHE                 std\n",
    "#   ALL IMPUTATION   0.5411206118067651   0.18198541669699742\n",
    "#   NO IMPUTATION    0.5368966745997649   0.19839874904963278\n",
    "#   SMART IMPUTATION 0.48921213504279787  0.16139291971437686\n",
    "#   WRONG IMPUTATION 0.49196703165837824  0.1635566026243228\n",
    "\n",
    "# X-Learner (10 runs | n=5000:10000) XGBoost\n",
    "#                    PEHE                 std\n",
    "#   ALL IMPUTATION   0.2405095450534244   0.062313804017148804\n",
    "#   NO IMPUTATION    0.25197845126111107  0.07592054733347772\n",
    "#   SMART IMPUTATION 0.25028434735400346  0.05551856896160914\n",
    "#   WRONG IMPUTATION 0.24208038749537458  0.06925279102004783\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# S-Learner (10 runs | n=5000:10000) XGBoost\n",
    "#                    PEHE                 std\n",
    "#   ALL IMPUTATION   0.2641162154929491   0.11288001339236028\n",
    "#   NO IMPUTATION    0.2793748577304963   0.11555372203809124\n",
    "#   SMART IMPUTATION 0.24833807460990198  0.12194151808293517\n",
    "#   WRONG IMPUTATION 0.27935659139470287  0.11426520897085507\n",
    "#   -----------------\n",
    "#   ALL IMPUTATION   0.2861844708655413   0.13755625212984332\n",
    "#   NO IMPUTATION    0.2896694279534494   0.13411197726048069\n",
    "#   SMART IMPUTATION 0.27540835342994074  0.1562677549217369\n",
    "#   WRONG IMPUTATION 0.2921668215014991   0.1422390730587516\n",
    "#   -----------------\n",
    "#   ALL IMPUTATION   0.23831874782808704  0.0864665280728755\n",
    "#   NO IMPUTATION    0.267208085932947    0.09782183979627333\n",
    "#   SMART IMPUTATION 0.2152925363455575   0.0815447247922131\n",
    "#   WRONG IMPUTATION 0.2632941604831127   0.08599820621825298\n",
    "\n",
    "\n",
    "# TODO\n",
    "# - DONE split PEHE to W=1/0\n",
    "# - rerun with 100 runs to make sure, we want 110% \n",
    "#     reprodocible as we are providing explanations\n",
    "# - ATE; likely the effect from bias will be greater here\n",
    "# (- combine T/X-Learner)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "3d228fc2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjwAAAAVCAYAAABCMzRNAAAACXBIWXMAAA7EAAAOxAGVKw4bAAAQV0lEQVR4Ae2d7bEctRKGF5cDMCYCIANsZwAZ2BCBIQMofpl/FGQARADcDIAI+MgAbgSYk8G97zOrHms0GunV7NrYrlGVVl+tt1utltSjnT3njSdPnnx6Op0+USQ8+uKLL/44Z5efqn9H8a9l7VE6NHBo4NDAoYFDA4cGDg28eA30/BK1L/ybWxLxLcVP1PCu4pazQ6f3XvxwDo6viwZkWyv7Ud0dxXdelzG+yHG87PpkXhU/VrxzLb08D8xryXbgvBwakI289PvMte342ngvx0zaUrDP4J9Ug9q+VnxXjf9RvHu7SpVVivihim/RMauesqr7KtX9rRTQr1Q3dAskegz0R8V7yt8oXQTVsWF+nlVS/lH1P2d1q2zCxZGL26sVDRXXpEtYm2NJ/IZ0JsygD/m/V93smCqPwxBjRDeUmYeFflw6mIh2SOeib8oIpsIvCTdkhwfh3jlZfvYwR2Vcou8rZTIN27v6Mi+fJc73lT6lrPrQR2o6J6pvrgtR2foU1jcZ+F3lH6vuJqubsgN0zliQH77fCHfCLz5uVP9m1Cnv2PEQJtjC7ekxRLBphenYe4nb3It6cqrd0fnM8xqZbJzD9h78jXHZe40pj70uQkbSJGd1jtTm2GYON+UbmLYdm7xtvDTWIVuSDF17z2gsW8noQ2+Lcy0qRddcv2r/GRpFHq6+jX5badPhEcBkjEpXh5Lqfhfol0rxnE5Kof1d6QeKTacn0X4nejZ9Nn8mYCtweMeBPtGojMPD7cDEe6Mjjkf1MCnoL6JDDuFZYxGtrTPRohNk41CcnBellIk4lyeV4U37rB/lcVB/UsrXk/ncdOnATMHSufC7MgagUuaagAFjH8iG/dwoncMApiXjDHxhRnLZc1eySmPi4P8g2pRnE4n1EvNr25L6dvUpHuAhN7ynBxal6P+/SnnAmNapUotO/U6iZc67YxHNA0XGVdsL3ld9OH9gwt+xTwsz4blr0tZ5GntzTWoctUCf1V7kypn4Ojqv8d5VJ56X2LutUwlnreMBebrrYkMhrTlybLMGW8UU4YgdO7wtPAQcsaVE27X3gbkJ/k1M4Y3YD5jc4rCX/qB4U5uIqGs6PCJiU86fDqd+Av1YmYXDASNFDjHo54196lB8JKEeUa385tdlaoMPC68Mj1Xxi2LV4UmYZZ9V+Rp0wkDB7lhGdIZRfC/8/KYGQ8gPEPSDZ4uDE7oI+s/VFnUu3Uk4Izp3ZJQYU/hD2JOeomIj7WIOyrjBxq9O/EbmrgRnHc1OKY3CZCND14x3uulQ2bIl+is4+uTQvyvc+XZWefr9pvp8nbp08LXGAqH4rPYB1U0PN0rDNiEdsc8uprBtPY7QSs6ubTKYPAi/dd3uymnrPOe9Ny+ZmY/d9u7qNPHp7u+D8jjrYqEa4W/OkQht28xBO5gntXfteIS3iYeII7bUtXfxHbUVB9NdF7nK2c9WY8sJyN8qK4ryhxpQ7ZqIg2v1xKK6XxXfV587Bc7eIjcZK8NogYk3T7AojLgZrk23yehZg60zycYtDeNY6F713J7l+mAOFmNVe23cLh3SWjofkBFMKwxgWjJaTD0ie+424LjR+FPju1O045xysExOQNF2jSJ2lDvIgYk95OvUpaO/Oxb2glpYPdGLyLXPEcwa7911A7Y581Afay+aO2xnXJ1vI4y1XGrvLjd3HT83eYw5cm1zHrOB6dqxy9vFQ0bLlgbs3Z6bAcxZl25G2JyV+CvlHruA2HR4knC1zRIAlPZ0gXQuBD3t1whM5EPJwg1GPhA8OTy6WvhItAtHoUakumvTbbCZq0d0xm3AjcZRc15mQLXz/eWbpFGpPIcXYdaPS3fuNjmtjs4tGROmm7iYe+zClaFGNzJ3tf7Mz1+ah635zG271n+4TrwCs7ZO/06A9126TABrLMLNb3Cm7qpj3X6ZYUW9a8c2ZsnjCmXXNnNW7h6T96nlLZ3XOu6su9TeXbbuOn6e8jTnSDZr2WYx4B6mZccu75G1JjldW3LtfWRuXMxCnXYR/+PDFvXtRiM3CShnEaTc2EgX9UXhblHeVWQi02RyiP+jPN/781SQf4UzY6udq8n5oJ8bisy16Qr4VVH8RnV2XyAckDwhfqTIAcW4my9rix7jm671lN90+lp0anN1PiyjsKfrT8nILwO51eAdHp5iIliYAzIG7u5UvEbnbsVLGDwF1QLze1J7roMaXbVO/Tb1qTYcZvrV1iL6J/ALBzZ08k06CAii3TUW9WO+eblwfndnAqx8iMa1Yxuzwma0yrLNANUYrL0o6FupsHbpvIW51SZeF9v7FnZZL17dvWaPPOqzuS5yGUQ3PEfq07TNnZiWHfd4x9hEt4mnNteWuvYurFFb6WLGGHam+Cv4LZtn3+0GMMLVnIfYGG8afR1FNLo/a2KCFJEDI+Ywx4sjXQTRMMls8nHLtGiPwrXpAreTjuos9McT+HxAKI/Txy9sFk8IKr8n/izEB4ocnr8prsIAnaPzIRklDPTzS2WShfniax6+pgvH2sZUH0fGlQ52VIzOncVC8jNn6GCeX6vjMyJHn9gJdlEGeBNC3y7duVfxaY6FNbtatzlUwunacdani5nRXpoNXXXXpMZh7UWXCGTqfA+L52LvW4JoHL11PCqPsy5Oo3OU9N20zVHMTCdNO3Z4Z1hkm3gF7Snhl3uRY+975gb23TVUymiW/xRdbb+bu9+ac+sMA65dh68p1zXxBLluGazRZDxUlxtFbjg4GJkYDkrq88BPCjc9u4zw2nQZ9EXZSWcaQxgaT8PleH4Qh+8ymomhyryox5vqeO/fK/LGeqmfk+pcuqbOhbNHRhwb5nEKyuOYMp+TUz2KKfqmjGcuL+xzj73/KOl4wp1fKB6RVv2a+kxYj0lFO28CyuPsxDzEw4FLl2BXSXMs4sma5Z2hcGxXAFSo3bLPRGthVhkNVkquUXt395hBSRbkTZ0vKK9f2GPvVSmk22us41ke4TnrAlmG5si0zSFMhBBu145N3sBZeBPh8mNhS+I3au9LtGVp97m2hLFK+CvoczPc2mw5X3HHxpiTtZyg8PjiHYG833BeiudWBwP+TJGveLiuiiu5+eBPdLXbqAXPa9MtwNuFPTqLwyhH5hcNGOP9vDLPa4w8rTNv00/387Y8v0WXdNTVecLaJWMmB/35WiU30i7moIwZu13ZPXPXZCT5sVXsOWy5ST/QuNCn8LGDtxV5iv5UkSc/dP2rImHStUt37rL8VF9nLHx3X5vXJVhWEm7PjocxM/i92doYFmtScrNndfeivQLQz9T5XhZXt/ctQZKuenvNNeQp18VFcyS5V7aZxrJn3ofsuMa70O8oXmv99ux9z9z0MIvhDBXB5nzcDC2Hp9pJCmcTJdSAo642qHOvsU826MWVf5pwbnvg9b7KbOD80qXJ89p04mkH8bZ1ltFGnxofxnwSLbdAPLGX4bdUMT3Zu3Spj6PzkC3Skj/lSUYy4s87VxwMW4H5C6xIa7SB2ZWx1nlPXSbXnUr/qGvaXt5PeGy2/FQc531XUN+uPgMY+RV58uQGkAcHNut4Ip7lVr1FF7ik6uOOhaf4zc1ROJYd57yVb2IWtBcV0U0CiLSGF457dy+qdXbrBnTuQi7osrGGbeftUTfbTd64I99dxyPyiLa7LkRjnRcxFtF3bXMUM7BTumnHDu8Ci+ImXkkr/Or6VX3YeaRlV8rYe7SHXeR0UTfZSkYbfXLayMf+HuXRlAuXFv7pdgORDSqELsniq6WyPm54mlfXZadaWQqCd34QzmRq4+mYjRt+KOmBylzL5QFHgEmhHqX/pHg1OuEuHDFh98KIzrZog0dsOJMTIVn4pVZroi06Ybg6Rw5XRmjvK9YOvMlexPcPiBS6mIMynlEv/9ySa8jeJTubEf/CZb7ZUX5a5EpjTh1pXX1uYbE2eFm5ZTP03aRTX2ssosOmGGPMMbhlsOwzOpmYQX6tdMsGAp/5s/YiyT+6d0w8XJ2HQBekW2MdsvcW/zSHzv4OjCtPd12ILw+B3XMgmyPHNnfNe9JBa204vGc1G3g5bW/9buk8MGK/2qKr2coWbYkZ5dGUvaZ2zsw4LYcnFvBMnGVwIvDOy3BPFXwP39tIy36rMhgp4rSEcnM6BsemTRuKXATV/6MK2ufDhfKCSIUL6Uq4VnlEZ1wzlg4c2OgXvcQ40HPt4GLhE4bohGvp/Aw9Xds7MkL+rbBrmzybT8gInTXuJGfTLgCLIPrqxhrtRjoyd1U4yYDzwEZb6oGNp3xXq4qRVVr6FC+w+aOCbys/rUmlrBv0ji1NwaXL6EfGErbY2ohcOw4RHMygvVZq2aaY5fY88ZZ+a3vRkFzCsHXOHCtO8z3E5Bnxxfb+DKqeQ74UnXXsytNdF+LJ/IzMUdc2k65HMEMpPTvu8g6glPbwJjLJ69iSa+/u3MDbxUzDGU7uqkfNV5iBbs25dYYnsgfr6slJYIN+KsU9jHbl7yjPb+CnFyCpp07xf4qTp0pdJcT1OsKWAWdl9S6K8LiK4/qyNTjkIfbCNek2xyJZLZ0hrGi5vcKRmZ1K5Vf6FQ2HJzdXcxAdcwItX2PcpAaXDnJL58J2ZQSTP4uPsc9BZX4SSpgd0gFMS0bAhYku+HVbywYh3Qzqa80dvBRX9q46nuLYGGifdBGp6vJ5ymXYtCURWfoUHXxLRwM54Mn6juDSndRvdCzonxC2eC4tP0fsk54OZnBo6TFoIt2k1bhH7D3wIkXekDnqyrTF29a55ITPv2rvxcA2xyU6ax1rTNb6E567LgoRp+LWHI3aZo69hRk0YRM3UVGko7x7eCfp0rIl0Vn2PjA38LYwCx207KcgnR7i8n2tbD+98eTJEw5VnIeFh6oyXiDOBu/LrILqUS59may/FR8ofqn6BUOV+akYhzcvU81BZTZeAk+bYNEPBwZZMPApKI8cnys+PddMn5t/j0b0HKxMKrgElPyr6he/hrkmnbDcsVg6m6TWh3DRL30IOIQ1/TLO2WlQnrHz12zL+bTo1Be+ts4dGRMmcrGACYyF+eSdEuxnERzMQRmxQUL1H9Sem9qf4mfNnehW9p7qGH8tcCN6LxqUd23J0qfwwmkOO6qunQE6xmeNhTEJF1qczcfKsxarQW0j9tnFFJ6lR4QZpO2uyRigcLt7kcNbNKM6/1ftfUSnGpu114jOXX/WuhicI9s209i7857oHDu2eUtHDt6oLXXt3Z2bTOcOpr1+M1zGxsPc4uxLuobnT5sOTyIC4JEAFk4MbUc4NPAqaUA2zM1X7eu/V2kYh6yHBiwNHPZuqekgek00IHvH2eOyZOuCZnJ4bnXGC9HiZqZDfzQfGnhZNcD7M6vbpJdV2EOuQwMXauCw9wsVeHR/pTTAtwf4K83QdHh0QPDVEi+V4T0d4dDAK6kB2S/X4XzteoRDA6+9Bg57f+2n+BhgpoHkn+CnzK/CZM2LbNPhSZS8H8J3kkc4NPCqauBjLYbFO1yv6kAOuQ8NGBo47N1Q0kHy2mgA/8T6Jire4YlfW1Xf10ke1MPj0HhtDOQYyKGBQwOHBg4NHBp4pTUgn4Rf+/Iveqq/2E7tOEP8UObR/wFig3Lg1A4LQQAAAABJRU5ErkJggg==\n",
      "text/latex": [
       "$\\displaystyle \\left[ \\left( 0.184441063386568, \\  0.215992737116412, \\  0.180543473726061\\right)\\right]$"
      ],
      "text/plain": [
       "[(0.18444106338656757, 0.21599273711641193, 0.1805434737260608)]"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "PEHE_impute_wrongly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "a5b6424e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAKoAAAAPCAYAAAB0p1TfAAAACXBIWXMAAA7EAAAOxAGVKw4bAAAF60lEQVRoBe2a/3EVNxDHnxkX4JgKgA7AVIDTAYQKCB2Qf/0fQzoAKuBHB5AKCHQAqSDBHTifj96topN1dzq/B5OZZGf0VlqtvtpdraS7sw8uLi42/9O/JwJnZ2e3KZ9Ki2gf0T6Gfynl/6X6YeksgXg2tP+E36I86w0OejfR/2UYfwL/yzbyUdCH/g3ymCtEr6Z0VaDvNuwx/HEMKPnQ/wbZHernZV9d79VFb9ZG+lf5rB0dc/+GjokZcbMu3dmy6d8Buxkj+p4XI4+pP0I2ihPtVf70YBZzpipjJtdxDi8nKkofQXoKfysi3AB9hP9Imd3J9Ovgc3UdK1F3kWP8+yTcytU1oUziJIfbtrg5psj+WLykwzhtfElxU7g5xG7SSt1FG8Fb43O3nRivL5ILatxdD9dllFTIWjQVI9fW9fnVQXCx/4C7qdPawtf6s4jpXA26ko3XBMLIn2FH8JSkg8zA2C53ol0tMilHJx1Ynq5iaFhJtj09c/LSdiEnNwO6T0qAqCM/pzygOPerkLf4Gl3G99jY7fPKuT+h/wPlgHKL4oZeTFJ0mjHCFzeyjw0pSY0NdTf875Rybbv9YVwvJqr/0C42pkQF6gFldFoN8B/gp0xgIs3RKZ2fG3omoxvA3bqB34e5m1/YDkLuqZ1P45DLkavvQi0ulvq7EvP12tjl86729IxfiJH+tA4B17tc2zX+9GJm83e1MRJVI+PKyeBUwkH758iE/IIxU8kUie7J5yk4pdea4yH6o8RuKe1R1mtjr897NG0Sqhkj4hZxb62t7yHSyZZtuvxZiTlAJ7aTjYfFxCVoXT+uBWUbDE/kFnkabuiP09qgmNDKH1IMls+lb5AZqBEh8zorr6dR/zdqdNmIbb0+rzYT7PQoxsDrFG8jn1EjhiO8uRjR56Ggfmv9xJbituvyZw3mFj6t/+Q69uIdAhZOzJ1ysTNj7kWOASajQYgvAY4JnBP6s5z6V4pvoeUzsmMNdJzqjv8e1G1jbQy2tnyu1Zbazv8arHMV4cbBxyofj0abeehbipExbd2I2iqFv9tW8Qv+lD/dmPuy8Vph11w1dt+cTt3nC8lbDI03zQiI3wnrq/w1ui+Rh45Yfmap9ZR/Myrm77WxtmXkc93Z08YGEzIlqfrU3agmaOtm6YnRowEnJyuYJmDMMXcQTPmzBnMvNpqorecXfZOOtyxd0UN1mREIg+oV37pOWoHxU4dJmp6XGOfV11oYxN+FFm2srVjwuVZf29aem8zh6ZqoN0bomZA3KH4deULx7V4cX5Sllq8b9CbXkL4uTPS61rEH71AligabKDWFrOlMrWwbLI3zc8joLZ52zKOTUxSLcYR+95xTYGvlvTbWuIxr+lzrLbXBeYeOsZv6uJ/Wg34TrTtG6Bvz+vOhCStdijP6i/4sYe7bRp9RJa+WvFuTZPsTJ+ro2ajoH1Ux7j4Cv/3lk3QweAM3IFPzBI462nEXfa+dktLz0iD3tM7PuKXSHuo9NuZpsGPJ56zbUfFGad1waR2YK16o9hEj4/kezNHBsaM/GRMcfdl1HTNeJKpJETuMaiZ3th+gR87k3qKCjqAaVieQCxnPml4ndQKK4jyeuLEhgtuXiL6vVAxs3gRD175Zr40bbOnxeY19L8Cs4+d4ny9zTNCxntsqSMgvxQiZ8fcD/Q3qaR3hR7TFHJ3cyLv86cFEZ682pkQFNAVIAyjpzRuuMz9R7lESDTKDYfJmJ6m7w01AE6l+tjxFFi9Uvlyp4/8QpAWBxzzpAT1N1P5RzzJF14eOY/j5lNIgn9TFni4b0evyuWHH5Nzo+mdOS76mqcdfnHo2aCtG2lmf0q6VLzlxQm+or/GnC7Phu6Ir2XgQ/z2FoQJ4qrrIft+8S3laOkNbhz7DTLYymMo0vkWjpFaBsc7jfJKJdWme1MMPuia+2J4AkhvpA/JI/jih7RfT4PsI8Q6dOMlpJqw1urM2gr3W5665wdXXOFWNjUk2+2dUxizFKG7LiPml79ZX8GcRE7sz7Wrj32/M3foVpfJ+AAAAAElFTkSuQmCC\n",
      "text/latex": [
       "$\\displaystyle 0.264116215492949$"
      ],
      "text/plain": [
       "0.2641162154929491"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array(PEHE_impute_all)[:,0].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a61f7eb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mcm (3.7.9)",
   "language": "python",
   "name": "mcm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
