{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5ff04ccb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x248390c5230>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd \n",
    "import torch \n",
    "import numpy as np \n",
    "\n",
    "from warnings import filterwarnings \n",
    "\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "import matplotlib.pyplot as plt \n",
    "SEED = 0\n",
    "\n",
    "filterwarnings('ignore')\n",
    "\n",
    "rng = np.random.default_rng(SEED)\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c652585c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['treat', 'bw', 'b.head', 'preterm', 'birth.o', 'nnhealth', 'momage',\n",
       "       'sex', 'twin', 'b.marr', 'mom.lths', 'mom.hs', 'mom.scoll', 'cig',\n",
       "       'first', 'booze', 'drugs', 'work.dur', 'prenatal', 'ark', 'ein', 'har',\n",
       "       'mia', 'pen', 'tex', 'was', 'momwhite', 'momblack', 'momhisp'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "IHDP = pd.read_csv(\"ihdp.csv\")\n",
    "IHDP.columns"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b5af6fb",
   "metadata": {},
   "source": [
    "# Linear Response Surface\n",
    "The goal of these experiments is to estimate the Average Treatment Effect on the Treated (ATE) for different response surfaces."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e9680743",
   "metadata": {},
   "outputs": [],
   "source": [
    "continuous_features = ['bw', 'b.head', 'preterm', 'birth.o', 'nnhealth', 'momage']\n",
    "IHDP = IHDP[IHDP['momwhite'] == 1]\n",
    "A = IHDP['treat']\n",
    "X = IHDP.loc[:,~IHDP.columns.isin(['treat','momwhite','momblack', 'momhisp'])]\n",
    "\n",
    "scaler = StandardScaler()\n",
    "scaler_y = StandardScaler()\n",
    "X.loc[:, X.columns.isin(continuous_features)]  = scaler.fit_transform(X.loc[:,  X.columns.isin(continuous_features)])\n",
    "\n",
    "β = rng.choice([0,1,2,3,4], p=[0.5, 0.2, 0.15, 0.1, 0.05], size = 25).reshape(-1,1)\n",
    "\n",
    "Y = A*rng.normal(loc = (X@β + 4).values.reshape(-1), scale = 1.0) + (1 - A)*rng.normal( loc = (X@β).values.reshape(-1) , scale = 1.0)\n",
    "\n",
    "Y = scaler_y.fit_transform(Y.values.reshape(-1,1))\n",
    "Y = Y.reshape(-1)\n",
    "\n",
    "X['treat'] = A "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4b44007d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Importting clipped adam...\n"
     ]
    }
   ],
   "source": [
    "from variationalRegressionTree import variationalRegressionTree\n",
    "from bartpy.sklearnmodel import SklearnModel\n",
    "\n",
    "train = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3698143d",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not train:\n",
    "    BART = SklearnModel(n_trees=100, n_burn = 200, n_samples = 200) # Use default parameters\n",
    "    BART.fit(X, Y)\n",
    "    train = True "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "277e5dbd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "DEVICE = 'cpu'\n",
    "\n",
    "X_tree = torch.tensor(X.values, device = DEVICE)\n",
    "y_tree = torch.tensor(Y, device = DEVICE)\n",
    "\n",
    "torch.manual_seed(SEED)\n",
    "λ = 1e0\n",
    "VaRT = variationalRegressionTree(5, X_tree, y_tree, device=DEVICE)\n",
    "elbos = VaRT.train(epochs = 1000, h1=λ, h2=-λ )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f88ff46d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x248511a57c0>]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAEDCAYAAAA2k7/eAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWmklEQVR4nO3dfYxnV33f8fcna2yMMQXq9eN6awcWpy4Fh4w2EEc0BkPsLfXmQaSmjXAC6RYpjtpUSWO6UqIqikRD1UQRbsiK0EILODTNYgsv+IFGckhD8DrBsH6CxZh6WYMXAsaJDfba3/7xu2N+T/N4Z2dnzrxf0mh+99wz99zz299+5sy5T6kqJEnt+77jvQOSpNVh4EvSBmHgS9IGYeBL0gZh4EvSBmHgS9IGseYDP8l7kzyc5MAi6/9MkruT3JXkg8d6/yRpvchaPw8/yauBvwXeX1UvXaDuNuDDwGuq6ptJTq+qh1djPyVprVvzI/yqug34m+GyJC9K8vEkdyT5syQ/0K36V8C1VfXN7mcNe0nqrPnAn8Me4Jeq6oeAXwH+a1f+EuAlSf48yaeSXHbc9lCS1pgTjvcOLFWS5wI/AvyvJLPFJ3XfTwC2AT8GbAH+LMlLq+pbq7ybkrTmrLvAZ/BXybeq6qIp6w4Bn6qqJ4EvJbmPwS+A21dx/yRpTVp3UzpV9W0GYf5GgAy8vFv9EeCSrvw0BlM89x+P/ZSktWbNB36SDwF/AVyQ5FCStwL/EnhrkjuBu4CdXfWbgG8kuRv4U+BXq+obx2O/JWmtWfOnZUqSVsaaH+FLklbGmj5oe9ppp9V55513vHdDktaNO+644+tVtXnaujUd+Oeddx779+8/3rshSetGki/Ptc4pHUnaIAx8SdogDHxJ2iAMfEnaIAx8SdogViTwk1yW5L4kB5NcM2V9kvxet/6zSV6xEu1Kkhavd+An2QRcC1wOXAi8KcmFY9UuZ3ATs23ALuD3+7YrSVqalTgPfztwsKruB0hyHYN729w9VGcngydWFfCpJM9PclZVPbQC7S9JVXH4ke/w+a8+yt89cZRHv3OUp54uarByUOd7L6nq1g2XdeVMlNdIndl1s1sYvotF1eLrrhXfuxt1f+N9n7PNJe7HYre7Fo13a7yfmfJuzNaZXTP72R3+LE6zUu/PYj8Si/3sTOvjQmoJvRn5P7jkllbPKSdu4l//kxet+HZXIvDPAR4cWj4E/PAi6pwDTAR+kl0M/gpg69atK7B7A5899C3e8t9v55HHn+TJp9byP7Wkje605560ZgN/2q/k8URdTJ1BYdUeBk+0YmZmpncy/9+DX+cX3r+fx554ir938rN40/atvPj05/KPzn4eJ52wiRecciLP2pRuJzMyYpp9wMrgNc/UmS1cbN3h0U0yWp5nyqf8/EoOqXtaqZvsVfXr31z70Xe7a814P6d1u8bqFt/7LM5+jtbCe7HYz85SPmKzfZ210F99o/8Hj/97crysROAfAs4dWt4CHF5GnWPit2+6j8eeeIpTTzqBj/7Sj3LuC5+zGs02Z6X+k/TdzFz70dr/4fF+zt+/td35xX52jtW/YWufjT5W4iyd24FtSc5PciJwJXDDWJ0bgDd3Z+u8EnhkNebvq4p7v/ptLrlgMx+5+mLDXtKG1nuEX1VHk1zN4OEjm4D3VtVdSd7WrX83sA/YARwEHgN+vm+7C9n714f4wtf+lu88+TSvfslmXrT5uce6SUla01bkbplVtY9BqA+XvXvodQG/uBJtLdYv/9Gdz7w+5/knr2bTkrQmbYgrbf/hWc873rsgScfdhgj8LS9whC9JTQZ+VT1zZP7Pr3nNhj4NS5JmNRn43z36NFXwqz9+gfP3ktRpMvAfe+IpYHB5siRpoMnAf/zJQeCfbOBL0jOaDPzvdIH/7GcZ+JI0q8nAf/wJA1+SxjUZ+LMj/JMNfEl6RqOB/zTgHL4kDWsy8GcP2j77BANfkmY1
Hfgnn9hk9yRpWZpMxO940FaSJrQZ+Ec9aCtJ45oMfE/LlKRJbQa+F15J0oReD0BJ8kLgj4DzgAeAn6mqb06p9wDwKPAUcLSqZvq0u5DHn3yKE0/4PjZ9n3fJlKRZfUf41wCfqKptwCe65blcUlUXHeuwB/juk087fy9JY/oG/k7gfd3r9wE/0XN7K+LxJ57i2c9qcrZKkpatbyqeUVUPAXTfT5+jXgE3J7kjya75NphkV5L9SfYfOXJkWTv1+JNPOX8vSWMWnMNPcitw5pRVu5fQzsVVdTjJ6cAtSe6tqtumVayqPcAegJmZmVpCG894usr5e0kas2DgV9Wlc61L8rUkZ1XVQ0nOAh6eYxuHu+8PJ9kLbAemBv5KWNZvCUlqXN8pnRuAq7rXVwHXj1dIckqSU2dfA68HDvRsd0GO7yVpVN/AfwfwuiRfAF7XLZPk7CT7ujpnAJ9McifwaeDGqvp4z3YlSUvU6zz8qvoG8Nop5YeBHd3r+4GX92ln6Tu2qq1J0rrQ7LmLiZM6kjSsycAvh/iSNKHJwAcP2krSuCYDvxzgS9KEJgMfwCl8SRrVZOA7wpekSU0GPkCcxZekEc0GviRpVJOB72mZkjSpycAHD9pK0rgmA9+DtpI0qcnAlyRNajLwHeBL0qQmAx+8eZokjWsy8J3Dl6RJTQY+ePM0SRrXK/CTvDHJXUmeTjIzT73LktyX5GCSa/q0uTgO8SVpXN8R/gHgp5jngeRJNgHXApcDFwJvSnJhz3YX5BS+JI3q+4jDe2DBA6TbgYPdow5Jch2wE7i7T9uSpKVZjTn8c4AHh5YPdWVTJdmVZH+S/UeOHFlWgx60laRJC47wk9wKnDll1e6qun4RbUwb/s8ZyVW1B9gDMDMzs+zodkpHkkYtGPhVdWnPNg4B5w4tbwEO99zmvBzgS9Kk1ZjSuR3YluT8JCcCVwI3HOtGvR++JI3qe1rmTyY5BLwKuDHJTV352Un2AVTVUeBq4CbgHuDDVXVXv92eXzmJL0kT+p6lsxfYO6X8MLBjaHkfsK9PW0vlHL4kjWrySlvH95I0qcnAB2+tIEnjmgx8p/AlaVKTgQ84iS9JY9oNfEnSiCYD3xkdSZrUZOCDB20laVyTge+FV5I0qcnAB4/ZStK4ZgNfkjSq2cB3gC9Jo5oMfKfwJWlSk4EPCz52UZI2nGYDX5I0qsnALy+9kqQJTQY+eNBWksb1feLVG5PcleTpJDPz1HsgyeeSfCbJ/j5tLoYHbSVpUq8nXgEHgJ8C/mARdS+pqq/3bG/RPGYrSaP6PuLwHlh7Z8Q4wpekSas1h1/AzUnuSLJrNRqMs/iSNGLBEX6SW4Ezp6zaXVXXL7Kdi6vqcJLTgVuS3FtVt83R3i5gF8DWrVsXuflRnqUjSZMWDPyqurRvI1V1uPv+cJK9wHZgauBX1R5gD8DMzMzyk9sBviSNOOZTOklOSXLq7Gvg9QwO9h4zzuFL0qS+p2X+ZJJDwKuAG5Pc1JWfnWRfV+0M4JNJ7gQ+DdxYVR/v0+6i9u1YNyBJ60zfs3T2AnunlB8GdnSv7wde3qcdSVJ/TV5p64yOJE1qMvDBC68kaVybge8QX5ImtBn4eOGVJI1rMvC98EqSJjUZ+OAcviSNazLwvfBKkiY1GfjgCF+SxjUb+JKkUU0GvjM6kjSpycAHT8uUpHFNBn551FaSJjQZ+OBBW0ka12TgO76XpElNBr4kaVKTge8UviRN6vvEq3cmuTfJZ5PsTfL8OepdluS+JAeTXNOnzSXs22o0I0nrRt8R/i3AS6vqZcDngbePV0iyCbgWuBy4EHhTkgt7tjsvB/iSNKlX4FfVzVV1tFv8FLBlSrXtwMGqur+qngCuA3b2aXcxHN9L0qiVnMN/C/CxKeXnAA8OLR/qyiRJq2jBh5gnuRU4c8qq3VV1fVdnN3AU+MC0TUwpm3PWJckuYBfA1q1bF9q96TxqK0kTFgz8qrp0vvVJrgLeALy2pl/iegg4
d2h5C3B4nvb2AHsAZmZmlp3cHrOVpFF9z9K5DPg14IqqemyOarcD25Kcn+RE4Erghj7tLsTxvSRN6juH/y7gVOCWJJ9J8m6AJGcn2QfQHdS9GrgJuAf4cFXd1bPdBTnAl6RRC07pzKeqXjxH+WFgx9DyPmBfn7aWtl+r1ZIkrR9NXmkLXnglSeOaDPxyFl+SJjQZ+OAcviSNazLwncOXpElNBj54Hr4kjWs28CVJo5oMfKd0JGlSk4E/4JyOJA1rMvAd4EvSpCYDHzxoK0njmgz86TftlKSNrcnAB2fwJWlcs4EvSRrVbOA7hy9Jo5oNfEnSqCYD32O2kjSpycAHiIdtJWlErydeJXkn8M+AJ4AvAj9fVd+aUu8B4FHgKeBoVc30aXch3g9fkib1HeHfAry0ql4GfB54+zx1L6mqi4512M/yoK0kjeoV+FV1c/eQcoBPAVv671J/zuFL0qSVnMN/C/CxOdYVcHOSO5Lsmm8jSXYl2Z9k/5EjR5a9M47wJWnUgnP4SW4FzpyyandVXd/V2Q0cBT4wx2YurqrDSU4Hbklyb1XdNq1iVe0B9gDMzMwsa6zuAF+SJi0Y+FV16Xzrk1wFvAF4bc1xE5uqOtx9fzjJXmA7MDXwV4pn6UjSqF5TOkkuA34NuKKqHpujzilJTp19DbweONCn3YV48zRJmtR3Dv9dwKkMpmk+k+TdAEnOTrKvq3MG8MkkdwKfBm6sqo/3bHdhDvAlaUSv8/Cr6sVzlB8GdnSv7wde3qcdSVJ/TV5p64SOJE1qMvDBGR1JGtdm4DvEl6QJbQY+EK+8kqQRTQa+A3xJmtRk4INz+JI0rsnA98IrSZrUZOCDN0+TpHHNBr4kaVSTge+EjiRNajLwwYO2kjSuycD3mK0kTWoy8MELryRpXJOBX87iS9KEJgMfnMOXpHFNBr5z+JI0qe8jDn8zyWe7p13dnOTsOepdluS+JAeTXNOnzcXv3Kq0IknrRt8R/jur6mVVdRHwUeDXxysk2QRcC1wOXAi8KcmFPdudlyN8SZrUK/Cr6ttDi6cw/Zqn7cDBqrq/qp4ArgN29ml3MeIQX5JG9HqmLUCS3wLeDDwCXDKlyjnAg0PLh4Afnmd7u4BdAFu3bu27e5KkzoIj/CS3Jjkw5WsnQFXtrqpzgQ8AV0/bxJSyOSddqmpPVc1U1czmzZsX2w9J0gIWHOFX1aWL3NYHgRuB3xgrPwScO7S8BTi8yG0um9ddSdKovmfpbBtavAK4d0q124FtSc5PciJwJXBDn3YX4v3wJWlS3zn8dyS5AHga+DLwNoDu9Mz3VNWOqjqa5GrgJmAT8N6quqtnuwtygC9Jo3oFflX99Bzlh4EdQ8v7gH192lrSfq1WQ5K0jjR5pS04hy9J45oMfKfwJWlSk4EPXnglSeOaDHxvjyxJk5oMfHAOX5LGNRv4kqRRTQa+B20laVKTgQ9O6UjSuCYD3wG+JE1qMvAHHOJL0rAmA985fEma1GTgg3P4kjSu0cB3iC9J4xoNfGfwJWlcs4EvSRrVZOB70FaSJvV6AEqS3wR2Mnji1cPAz3UPPxmv9wDwKPAUcLSqZvq0u7h9O9YtSNL60neE/86qellVXQR8FPj1eepeUlUXrUbYO8CXpEm9Ar+qvj20eAprKGu9H74kjer7EHOS/BbwZuAR4JI5qhVwc5IC/qCq9syzvV3ALoCtW7cua5/KSXxJmrDgCD/JrUkOTPnaCVBVu6vqXOADwNVzbObiqnoFcDnwi0lePVd7VbWnqmaqambz5s3L6NLsfi/7RyWpSQuO8Kvq0kVu64PAjcBvTNnG4e77w0n2AtuB25awn0vi+F6SJvWaw0+ybWjxCuDeKXVOSXLq7Gvg9cCBPu0uat+OdQOStM70ncN/R5ILGJyW+WXgbQBJzgbeU1U7gDOAvRnMsZwAfLCqPt6z3Xk5hS9Jk3oFflX99Bzlh4Ed3ev7gZf3aWc54iS+JI1o8kpbSdKkJgPf0zIl
aVKTgS9JmtRk4Du+l6RJTQY+eOGVJI1rM/Ad4kvShDYDH2+eJknjmgx8B/iSNKnJwAfn8CVpXLOBL0ka1WTge+GVJE1qMvDBu2VK0rgmA9/xvSRNajLwwYO2kjSuycB3Cl+SJjUZ+OD98CVp3IoEfpJfSVJJTptj/WVJ7ktyMMk1K9HmfC576Zn8wJmnHutmJGld6fuIQ5KcC7wO+H9zrN8EXNvVOQTcnuSGqrq7b9tz+Z1/ftGx2rQkrVsrMcL/HeDfM/fJMduBg1V1f1U9AVwH7FyBdiVJS9Ar8JNcAXylqu6cp9o5wINDy4e6srm2uSvJ/iT7jxw50mf3JElDFpzSSXIrcOaUVbuB/wC8fqFNTCmb8zyaqtoD7AGYmZnxfBtJWiELBn5VXTqtPMk/Bs4H7uzOiNkC/FWS7VX11aGqh4Bzh5a3AIeXvceSpGVZ9kHbqvoccPrscpIHgJmq+vpY1duBbUnOB74CXAn8i+W2K0lanmNyHn6Ss5PsA6iqo8DVwE3APcCHq+quY9GuJGluvU/LnFVV5w29PgzsGFreB+xbqbYkSUvX7JW2kqRRWcv3jk9yBPjyMn/8NGD8eELr7PPGYJ/b16e//6CqNk9bsaYDv48k+6tq5njvx2qyzxuDfW7fseqvUzqStEEY+JK0QbQc+HuO9w4cB/Z5Y7DP7Tsm/W12Dl+SNKrlEb4kaYiBL0kbRHOBv9pP11otSc5N8qdJ7klyV5J/05W/MMktSb7QfX/B0M+8vXsf7kvy48dv7/tJsinJXyf5aLfcdJ+TPD/JHye5t/v3ftUG6PMvd5/rA0k+lOTZrfU5yXuTPJzkwFDZkvuY5IeSfK5b93tZyvNcq6qZL2AT8EXg+4ETgTuBC4/3fq1Q384CXtG9PhX4PHAh8NvANV35NcB/6l5f2PX/JAZ3Nf0isOl492OZff93wAeBj3bLTfcZeB/wC93rE4Hnt9xnBs/H+BJwcrf8YeDnWusz8GrgFcCBobIl9xH4NPAqBree/xhw+WL3obURfrNP16qqh6rqr7rXjzK4Ed05DPr3vq7a+4Cf6F7vBK6rqu9W1ZeAgwzen3UlyRbgnwLvGSputs9JnscgGP4QoKqeqKpv0XCfOycAJyc5AXgOg1uoN9XnqroN+Jux4iX1MclZwPOq6i9qkP7vH/qZBbUW+Et6utZ6leQ84AeBvwTOqKqHYPBLge/dsrqV9+J3GTxC8+mhspb7/P3AEeC/ddNY70lyCg33uaq+AvxnBs/Ffgh4pKpupuE+D1lqH8/pXo+XL0prgb+kp2utR0meC/xv4N9W1bfnqzqlbF29F0neADxcVXcs9kemlK2rPjMY6b4C+P2q+kHg7xj8qT+Xdd/nbt56J4Opi7OBU5L87Hw/MqVsXfV5EebqY6++txb4TT9dK8mzGIT9B6rqT7rir3V/5tF9f7grb+G9uBi4onu4znXAa5L8T9ru8yHgUFX9Zbf8xwx+AbTc50uBL1XVkap6EvgT4Edou8+zltrHQ93r8fJFaS3wn3m6VpITGTxd64bjvE8rojsS/4fAPVX1X4ZW3QBc1b2+Crh+qPzKJCd1TxvbxuBgz7pRVW+vqi01eNbClcD/qaqfpe0+fxV4MMkFXdFrgbtpuM8MpnJemeQ53ef8tQyOUbXc51lL6mM37fNokld279Wbh35mYcf7yPUxOBK+g8EZLF8Edh/v/VnBfv0ogz/dPgt8pvvaAfx94BPAF7rvLxz6md3d+3AfSziSvxa/gB/je2fpNN1n4CJgf/dv/RHgBRugz/8RuBc4APwPBmenNNVn4EMMjlE8yWCk/tbl9BGY6d6nLwLvortjwmK+vLWCJG0QrU3pSJLmYOBL0gZh4EvSBmHgS9IGYeBL0gZh4EvSBmHgS9IG8f8BVxnru1c8loIAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(elbos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "179d0ae8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    }
   ],
   "source": [
    "yhat_VaRT = VaRT.predict(X_tree)\n",
    "yhat_BART = BART.predict(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "de7b098a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.22572623921177254, 0.2785888808951609)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.sqrt(((yhat_BART - Y)**2)).mean() , VaRT.rmse(yhat_VaRT, y_tree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "563cc1e4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    }
   ],
   "source": [
    "X_0 = X.copy()\n",
    "X_0['treat'] = 0\n",
    "X_1 = X.copy()\n",
    "X_1['treat'] = 1 \n",
    "\n",
    "y1_BART = scaler_y.inverse_transform(BART.predict(X_1).reshape(-1,1)).reshape(-1)\n",
    "y0_BART = scaler_y.inverse_transform(BART.predict(X_0).reshape(-1,1)).reshape(-1)\n",
    "\n",
    "y1_VaRT = scaler_y.inverse_transform(VaRT.predict(torch.tensor(X_1.values, device=DEVICE), samples=1000).cpu().numpy().reshape(-1,1)).reshape(-1)\n",
    "y0_VaRT = scaler_y.inverse_transform(VaRT.predict(torch.tensor(X_0.values, device=DEVICE), samples=1000).cpu().numpy().reshape(-1,1)).reshape(-1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "b61bfbeb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4.058290514431225, 4.0571747)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(y1_BART - y0_BART).mean(), (y1_VaRT - y0_VaRT).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "840d4654",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "4c015488",
   "metadata": {},
   "source": [
    "# Non-linear Response Surface"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "dd3159d3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x248390c5230>"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd \n",
    "import torch \n",
    "import numpy as np \n",
    "\n",
    "from warnings import filterwarnings \n",
    "\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "import matplotlib.pyplot as plt \n",
    "\n",
    "SEED = 0\n",
    "\n",
    "filterwarnings('ignore')\n",
    "\n",
    "rng = np.random.default_rng(SEED)\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "41bcffbe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['treat', 'bw', 'b.head', 'preterm', 'birth.o', 'nnhealth', 'momage',\n",
       "       'sex', 'twin', 'b.marr', 'mom.lths', 'mom.hs', 'mom.scoll', 'cig',\n",
       "       'first', 'booze', 'drugs', 'work.dur', 'prenatal', 'ark', 'ein', 'har',\n",
       "       'mia', 'pen', 'tex', 'was', 'momwhite', 'momblack', 'momhisp'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "IHDP = pd.read_csv(\"ihdp.csv\")\n",
    "IHDP.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "a7a3c4d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "continuous_features = ['bw', 'b.head', 'preterm', 'birth.o', 'nnhealth', 'momage']\n",
    "IHDP = IHDP[IHDP['momwhite'] == 1]\n",
    "A = IHDP['treat']\n",
    "X = IHDP.loc[:,~IHDP.columns.isin(['treat','momwhite','momblack', 'momhisp'])]\n",
    "\n",
    "scaler = StandardScaler()\n",
    "scaler_y = StandardScaler()\n",
    "X.loc[:, X.columns.isin(continuous_features)]  = scaler.fit_transform(X.loc[:,  X.columns.isin(continuous_features)])\n",
    "\n",
    "β = rng.choice([0,0.1, 0.2, 0.3, 0.4], p=[0.6, 0.1, 0.1, 0.1, 0.1], size = 25).reshape(-1,1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "41bea51e",
   "metadata": {},
   "outputs": [],
   "source": [
    "ω = (X@β).values - np.exp((X + 0.5)@β).values - 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "bb53bd5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "Y = A*rng.normal(loc = (X@β - ω).values.reshape(-1), scale = 1.0) + (1 - A)*rng.normal( loc = (np.exp((X + 0.5)@β)).values.reshape(-1) , scale = 1.0)\n",
    "Y = scaler_y.fit_transform(Y.values.reshape(-1,1))\n",
    "Y = Y.reshape(-1)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "8307d996",
   "metadata": {},
   "outputs": [],
   "source": [
    "X['treat'] = A "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "53318064",
   "metadata": {},
   "outputs": [],
   "source": [
    "from variationalRegressionTree import CVTree\n",
    "from bartpy.sklearnmodel import SklearnModel\n",
    "\n",
    "train = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "e3571cd9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    }
   ],
   "source": [
    "DEVICE = 'cpu'\n",
    "\n",
    "X_tree = torch.tensor(X.values, device = DEVICE)\n",
    "y_tree = torch.tensor(Y, device = DEVICE)\n",
    "\n",
    "torch.manual_seed(SEED)\n",
    "λ = 1e0\n",
    "VaRT = variationalRegressionTree(5, X_tree, y_tree, device=DEVICE)\n",
    "elbos = VaRT.train(epochs = 3000, h1=λ, h2=-λ )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "49acbbce",
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x248424be3a0>]"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZMAAAD4CAYAAAApWAtMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAAsTAAALEwEAmpwYAAAyPUlEQVR4nO3deXxV5Z348c83KwkEAmGVsCmgBUSFiNi6UKGAXYQ62DJdZGbsMLV2tZvWtoxa22pntGNb7TCVVp3+qtbawtQqBddaUYkKCrIFRAlrQgIkQPbv74/z3HCS3C259+Yk1+/79bqv3DzneZ57Tm7u+d5nOecRVcUYY4xJREbQO2CMMab3s2BijDEmYRZMjDHGJMyCiTHGmIRZMDHGGJOwrKB3IAiDBw/WsWPHBr0bxhjTq7z66quVqjok3Lb3ZDAZO3YspaWlQe+GMcb0KiLyTqRt1s1ljDEmYRZMjDHGJMyCiTHGmIRZMDHGGJMwCybGGGMSZsHEGGNMwiyYGGOMSVhaBBMRmS8i20SkTERuCHp/ulNLi7Ju52EOHquLmq+puYWtB47x9NaDEfOoKrsrj/PnN/bxsZ+9wDuHj0fMd7y+idWbD3DpT55h9eYDEeusb2rmbzsquOPJrWzedzTqa1fU1PPXzQd4ZuuhqMfS2NzCviMn2X6whsbmlqh5VZWaukbiWWqhuUXjymeM6Uh6+4dHRDKB7cCHgHJgPfCPqvpWpDIlJSUaxEWLqsqLOw9z55rt3HzFZKaMHNBme219Ey+WVfL6niPc++xOAG6+YjJL3j+2tfyuyuP87Kkd/GnDvg71l912OVmZ3veDrQeOMf+nfwu7H7+6uoQ5k4bR0NTCXWu3t75We5+dOYZbF04B4PE39nPd/3st4rHt/vFHAHixrJJfPr+LpuYWXtx5uEO+t3/0YUSE+qZmfvSXrfzmxd1h63v8yxcx+TTv7/PstkP806/Xh81368IpfHbmGADqGpt5eP0eVvz9bd45fAKAvOxMTjY28+kLRnPbx88GvAD82Ot7eXLTfjbsOUJlbUNrfbPPGsp9/3R+6++b9h7lp2u3s3bLIU4f0pddFacC7KvfnUNRv9zWOp/fUcHfyypZu+UQb1eeyrf6q5dw5vCC1t+rjjewae9R3q06waGaelZt2Mu5owq5qmQUHxg/uDVfU3MLb+0/xt7qkxw8Vsdr7x6htr6Jz8wczWVnDWvzd6hrbObgsTqqTzRSfaKB19+pZtqYgcw6c2iHv1ljcwsHj9VxqKaejXuOMHZwXy4aP5jszLbfLVtalMrj9Rw72cSBo3XUNzXz/jMGk5eT2aHOhqYWjpxsoL6xhRMNzZxW2IeCPtlh37PG5haOnmzk2MlGCvNzGNQ3J2w+VaW+qYWauiYyMyRivvb7nJEhMfOpKiKx85m2RORVVS0Jty0droCfAZSp6i4AEXkIWABEDCbJVt/UzLuHT/BI6R7mTxnBs9sOMbIwj5+s3sbxhibqGjt+e/7oz15oPQGXHaplzp3Pha172arNXFVSzBvlR1m8/KWo+1Fb34QqbD9Ywyej5H1j71HOG13IZ+57hS37j3XYfuV5I3ns9b2tJ41VG/fx5d+9HrG+LPfh/evmAyx98FUAxg/tFzavKlTU1DHjh0+F3T4wP5vqE43U1jUB8MC63Xx/5eY2eaYWD+CNcq+Vc+S4Fwgqa+sp+cHaDvV98vxR/ObF3eypPunlP9HAF377Wmugy8wQ+vfJ4ph7vadcq0hV+fdVm7l/3akLfv2BBOCdqhP0z8vml8/u5D/XbA97PAAvlFVy5vACdhys4XsrN/HSrqoOeXYfPkGzwgfGD2bjniPc/dQO1u+uat0vvzf3HuWlG4dSU9fIc9srWPHC22zad4zmlo5fDF/5zmwUKK8+wS1/3sLGPUfC7uMvPzOd+VOGs2X/Mb728Aa2HqgJm++MIX156uuzqK1v4rHXyvn+ys2t75nf+WMH
8vvPv5+GphbWbjnIF34b+YvI1lvn0yc7k7JDNfz5DS/Abz9Qw76jp1rbGQJPfOUSRg3KY81bB3l66yEyRXi36gRHXGCqqWuirqmZWxdM4fIpwznZ2Mz63VU8uclrOT+3vaLNZ/ETJcXc9JFJ7KqoZWfFcXZV1HLPszuZOKwfgrDt4Km/waOfv5ADx+rYXXmcXZXHqW9s4fE39wMwalAee6q8/687P3EOuVmZvLizkrzsTKqON/Dm3qPUNTWzp+okIwvz2HvkJN+cdyZnDisgM0OoqK3ncG0DlbX1HKqp5/827mPa6EJONDQztqgvn545mvLqkzS1eD0CNXWNHDvpftZ5P3OyMlhw7khONjRzvKGJE/XtfjY0c7ze+/nLz05nZGFexPejq9IhmIwE9vh+LwcuaJ9JRJYCSwFGjx6d8IvurKjl+kc2dvhw/s/f3o6r/OVThgPw2rvVXHnPi1HzTvr+6ojbVn/1Em7981u8UFZJZW09H777BRqa2gavm6+YzLJVmzljSF92Vhxn9KB8/vv5Xew4WMO9n57G63uOsPz5Xbz+vQ8x0H37e2LTAVSVtyuP8+Xfvc7oQfn8bulMyg7VsmTFK6z+6iXM++nzAJxdPICVG/bylYc2UNAni5e/M5v8nCyuf3gDYwf35cuzJ/Bfa3dw19rtKHCPawl976OT+OFftjB30jDu/cx0wGvZfOpXL7P1QA3Pba/g13/fzfljB7J+dzUA234wn9ysTFpalNO/8xf+c812lv9tFzXupLvsY5NYeO5IauubGJCfTf8+2bz+bjXPb6+gpUX5xu83Urq7mq/OmcBHp57WGvRC70O+C6C//vtu7l/3Dv/0/rHMnTyM/Jwspo4cwL6jJ9l2oIZr7i8lU4QlK15p0wL7t0tP59pLzyAnK4OauiYu+OFTNLe0sPAXf2eD739lTFE+Dy2dydCCPpxoaOLKe15k3c7DLLr3RTbtO0q/3CzmTR7OxGEFTD6tP1UnGnhmawVrtxyksbmFf3uwlLVbvMB35rACvjDrDMYU9WVQX++Yf/Pibv78xn7eKD/K5x5o2wr/yNQRFOZls3nfMRTYuOcIn/9f70tAhkCLwmDX4po3eRjnjx3EvqMnuePJbeysOM4dT27lwZfeaf2bXzRhCEMLcrnvBe9/Pyczg10Vx3n01XJ+unY75S6Qh3z5svH8ccPe1hPw91du4olNB6ipayJDYOKwAqaMHMDpQ/oxID+bgfnZ/O9L7zLvp8+TnSk0NntB87QBfRhdlM/EYf0oyM2mf14WD6/fw3f/tInv/mlTm9ccP7QfedmZrcGkIDeLR0rLeaS0vMNnavvBWuZOGtYmmCz65brW532yMzjNdzI+d9TA1mO5/pGNbeoa3C+X/JzM1u2jBnnB5Cert3V43fycTPJzslyZLF579whbD9TwZLtu5MwMoaBPFv37ZFPQJ4uDx+qprK3n72Wn/g8zBPrmZtE3J4v83EzvZ04mQwpyU9aVmw7BJFxbtcNfS1WXA8vB6+bq6oudaGiKenIH+Mmiqcw8vYgB+dn0y8miqUXJyTrVhTDjtrUU5ntdAP5A8ucvXcSkEf0Jtb6vXvEKb+07xuHjp7phtv/g8jZ1Acx+31BeKKtkzp3Pt0kPtXwA/mF6Mbsrj/PRn71Av9xM/m/jPmadOYTLzx7BhyYNY+klp7cGEgARrxXxufu97qV7Pj2NkYV5jCzMa613948/wmfve5na+iZ+/nQZAIumF5Of4/1b3fnJc9vUB3D0ZCOPlO5h4bmncc1F47jmonFt/3gu37JVp1ojP7pyKqcP7kuzamtXjL8rI3RSO6d4AP/8Aa8+/7FsdK2YK37xApv2HuOGy8/i85ee0eZlp40eyNJLTueBdbtpaGrhnmfLuHjCYJZ9bFKb7pDigfnsdC2U+qZTXXmbbp5Hv9y2H6esDG9fG5paWgPJc9+cxavvVPPhs0fQJ9sLXAV9ssnKzKCy1jspjCzMY+UXP9B6Qg/5
6NTT+NFftvCbF3e3BpJpowt59PPv79C1U32ikT+/sZ/N+061PFd98QNMGtG/tSsUoKKmnvNvO9Wiu+Kc07j5iikMyO/YRXXHk94JMPRl4Pefv5Dzxw5q3X7BuEH89a2DDC3I5Z5nd/KN329kysj+3LJgMhdPGMJTWw4xrH8u540eyNc+NJEfP7mV/35uV+sJ/ZvzzmTx+aNauw5DWlqUdw6fYNPeoyyaXszU4kLOGl7AhGEFtNei8Nhr5VSfaOSWBZOZPmYgZwzp1/q3Dtl64Bh/eLWcIQW5jBvcjzOG9GXUoHzKDtUy0bUYauoaERFe2FFJn+wMhvXvQ/HAPPJzssjMEFQVVe9/8cdXns1b+4/R0NTSGkCKB+a1/u+0tCiHjzcwpCCXN8qP8MrbVRw92cgHzxrKkH65FPXLaf3chFTW1vPN32/kY+ecxsRhBRT1y2FAXjZ52Zlt/idDX/qyMzPom+sFjdysjG7vxkuHYFIOjPL9Xgx0HFBIkn/29d0PLcjlUE09i6YXc/MVk2loaiEvJ7PDP25OHH24QIcxlP552dT7WhlPf/3SDoEkkkc/f2Gb3/vlZrWe0LcfrGX/0Tq+NmciAFmZGR1OXIIXkauON1CYn91h3/xUvX3fcaiWGy9/X9T9+tXfdnGioZkF540Mu13afTeYMLRfa+sho922zAxp07XzSLtjbm/TXu/EevWFY8JuzxChReGvbx2gsraBf7loXNgPZLZ7Pzfs8VpL35g7sUMgAcjO9PKFgtmSC8cwpqgvY4r6RswLXkuy/fsR0ic7s83/xGNf+EDYfKEW1l1rve63NV+7JOzJt0/2qf+nf5hWzH9cNTWuk1Coa8pv7uThzJ08nK0HjlFZW8+c9w3jQ5OGtdY337XGAUSEpRefzks7DzN2cF9+suiciP/bGRnCg9d06GwI63sfncT3PjopZr6zhvfnpo90zPe+Ef1bn4fGfPz77ScirZ+pvrlZbQJrexkZwpAC7z2dWlzI1OLCmPs4uF8uv/7nGTHziQinDwnfrdyd0iGYrAcmiMg4YC+wGPhUql4s1GTf+cMPk9kuSPQN//mPy8fDnFwFryUUEukfJtxHvyTKP/adrn//jKEdT2qtdYqgCk0typURTvyt+YCDx+o4b3RhxBNCaB9D32pLxgyMUF/b399/RlHE184UodnXCM3N6jgw3N7i80d1+AYYkiHeN8intxyiqG8Ol0wIe6ft1m/2f9tRCcA/zgjfbSoi5GRlsOYtbwbdV13wDic0BgQwZ9KwiPn8f9+nvn5pxHz+E/2dnzgnbCAB74T5i09NY+KwfhHzhLz63Tkcr29mdFF+1HxnDe/PHYvOiZoHoKhfLiu/eFHMfKZ36PXBRFWbROSLwGogE1ihqptjFOuSytp69h45yacvGN0hkCTqug+O75Am7ptysrT/1j+0oE+UvF4gq6lrYtiA6PlQpbz6JOeOKoycz730rDOHsHV/TcSZPu3/qmf5vim2l5EBNHvPr/vgGRHz+U2LEMTAa+m0qFL6TjUlYwdGfI9D6S+UVTJhaL8O3TJ+/vGrgXHMRnrum7Oibu+f5/3dJgztxxlRvo2eVui9Z4vPHxX2i4rfR6aOiLlf4J38i4L/Amx6qF4fTABU9S/AX1L9Ojc+9iYAx+s7zrBJVKTZT8nU/lt/qNkdTkNzCw+t9+Y1DO8fJZi4Adt9R05GPSmFujoO1zbEbBH5jQ3TJRSS6csbrSXmd2aUb9+h4P1u1Qk+FKV1EOqSCnXvxeOeT0+Lun3HbZfT3KIduo7aWzStmDGD8ttMIQ5nxIC8sONrxqSK/ad1Qmga7YgkTKuLZ0KF/7T6Q3eNRNh87U7AI6K0JEL698mKeuLy98tHa8EAVJ9ooKlFGRolOIVU1NRHHA+AjgFvepSWhP9POD7OPuMJwyLn8zdEZp/V8fqMkNDAOkTvhoNTxzM3SnACyM7MiBlIAPJyMrlk4pC4WsYWSEx3SouWSXe5aPxgHlq/
h6/MnpBQPfFOsvDny8qMv1stN9K4ha+KeLpcTuUN3yUFXsCrqKkHiBokQipr6ymKMrjU/iijnRD9Afm0OAN8pPESaNvSidYd5h8sP7s4esuk9KY5ZGVktJlBZUw6smDSCScbmxlTlB/XN8jOmDa6MGy6/8RaEGa2UCRfjhDs/GMmA/IiB4j2BuZHDjze1exeKyZat1noPN3UohT1i1Zf3LtFiy+aJGMMq8F3a5Zo77E/MJw1PPKYDhB1PMWYdGJflzrhRENz1G+2nfXYa978+jNjnJAALntf5G6X9qaNjvytOqQzwSRaXv8pPGr3lS9nYZhrGMLVWBSj9RQKJTPiHC+JNS71i2fK4qon1DIZGPU4jHlvsWDSCROG9mP6mMKk1Re6MK+iJvxNGv1jIf5++o752v6eH+beSe3zdSaYRKqvvSFxjoWEuyYjXD7/xZphuWjysXOiz0aa5GaE/dJdZR/Jdz7sXSPzt299MGq+kYV5fHv+WTz+5Yuj758x7yHWzdUJ35p/VlLrC125/fz2yrDb/TGiM704QyPMvvJX0ZlgEu0ittCmDIH+eVGChO953yitu850VoW6pWKNR/x08blsPVATs2XyuYtP53MXnx7zdUWEa2fFNxXZmPcKCyY9wA/cnXkjEYl+Qu+K6F1NneHtV35OVtz72K9PtJbJqTp+dGXkGWx+0VpE4N3raWKMC/KMMYmxbq6A7D1y6uZ3s84Mf6U1rd/6o5+k4w0zXe3miqfOcLclj/TaUbu5fM/7xjnpoHhQ8u+AaozpHAsmARCEN/eeun1G5G4p79SavIvtuzabKx6xxlX8A/DxjpnkxzlrbvSg6Lf3MMakngWTgMRqbYB/PCL5d/8ckBf/dSbRhGbnHmm3nkV7/kOIt8UR78B/MmfYGWO6xj6FAelMgEhWMElFN9fG8iOAd2v5eEXv5jq1k/kxgs6fv3RR1LqMMd3HPokBieeC6NBpNTc7RuY4g01XZ3PFW2fUfL597BPleNp0c8VomcR7XyxjTOpZN1dAOtPNlZOCW3EURJlR5RfrtbvSaop31le8+2iMCZ4Fk4B05iScrBv2+U/i8dYZK1+8kwNC2fp2YtZXsicJGGNSx4JJQOK5cWNo/CDWCT3urib/68cZBbJj7Gf75WIjvrbLFmscxD9mkpfke6AZY1LHgkkARE7NQLr7H8+Lmg9S080V711sq2PM0gqtPBlLKETEij3+lkl3r2FtjOm6lAUTEfmJiGwVkTdE5I8iUujbdqOIlInINhGZ50ufLiJvum13izubiEiuiDzs0l8WkbG+MktEZId7LEnV8SSbujm10bp9QufS3CR9Q/efm2O1OJItFBgOHquPka879sYYk2ypbJmsAaao6lRgO3AjgIhMwlunfTIwH7hHREJny3uBpcAE95jv0q8BqlV1PHAXcLuraxCwDLgAmAEsE5HYt8ztAULXZ8Rz8sztRMvk21HuH+bvQkr2ssOXTIxwFb/THOf6w+2XFjbG9A4pCyaq+ldVDa1v+xJQ7J4vAB5S1XpVfRsoA2aIyAigv6quU+9r+wPAQl+Z+93zR4HZrtUyD1ijqlWqWo0XwEIBqEdTd8vb6CfPOMdM2tyqJL5WTHaUuxB3xQXjot8GvtG3Vkg01jIxpnfqrjGTfwGecM9HAnt828pd2kj3vH16mzIuQB0FiqLU1YGILBWRUhEpraioSOhgkqF1XacoJ8/WMZPOzOaK4w6/EP/AeRJeFvAWxYpHfWN8QccY07MkNJFfRNYCw8NsuklVV7o8NwFNwG9DxcLk1yjpXS3TNlF1ObAcoKSkJL4zWwqFdiCeKcKdGYBP9hf7y6Kshd72daO/coNbjTHWksehuxlfeV7Y7wTGmB4qoWCiqnOibXcD4h8FZqu2fhcvB0b5shUD+1x6cZh0f5lyEckCBgBVLn1WuzLPduFQul1oydnYnVyxr4D3n8ijxaaudCHFe6v6WHWHurlitbJGDcrnrVvm2f22jOll
Ujmbaz7wbeAKVT3h27QKWOxmaI3DG2h/RVX3AzUiMtONh1wNrPSVCc3UWgQ87YLTamCuiAx0A+9zXVqPJsQ3AN+VqcHJHsCOt75Y+1hb7w2fxdMSs0BiTO+TyjGTnwMFwBoR2SAivwRQ1c3AI8BbwJPAdara7MpcC/wKb1B+J6fGWe4DikSkDLgeuMHVVQXcCqx3j1tcWo+nGnsAPt6LFtuUiRqcunLrk/jyxbqP1gPr3gFg1cZ9UfMZY3qnlH0FdNN4I227DbgtTHop0GHZQVWtA66KUNcKYEXX9zQYrYNBcZysOxVMurgtknjH6T8YY2xlxIA+7D9ax9xJw7qwF8aYns6ugA9Ip7q5OjE1ONlTa+O9h9iwCAt8hXxg/GAAPnz2iIT3yRjT81jndEBa4ujmCsnNiv8K+Gj1tWjnJ7El65YmtyyYzEfOHsGZw20tdmPSkbVMAhJPy6Sx2cuU26nrTCJvivcqdL9kXY6Sn5MVsyvMGNN7WcskYNHO1X/dfACIfesTifC8vXgvHGxTd4xg8oOFU3h2W/AXgRpjgmXBJCChLqdoV6IfPt4AnLrgL+HX7EIwuciNdUTymZlj+MzMMV3dJWNMmrBurgCISGsrIdoA9znFA1yeztUdSVdaJheMK+p0GWPMe48Fk4CErjOJ1oV10YTorYJwosWdor45na/PbrxojImDBZOANLe2TBKvK96pwUNjTN81xpiusmASkOY4urm6ItktCVtfxBgTDwsmATlW592rKlo3V1dO5Ek/+VssMcbEwYJJwGrqmmJniiHeuwaHjCnKT/g1jTHGz6YGB6wpzhUIk+Xl78ymb278b7sNwBtj4mHBJGCNXZiuG02s25/EuodWh/oS2RljzHuGdXMFbOrIAUmtz07+xpggWDAJWEGfJDQOU3jX4GTd6NEYk94smAQs2SfrZM/mykrWnR6NMWkt5cFERL4hIioig31pN4pImYhsE5F5vvTpIvKm23a3W74Xt8Tvwy79ZREZ6yuzRER2uMcSepmoi1l14Tye7IZEZ5YMNsa8d6X0TCEio4APAe/60iYBi4HJwHzgHhEJLdhxL7AUb134CW47wDVAtVu98S7gdlfXIGAZcAEwA1jm1oLvNZJx8o/3rsFdEe1GlMYYE5Lqr513Ad/i1Cq1AAuAh1S1XlXfxlvvfYaIjAD6q+o69W5c9QCw0Ffmfvf8UWC2a7XMA9aoapWqVgNrOBWAeoWkd3PZud8YE4CUBRMRuQLYq6ob220aCezx/V7u0ka65+3T25RR1SbgKFAUpa5w+7NUREpFpLSiItj1N0In/NSc+C2aGGO6X0JTiURkLTA8zKabgO8Ac8MVC5OmUdK7WqZtoupyYDlASUlJci/u6KJYp/3QXX4HduJuv3WNzQnskTHGdE1CwURV54RLF5GzgXHARteNUwy8JiIz8FoPo3zZi4F9Lr04TDq+MuUikgUMAKpc+qx2ZZ5N5Ji6U6wurs9eOJb+edksPDdsYytsPSs37GXhedHzG2NMsqWkm0tV31TVoao6VlXH4p30p6nqAWAVsNjN0BqHN9D+iqruB2pEZKYbD7kaWOmqXAWEZmotAp524yqrgbkiMtANvM91ab1CrJZJZoZw5bTiTg2Cd2UBLGOMSVS3305FVTeLyCPAW0ATcJ2qhvpmrgV+A+QBT7gHwH3AgyJShtciWezqqhKRW4H1Lt8tqlrVLQeSBKkYM2ns5nt9GWMMdFMwca0T/++3AbeFyVcKTAmTXgdcFaHuFcCKpOxoN0vWTC5/LXPeNywpdT799Uvt6ndjTNzsRo8BSsWp+tKJQ5JSz+lD+iWlHmPMe4Nd3hyAVE4NttaEMSYIFkwClIolce2CdWNMECyYBChZjQh/PcleU94YY+JhwSRAqTjtWzAxxgTBgkmAUjG+YbHEGBMECyYBSkU3lwUTY0wQLJgEyLq5jDHpwoJJAEKzuGwarzEmXVgwCVAqYkl9k91OxRjT/SyYBChZscR/
vUpBH7upgTGm+1kwCVAqurkG98tNep3GGBOLBZMA2YiJMSZdWDAJUNLuGmxRyRgTMAsmAbIgYIxJFxZMAtB61+Bgd8MYY5ImpcFERL4kIttEZLOI3OFLv1FEyty2eb706SLyptt2t1u+F7fE78Mu/WURGesrs0REdrjHEnoRa5kYY9JFyuaRisgHgQXAVFWtF5GhLn0S3rK7k4HTgLUiMtEt3XsvsBR4CfgLMB9v6d5rgGpVHS8ii4HbgU+KyCBgGVACKPCqiKxS1epUHVcypeIW9MYYE4RUtkyuBX6sqvUAqnrIpS8AHlLVelV9GygDZojICKC/qq5TVQUeABb6ytzvnj8KzHatlnnAGlWtcgFkDV4A6hWS3TLpl2vXmBhjgpHKYDIRuNh1Sz0nIue79JHAHl++cpc20j1vn96mjKo2AUeBoih19Qr7j9Yltb7C/Oyk1meMMfFK6KusiKwFhofZdJOreyAwEzgfeERETif8uLNGSaeLZdrv61K8LjRGjx4dLkuvZff4MsYELaFgoqpzIm0TkWuBx1yX1Ssi0gIMxms9jPJlLQb2ufTiMOn4ypSLSBYwAKhy6bPalXk2wr4uB5YDlJSUhA04xhhjuiaV3Vx/Ai4DEJGJQA5QCawCFrsZWuOACcArqrofqBGRmW485GpgpatrFRCaqbUIeNoFqdXAXBEZKCIDgbkuzRhjTDdK5YjtCmCFiGwCGoAlLgBsFpFHgLeAJuA6N5MLvEH73wB5eLO4nnDp9wEPikgZXotkMYCqVonIrcB6l+8WVa1K4TElhXVKGWPSTcqCiao2AJ+JsO024LYw6aXAlDDpdcBVEepagRe43rMsOBljgmZXwBtjjEmYBZM0ojatwBgTEAsmacBmBhtjgmbBJI1YUDHGBMWCSRqxbi5jTFAsmAQg2Ves2w0jjTFBs2BijDEmYRZMAnTZWUOTWp+NmRhjgmLBJEAZST7525iJMSYoFkwClZxoYi0SY0zQLJgYY4xJmAWTAFmLwhiTLiyYBCDZMcRikjEmaBZMAmRBwBiTLiyYBMi6uYwx6cKCiTHGmIRZMAlQsm6DYi0cY0zQUhZMRORcEXlJRDaISKmIzPBtu1FEykRkm4jM86VPF5E33ba73VrwuPXiH3bpL4vIWF+ZJSKywz2WYIwxptulsmVyB3Czqp4LfN/9johMwlvDfTIwH7hHRDJdmXuBpcAE95jv0q8BqlV1PHAXcLuraxCwDLgAmAEsE5GBKTympLIWhTEmXaQymCjQ3z0fAOxzzxcAD6lqvaq+DZQBM0RkBNBfVdepqgIPAAt9Ze53zx8FZrtWyzxgjapWqWo1sIZTAajnSnoQsahkjAlWVgrr/iqwWkT+Ay9ovd+ljwRe8uUrd2mN7nn79FCZPQCq2iQiR4Eif3qYMm2IyFK8Vg+jR4/u6jEllbVMjDHpIqFgIiJrgeFhNt0EzAa+pqp/EJFPAPcBcwj/NVqjpNPFMm0TVZcDywFKSkp6xC0RbR0SY0y6SCiYqOqcSNtE5AHgK+7X3wO/cs/LgVG+rMV4XWDl7nn7dH+ZchHJwus2q3Lps9qVebbzR2KMMSYRqRwz2Qdc6p5fBuxwz1cBi90MrXF4A+2vqOp+oEZEZrrxkKuBlb4yoZlai4Cn3bjKamCuiAx0A+9zXVrvkKSGiXWXGWOClsoxk38F/su1JOpw4xWqullEHgHeApqA61S12ZW5FvgNkAc84R7gdZE9KCJleC2Sxa6uKhG5FVjv8t2iqlUpPKakOnSsLuhdMMaYpEhZMFHVF4DpEbbdBtwWJr0UmBImvQ64KkJdK4AVCe1sQNbvrg56F4wxJinsCvgA2F2DjTHpxoKJMcaYhFkwMcYYkzALJsYYYxJmwSQNiM0NNsYEzIKJMcaYhFkwSSPedZzGGNP9LJgEINndUqHarLvLGBMUCybGGGMSZsHEGGNMwiyYGGOMSZgFkwAke6DchkqMMUGzYJJGLKgYY4JiwSQANoHX
GJNuLJgE6F8vHpeUeqxFYowJmgWTAPXLzU5qfRZUjDFBSSiYiMhVIrJZRFpEpKTdthtFpExEtonIPF/6dBF502272y3Ri1vG92GX/rKIjPWVWSIiO9xjiS99nMu7w5XNSeR4upud/I0x6SLRlskm4ErgeX+iiEzCW1p3MjAfuEdEMt3me/GW8J3gHvNd+jVAtaqOB+4Cbnd1DQKWARcAM4Blbr13XJ67VHUCUO3q6PlSNGgitkyWMSYgCQUTVd2iqtvCbFoAPKSq9ar6NlAGzBCREUB/VV2n3vzYB4CFvjL3u+ePArNdq2UesEZVq1S1GlgDzHfbLnN5cWVDdfUKyTr1WxAxxgQtVWMmI4E9vt/LXdpI97x9epsyqtoEHAWKotRVBBxxedvX1YGILBWRUhEpraio6OJhJYfN5jLGpJusWBlEZC0wPMymm1R1ZaRiYdI0SnpXykSrq+MG1eXAcoCSkpIecT5P9piJjcEYY4ISM5io6pwu1FsOjPL9Xgzsc+nFYdL9ZcpFJAsYAFS59FntyjwLVAKFIpLlWif+ut5bLIgYYwKWqm6uVcBiN0NrHN5A+yuquh+oEZGZbszjamClr0xoptYi4Gk3rrIamCsiA93A+1xgtdv2jMuLKxuppdQj2S3jjTHpItGpwR8XkXLgQuBxEVkNoKqbgUeAt4AngetUtdkVuxb4Fd6g/E7gCZd+H1AkImXA9cANrq4q4FZgvXvc4tIAvg1c78oUuTresyw0GWOCErObKxpV/SPwxwjbbgNuC5NeCkwJk14HXBWhrhXAijDpu/CmC/cqtiKiMSbd2BXwacBaJMaYoFkwCUCq2iU2BmOMCYoFkwDZud8Yky4smKQBa5EYY4JmwSRAyb4NioUUY0xQLJgEIGWTuSyaGGMCYsEkQNY7ZYxJFxZMAqBJns9lMckYEzQLJgGyIGCMSRcWTAKQqjETC07GmKBYMAlQssZMbOzFGBM0CybGGGMSZsEkQEm/zsSaKMaYgFgwCYDdNNgYk24smAQoaWMmNvRujAmYBZM0YiHFGBOURFdavEpENotIi4iU+NI/JCKvisib7udlvm3TXXqZiNztlu/FLfH7sEt/WUTG+sosEZEd7rHElz7O5d3hyuYkcjzGGGO6JtGWySbgSuD5dumVwMdU9Wy8tdkf9G27F1iKty78BGC+S78GqFbV8cBdwO0AIjIIWAZcgLeq4jK3Fjwuz12qOgGodnW859i4uzEmaAkFE1XdoqrbwqS/rqr73K+bgT6u5TEC6K+q69Rbu/YBYKHLtwC43z1/FJjtWi3zgDWqWqWq1cAaYL7bdpnLiysbqqtXSPbsKwsqxpigdMeYyT8Ar6tqPTASKPdtK3dpuJ97AFS1CTgKFPnT25UpAo64vO3r6kBElopIqYiUVlRUJHxQiUjVGvA2EG+MCUpWrAwishYYHmbTTaq6MkbZyXhdUXNDSWGyaYxtnU0PS1WXA8sBSkpKesTkXDv1G2PSRcxgoqpzulKxiBQDfwSuVtWdLrkcKPZlKwb2+baNAspFJAsYAFS59FntyjyLNy5TKCJZrnXir6tHS3Yks6BkjAlaSrq5RKQQeBy4UVX/HkpX1f1AjYjMdGMeVwOh1s0qvMF6gEXA025cZTUwV0QGuoH3ucBqt+0ZlxdXNmpLqadJ9hhHZoaFFWNMMBKdGvxxESkHLgQeF5HVbtMXgfHA90Rkg3sMdduuBX4FlAE7gSdc+n1AkYiUAdcDNwCoahVwK7DePW5xaQDfBq53ZYpcHT1eqq6Az860YGKMCUbMbq5oVPWPeF1Z7dN/APwgQplSYEqY9DrgqghlVgArwqTvwpsu3Csl7dTvKsrKtGtQjTHBsLNPOnAtnSzr5jLGBMSCSYCSdZ1JY4sXTbKtZWKMCYidfQKQ7DXgW9wgjA3AG2OCYsEkQMmazRW6CNJiiTEmKBZM0kBzi/czw+6nYowJiAWTACR7anCom8tWWjTGBMWCSQDceHnSWhLa
Wl9SqjPGmE6zYBKIUEsiObW1tI6ZWDQxxgTDgkkANMktE5vNZYwJmgWTALSOcSStPu+nNUyMMUGxYBKA0Ph78sZMrJvLGBMsCyYBaGltSiSnvuYWu87EGBMsCyYBOHXyT9aYCUmtzxhjOsuCSQCONzQDyWtJ2HUmxpigWTAJULJvp2L3eTTGBMVOPwFKdjeX2AK+xpiAJLrS4lUisllEWkSkJMz20SJSKyLf8KVNF5E3RaRMRO52y/ciIrki8rBLf1lExvrKLBGRHe6xxJc+zuXd4crmJHI8vVXrRYv21cAYE5BETz+bgCuB5yNsv4tTy/KG3AssBSa4x3yXfg1QrarjXbnbAURkELAMuABvVcVlbi14XJ67VHUCUO3q6DVsAN4Yky4SCiaqukVVt4XbJiILgV3AZl/aCKC/qq5Tr6P/AWCh27wAuN89fxSY7Vot84A1qlqlqtXAGmC+23aZy4srG6qrV0haMEny7DBjjOmslHSMiEhf4NvAze02jQTKfb+Xu7TQtj0AqtoEHAWK/OntyhQBR1ze9nWF26elIlIqIqUVFRVdOaykS/69uZJTnzHGdFbMYCIia0VkU5jHgijFbsbrfqptX12YvBpjW2fTw1LV5apaoqolQ4YMiZStWyVvarD306YGG2OCkhUrg6rO6UK9FwCLROQOoBBoEZE64A9AsS9fMbDPPS8HRgHlIpIFDACqXPqsdmWeBSqBQhHJcq0Tf129RHJvp2I3ejTGBCUl3VyqerGqjlXVscBPgR+q6s9VdT9QIyIz3ZjH1cBKV2wVEJqptQh42o2rrAbmishAN/A+F1jttj3j8uLKhurqFZJ17g+NlWTbhSbGmIDEbJlEIyIfB34GDAEeF5ENqjovRrFrgd8AeXgzvUKzve4DHhSRMrwWyWIAVa0SkVuB9S7fLapa5Z5/G3hIRH4AvO7q6DWS1S31qQtGs/fISb502fik1GeMMZ0lmuw1ZHuBkpISLS0tDez1x97wOAD3LSlh9vuGBbYfxhjTGSLyqqp2uKYQ7Ar4QNlUXmNMurBgEiSLJcaYNGHBJECHaxuC3gVjjEkKCyYBKq8+EfQuGGNMUlgwCZDd5dcYky4smATIrjE0xqQLCyYBsslcxph0YcEkQHYvLWNMurBgEiCLJcaYdGHBJEA2AG+MSRcWTAJkLRNjTLqwYBIgm81ljEkXFkwClGO3jDfGpAk7mwUgO9Nrknzi/FEB74kxxiRHQuuZmK554isX8/TWQ+Tn2J/fGJMe7GwWgPFDCxg/tCDo3TDGmKRJqJtLRK4Skc0i0iIiJe22TRWRdW77myLSx6VPd7+XicjdbvleRCRXRB526S+LyFhfXUtEZId7LPGlj3N5d7iyOYkcjzHGmK5JdMxkE3Al8Lw/UUSygP8FPq+qk4FZQKPbfC+wFJjgHvNd+jVAtaqOB+4Cbnd1DQKWARcAM4Blbi14XJ67VHUCUO3qMMYY080SCiaqukVVt4XZNBd4Q1U3unyHVbVZREYA/VV1nXrrBT8ALHRlFgD3u+ePArNdq2UesEZVq1S1GlgDzHfbLnN5cWVDdRljjOlGqZrNNRFQEVktIq+JyLdc+kig3Jev3KWFtu0BUNUm4ChQ5E9vV6YIOOLytq+rAxFZKiKlIlJaUVGR0MEZY4xpK+YAvIisBYaH2XSTqq6MUu9FwPnACeApEXkVOBYmr4ZeKsK2zqaHparLgeUAJSUlEfMZY4zpvJjBRFXndKHecuA5Va0EEJG/ANPwxlGKffmKgX2+MqOAcjfmMgCocumz2pV5FqgECkUky7VO/HUZY4zpRqnq5loNTBWRfBcYLgXeUtX9QI2IzHRjHlcDodbNKiA0U2sR8LQbV1kNzBWRgW7gfS6w2m17xuXFlY3UUjLGGJNCiU4N/riIlAMXAo+LyGoAN1B+J7Ae2AC8pqqPu2LXAr8CyoCdwBMu/T6gSETKgOuBG1xdVcCtrq71wC0uDeDbwPWuTJGrwxhjTDcT
7wv+e4uIVADvdLH4YLwutnRgx9JzpdPx2LH0TF05ljGqOiTchvdkMEmEiJSqaknsnD2fHUvPlU7HY8fSMyX7WOxGj8YYYxJmwcQYY0zCLJh03vKgdyCJ7Fh6rnQ6HjuWnimpx2JjJsYYYxJmLRNjjDEJs2BijDEmYRZM4iQi80Vkm1tv5Yag9yceIrLbrR2zQURKXdogEVnj1oBZ47udPyJyozu+bSIyL7g9b92fFSJySEQ2+dI6vf+R1tDpAcfy7yKy170/G0Tkw73kWEaJyDMissWtV/QVl97r3psox9Lr3hsR6SMir4jIRncsN7v07nlfVNUeMR5AJt7V+qcDOcBGYFLQ+xXHfu8GBrdLuwO4wT2/AbjdPZ/kjisXGOeONzPg/b8E755umxLZf+AVvLs0CN4dFy7vIcfy78A3wuTt6ccyApjmnhcA290+97r3Jsqx9Lr3xr1uP/c8G3gZmNld74u1TOIzAyhT1V2q2gA8hLf+Sm/kXzfGvwbMAuAhVa1X1bfxbnczo/t37xRVfR7vZp9+ndp/ib6GTreJcCyR9PRj2a+qr7nnNcAWvOUfet17E+VYIunJx6KqWut+zXYPpZveFwsm8Ym0pkpPp8BfReRVEVnq0oapd8NN3M+hLr23HGNn9z/aGjo9wRdF5A3XDRbqfug1xyLe8trn4X0L7tXvTbtjgV743ohIpohsAA7hLSrYbe+LBZP4dGrtlB7kA6o6DbgcuE5ELomSt7ceY0hS1r3pZvcCZwDnAvuB/3TpveJYRKQf8Afgq6oabq2i1qxh0nrU8YQ5ll753qhqs6qei7ckxwwRmRIle1KPxYJJfEJrrYT0irVTVHWf+3kI+CNet9VB14zF/TzksveWY+zs/pcTeQ2dQKnqQffhbwH+h1Pdij3+WEQkG+/k+1tVfcwl98r3Jtyx9Ob3BkBVj+Ct+zSfbnpfLJjEZz0wQUTGiUgOsBhv/ZUeS0T6ikhB6DneOjCbaLtujH8NmFXAYhHJFZFxwAS8QbieplP7r9HX0AlU6APufBzv/YEefizute8Dtqjqnb5Nve69iXQsvfG9EZEhIlLonucBc4CtdNf70p2zDXrzA/gw3kyPnXhLFge+TzH293S8mRobgc2hfcZb9+UpYIf7OchX5iZ3fNsIYJZQmGP4HV4XQyPet6VrurL/QAneyWAn8HPcnR96wLE8CLwJvOE+2CN6ybFchNft8QbeekUb3Oej1703UY6l1703wFTgdbfPm4Dvu/RueV/sdirGGGMSZt1cxhhjEmbBxBhjTMIsmBhjjEmYBRNjjDEJs2BijDEmYRZMjDHGJMyCiTHGmIT9f5RTXelKdFduAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(elbos[25:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "581b2ed4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit BART only once: `train` is a guard flag that must already exist from an\n",
    "# earlier cell, so re-running this cell does not refit the model.\n",
    "# NOTE: hyperparameters are set explicitly here rather than left at the\n",
    "# library defaults.\n",
    "if not train:\n",
    "    BART = SklearnModel(n_trees=100, n_burn=200, n_samples=200)\n",
    "    BART.fit(X, Y)\n",
    "    train = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "d39dee23",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    }
   ],
   "source": [
    "# In-sample predictions from both models (VaRT is evaluated first, then BART).\n",
    "yhat_VaRT, yhat_BART = VaRT.predict(X_tree), BART.predict(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "536324c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# In-sample RMSE of BART and VaRT (presumably on the standardized scale, since\n",
    "# predictions are mapped back with scaler_y in a later cell).\n",
    "# The square root must be taken AFTER averaging: sqrt(mean(err**2)).\n",
    "# Taking sqrt per element first would give mean(|err|), i.e. MAE, not RMSE.\n",
    "rmse_BART = np.sqrt(np.mean((yhat_BART - Y) ** 2))\n",
    "rmse_BART, VaRT.rmse(yhat_VaRT, y_tree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "2e822ddb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    }
   ],
   "source": [
    "# Counterfactual design matrices: every unit set to control (treat=0) or treated (treat=1).\n",
    "X_0 = X.assign(treat=0)\n",
    "X_1 = X.assign(treat=1)\n",
    "\n",
    "\n",
    "def _to_outcome_scale(pred):\n",
    "    # Map 1-D model predictions back to the original outcome scale with scaler_y.\n",
    "    return scaler_y.inverse_transform(pred.reshape(-1, 1)).reshape(-1)\n",
    "\n",
    "\n",
    "def _vart_predict(design):\n",
    "    # VaRT prediction (samples=1000) for a DataFrame, returned as a CPU NumPy array.\n",
    "    features = torch.tensor(design.values, device=DEVICE)\n",
    "    return VaRT.predict(features, samples=1000).cpu().numpy()\n",
    "\n",
    "\n",
    "# Keep the call order BART(1), BART(0), VaRT(1), VaRT(0): if VaRT.predict samples\n",
    "# randomly, reordering would change the draws.\n",
    "y1_BART = _to_outcome_scale(BART.predict(X_1))\n",
    "y0_BART = _to_outcome_scale(BART.predict(X_0))\n",
    "y1_VaRT = _to_outcome_scale(_vart_predict(X_1))\n",
    "y0_VaRT = _to_outcome_scale(_vart_predict(X_0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "0b69ad78",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3.522810318414722, 3.78311)"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Estimated ATE on the original outcome scale: average of (y1 - y0) over all units.\n",
    "ate_BART = np.mean(y1_BART - y0_BART)\n",
    "ate_VaRT = np.mean(y1_VaRT - y0_VaRT)\n",
    "ate_BART, ate_VaRT"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0f2e7b3",
   "metadata": {},
   "source": [
    "## Results\n",
    "The cell above reports the estimated ATE, $\\frac{1}{n}\\sum_i \\left(\\hat{y}_i(1) - \\hat{y}_i(0)\\right)$, from BART and VaRT respectively, on the original outcome scale (predictions are mapped back with `scaler_y`).\n",
    "\n",
    "**TODO:** compare both estimates against the true ATE of the response surface used to generate the outcome."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
