{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.mlab   as mlab\n",
    "\n",
    "from kte import kernel_two_sample_test_nonuniform\n",
    "from xkte import kernel_two_sample_test_agnostic\n",
    "from xkte import kernel_dr_two_sample_test_agnostic\n",
    "from dr_kte import kernel_dr_nonuniform\n",
    "\n",
    "\n",
    "from baselines import vanilla_dr_baseline_test\n",
    "from baselines import BART_baseline_test\n",
    "from baselines import CausalForest_baseline_test\n",
    "\n",
    "\n",
    "from sklearn.metrics import pairwise_distances\n",
    "\n",
    "from scipy.spatial.distance import cdist\n",
    "from scipy.special import expit\n",
    "from scipy.stats import bernoulli\n",
    "from numpy.polynomial.polynomial import polyval\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import time\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Main function: run every requested combination of tests and save the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulation constants shared by every setting.\n",
    "_COVARIATE_DIM = 5\n",
    "_BETA_VEC = np.array([0.1, 0.2, 0.3, 0.4, 0.5])        # outcome coefficients\n",
    "_ALPHA_VEC = np.array([0.05, 0.04, 0.03, 0.02, 0.01])  # propensity coefficients (observational design)\n",
    "_ALPHA_0 = 0.05                                        # propensity intercept (observational design)\n",
    "_NOISE_STD = .5  # outcome noise is _NOISE_STD * N(0, 1): std 0.5, i.e. variance 0.25\n",
    "\n",
    "_KNOWN_B = ('I', 'II', 'III', 'IV')\n",
    "_KNOWN_CASES = (1, 2)\n",
    "_KNOWN_METHODS = ('DR-xKTE', 'IPW-xKTE', 'KTE', 'linear', 'DR-CFME',\n",
    "                  'Vanilla_DR', 'BART', 'CausalForest')\n",
    "\n",
    "\n",
    "def _treatment_shift(b, n_treated):\n",
    "    \"\"\"Return the effect added to the n_treated treated outcomes under scenario b.\n",
    "\n",
    "    'I': no effect; 'II': constant 0.5; 'III': +1 or -1, each with probability 1/2;\n",
    "    'IV': uniform on [-2, 2]. Unrecognized values also give no effect.\n",
    "    \"\"\"\n",
    "    if b == 'II':\n",
    "        return 0.5\n",
    "    if b == 'III':\n",
    "        return (2 * bernoulli.rvs(0.5, size=n_treated) - 1) * 1.\n",
    "    if b == 'IV':\n",
    "        return np.random.uniform(-2, 2, n_treated)\n",
    "    return 0\n",
    "\n",
    "\n",
    "def _simulate_data(ns, experiment, case, b):\n",
    "    \"\"\"Draw one synthetic dataset of size ns.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    X : ndarray (ns, _COVARIATE_DIM)\n",
    "        Standard-normal covariates.\n",
    "    T : ndarray (ns,)\n",
    "        Binary treatment indicator.\n",
    "    w : ndarray (ns,)\n",
    "        Treatment probabilities passed to the tests: the known 0.5 in the\n",
    "        randomized design, logistic-regression estimates otherwise.\n",
    "    Y : ndarray (ns,)\n",
    "        Outcomes, with the scenario-b effect added to the treated units.\n",
    "    \"\"\"\n",
    "    X = np.random.randn(ns, _COVARIATE_DIM)\n",
    "\n",
    "    if experiment:\n",
    "        # Completely randomized design: exactly ns // 2 controls, in random order.\n",
    "        labels = np.concatenate((np.repeat(0, ns // 2), np.repeat(1, ns - ns // 2)))\n",
    "        T = np.random.choice(labels, replace=False, size=len(labels))\n",
    "        w = np.full(ns, 0.5)\n",
    "    else:\n",
    "        # The true propensity is logistic in X; the tests only see a fitted estimate.\n",
    "        T = bernoulli.rvs(expit(np.dot(_ALPHA_VEC, X.T) + _ALPHA_0))\n",
    "        # C=1e6 makes the logistic regression essentially unpenalized.\n",
    "        w = LogisticRegression(C=1e6, max_iter=1000).fit(X, T).predict_proba(X)[:, 1]\n",
    "\n",
    "    linear_index = np.dot(_BETA_VEC, X.T)\n",
    "    if case == 1:\n",
    "        Y = linear_index + _NOISE_STD * np.random.randn(ns)\n",
    "    elif case == 2:\n",
    "        Y = np.cos(linear_index) + _NOISE_STD * np.random.randn(ns)\n",
    "    else:\n",
    "        raise ValueError('case not recognized: ' + str(case))\n",
    "\n",
    "    treated = T == 1\n",
    "    Y[treated] += _treatment_shift(b, int(treated.sum()))\n",
    "    return X, T, w, Y\n",
    "\n",
    "\n",
    "def _run_method(method, X, T, w, Y, experiment, iterations):\n",
    "    \"\"\"Apply test `method` to one dataset.\n",
    "\n",
    "    Returns (statistic, p_value, elapsed_seconds); only the test itself is timed.\n",
    "    \"\"\"\n",
    "    Y_col = Y[:, np.newaxis]\n",
    "    T_col = T[:, np.newaxis]  # the baseline tests are called with column-vector T\n",
    "    Y0 = Y_col[T == 0]\n",
    "    Y1 = Y_col[T == 1]\n",
    "    # Gaussian RBF bandwidth: squared median distance between control and treated outcomes.\n",
    "    sigma2 = np.median(pairwise_distances(Y0, Y1, metric='euclidean'))**2\n",
    "\n",
    "    start = time.perf_counter()\n",
    "    if method == 'DR-xKTE':\n",
    "        value, p_value = kernel_dr_two_sample_test_agnostic(Y_col, X, T, w,\n",
    "                                                            kernel_function='rbf',\n",
    "                                                            gamma=1.0/sigma2,\n",
    "                                                            verbose=False)\n",
    "    elif method == 'IPW-xKTE':\n",
    "        value, p_value = kernel_two_sample_test_agnostic(Y_col, T, w,\n",
    "                                                         kernel_function='rbf',\n",
    "                                                         gamma=1.0/sigma2,\n",
    "                                                         verbose=False)\n",
    "    elif method == 'KTE':\n",
    "        # Returns (MMD^2 statistic, null-distribution samples, p-value); keep the statistic.\n",
    "        value, _, p_value = kernel_two_sample_test_nonuniform(Y0, Y1, T, w,\n",
    "                                                              kernel_function='rbf',\n",
    "                                                              gamma=1.0/sigma2,\n",
    "                                                              verbose=False,\n",
    "                                                              iterations=iterations)\n",
    "    elif method == 'linear':\n",
    "        value, _, p_value = kernel_two_sample_test_nonuniform(Y0, Y1, T, w,\n",
    "                                                              kernel_function='linear',\n",
    "                                                              verbose=False,\n",
    "                                                              iterations=iterations)\n",
    "    elif method == 'DR-CFME':\n",
    "        value, _, p_value = kernel_dr_nonuniform(Y_col, X, T, w, experiment,\n",
    "                                                 iterations=iterations,\n",
    "                                                 verbose=False,\n",
    "                                                 kernel_function='rbf',\n",
    "                                                 gamma=1.0/sigma2)\n",
    "    # The baselines' permutation_test() returns (p_value, statistic).\n",
    "    elif method == 'Vanilla_DR':\n",
    "        p_value, value = vanilla_dr_baseline_test(X, T_col, Y_col, iterations).permutation_test()\n",
    "    elif method == 'BART':\n",
    "        p_value, value = BART_baseline_test(X, T_col, Y_col, iterations).permutation_test()\n",
    "    elif method == 'CausalForest':\n",
    "        p_value, value = CausalForest_baseline_test(X, T_col, Y_col, iterations).permutation_test()\n",
    "    else:\n",
    "        raise ValueError('Method not recognized: ' + str(method))\n",
    "    return value, p_value, time.perf_counter() - start\n",
    "\n",
    "\n",
    "def run_tests(b_list, method_list, case_list, ns_list, experiment, name_folder,\n",
    "              num_experiments, iterations, seed=0):\n",
    "    \"\"\"Run every (b, method, case, ns) combination and save one CSV per setting.\n",
    "\n",
    "    For each setting, num_experiments datasets of size ns are simulated and the\n",
    "    test `method` is applied to each. The per-run wall-clock times, p-values and\n",
    "    test statistics (columns 'times', 'p_values', 'stat_values') are written to\n",
    "    name_folder + 'ns<ns>b<b>case<case><method>.csv'.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    b_list : list of str\n",
    "        Treatment-effect scenarios 'I' (null), 'II', 'III', 'IV'; see\n",
    "        _treatment_shift. Unrecognized values print a warning and add no effect.\n",
    "    method_list : list of str\n",
    "        Tests to run, any of _KNOWN_METHODS.\n",
    "    case_list : list of int\n",
    "        Outcome models: 1 is linear in X, 2 is the cosine of that linear index.\n",
    "    ns_list : iterable of int\n",
    "        Sample sizes.\n",
    "    experiment : bool\n",
    "        True: randomized design with known propensity 0.5. False: observational\n",
    "        design with a propensity estimated by logistic regression.\n",
    "    name_folder : str\n",
    "        Prefix of the output paths; its directory part is created if missing.\n",
    "    num_experiments : int\n",
    "        Number of simulated datasets per setting.\n",
    "    iterations : int\n",
    "        Forwarded to the KTE, linear, DR-CFME and baseline tests (not passed\n",
    "        to the xKTE tests).\n",
    "    seed : int, default 0\n",
    "        Seed for numpy's global RNG, set once before the first setting, so a\n",
    "        setting's results depend on the settings run before it in the same call.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If a method or case is not recognized; checked before any simulation.\n",
    "    \"\"\"\n",
    "    import os  # local import keeps this cell self-contained\n",
    "\n",
    "    unknown_methods = [m for m in method_list if m not in _KNOWN_METHODS]\n",
    "    if unknown_methods:\n",
    "        raise ValueError('Method not recognized: ' + str(unknown_methods))\n",
    "    unknown_cases = [c for c in case_list if c not in _KNOWN_CASES]\n",
    "    if unknown_cases:\n",
    "        raise ValueError('case not recognized: ' + str(unknown_cases))\n",
    "\n",
    "    # Create the output directory up front rather than failing in to_csv after the simulations.\n",
    "    out_dir = os.path.dirname(name_folder)\n",
    "    if out_dir:\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "    np.random.seed(seed)\n",
    "\n",
    "    for b in b_list:\n",
    "        print('b = ', b)\n",
    "        if b not in _KNOWN_B:\n",
    "            print('b not recognized! Setting b1 = 0.')\n",
    "        for method in method_list:\n",
    "            for case in case_list:\n",
    "                for ns in ns_list:\n",
    "                    times = np.zeros(num_experiments)\n",
    "                    p_values = np.zeros(num_experiments)\n",
    "                    values = np.zeros(num_experiments)\n",
    "\n",
    "                    for n in range(num_experiments):\n",
    "                        X, T, w, Y = _simulate_data(ns, experiment, case, b)\n",
    "                        values[n], p_values[n], times[n] = _run_method(method, X, T, w, Y,\n",
    "                                                                       experiment, iterations)\n",
    "\n",
    "                    results = pd.DataFrame({'times': times,\n",
    "                                            'p_values': p_values,\n",
    "                                            'stat_values': values})\n",
    "                    results.to_csv(name_folder + 'ns' + str(ns) + 'b' + str(b) + 'case' + str(case) + method + '.csv')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the tests in different settings"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Null hypothesis (b = 'I': no treatment effect)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Type I error check: b = 'I' adds no treatment effect; run for both designs.\n",
    "# The sweep is expensive (19 sample sizes x 2 cases x 2 methods x 2 designs x 500 runs),\n",
    "# so it is off by default; set RUN_NULL_EXPERIMENTS = True to regenerate its CSVs.\n",
    "RUN_NULL_EXPERIMENTS = False\n",
    "\n",
    "num_experiments = 500\n",
    "iterations = 100\n",
    "\n",
    "ns_list = np.arange(100, 1050, 50)\n",
    "b_list = ['I']\n",
    "case_list = [1, 2]\n",
    "method_list = ['DR-xKTE', 'IPW-xKTE']\n",
    "\n",
    "if RUN_NULL_EXPERIMENTS:\n",
    "    for experiment in [False, True]:\n",
    "        name_folder = 'data' + str(experiment) + '/'\n",
    "        run_tests(b_list, method_list, case_list, ns_list, experiment, name_folder, num_experiments, iterations)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Experimental setting (randomized treatment, known propensity 0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Randomized design: DR-xKTE, IPW-xKTE and KTE under the alternatives II-IV.\n",
    "num_experiments = 500\n",
    "iterations = 100\n",
    "\n",
    "ns_list = [100, 150, 200, 250, 300, 350]\n",
    "b_list = ['II', 'III', 'IV']\n",
    "method_list = ['DR-xKTE', 'IPW-xKTE', 'KTE']\n",
    "case_list = [2]\n",
    "experiment = True\n",
    "name_folder = 'data' + str(experiment) + '/'\n",
    "\n",
    "run_tests(b_list, method_list, case_list, ns_list, experiment, name_folder, num_experiments, iterations)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Observational setting (treatment depends on X; propensity estimated by logistic regression)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Observational design: DR-xKTE against the BART, causal-forest and vanilla DR baselines.\n",
    "num_experiments, iterations = 500, 100\n",
    "\n",
    "experiment = False\n",
    "name_folder = 'data' + str(experiment) + '/'\n",
    "ns_list = list(range(100, 400, 50))\n",
    "b_list = ['II', 'III', 'IV']\n",
    "case_list = [2]\n",
    "method_list = ['DR-xKTE', 'BART', 'CausalForest', 'Vanilla_DR']\n",
    "\n",
    "run_tests(b_list, method_list, case_list, ns_list, experiment,\n",
    "          name_folder, num_experiments, iterations)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
