{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.mlab   as mlab\n",
    "\n",
    "from kte import kernel_two_sample_test_nonuniform\n",
    "from xkte import kernel_two_sample_test_agnostic\n",
    "from xkte import kernel_dr_two_sample_test_agnostic\n",
    "from dr_kte import kernel_dr_nonuniform\n",
    "\n",
    "from baselines import vanilla_dr_baseline_test\n",
    "from baselines import BART_baseline_test\n",
    "from baselines import CausalForest_baseline_test\n",
    "\n",
    "from sklearn.metrics import pairwise_distances\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "from scipy.spatial.distance import cdist\n",
    "from scipy.special import expit\n",
    "from scipy.stats import bernoulli\n",
    "\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import time\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the dataset; the first CSV column is used as the row index.\n",
    "df = pd.read_csv('idhp.csv', index_col=0)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Index(['iqsb.36', 'dose400', 'treat', 'bw', 'momage', 'nnhealth', 'birth.o',\n",
       "        'parity', 'moreprem', 'cigs', 'alcohol', 'ppvt.imp', 'bwg', 'female',\n",
       "        'mlt.birt', 'b.marry', 'livwho', 'language', 'whenpren', 'drugs',\n",
       "        'othstudy', 'mom.lths', 'mom.hs', 'mom.coll', 'mom.scoll', 'site1',\n",
       "        'site2', 'site3', 'site4', 'site5', 'site6', 'site7', 'site8',\n",
       "        'momblack', 'momhisp', 'momwhite', 'workdur.imp', 'bwg.1', 'female.1',\n",
       "        'mlt.birtF', 'b.marryF', 'livwhoF', 'languageF', 'whenprenF', 'drugs.1',\n",
       "        'othstudy.1', 'momed4F', 'siteF', 'momraceF', 'workdur.imp.1'],\n",
       "       dtype='object'),\n",
       " (985, 50))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Column index and (rows, columns) shape of the loaded table.\n",
    "(df.columns, df.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Covariates that are standardised during preprocessing.\n",
    "covs_cont = [\n",
    "    'bw', 'momage', 'nnhealth', 'birth.o', 'parity',\n",
    "    'moreprem', 'cigs', 'alcohol', 'ppvt.imp',\n",
    "]\n",
    "\n",
    "# Remaining covariates (left unscaled), including the site1..site8 indicators.\n",
    "covs_cat = (\n",
    "    ['bwg', 'female', 'mlt.birt', 'b.marry', 'livwho', 'language', 'whenpren', 'drugs', 'othstudy']\n",
    "    + ['mom.lths', 'mom.hs', 'mom.coll', 'mom.scoll']\n",
    "    + ['site' + str(k) for k in range(1, 9)]\n",
    "    + ['momblack', 'momhisp', 'momwhite', 'workdur.imp']\n",
    ")\n",
    "\n",
    "# Outcome column followed by the treatment column.\n",
    "ty = ['iqsb.36', 'treat']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Covariate lists actually used below. This overrides the previous cell:\n",
    "# covs_cont is unchanged, covs_cat keeps only its first nine entries.\n",
    "covs_cont = [\n",
    "    'bw', 'momage', 'nnhealth', 'birth.o', 'parity',\n",
    "    'moreprem', 'cigs', 'alcohol', 'ppvt.imp',\n",
    "]\n",
    "covs_cat = [\n",
    "    'bwg', 'female', 'mlt.birt', 'b.marry', 'livwho',\n",
    "    'language', 'whenpren', 'drugs', 'othstudy',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(9, 9)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of entries in each covariate list.\n",
    "tuple(len(cols) for cols in (covs_cont, covs_cat))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only rows with no missing value in any covariate, the outcome or the treatment.\n",
    "feature_cols = covs_cont + covs_cat\n",
    "df1 = df[feature_cols + ty].dropna()\n",
    "\n",
    "# Covariate table; df1 is already free of missing values, so no second dropna is needed.\n",
    "dX = df1[feature_cols].copy()\n",
    "\n",
    "# Standardise the covs_cont columns only; the covs_cat columns are left as they are.\n",
    "scaler = StandardScaler()\n",
    "dX[covs_cont] = scaler.fit_transform(dX[covs_cont])\n",
    "X_original = np.array(dX[feature_cols])\n",
    "\n",
    "T_original = np.array(df1['treat'])\n",
    "\n",
    "# Outcome vector rescaled to unit Euclidean norm.\n",
    "outcome = np.array(df1['iqsb.36'])\n",
    "Y_original = outcome / np.linalg.norm(outcome)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(908, 347)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of retained rows and sum of the treatment column.\n",
    "(T_original.shape[0], T_original.sum())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Factor multiplying the standard-normal noise added to the simulated outcomes.\n",
    "noise_var = 0.5\n",
    "# One unit coefficient per covariate column of dX.\n",
    "beta_vec = np.ones(len(dX.columns))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:03<00:00, 30.26it/s]\n",
      "100%|██████████| 100/100 [00:02<00:00, 45.31it/s]\n",
      "100%|██████████| 100/100 [00:03<00:00, 29.78it/s]\n",
      "100%|██████████| 100/100 [00:02<00:00, 39.07it/s]\n",
      "100%|██████████| 100/100 [00:03<00:00, 26.27it/s]\n",
      "100%|██████████| 100/100 [00:02<00:00, 36.95it/s]\n",
      "100%|██████████| 100/100 [00:03<00:00, 26.03it/s]\n",
      "100%|██████████| 100/100 [00:02<00:00, 37.21it/s]\n",
      "100%|██████████| 100/100 [00:03<00:00, 26.74it/s]\n",
      "100%|██████████| 100/100 [00:02<00:00, 38.91it/s]\n"
     ]
    }
   ],
   "source": [
    "name_folder = 'data_ihdp/'\n",
    "size_subset = 500\n",
    "num_experiments = 100\n",
    "iterations = 100\n",
    "\n",
    "b_list = ['I', 'II', 'III', 'IV', 'V']\n",
    "# Entries must match a dispatch branch below exactly. Supported names:\n",
    "# 'DR-xKTE', 'IPW-xKTE', 'KTE', 'Vanilla_DR', 'BART', 'CausalForest'.\n",
    "# The former entries 'DR-xKTENEW' / 'IPW-xKTENEW' matched no branch, so no\n",
    "# test ran and all-zero results were written.\n",
    "method_list = ['DR-xKTE', 'IPW-xKTE']\n",
    "# Suffix appended to the output file names, so the files keep their\n",
    "# previous names (e.g. 'bIDR-xKTENEW.csv').\n",
    "file_tag = 'NEW'\n",
    "experiment = False\n",
    "\n",
    "# Baselines sharing one call pattern: build with (X, T, Y, iterations),\n",
    "# then call permutation_test().\n",
    "baseline_tests = {\n",
    "    'Vanilla_DR': vanilla_dr_baseline_test,\n",
    "    'BART': BART_baseline_test,\n",
    "    'CausalForest': CausalForest_baseline_test,\n",
    "}\n",
    "\n",
    "np.random.seed(0)\n",
    "\n",
    "for b in b_list:\n",
    "    for method in method_list:\n",
    "        p_values = np.zeros(num_experiments)\n",
    "        values = np.zeros(num_experiments)\n",
    "        times = np.zeros(num_experiments)\n",
    "        for n in tqdm(range(num_experiments)):\n",
    "            # Random subsample of rows, drawn without replacement.\n",
    "            idx = np.random.choice(np.arange(dX.shape[0]), size=size_subset, replace=False)\n",
    "            X = X_original[idx, :].copy()\n",
    "            T = T_original[idx].copy()\n",
    "            treated = T == 1\n",
    "            n_treated = int(treated.sum())\n",
    "\n",
    "            if experiment:\n",
    "                # Same treatment probability (0.5) for every unit.\n",
    "                w = np.zeros(X.shape[0]) + 0.5\n",
    "            else:\n",
    "                # Treatment probabilities from a logistic regression of T on X.\n",
    "                w = LogisticRegression(C=1e6, max_iter=1000).fit(X, T).predict_proba(X)[:, 1]\n",
    "\n",
    "            # Simulated outcome: linear in X plus standard-normal noise scaled by noise_var.\n",
    "            # It is drawn for every scenario, including 'V', which then replaces it.\n",
    "            Y = np.dot(beta_vec, X.T) + noise_var*np.random.randn(X.shape[0])\n",
    "\n",
    "            # Scenario-specific shift b1 added to the outcomes of the T == 1 group.\n",
    "            if b == 'I':\n",
    "                b1 = 0\n",
    "                Y[treated] += b1\n",
    "            elif b == 'II':\n",
    "                b1 = 1\n",
    "                Y[treated] += b1\n",
    "            elif b == 'III':\n",
    "                # Shift of +beta or -beta, each with probability 0.5.\n",
    "                Z = bernoulli.rvs(0.5, size=n_treated)\n",
    "                beta = 2.\n",
    "                b1 = (2*Z - 1)*beta\n",
    "                Y[treated] += b1\n",
    "            elif b == 'IV':\n",
    "                # Shift drawn uniformly between -beta and beta.\n",
    "                beta = 4\n",
    "                b1 = np.random.uniform(-beta, beta, n_treated)\n",
    "                Y[treated] += b1\n",
    "            elif b == 'V':\n",
    "                # Use the normalised outcomes from the dataset instead of simulated ones.\n",
    "                Y = Y_original[idx]\n",
    "            else:\n",
    "                print('b not recognized! Setting b1 = 0.')\n",
    "                b1 = 0\n",
    "\n",
    "            # Column-vector views of the outcomes, overall and per treatment group.\n",
    "            YY0 = Y[T == 0][:, np.newaxis]\n",
    "            YY1 = Y[treated][:, np.newaxis]\n",
    "            Y = Y[:, np.newaxis]\n",
    "\n",
    "            # Gaussian RBF kernel: sigma2 is the squared median distance between\n",
    "            # the outcomes of the two groups.\n",
    "            sigma2 = np.median(pairwise_distances(YY0, YY1, metric='euclidean'))**2\n",
    "            gamma = 1.0/sigma2\n",
    "\n",
    "            if method == 'DR-xKTE':\n",
    "                t0 = time.time()\n",
    "                value, p_value = kernel_dr_two_sample_test_agnostic(\n",
    "                    Y, X, T, w, kernel_function='rbf', gamma=gamma, verbose=False)\n",
    "                times[n] = time.time() - t0\n",
    "                p_values[n] = p_value\n",
    "                values[n] = value\n",
    "\n",
    "            elif method == 'IPW-xKTE':\n",
    "                t0 = time.time()\n",
    "                value, p_value = kernel_two_sample_test_agnostic(\n",
    "                    Y, T, w, kernel_function='rbf', gamma=gamma, verbose=False)\n",
    "                times[n] = time.time() - t0\n",
    "                p_values[n] = p_value\n",
    "                values[n] = value\n",
    "\n",
    "            elif method == 'KTE':\n",
    "                t0 = time.time()\n",
    "                mmd2u_rbf, mmd2u_null_rbf, p_value = kernel_two_sample_test_nonuniform(\n",
    "                    YY0, YY1, T, w, kernel_function='rbf', gamma=gamma,\n",
    "                    verbose=False, iterations=iterations)\n",
    "                times[n] = time.time() - t0\n",
    "                p_values[n] = p_value\n",
    "                # NOTE(review): mmd2u_rbf is not stored, so stat_values stays 0 for KTE.\n",
    "\n",
    "            elif method in baseline_tests:\n",
    "                # The baselines are given T as a column vector.\n",
    "                T = T[:, np.newaxis]\n",
    "                t0 = time.time()\n",
    "                baseline = baseline_tests[method](X, T, Y, iterations)\n",
    "                p_value, value = baseline.permutation_test()\n",
    "                times[n] = time.time() - t0\n",
    "                p_values[n] = p_value\n",
    "                values[n] = value\n",
    "\n",
    "            else:\n",
    "                # Fail loudly rather than saving all-zero results for an unknown method.\n",
    "                raise ValueError('Method not recognized: ' + repr(method))\n",
    "\n",
    "            # Saved inside the run loop, so the file is rewritten after every run with\n",
    "            # the results collected so far (rows of runs not yet done are still zero).\n",
    "            res = pd.DataFrame({'times': times, 'p_values': p_values, 'stat_values': values})\n",
    "            res.to_csv(name_folder + 'b' + b + method + file_tag + '.csv')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Scenario VI, where the ATE is subtracted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.003471902998471982"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.random.seed(0)\n",
    "\n",
    "# Build the causal-forest baseline test on the full dataset and take its\n",
    "# ref_stat attribute as the ATE value.\n",
    "full_data_outcome = Y_original[:, np.newaxis]\n",
    "causal_forest = CausalForest_baseline_test(X_original, T_original, full_data_outcome, 100)\n",
    "ate = causal_forest.ref_stat\n",
    "ate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:02<00:00, 34.67it/s]\n",
      "100%|██████████| 100/100 [00:01<00:00, 52.12it/s]\n"
     ]
    }
   ],
   "source": [
    "name_folder = 'data_ihdp/'\n",
    "size_subset = 500\n",
    "num_experiments = 100\n",
    "iterations = 100\n",
    "\n",
    "method_list = ['DR-xKTE', 'IPW-xKTE', 'Vanilla_DR', 'BART', 'CausalForest']\n",
    "experiment = False\n",
    "\n",
    "# Baselines sharing one call pattern: build with (X, T, Y, iterations),\n",
    "# then call permutation_test().\n",
    "baseline_tests = {\n",
    "    'Vanilla_DR': vanilla_dr_baseline_test,\n",
    "    'BART': BART_baseline_test,\n",
    "    'CausalForest': CausalForest_baseline_test,\n",
    "}\n",
    "\n",
    "np.random.seed(0)\n",
    "\n",
    "# ref_stat of the causal-forest baseline built on the full dataset; it is\n",
    "# subtracted from the T == 1 outcomes in every run below.\n",
    "causal_forest = CausalForest_baseline_test(X_original, T_original, Y_original[:, np.newaxis], 100)\n",
    "ate = causal_forest.ref_stat\n",
    "\n",
    "for method in method_list:\n",
    "    # Three independent result arrays, one slot per run.\n",
    "    p_values, values, times = (np.zeros(num_experiments) for _ in range(3))\n",
    "    for n in tqdm(range(num_experiments)):\n",
    "        # Random subsample of rows, drawn without replacement.\n",
    "        idx = np.random.choice(np.arange(dX.shape[0]), size=size_subset, replace=False)\n",
    "        X = X_original[idx, :].copy()\n",
    "        T = T_original[idx].copy()\n",
    "        treated = T == 1\n",
    "\n",
    "        if experiment:\n",
    "            # Same treatment probability (0.5) for every unit.\n",
    "            w = np.full(X.shape[0], 0.5)\n",
    "        else:\n",
    "            # Treatment probabilities from a logistic regression of T on X.\n",
    "            propensity_model = LogisticRegression(C=1e6, max_iter=1000).fit(X, T)\n",
    "            w = propensity_model.predict_proba(X)[:, 1]\n",
    "\n",
    "        # Dataset outcomes with the ATE value removed from the T == 1 group.\n",
    "        Y = Y_original[idx]\n",
    "        Y[treated] -= ate\n",
    "\n",
    "        # Column-vector views of the outcomes, overall and per treatment group.\n",
    "        YY0 = Y[T == 0][:, np.newaxis]\n",
    "        YY1 = Y[treated][:, np.newaxis]\n",
    "        Y = Y[:, np.newaxis]\n",
    "\n",
    "        # Gaussian RBF kernel: sigma2 is the squared median distance between\n",
    "        # the outcomes of the two groups.\n",
    "        sigma2 = np.median(pairwise_distances(YY0, YY1, metric='euclidean'))**2\n",
    "        gamma = 1.0/sigma2\n",
    "\n",
    "        if method == 'DR-xKTE':\n",
    "            t0 = time.time()\n",
    "            value, p_value = kernel_dr_two_sample_test_agnostic(\n",
    "                Y, X, T, w, kernel_function='rbf', gamma=gamma, verbose=False)\n",
    "            times[n] = time.time() - t0\n",
    "            p_values[n] = p_value\n",
    "            values[n] = value\n",
    "\n",
    "        elif method == 'IPW-xKTE':\n",
    "            t0 = time.time()\n",
    "            value, p_value = kernel_two_sample_test_agnostic(\n",
    "                Y, T, w, kernel_function='rbf', gamma=gamma, verbose=False)\n",
    "            times[n] = time.time() - t0\n",
    "            p_values[n] = p_value\n",
    "            values[n] = value\n",
    "\n",
    "        elif method == 'KTE':\n",
    "            t0 = time.time()\n",
    "            mmd2u_rbf, mmd2u_null_rbf, p_value = kernel_two_sample_test_nonuniform(\n",
    "                YY0, YY1, T, w, kernel_function='rbf', gamma=gamma,\n",
    "                verbose=False, iterations=iterations)\n",
    "            times[n] = time.time() - t0\n",
    "            p_values[n] = p_value\n",
    "\n",
    "        elif method in baseline_tests:\n",
    "            # The baselines are given T as a column vector.\n",
    "            T = T[:, np.newaxis]\n",
    "            t0 = time.time()\n",
    "            baseline = baseline_tests[method](X, T, Y, iterations)\n",
    "            p_value, value = baseline.permutation_test()\n",
    "            times[n] = time.time() - t0\n",
    "            p_values[n] = p_value\n",
    "            values[n] = value\n",
    "\n",
    "        else:\n",
    "            print('Method not recognized.')\n",
    "\n",
    "        # Saved inside the run loop, so the file is rewritten after every run with\n",
    "        # the results collected so far (rows of runs not yet done are still zero).\n",
    "        res = pd.DataFrame({'times': times, 'p_values': p_values, 'stat_values': values})\n",
    "        res.to_csv(name_folder + 'b' + 'VI' + method + '.csv')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
