{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running the Experiment on the STAR Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Current working directory: /hpc/home/aa671/phd/generating_confounders/Deconfounding-MD\n",
      "Device:  cuda:0\n"
     ]
    }
   ],
   "source": [
    "# import libraries\n",
    "import os\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "# Project modules (star imports kept for compatibility with the rest of the notebook).\n",
    "# When several modules export the same name the later import wins, so generate_data\n",
    "# is imported after train_models/utils and models comes last.\n",
    "from train_models import *\n",
    "from utils import *\n",
    "from generate_data import *\n",
    "from models import *\n",
    "\n",
    "# Seed every RNG the experiment uses (Python, NumPy, PyTorch CPU and CUDA): model\n",
    "# initialisation, generator noise and DataLoader shuffling all draw from torch.\n",
    "seed = 2024\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed_all(seed)\n",
    "print(\"Current working directory:\", os.getcwd())\n",
    "\n",
    "# check if CUDA is available\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda:0\" if use_cuda else \"cpu\")\n",
    "print(\"Device: \", device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Current working directory: /hpc/home/aa671/phd/generating_confounders/Deconfounding-MD\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/hpc/home/aa671/.local/lib/python3.8/site-packages/pyreadr/_pyreadr_parser.py:260: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider using pd.concat instead.  To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  self.df['rownames'] = self.row_names\n",
      "/hpc/home/aa671/phd/generating_confounders/Deconfounding-MD/generate_data.py:165: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider using pd.concat instead.  To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df['treatment'] = (df['g1classtype']=='SMALL CLASS').astype(int)\n",
      "/hpc/home/aa671/phd/generating_confounders/Deconfounding-MD/generate_data.py:166: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider using pd.concat instead.  To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df['outcome'] = df['g1tlistss'] + df['g1treadss'] + df['g1tmathss']\n",
      "/hpc/home/aa671/phd/generating_confounders/Deconfounding-MD/generate_data.py:167: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider using pd.concat instead.  To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df['rural'] = (df['g1surban'] == 'RURAL') | (df['g1surban'] == 'INNER CITY')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Size of unconfounded (RCT) features: (1280, 8)\n",
      "Size of confounded (Observational) features: (626, 8)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/hpc/home/aa671/.local/lib/python3.8/site-packages/sklearn/base.py:443: UserWarning: X has feature names, but StandardScaler was fitted without feature names\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Load the STAR data (a dict-like object holding the RCT and observational splits)\n",
    "data = read_START_data()\n",
    "x_unc, x_conf = data['x_unc'], data['x_conf']\n",
    "\n",
    "# Print the shape of unconfounded and confounded features\n",
    "print(f\"Size of unconfounded (RCT) features: {x_unc.shape}\")\n",
    "print(f\"Size of confounded (Observational) features: {x_conf.shape}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_tensors(x, t, y, device='cpu'):\n",
    "    \"\"\"Convert covariates, treatments and outcomes to float32 tensors.\n",
    "\n",
    "    Each input is passed through ``torch.tensor`` (which copies the data) and placed\n",
    "    on ``device``.\n",
    "\n",
    "    Returns:\n",
    "        ``(X_tensor, T_tensor, Y_tensor)``, in the same order as the inputs.\n",
    "    \"\"\"\n",
    "    return tuple(\n",
    "        torch.tensor(values, dtype=torch.float32, device=device)\n",
    "        for values in (x, t, y)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Size of test features: torch.Size([1152, 8])\n",
      "Size of observational train features: torch.Size([626, 8])\n",
      "Size of RCT train: torch.Size([128, 1])\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import TensorDataset\n",
    "\n",
    "# Get the unconfounded (RCT) and confounded (observational) data\n",
    "x_unc = data['x_unc']\n",
    "t_unc = data['t_unc']\n",
    "y_unc = data['y_unc']\n",
    "\n",
    "x_conf = data['x_conf']\n",
    "t_conf = data['t_conf']\n",
    "y_conf = data['y_conf']\n",
    "\n",
    "# Split the RCT (unconfounded) data: 10% for training, 90% held out for evaluation.\n",
    "# train_test_split comes from the project star imports in the first cell.\n",
    "X_rct_train, X_rct_test, T_rct_train, T_rct_test, Y_rct_train, Y_rct_test = train_test_split(\n",
    "    x_unc, t_unc, y_unc, test_size=0.9, random_state=42)\n",
    "assert len(T_rct_train) + len(T_rct_test) == len(t_unc)\n",
    "\n",
    "# Convert RCT (unconfounded) and observational (confounded) data into tensors.\n",
    "# train_model_mb_plus_pb only reads T and Y from rct_data, so the RCT training\n",
    "# covariates are converted but not put into the dataset.\n",
    "X_rct_train_tensor, T_rct_train_tensor, Y_rct_train_tensor = convert_to_tensors(X_rct_train, T_rct_train, Y_rct_train, device=device)\n",
    "# Test tensors stay on the CPU, where the evaluation cells run the models.\n",
    "X_rct_test_tensor, T_rct_test_tensor, Y_rct_test_tensor = convert_to_tensors(X_rct_test, T_rct_test, Y_rct_test)\n",
    "\n",
    "X_obs_train_tensor, T_obs_train_tensor, Y_obs_train_tensor = convert_to_tensors(x_conf, t_conf, y_conf, device=device)\n",
    "\n",
    "observation_data = TensorDataset(X_obs_train_tensor, T_obs_train_tensor, Y_obs_train_tensor)\n",
    "rct_data = TensorDataset(T_rct_train_tensor, Y_rct_train_tensor)\n",
    "\n",
    "# Print the sizes to check everything\n",
    "print(f\"Size of test features: {X_rct_test_tensor.shape}\")\n",
    "print(f\"Size of observational train features: {X_obs_train_tensor.shape}\")\n",
    "print(f\"Size of RCT train: {T_rct_train_tensor.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 1_500  # number of training epochs used for every model below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "\n",
    "def train_baseline(observational_data, baseline_cate_learner, num_epochs=1000, batch_size=256, device='cpu', lr=0.001, log_every=100):\n",
    "    \"\"\"Fit an outcome model on ``[X, T] -> Y`` by mini-batch MSE regression (Adam).\n",
    "\n",
    "    The treatment is simply an extra input feature, so the fitted model can be queried\n",
    "    at T=1 and T=0 to obtain an ITE estimate ``f([x, 1]) - f([x, 0])``.\n",
    "\n",
    "    Args:\n",
    "        observational_data: dataset yielding ``(X, T, Y)`` tensors; X and T must be 2-D\n",
    "            so they can be concatenated along dim 1 (T presumably ``(n, 1)``).\n",
    "        baseline_cate_learner: ``nn.Module`` taking ``X.shape[1] + T.shape[1]`` inputs,\n",
    "            already on ``device``.\n",
    "        num_epochs: number of passes over the data.\n",
    "        batch_size: mini-batch size (batches are reshuffled every epoch).\n",
    "        device: device every batch is moved to.\n",
    "        lr: Adam learning rate.\n",
    "        log_every: print the mean batch loss every ``log_every`` epochs; ``0`` or\n",
    "            ``None`` disables logging.\n",
    "\n",
    "    Returns:\n",
    "        The same model object, trained in place.\n",
    "    \"\"\"\n",
    "    mse = nn.MSELoss()\n",
    "    optimizer = optim.Adam(baseline_cate_learner.parameters(), lr=lr)\n",
    "    data_loader = DataLoader(observational_data, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "    # The model may have been switched to eval mode by an earlier evaluation;\n",
    "    # dropout / batch-norm layers (if any) must be in training mode here.\n",
    "    baseline_cate_learner.train()\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        epoch_loss = 0.0\n",
    "        for X_batch, T_batch, Y_batch in data_loader:\n",
    "            X_batch, T_batch, Y_batch = X_batch.to(device), T_batch.to(device), Y_batch.to(device)\n",
    "\n",
    "            # Concatenate X and T as input to the model\n",
    "            XT = torch.cat((X_batch, T_batch), dim=1)\n",
    "            loss = mse(baseline_cate_learner(XT), Y_batch)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            epoch_loss += loss.item()\n",
    "\n",
    "        if log_every and epoch % log_every == 0:\n",
    "            print(f'Epoch {epoch}, Loss MSE: {epoch_loss / len(data_loader)}')\n",
    "\n",
    "    return baseline_cate_learner\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, Loss MSE: 2.597060441970825\n",
      "Epoch 100, Loss MSE: 0.14329662919044495\n",
      "Epoch 200, Loss MSE: 0.138443852464358\n",
      "Epoch 300, Loss MSE: 0.13961943238973618\n",
      "Epoch 400, Loss MSE: 0.13998555143674216\n",
      "Epoch 500, Loss MSE: 0.13782229522864023\n",
      "Epoch 600, Loss MSE: 0.12859762956698736\n",
      "Epoch 700, Loss MSE: 0.12747720380624136\n",
      "Epoch 800, Loss MSE: 0.12746480604012808\n",
      "Epoch 900, Loss MSE: 0.12753313034772873\n",
      "Epoch 1000, Loss MSE: 0.13129897912343344\n",
      "Epoch 1100, Loss MSE: 0.12470560520887375\n",
      "Epoch 1200, Loss MSE: 0.1231997733314832\n",
      "Epoch 1300, Loss MSE: 0.12226802110671997\n",
      "Epoch 1400, Loss MSE: 0.12512065966924033\n"
     ]
    }
   ],
   "source": [
    "dim_X = X_obs_train_tensor.shape[1]\n",
    "\n",
    "# Training loop for the baseline\n",
    "baseline_cate_learner = snet(input_dim=dim_X+1, hidden_dim=16)\n",
    "baseline_cate_learner.to(device)\n",
    "\n",
    "baseline_cate_learner =  train_baseline(observation_data, baseline_cate_learner,num_epochs=num_epochs, device=device, lr=0.001)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Marginals and Projections Balancing (MB+PB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def _balancing_schedule(epoch, num_epochs, alpha_start, alpha_end,\n",
    "                        balancing_iterations_start, balancing_iterations_end):\n",
    "    \"\"\"Return ``(alpha, balancing_iterations)`` for the given epoch.\n",
    "\n",
    "    Before ``int(2/3 * num_epochs - 100)`` the start values are used, after\n",
    "    ``int(2/3 * num_epochs + 100)`` the end values. In the 200-epoch window in\n",
    "    between, alpha decays linearly from ``alpha_start`` towards ``alpha_end`` and the\n",
    "    midpoint number of balancing iterations is used.\n",
    "    \"\"\"\n",
    "    ramp_start = int(num_epochs * 2 / 3 - 100)\n",
    "    ramp_end = int(num_epochs * 2 / 3 + 100)\n",
    "    if epoch < ramp_start:\n",
    "        return alpha_start, balancing_iterations_start\n",
    "    if epoch > ramp_end:\n",
    "        return alpha_end, balancing_iterations_end\n",
    "    alpha = alpha_start - (alpha_start - alpha_end) * (epoch - ramp_start) / 200\n",
    "    return alpha, int((balancing_iterations_start + balancing_iterations_end) / 2)\n",
    "\n",
    "\n",
    "def _sample_rows(X, n, replace):\n",
    "    \"\"\"Return ``n`` randomly chosen rows of ``X`` (works for any batch size, even < n).\n",
    "\n",
    "    A batch of exactly ``n`` rows is returned unchanged. Otherwise rows are drawn\n",
    "    without replacement when ``replace`` is False and enough rows exist, and with\n",
    "    replacement in every other case.\n",
    "    \"\"\"\n",
    "    num_rows = X.shape[0]\n",
    "    if num_rows == n:\n",
    "        return X\n",
    "    if replace or num_rows < n:\n",
    "        idx = torch.randint(0, num_rows, (n,), device=X.device)\n",
    "    else:\n",
    "        idx = torch.randperm(num_rows, device=X.device)[:n]\n",
    "    return X[idx]\n",
    "\n",
    "\n",
    "def train_model_mb_plus_pb(observational_data, rct_data, model_g, model_f, generator, cate_learner, \n",
    "                           alpha_start=100, alpha_end=0.01, num_epochs=500, \n",
    "                           balancing_iterations_start=5, balancing_iterations_end=100, \n",
    "                           generator_input_dim=1, batch_size=512, device='cpu', \n",
    "                           lr_g=0.001, lr_te=0.001, lr_f=0.001):\n",
    "    \"\"\"Train ``generator`` + ``cate_learner`` with marginal and projection balancing.\n",
    "\n",
    "    On every observational mini-batch two players are updated alternately:\n",
    "\n",
    "    * learner (``generator``, ``cate_learner``): minimises\n",
    "      ``alpha * MSE(Y_pred, Y_obs)`` plus, for each arm t in {0, 1}, the squared gaps\n",
    "      ``(mean(f(X) * Y_pred(t)) - mean(f(X_s) * Y_rct(t)))^2`` and\n",
    "      ``(mean(g(Y_pred(t))) - mean(g(Y_rct(t))))^2``, where X_s are batch rows\n",
    "      resampled to the size of the RCT arm;\n",
    "    * adversaries (``model_f``, ``model_g``): ``balancing_iterations`` Adam steps that\n",
    "      maximise the same gaps with the learner's predictions held fixed.\n",
    "\n",
    "    ``alpha`` and ``balancing_iterations`` follow ``_balancing_schedule``.\n",
    "\n",
    "    Args:\n",
    "        observational_data: dataset yielding ``(X, T, Y)``; T is concatenated with X\n",
    "            along dim 1, so it must be 2-D.\n",
    "        rct_data: dataset whose first two tensors are the RCT treatments and outcomes,\n",
    "            already on ``device``; both arms (T == 1 and T == 0) must be present.\n",
    "        model_g: adversary applied to outcomes reshaped to ``(n, 1)``.\n",
    "        model_f: adversary applied to the covariates X.\n",
    "        generator: maps ``(batch, generator_input_dim)`` Gaussian noise to ``U_hat``.\n",
    "        cate_learner: outcome model on ``[X, U_hat, T]``.\n",
    "        lr_g: learning rate of both ``generator`` and ``model_g``.\n",
    "        lr_te, lr_f: learning rates of ``cate_learner`` and ``model_f``.\n",
    "\n",
    "    Returns:\n",
    "        ``(generator, cate_learner)``: the same objects, trained in place.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if one of the RCT arms is empty.\n",
    "    \"\"\"\n",
    "    # Extract RCT data (treatment and outcome) and split the outcomes by arm once:\n",
    "    # they are constant for the whole training run.\n",
    "    T_rct, Y_rct = rct_data[:][0], rct_data[:][1]\n",
    "    T_rct = T_rct.view(-1, 1)\n",
    "    Y_rct_1 = Y_rct[T_rct == 1].view(-1, 1)\n",
    "    Y_rct_0 = Y_rct[T_rct == 0].view(-1, 1)\n",
    "    n_rct_1, n_rct_0 = Y_rct_1.shape[0], Y_rct_0.shape[0]\n",
    "    if n_rct_1 == 0 or n_rct_0 == 0:\n",
    "        raise ValueError('rct_data must contain both treated (T == 1) and control (T == 0) units')\n",
    "\n",
    "    mse = nn.MSELoss()\n",
    "\n",
    "    # Optimizers for different models\n",
    "    optimizer_g = optim.Adam(generator.parameters(), lr=lr_g)\n",
    "    optimizer_te = optim.Adam(cate_learner.parameters(), lr=lr_te)\n",
    "    optimizer_f = optim.Adam(model_f.parameters(), lr=lr_f)\n",
    "    optimizer_g_model = optim.Adam(model_g.parameters(), lr=lr_g)\n",
    "\n",
    "    data_loader = DataLoader(observational_data, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "    # Models may have been left in eval mode by an earlier evaluation.\n",
    "    for module in (generator, cate_learner, model_f, model_g):\n",
    "        module.train()\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        epoch_loss = 0.0\n",
    "        alpha, balancing_iterations = _balancing_schedule(\n",
    "            epoch, num_epochs, alpha_start, alpha_end,\n",
    "            balancing_iterations_start, balancing_iterations_end)\n",
    "\n",
    "        for X_batch, T_batch, Y_batch in data_loader:\n",
    "            X_batch, T_batch, Y_batch = X_batch.to(device), T_batch.to(device), Y_batch.to(device)\n",
    "            n_batch = X_batch.shape[0]\n",
    "\n",
    "            # Generate U_hat from generator\n",
    "            Z = torch.randn(n_batch, generator_input_dim, device=device)\n",
    "            U_hat = generator(Z)\n",
    "\n",
    "            # Factual prediction and both potential outcomes for every unit\n",
    "            Y_pred = cate_learner(torch.cat((X_batch, U_hat, T_batch), dim=1))\n",
    "            T_one = torch.ones(n_batch, 1, device=device)\n",
    "            Y_pred_one = cate_learner(torch.cat((X_batch, U_hat, T_one), dim=1)).view(-1, 1)\n",
    "            Y_pred_zero = cate_learner(torch.cat((X_batch, U_hat, torch.zeros_like(T_one)), dim=1)).view(-1, 1)\n",
    "\n",
    "            # ---- learner step ----\n",
    "            # NOTE(review): f is evaluated on observational rows resampled to the RCT\n",
    "            # arm size (rct_data carries no covariates); confirm this pairing is intended.\n",
    "            X_small_1 = _sample_rows(X_batch, n_rct_1, replace=True)\n",
    "            X_small_0 = _sample_rows(X_batch, n_rct_0, replace=True)\n",
    "            f_X_batch = model_f(X_batch)\n",
    "\n",
    "            loss1 = mse(Y_pred, Y_batch)\n",
    "            loss2 = mse((f_X_batch * Y_pred_one).mean(), (model_f(X_small_1) * Y_rct_1).mean())\n",
    "            loss3 = mse((f_X_batch * Y_pred_zero).mean(), (model_f(X_small_0) * Y_rct_0).mean())\n",
    "            loss2p = mse(model_g(Y_pred_one).mean(), model_g(Y_rct_1).mean())\n",
    "            loss3p = mse(model_g(Y_pred_zero).mean(), model_g(Y_rct_0).mean())\n",
    "            loss = alpha * loss1 + loss2 + loss3 + loss2p + loss3p\n",
    "\n",
    "            optimizer_te.zero_grad()\n",
    "            optimizer_g.zero_grad()\n",
    "            # The adversary steps below only use detached predictions, so the graph\n",
    "            # does not need to be retained.\n",
    "            loss.backward()\n",
    "            optimizer_te.step()\n",
    "            optimizer_g.step()\n",
    "\n",
    "            epoch_loss += loss.item()\n",
    "\n",
    "            # ---- adversary (balancing) steps ----\n",
    "            # Detach the learner's predictions *before* feeding them to model_g:\n",
    "            # detaching model_g's output instead would cut its gradient, and model_g\n",
    "            # would never be updated.\n",
    "            Y_pred_one_fixed = Y_pred_one.detach()\n",
    "            Y_pred_zero_fixed = Y_pred_zero.detach()\n",
    "            for _ in range(balancing_iterations):\n",
    "                X_small_1 = _sample_rows(X_batch, n_rct_1, replace=False)\n",
    "                X_small_0 = _sample_rows(X_batch, n_rct_0, replace=False)\n",
    "                f_X_batch = model_f(X_batch)\n",
    "\n",
    "                loss4 = mse((model_f(X_small_1) * Y_rct_1).mean(), (f_X_batch * Y_pred_one_fixed).mean())\n",
    "                loss5 = mse((model_f(X_small_0) * Y_rct_0).mean(), (f_X_batch * Y_pred_zero_fixed).mean())\n",
    "                loss4p = mse(model_g(Y_pred_one_fixed).mean(), model_g(Y_rct_1).mean())\n",
    "                loss5p = mse(model_g(Y_pred_zero_fixed).mean(), model_g(Y_rct_0).mean())\n",
    "\n",
    "                # Gradient ascent on the balancing gaps\n",
    "                loss_f = -loss4 - loss5 - loss4p - loss5p\n",
    "                optimizer_f.zero_grad()\n",
    "                optimizer_g_model.zero_grad()\n",
    "                loss_f.backward()\n",
    "                optimizer_f.step()\n",
    "                optimizer_g_model.step()\n",
    "\n",
    "        # Logging and printing every 100 epochs\n",
    "        if epoch % 100 == 0:\n",
    "            print(f'Epoch {epoch}, Loss: {epoch_loss / len(data_loader)}, Alpha: {alpha}, Balancing Iterations: {balancing_iterations}')\n",
    "\n",
    "    return generator, cate_learner\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, Loss: 23.724104563395183, Alpha: 10, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 4.625618775685628, Alpha: 10, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 4.538351694742839, Alpha: 10, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 4.336471239725749, Alpha: 10, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 4.240451494852702, Alpha: 10, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 4.117396195729573, Alpha: 10, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 4.078203280766805, Alpha: 10, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 4.0095163981119795, Alpha: 10, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 4.041411558787028, Alpha: 10, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 3.848877191543579, Alpha: 10.0, Balancing Iterations: 52\n",
      "Epoch 1000, Loss: 2.676098585128784, Alpha: 5.05, Balancing Iterations: 52\n",
      "Epoch 1100, Loss: 0.18788951138655344, Alpha: 0.09999999999999964, Balancing Iterations: 52\n",
      "Epoch 1200, Loss: 0.13917935887972513, Alpha: 0.1, Balancing Iterations: 100\n",
      "Epoch 1300, Loss: 0.14433428645133972, Alpha: 0.1, Balancing Iterations: 100\n",
      "Epoch 1400, Loss: 0.1340770572423935, Alpha: 0.1, Balancing Iterations: 100\n"
     ]
    }
   ],
   "source": [
    "\n",
    "dim_X = X_obs_train_tensor.shape[1]\n",
    "gen_input_dim = 2   # dimension of the Gaussian noise fed to the generator\n",
    "gen_output_dim = 1  # dimension of the generated latent U_hat\n",
    "\n",
    "# Adversarial test functions: f acts on the covariates X, g on (1-D) outcomes\n",
    "model_f = BoundedContinuousFunctionModel(input_dim=dim_X, output_dim=1).to(device)\n",
    "model_g = BoundedContinuousFunctionModel(input_dim=1, output_dim=1).to(device)\n",
    "\n",
    "# Generator for U_hat and the CATE learner, whose input is [X, U_hat, T]\n",
    "generator = Generator(input_dim=gen_input_dim, output_dim=gen_output_dim).to(device)\n",
    "cate_learner = snet(input_dim=dim_X + 1 + gen_output_dim, hidden_dim=16).to(device)\n",
    "\n",
    "# Train the new models with the mb+pb method\n",
    "generator, cate_learner = train_model_mb_plus_pb(\n",
    "    observational_data=observation_data, rct_data=rct_data,\n",
    "    model_g=model_g, model_f=model_f,\n",
    "    generator=generator, cate_learner=cate_learner,\n",
    "    alpha_start=100, alpha_end=0.01,\n",
    "    generator_input_dim=gen_input_dim, num_epochs=num_epochs,\n",
    "    batch_size=256, device=device,\n",
    "    lr_g=0.001, lr_te=0.001, lr_f=0.001,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, Loss MSE: 1.0226651787757874\n",
      "Epoch 100, Loss MSE: 0.7902319550514221\n",
      "Epoch 200, Loss MSE: 0.7665843009948731\n",
      "Epoch 300, Loss MSE: 0.7311073303222656\n",
      "Epoch 400, Loss MSE: 0.7336826682090759\n",
      "Epoch 500, Loss MSE: 0.7289070248603821\n",
      "Epoch 600, Loss MSE: 0.7151643633842468\n",
      "Epoch 700, Loss MSE: 0.7025520443916321\n",
      "Epoch 800, Loss MSE: 0.6910792827606201\n",
      "Epoch 900, Loss MSE: 0.6848186254501343\n",
      "Epoch 1000, Loss MSE: 0.664492154121399\n",
      "Epoch 1100, Loss MSE: 0.6720301270484924\n",
      "Epoch 1200, Loss MSE: 0.6642081618309021\n",
      "Epoch 1300, Loss MSE: 0.6611286759376526\n",
      "Epoch 1400, Loss MSE: 0.6590203881263733\n"
     ]
    }
   ],
   "source": [
    "# Oracle: the same architecture fitted on the held-out RCT split (randomised treatment);\n",
    "# its T=1 vs T=0 contrast serves as the reference ITE in the evaluation.\n",
    "# NOTE(review): the oracle is trained and evaluated on the same RCT test split.\n",
    "oracle = snet(input_dim=dim_X + 1, hidden_dim=16).to(device)\n",
    "rct_data_test = TensorDataset(X_rct_test_tensor, T_rct_test_tensor, Y_rct_test_tensor)\n",
    "\n",
    "oracle = train_baseline(rct_data_test, oracle, num_epochs=num_epochs, device=device, lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPEHE for mb+pb model: 0.5298659801483154\n",
      "EPEHE for baseline model: 2.6846933364868164\n"
     ]
    }
   ],
   "source": [
    "# Move all models to the CPU for evaluation and switch them to eval mode so that\n",
    "# dropout / batch-norm layers (if any) behave deterministically.\n",
    "oracle = oracle.to('cpu').eval()\n",
    "baseline_cate_learner = baseline_cate_learner.to('cpu').eval()\n",
    "cate_learner = cate_learner.to('cpu').eval()\n",
    "generator = generator.to('cpu').eval()\n",
    "\n",
    "# Move test data to CPU\n",
    "X_rct_test_tensor = X_rct_test_tensor.to('cpu')\n",
    "n_test = X_rct_test_tensor.shape[0]\n",
    "T_one = torch.ones(n_test, 1)\n",
    "T_zero = torch.zeros(n_test, 1)\n",
    "\n",
    "n_samples = 10  # Number of U samples the mb+pb predictions are averaged over\n",
    "\n",
    "# Pure inference: no autograd graph is needed.\n",
    "with torch.no_grad():\n",
    "    # Oracle ITE (used as the true ITE)\n",
    "    Y_pred_one_test = oracle(torch.cat((X_rct_test_tensor, T_one), dim=1))\n",
    "    Y_pred_zero_test = oracle(torch.cat((X_rct_test_tensor, T_zero), dim=1))\n",
    "    true_ite = Y_pred_one_test - Y_pred_zero_test\n",
    "\n",
    "    # Baseline ITE\n",
    "    Y_pred_one_baseline = baseline_cate_learner(torch.cat((X_rct_test_tensor, T_one), dim=1))\n",
    "    Y_pred_zero_baseline = baseline_cate_learner(torch.cat((X_rct_test_tensor, T_zero), dim=1))\n",
    "    baseline_ite = Y_pred_one_baseline - Y_pred_zero_baseline\n",
    "\n",
    "    # MB+PB ITE: average the potential-outcome predictions over several U samples\n",
    "    y_list_one = []\n",
    "    y_list_zero = []\n",
    "    for _ in range(n_samples):\n",
    "        Z_test = torch.randn(n_test, gen_input_dim)\n",
    "        U_hat_test = generator(Z_test)\n",
    "\n",
    "        test_input_te_one = torch.cat((X_rct_test_tensor, U_hat_test, T_one), dim=1)\n",
    "        Y_pred_one_model = cate_learner(test_input_te_one)\n",
    "        y_list_one.append(Y_pred_one_model)\n",
    "\n",
    "        test_input_te_zero = torch.cat((X_rct_test_tensor, U_hat_test, T_zero), dim=1)\n",
    "        Y_pred_zero_model = cate_learner(test_input_te_zero)\n",
    "        y_list_zero.append(Y_pred_zero_model)\n",
    "\n",
    "    Y_pred_one_avg = torch.mean(torch.stack(y_list_one), dim=0)\n",
    "    Y_pred_zero_avg = torch.mean(torch.stack(y_list_zero), dim=0)\n",
    "    ite_mb_pb = Y_pred_one_avg - Y_pred_zero_avg\n",
    "\n",
    "# EPEHE = sqrt(mean((true_ite - estimated_ite)^2)) over the test units\n",
    "mse = nn.MSELoss()\n",
    "epehe_mb_pb = torch.sqrt(mse(true_ite, ite_mb_pb))\n",
    "epehe_baseline = torch.sqrt(mse(true_ite, baseline_ite))\n",
    "\n",
    "print(f'EPEHE for mb+pb model: {epehe_mb_pb.item()}')\n",
    "print(f'EPEHE for baseline model: {epehe_baseline.item()}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Factual loss for oracle: 0.7350462079048157\n",
      "Factual loss for baseline: 2.6231489181518555\n",
      "Factual loss for mb+pb model: 1.2278058528900146\n"
     ]
    }
   ],
   "source": [
    "# Factual loss: MSE between each model's prediction under the *observed* treatment\n",
    "# and the observed outcome on the RCT test split.\n",
    "with torch.no_grad():\n",
    "    # factual loss for the oracle: use the observed treatment, not T=1 for every unit\n",
    "    Y_pred_oracle = oracle(torch.cat((X_rct_test_tensor, T_rct_test_tensor), dim=1))\n",
    "    factual_loss_oracle = mse(Y_pred_oracle, Y_rct_test_tensor)\n",
    "\n",
    "    # factual loss for the baseline\n",
    "    Y_pred_baseline = baseline_cate_learner(torch.cat((X_rct_test_tensor, T_rct_test_tensor), dim=1))\n",
    "    factual_loss_baseline = mse(Y_pred_baseline, Y_rct_test_tensor)\n",
    "\n",
    "    # factual loss for the mb+pb model: average over several fresh U samples,\n",
    "    # consistent with how its ITE is evaluated\n",
    "    y_list_factual = []\n",
    "    for _ in range(n_samples):\n",
    "        U_hat_factual = generator(torch.randn(X_rct_test_tensor.shape[0], gen_input_dim))\n",
    "        y_list_factual.append(cate_learner(torch.cat((X_rct_test_tensor, U_hat_factual, T_rct_test_tensor), dim=1)))\n",
    "    Y_pred_model = torch.mean(torch.stack(y_list_factual), dim=0)\n",
    "    factual_loss_model = mse(Y_pred_model, Y_rct_test_tensor)\n",
    "\n",
    "print(f'Factual loss for oracle: {factual_loss_oracle.item()}')\n",
    "print(f'Factual loss for baseline: {factual_loss_baseline.item()}')\n",
    "print(f'Factual loss for mb+pb model: {factual_loss_model.item()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 1/10\n",
      "Epoch 0, Loss MSE: 2.56801430384318\n",
      "Epoch 100, Loss MSE: 0.1500848283370336\n",
      "Epoch 200, Loss MSE: 0.13552939891815186\n",
      "Epoch 300, Loss MSE: 0.13103613754113516\n",
      "Epoch 400, Loss MSE: 0.13395926853020987\n",
      "Epoch 500, Loss MSE: 0.13281472027301788\n",
      "Epoch 600, Loss MSE: 0.12853018442789713\n",
      "Epoch 700, Loss MSE: 0.1317370260755221\n",
      "Epoch 800, Loss MSE: 0.131316972275575\n",
      "Epoch 900, Loss MSE: 0.1266584942738215\n",
      "Epoch 1000, Loss MSE: 0.12334658950567245\n",
      "Epoch 1100, Loss MSE: 0.12545419732729593\n",
      "Epoch 1200, Loss MSE: 0.12788542608420053\n",
      "Epoch 1300, Loss MSE: 0.12494890143473943\n",
      "Epoch 1400, Loss MSE: 0.1178762490550677\n",
      "Epoch 0, Loss: 215.26797993977866, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 18.470491409301758, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 18.289758046468098, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 18.06126085917155, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 17.547835032145183, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 17.53635088602702, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 17.52212079366048, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 17.2751522064209, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 16.850069046020508, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 17.059695879618328, Alpha: 100.0, Balancing Iterations: 52\n",
      "Epoch 1000, Loss: 10.263731320699057, Alpha: 50.005, Balancing Iterations: 52\n",
      "Epoch 1100, Loss: 2.1716747283935547, Alpha: 0.010000000000005116, Balancing Iterations: 52\n",
      "Epoch 1200, Loss: 0.40039925773938495, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1300, Loss: 0.14065530399481455, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1400, Loss: 0.05972563227017721, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 0, Loss MSE: 3.1238230228424073\n",
      "Epoch 100, Loss MSE: 0.8014254689216613\n",
      "Epoch 200, Loss MSE: 0.7614701151847839\n",
      "Epoch 300, Loss MSE: 0.7585489273071289\n",
      "Epoch 400, Loss MSE: 0.7343188881874084\n",
      "Epoch 500, Loss MSE: 0.7248520612716675\n",
      "Epoch 600, Loss MSE: 0.7213839411735534\n",
      "Epoch 700, Loss MSE: 0.7023608446121216\n",
      "Epoch 800, Loss MSE: 0.6937752604484558\n",
      "Epoch 900, Loss MSE: 0.6919978380203247\n",
      "Epoch 1000, Loss MSE: 0.677921187877655\n",
      "Epoch 1100, Loss MSE: 0.6937908053398132\n",
      "Epoch 1200, Loss MSE: 0.6746914863586426\n",
      "Epoch 1300, Loss MSE: 0.665681529045105\n",
      "Epoch 1400, Loss MSE: 0.6683744668960572\n",
      "Run 2/10\n",
      "Epoch 0, Loss MSE: 2.5817155838012695\n",
      "Epoch 100, Loss MSE: 0.15545530120531717\n",
      "Epoch 200, Loss MSE: 0.13800521194934845\n",
      "Epoch 300, Loss MSE: 0.13819496830304465\n",
      "Epoch 400, Loss MSE: 0.13531170785427094\n",
      "Epoch 500, Loss MSE: 0.14112109442551932\n",
      "Epoch 600, Loss MSE: 0.13392911851406097\n",
      "Epoch 700, Loss MSE: 0.13056535025437674\n",
      "Epoch 800, Loss MSE: 0.12874835977951685\n",
      "Epoch 900, Loss MSE: 0.1272734229763349\n",
      "Epoch 1000, Loss MSE: 0.13114256411790848\n",
      "Epoch 1100, Loss MSE: 0.12906857331593832\n",
      "Epoch 1200, Loss MSE: 0.12647768606742224\n",
      "Epoch 1300, Loss MSE: 0.1320205976565679\n",
      "Epoch 1400, Loss MSE: 0.1212399626771609\n",
      "Epoch 0, Loss: 276.0569559733073, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 19.31166394551595, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 17.804541905721027, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 17.67790349324544, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 17.433520634969074, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 17.549991607666016, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 17.083105087280273, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 17.41920344034831, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 16.935956319173176, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 16.68070920308431, Alpha: 100.0, Balancing Iterations: 52\n",
      "Epoch 1000, Loss: 10.029836336771647, Alpha: 50.005, Balancing Iterations: 52\n",
      "Epoch 1100, Loss: 2.121769110361735, Alpha: 0.010000000000005116, Balancing Iterations: 52\n",
      "Epoch 1200, Loss: 0.3262627025445302, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1300, Loss: 0.07612292716900508, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1400, Loss: 0.02911963996787866, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 0, Loss MSE: 1.1126047611236571\n",
      "Epoch 100, Loss MSE: 0.7814292192459107\n",
      "Epoch 200, Loss MSE: 0.7806070208549499\n",
      "Epoch 300, Loss MSE: 0.7430508732795715\n",
      "Epoch 400, Loss MSE: 0.7271391868591308\n",
      "Epoch 500, Loss MSE: 0.7064221143722534\n",
      "Epoch 600, Loss MSE: 0.7106117486953736\n",
      "Epoch 700, Loss MSE: 0.686673927307129\n",
      "Epoch 800, Loss MSE: 0.689939546585083\n",
      "Epoch 900, Loss MSE: 0.680771815776825\n",
      "Epoch 1000, Loss MSE: 0.6688156723976135\n",
      "Epoch 1100, Loss MSE: 0.6758832454681396\n",
      "Epoch 1200, Loss MSE: 0.6612313032150269\n",
      "Epoch 1300, Loss MSE: 0.6781937837600708\n",
      "Epoch 1400, Loss MSE: 0.6640569925308227\n",
      "Run 3/10\n",
      "Epoch 0, Loss MSE: 2.7280455430348716\n",
      "Epoch 100, Loss MSE: 0.1596800535917282\n",
      "Epoch 200, Loss MSE: 0.13839488724867502\n",
      "Epoch 300, Loss MSE: 0.13469504316647848\n",
      "Epoch 400, Loss MSE: 0.13364597161610922\n",
      "Epoch 500, Loss MSE: 0.13454733788967133\n",
      "Epoch 600, Loss MSE: 0.13100843131542206\n",
      "Epoch 700, Loss MSE: 0.12997127572695413\n",
      "Epoch 800, Loss MSE: 0.13087079674005508\n",
      "Epoch 900, Loss MSE: 0.12696785231431326\n",
      "Epoch 1000, Loss MSE: 0.12181210269530614\n",
      "Epoch 1100, Loss MSE: 0.12363623827695847\n",
      "Epoch 1200, Loss MSE: 0.12396814674139023\n",
      "Epoch 1300, Loss MSE: 0.12195108830928802\n",
      "Epoch 1400, Loss MSE: 0.12297399342060089\n",
      "Epoch 0, Loss: 216.19869486490884, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 17.80650742848714, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 18.08923403422038, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 17.434765497843426, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 17.59410031636556, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 17.658896446228027, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 17.659342130025227, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 17.24172083536784, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 16.82024097442627, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 16.64068349202474, Alpha: 100.0, Balancing Iterations: 52\n",
      "Epoch 1000, Loss: 10.213027636210123, Alpha: 50.005, Balancing Iterations: 52\n",
      "Epoch 1100, Loss: 2.0235960483551025, Alpha: 0.010000000000005116, Balancing Iterations: 52\n",
      "Epoch 1200, Loss: 0.31176023681958515, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1300, Loss: 0.08543365945418675, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1400, Loss: 0.03353733569383621, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 0, Loss MSE: 1.3635905504226684\n",
      "Epoch 100, Loss MSE: 0.7884147644042969\n",
      "Epoch 200, Loss MSE: 0.7788406014442444\n",
      "Epoch 300, Loss MSE: 0.7401029467582703\n",
      "Epoch 400, Loss MSE: 0.7357568740844727\n",
      "Epoch 500, Loss MSE: 0.7276833295822144\n",
      "Epoch 600, Loss MSE: 0.7050598502159119\n",
      "Epoch 700, Loss MSE: 0.7021937012672425\n",
      "Epoch 800, Loss MSE: 0.6918884873390198\n",
      "Epoch 900, Loss MSE: 0.6832594037055969\n",
      "Epoch 1000, Loss MSE: 0.6969392418861389\n",
      "Epoch 1100, Loss MSE: 0.6759308695793151\n",
      "Epoch 1200, Loss MSE: 0.6743337035179138\n",
      "Epoch 1300, Loss MSE: 0.6721014261245728\n",
      "Epoch 1400, Loss MSE: 0.6745543956756592\n",
      "Run 4/10\n",
      "Epoch 0, Loss MSE: 3.386211713155111\n",
      "Epoch 100, Loss MSE: 0.16139421363671622\n",
      "Epoch 200, Loss MSE: 0.14002913236618042\n",
      "Epoch 300, Loss MSE: 0.13699795802434286\n",
      "Epoch 400, Loss MSE: 0.13803314169247946\n",
      "Epoch 500, Loss MSE: 0.134576678276062\n",
      "Epoch 600, Loss MSE: 0.12914619594812393\n",
      "Epoch 700, Loss MSE: 0.13363722463448843\n",
      "Epoch 800, Loss MSE: 0.13650073607762656\n",
      "Epoch 900, Loss MSE: 0.13064433634281158\n",
      "Epoch 1000, Loss MSE: 0.13486810525258383\n",
      "Epoch 1100, Loss MSE: 0.1235605850815773\n",
      "Epoch 1200, Loss MSE: 0.12849281976620355\n",
      "Epoch 1300, Loss MSE: 0.1305484821399053\n",
      "Epoch 1400, Loss MSE: 0.1265934780240059\n",
      "Epoch 0, Loss: 242.97460428873697, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 19.045474370320637, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 18.429515202840168, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 17.82745424906413, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 17.99134063720703, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 17.08654499053955, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 17.104562123616535, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 17.239498774210613, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 17.13503138224284, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 16.98638153076172, Alpha: 100.0, Balancing Iterations: 52\n",
      "Epoch 1000, Loss: 10.32024097442627, Alpha: 50.005, Balancing Iterations: 52\n",
      "Epoch 1100, Loss: 2.078859011332194, Alpha: 0.010000000000005116, Balancing Iterations: 52\n",
      "Epoch 1200, Loss: 0.3669367730617523, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1300, Loss: 0.1217108244697253, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1400, Loss: 0.04785575593511263, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 0, Loss MSE: 0.9889953374862671\n",
      "Epoch 100, Loss MSE: 0.7690680623054504\n",
      "Epoch 200, Loss MSE: 0.7607437252998352\n",
      "Epoch 300, Loss MSE: 0.7521781325340271\n",
      "Epoch 400, Loss MSE: 0.7346855163574219\n",
      "Epoch 500, Loss MSE: 0.7207570195198059\n",
      "Epoch 600, Loss MSE: 0.7126500725746154\n",
      "Epoch 700, Loss MSE: 0.6909961938858032\n",
      "Epoch 800, Loss MSE: 0.6805786967277527\n",
      "Epoch 900, Loss MSE: 0.6700534701347352\n",
      "Epoch 1000, Loss MSE: 0.6712855577468873\n",
      "Epoch 1100, Loss MSE: 0.641958224773407\n",
      "Epoch 1200, Loss MSE: 0.6421032667160034\n",
      "Epoch 1300, Loss MSE: 0.6415232658386231\n",
      "Epoch 1400, Loss MSE: 0.6430862545967102\n",
      "Run 5/10\n",
      "Epoch 0, Loss MSE: 2.4733285109202066\n",
      "Epoch 100, Loss MSE: 0.1459791213274002\n",
      "Epoch 200, Loss MSE: 0.1357977936665217\n",
      "Epoch 300, Loss MSE: 0.1359331657489141\n",
      "Epoch 400, Loss MSE: 0.12978029747804007\n",
      "Epoch 500, Loss MSE: 0.12424187362194061\n",
      "Epoch 600, Loss MSE: 0.12517482539017996\n",
      "Epoch 700, Loss MSE: 0.12838426232337952\n",
      "Epoch 800, Loss MSE: 0.1268147975206375\n",
      "Epoch 900, Loss MSE: 0.12705648442109427\n",
      "Epoch 1000, Loss MSE: 0.12077133854230244\n",
      "Epoch 1100, Loss MSE: 0.12195749829212825\n",
      "Epoch 1200, Loss MSE: 0.11891863743464152\n",
      "Epoch 1300, Loss MSE: 0.12377561132113139\n",
      "Epoch 1400, Loss MSE: 0.11784907430410385\n",
      "Epoch 0, Loss: 270.60694376627606, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 19.473617553710938, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 18.209190368652344, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 18.045076370239258, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 17.677011489868164, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 17.40911865234375, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 17.20317014058431, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 17.378552118937176, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 17.11054865519206, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 16.763131459554035, Alpha: 100.0, Balancing Iterations: 52\n",
      "Epoch 1000, Loss: 10.348597844441732, Alpha: 50.005, Balancing Iterations: 52\n",
      "Epoch 1100, Loss: 1.9473637739817302, Alpha: 0.010000000000005116, Balancing Iterations: 52\n",
      "Epoch 1200, Loss: 0.29195436835289, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1300, Loss: 0.11332407842079799, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1400, Loss: 0.05557463193933169, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 0, Loss MSE: 0.9835596919059754\n",
      "Epoch 100, Loss MSE: 0.7881011843681336\n",
      "Epoch 200, Loss MSE: 0.7599525213241577\n",
      "Epoch 300, Loss MSE: 0.7306415438652039\n",
      "Epoch 400, Loss MSE: 0.7109872102737427\n",
      "Epoch 500, Loss MSE: 0.7210091829299927\n",
      "Epoch 600, Loss MSE: 0.7218286871910096\n",
      "Epoch 700, Loss MSE: 0.6910683751106262\n",
      "Epoch 800, Loss MSE: 0.695120632648468\n",
      "Epoch 900, Loss MSE: 0.6823610663414001\n",
      "Epoch 1000, Loss MSE: 0.6849450469017029\n",
      "Epoch 1100, Loss MSE: 0.6752638101577759\n",
      "Epoch 1200, Loss MSE: 0.6437813878059387\n",
      "Epoch 1300, Loss MSE: 0.6422345399856567\n",
      "Epoch 1400, Loss MSE: 0.6555190682411194\n",
      "Run 6/10\n",
      "Epoch 0, Loss MSE: 2.2368434270222983\n",
      "Epoch 100, Loss MSE: 0.15251622597376505\n",
      "Epoch 200, Loss MSE: 0.13754361867904663\n",
      "Epoch 300, Loss MSE: 0.1455103432138761\n",
      "Epoch 400, Loss MSE: 0.13755608101685843\n",
      "Epoch 500, Loss MSE: 0.14203323423862457\n",
      "Epoch 600, Loss MSE: 0.1303450663884481\n",
      "Epoch 700, Loss MSE: 0.13772346824407578\n",
      "Epoch 800, Loss MSE: 0.1295821194847425\n",
      "Epoch 900, Loss MSE: 0.12898673117160797\n",
      "Epoch 1000, Loss MSE: 0.12216753264268239\n",
      "Epoch 1100, Loss MSE: 0.1266083170970281\n",
      "Epoch 1200, Loss MSE: 0.11719775199890137\n",
      "Epoch 1300, Loss MSE: 0.12403542300065358\n",
      "Epoch 1400, Loss MSE: 0.12002277125914891\n",
      "Epoch 0, Loss: 219.09341430664062, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 17.942355473836262, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 17.77343813578288, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 17.728925069173176, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 18.03349431355794, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 17.432627360026043, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 18.122217814127605, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 17.432850519816082, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 16.89868672688802, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 17.305727005004883, Alpha: 100.0, Balancing Iterations: 52\n",
      "Epoch 1000, Loss: 10.658878962198893, Alpha: 50.005, Balancing Iterations: 52\n",
      "Epoch 1100, Loss: 2.1327105363210044, Alpha: 0.010000000000005116, Balancing Iterations: 52\n",
      "Epoch 1200, Loss: 0.3642309010028839, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1300, Loss: 0.11091019213199615, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1400, Loss: 0.043685405204693474, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 0, Loss MSE: 1.0183188915252686\n",
      "Epoch 100, Loss MSE: 0.789329731464386\n",
      "Epoch 200, Loss MSE: 0.749407422542572\n",
      "Epoch 300, Loss MSE: 0.7509783625602722\n",
      "Epoch 400, Loss MSE: 0.7341128826141358\n",
      "Epoch 500, Loss MSE: 0.7054536581039429\n",
      "Epoch 600, Loss MSE: 0.6920508503913879\n",
      "Epoch 700, Loss MSE: 0.6954913973808289\n",
      "Epoch 800, Loss MSE: 0.6793006658554077\n",
      "Epoch 900, Loss MSE: 0.6942174434661865\n",
      "Epoch 1000, Loss MSE: 0.6744353532791137\n",
      "Epoch 1100, Loss MSE: 0.664788794517517\n",
      "Epoch 1200, Loss MSE: 0.6569962859153747\n",
      "Epoch 1300, Loss MSE: 0.6470621585845947\n",
      "Epoch 1400, Loss MSE: 0.6457537770271301\n",
      "Run 7/10\n",
      "Epoch 0, Loss MSE: 2.5602389176686606\n",
      "Epoch 100, Loss MSE: 0.15060368676980337\n",
      "Epoch 200, Loss MSE: 0.1379337360461553\n",
      "Epoch 300, Loss MSE: 0.1388035168250402\n",
      "Epoch 400, Loss MSE: 0.13460556666056314\n",
      "Epoch 500, Loss MSE: 0.13669422020514807\n",
      "Epoch 600, Loss MSE: 0.13487515350182852\n",
      "Epoch 700, Loss MSE: 0.13067896167437235\n",
      "Epoch 800, Loss MSE: 0.1304337481657664\n",
      "Epoch 900, Loss MSE: 0.13088048746188483\n",
      "Epoch 1000, Loss MSE: 0.13382165630658469\n",
      "Epoch 1100, Loss MSE: 0.1256595402956009\n",
      "Epoch 1200, Loss MSE: 0.12465575834115346\n",
      "Epoch 1300, Loss MSE: 0.12523890535036722\n",
      "Epoch 1400, Loss MSE: 0.11932059874137242\n",
      "Epoch 0, Loss: 221.81624857584634, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 18.996848424275715, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 17.90662384033203, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 17.239693959554035, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 17.277626355489094, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 17.332043011983234, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 17.17125129699707, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 17.049923578898113, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 16.97489579518636, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 16.421181042989094, Alpha: 100.0, Balancing Iterations: 52\n",
      "Epoch 1000, Loss: 9.99747085571289, Alpha: 50.005, Balancing Iterations: 52\n",
      "Epoch 1100, Loss: 2.014741818110148, Alpha: 0.010000000000005116, Balancing Iterations: 52\n",
      "Epoch 1200, Loss: 0.35244664549827576, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1300, Loss: 0.11405972391366959, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1400, Loss: 0.04517166316509247, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 0, Loss MSE: 1.0259965896606444\n",
      "Epoch 100, Loss MSE: 0.7961587905883789\n",
      "Epoch 200, Loss MSE: 0.7570922613143921\n",
      "Epoch 300, Loss MSE: 0.7303380727767944\n",
      "Epoch 400, Loss MSE: 0.7108319163322449\n",
      "Epoch 500, Loss MSE: 0.7032505750656128\n",
      "Epoch 600, Loss MSE: 0.7172064781188965\n",
      "Epoch 700, Loss MSE: 0.692798113822937\n",
      "Epoch 800, Loss MSE: 0.6867125988006592\n",
      "Epoch 900, Loss MSE: 0.6794603586196899\n",
      "Epoch 1000, Loss MSE: 0.6714487433433532\n",
      "Epoch 1100, Loss MSE: 0.6607993125915528\n",
      "Epoch 1200, Loss MSE: 0.668963897228241\n",
      "Epoch 1300, Loss MSE: 0.6603252410888671\n",
      "Epoch 1400, Loss MSE: 0.6555694818496705\n",
      "Run 8/10\n",
      "Epoch 0, Loss MSE: 2.560335636138916\n",
      "Epoch 100, Loss MSE: 0.1403905153274536\n",
      "Epoch 200, Loss MSE: 0.1344373549024264\n",
      "Epoch 300, Loss MSE: 0.13619631777207056\n",
      "Epoch 400, Loss MSE: 0.1307465781768163\n",
      "Epoch 500, Loss MSE: 0.13193872074286142\n",
      "Epoch 600, Loss MSE: 0.1295565441250801\n",
      "Epoch 700, Loss MSE: 0.13106648127237955\n",
      "Epoch 800, Loss MSE: 0.125999353826046\n",
      "Epoch 900, Loss MSE: 0.12815479189157486\n",
      "Epoch 1000, Loss MSE: 0.12897262970606485\n",
      "Epoch 1100, Loss MSE: 0.1262020468711853\n",
      "Epoch 1200, Loss MSE: 0.12730536113182703\n",
      "Epoch 1300, Loss MSE: 0.12331310907999675\n",
      "Epoch 1400, Loss MSE: 0.12610695014397302\n",
      "Epoch 0, Loss: 233.94527689615884, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 19.3261775970459, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 18.347021102905273, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 18.325895309448242, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 17.683547337849934, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 17.604819615681965, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 17.580393473307293, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 17.53828493754069, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 16.864261945088703, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 17.287704149882, Alpha: 100.0, Balancing Iterations: 52\n",
      "Epoch 1000, Loss: 10.286628405253092, Alpha: 50.005, Balancing Iterations: 52\n",
      "Epoch 1100, Loss: 2.1184476216634116, Alpha: 0.010000000000005116, Balancing Iterations: 52\n",
      "Epoch 1200, Loss: 0.4316501518090566, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1300, Loss: 0.17624154686927795, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1400, Loss: 0.08275967091321945, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 0, Loss MSE: 1.0341660499572753\n",
      "Epoch 100, Loss MSE: 0.7749505281448364\n",
      "Epoch 200, Loss MSE: 0.7729147791862487\n",
      "Epoch 300, Loss MSE: 0.7468930125236511\n",
      "Epoch 400, Loss MSE: 0.7192745327949523\n",
      "Epoch 500, Loss MSE: 0.7112675070762634\n",
      "Epoch 600, Loss MSE: 0.6985648393630981\n",
      "Epoch 700, Loss MSE: 0.6932977318763733\n",
      "Epoch 800, Loss MSE: 0.671553087234497\n",
      "Epoch 900, Loss MSE: 0.6574442267417908\n",
      "Epoch 1000, Loss MSE: 0.6522838950157166\n",
      "Epoch 1100, Loss MSE: 0.6371167063713074\n",
      "Epoch 1200, Loss MSE: 0.6358277201652527\n",
      "Epoch 1300, Loss MSE: 0.6300937294960022\n",
      "Epoch 1400, Loss MSE: 0.6365414381027221\n",
      "Run 9/10\n",
      "Epoch 0, Loss MSE: 3.418764670689901\n",
      "Epoch 100, Loss MSE: 0.34817371765772503\n",
      "Epoch 200, Loss MSE: 0.1376307432850202\n",
      "Epoch 300, Loss MSE: 0.1358389506737391\n",
      "Epoch 400, Loss MSE: 0.13857191304365793\n",
      "Epoch 500, Loss MSE: 0.12850159406661987\n",
      "Epoch 600, Loss MSE: 0.13133248686790466\n",
      "Epoch 700, Loss MSE: 0.13107594847679138\n",
      "Epoch 800, Loss MSE: 0.13099536299705505\n",
      "Epoch 900, Loss MSE: 0.1274005447824796\n",
      "Epoch 1000, Loss MSE: 0.12817775706450144\n",
      "Epoch 1100, Loss MSE: 0.12675178050994873\n",
      "Epoch 1200, Loss MSE: 0.12581912179787955\n",
      "Epoch 1300, Loss MSE: 0.1275190512339274\n",
      "Epoch 1400, Loss MSE: 0.12645428876082102\n",
      "Epoch 0, Loss: 211.57486470540366, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 20.428265889485676, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 18.54543972015381, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 17.439343134562176, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 17.23773701985677, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 16.858670870463055, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 16.80514685312907, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 17.36233107248942, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 16.696459134419758, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 16.762948989868164, Alpha: 100.0, Balancing Iterations: 52\n",
      "Epoch 1000, Loss: 10.111520449320475, Alpha: 50.005, Balancing Iterations: 52\n",
      "Epoch 1100, Loss: 2.0643486181894937, Alpha: 0.010000000000005116, Balancing Iterations: 52\n",
      "Epoch 1200, Loss: 0.19897520542144775, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1300, Loss: 0.026872392122944195, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1400, Loss: 0.020561058074235916, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 0, Loss MSE: 1.0210149526596068\n",
      "Epoch 100, Loss MSE: 0.7838486433029175\n",
      "Epoch 200, Loss MSE: 0.750796663761139\n",
      "Epoch 300, Loss MSE: 0.735885488986969\n",
      "Epoch 400, Loss MSE: 0.7383949160575867\n",
      "Epoch 500, Loss MSE: 0.7041282057762146\n",
      "Epoch 600, Loss MSE: 0.7057451963424682\n",
      "Epoch 700, Loss MSE: 0.677926528453827\n",
      "Epoch 800, Loss MSE: 0.6768294453620911\n",
      "Epoch 900, Loss MSE: 0.6695768594741821\n",
      "Epoch 1000, Loss MSE: 0.6723948240280151\n",
      "Epoch 1100, Loss MSE: 0.6778181791305542\n",
      "Epoch 1200, Loss MSE: 0.6628174901008606\n",
      "Epoch 1300, Loss MSE: 0.6343737721443177\n",
      "Epoch 1400, Loss MSE: 0.6425625681877136\n",
      "Run 10/10\n",
      "Epoch 0, Loss MSE: 2.863548755645752\n",
      "Epoch 100, Loss MSE: 0.15780570606390634\n",
      "Epoch 200, Loss MSE: 0.1415309210618337\n",
      "Epoch 300, Loss MSE: 0.1424944450457891\n",
      "Epoch 400, Loss MSE: 0.12940448274215063\n",
      "Epoch 500, Loss MSE: 0.13155990093946457\n",
      "Epoch 600, Loss MSE: 0.13157122830549875\n",
      "Epoch 700, Loss MSE: 0.13296478986740112\n",
      "Epoch 800, Loss MSE: 0.13357413311799368\n",
      "Epoch 900, Loss MSE: 0.1308335637052854\n",
      "Epoch 1000, Loss MSE: 0.13405130306879678\n",
      "Epoch 1100, Loss MSE: 0.1300922930240631\n",
      "Epoch 1200, Loss MSE: 0.13313867151737213\n",
      "Epoch 1300, Loss MSE: 0.1295783519744873\n",
      "Epoch 1400, Loss MSE: 0.13254990180333456\n",
      "Epoch 0, Loss: 261.9461161295573, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 18.563074111938477, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 18.51245880126953, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 18.735082626342773, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 17.42875607808431, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 17.700167338053387, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 17.450899759928387, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 17.558409372965496, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 16.930500030517578, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 17.341607411702473, Alpha: 100.0, Balancing Iterations: 52\n",
      "Epoch 1000, Loss: 10.468817710876465, Alpha: 50.005, Balancing Iterations: 52\n",
      "Epoch 1100, Loss: 2.084432045618693, Alpha: 0.010000000000005116, Balancing Iterations: 52\n",
      "Epoch 1200, Loss: 0.24008202056090036, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1300, Loss: 0.05997398619850477, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 1400, Loss: 0.02358374061683814, Alpha: 0.01, Balancing Iterations: 100\n",
      "Epoch 0, Loss MSE: 1.175162410736084\n",
      "Epoch 100, Loss MSE: 0.7861440539360046\n",
      "Epoch 200, Loss MSE: 0.7546347737312317\n",
      "Epoch 300, Loss MSE: 0.7351032018661499\n",
      "Epoch 400, Loss MSE: 0.7219625949859619\n",
      "Epoch 500, Loss MSE: 0.7000554919242858\n",
      "Epoch 600, Loss MSE: 0.6901679396629333\n",
      "Epoch 700, Loss MSE: 0.7119056701660156\n",
      "Epoch 800, Loss MSE: 0.6751713156700134\n",
      "Epoch 900, Loss MSE: 0.6718607187271118\n",
      "Epoch 1000, Loss MSE: 0.6871438503265381\n",
      "Epoch 1100, Loss MSE: 0.678342604637146\n",
      "Epoch 1200, Loss MSE: 0.6609416961669922\n",
      "Epoch 1300, Loss MSE: 0.6682433843612671\n",
      "Epoch 1400, Loss MSE: 0.6507738947868347\n",
      "MB+PB Model: EPEHE Mean: 0.3579904168844223, Std: 0.03778232749831618\n",
      "Baseline Model: EPEHE Mean: 2.661280632019043, Std: 0.023697555580681477\n",
      "MB+PB Model: Factual Loss Mean: 1.076923781633377, Std: 0.13198194273849534\n",
      "Baseline Model: Factual Loss Mean: 3.802150917053223, Std: 0.06279535503296144\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from torch.utils.data import TensorDataset\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "\n",
    "def predict_potential_outcomes(model, X, U=None):\n",
    "    \"\"\"Predict both potential outcomes (Y(1), Y(0)) for covariates X.\n",
    "\n",
    "    The model is fed torch.cat([X, T]) when U is None (baseline / oracle) and\n",
    "    torch.cat([X, U, T]) otherwise (MB+PB CATE learner), where T is a column\n",
    "    of ones for Y(1) and a column of zeros for Y(0).\n",
    "    Returns the pair of model outputs (treated arm, control arm).\n",
    "    \"\"\"\n",
    "    n = X.shape[0]\n",
    "    preds = []\n",
    "    for T_col in (torch.ones(n, 1), torch.zeros(n, 1)):\n",
    "        parts = (X, T_col) if U is None else (X, U, T_col)\n",
    "        preds.append(model(torch.cat(parts, dim=1)))\n",
    "    return preds[0], preds[1]\n",
    "\n",
    "\n",
    "def factual_predictions(Y_one, Y_zero, T):\n",
    "    \"\"\"Select each unit's prediction under its observed binary treatment T.\"\"\"\n",
    "    return T * Y_one + (1 - T) * Y_zero\n",
    "\n",
    "\n",
    "# Initialize lists to store per-run results\n",
    "epehe_mb_pb_list = []\n",
    "epehe_baseline_list = []\n",
    "factual_loss_mb_pb_list = []\n",
    "factual_loss_baseline_list = []\n",
    "\n",
    "n_runs = 10  # Number of independent train/evaluate runs\n",
    "n_samples = 10  # Generator draws of U averaged per test unit (MB+PB)\n",
    "mse = nn.MSELoss()\n",
    "\n",
    "for run in range(n_runs):\n",
    "    print(f\"Run {run+1}/{n_runs}\")\n",
    "\n",
    "    # Initialize and train the baseline CATE learner\n",
    "    baseline_cate_learner = snet(input_dim=dim_X + 1, hidden_dim=16)\n",
    "    baseline_cate_learner.to(device)\n",
    "    baseline_cate_learner = train_baseline(observation_data, baseline_cate_learner, num_epochs=num_epochs, device=device, lr=0.001)\n",
    "\n",
    "    # Initialize and train the mb+pb models\n",
    "    model_f = BoundedContinuousFunctionModel(input_dim=dim_X, output_dim=1)\n",
    "    model_f.to(device)\n",
    "\n",
    "    model_g = BoundedContinuousFunctionModel(input_dim=1, output_dim=1)\n",
    "    model_g.to(device)\n",
    "\n",
    "    generator = Generator(input_dim=gen_input_dim, output_dim=gen_output_dim)\n",
    "    generator.to(device)\n",
    "\n",
    "    cate_learner = snet(input_dim=dim_X + 1 + gen_output_dim, hidden_dim=16)\n",
    "    cate_learner.to(device)\n",
    "\n",
    "    generator, cate_learner = train_model_mb_plus_pb(\n",
    "        observational_data=observation_data,\n",
    "        rct_data=rct_data,\n",
    "        model_g=model_g,\n",
    "        model_f=model_f,\n",
    "        generator=generator,\n",
    "        cate_learner=cate_learner,\n",
    "        alpha_start=100,\n",
    "        alpha_end=0.01,\n",
    "        generator_input_dim=gen_input_dim,\n",
    "        num_epochs=num_epochs,\n",
    "        batch_size=250,\n",
    "        device=device,\n",
    "        lr_g=0.001,\n",
    "        lr_te=0.001,\n",
    "        lr_f=0.001\n",
    "    )\n",
    "\n",
    "    # Train the oracle model (baseline fit directly on the RCT test data)\n",
    "    oracle = snet(input_dim=dim_X + 1, hidden_dim=16)\n",
    "    oracle.to(device)\n",
    "    rct_data_test = TensorDataset(X_rct_test_tensor, T_rct_test_tensor, Y_rct_test_tensor)\n",
    "    oracle = train_baseline(rct_data_test, oracle, num_epochs=num_epochs, device=device, lr=0.001)\n",
    "\n",
    "    # Evaluate on CPU in inference mode: eval() turns off dropout / batch-norm\n",
    "    # statistic updates (if the models use them), and no_grad() skips building\n",
    "    # autograd graphs. X_eval is a local copy, so the global test tensor is\n",
    "    # not mutated between runs.\n",
    "    oracle = oracle.to('cpu').eval()\n",
    "    baseline_cate_learner = baseline_cate_learner.to('cpu').eval()\n",
    "    cate_learner = cate_learner.to('cpu').eval()\n",
    "    generator = generator.to('cpu').eval()\n",
    "    X_eval = X_rct_test_tensor.detach().cpu()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # Oracle ITE: proxy for the true ITE (model fit on the RCT test set)\n",
    "        Y_pred_one_test, Y_pred_zero_test = predict_potential_outcomes(oracle, X_eval)\n",
    "        true_ite = Y_pred_one_test - Y_pred_zero_test\n",
    "\n",
    "        # Baseline ITE\n",
    "        Y_pred_one_baseline, Y_pred_zero_baseline = predict_potential_outcomes(baseline_cate_learner, X_eval)\n",
    "        baseline_ite = Y_pred_one_baseline - Y_pred_zero_baseline\n",
    "\n",
    "        # MB+PB ITE: average predictions over n_samples draws of U from the generator\n",
    "        y_list_one = []\n",
    "        y_list_zero = []\n",
    "        for _ in range(n_samples):\n",
    "            Z_test = torch.randn(X_eval.shape[0], gen_input_dim)\n",
    "            U_hat_test = generator(Z_test)\n",
    "            Y_pred_one_model, Y_pred_zero_model = predict_potential_outcomes(cate_learner, X_eval, U_hat_test)\n",
    "            y_list_one.append(Y_pred_one_model)\n",
    "            y_list_zero.append(Y_pred_zero_model)\n",
    "        Y_pred_one_avg = torch.mean(torch.stack(y_list_one), dim=0)\n",
    "        Y_pred_zero_avg = torch.mean(torch.stack(y_list_zero), dim=0)\n",
    "        ite_mb_pb = Y_pred_one_avg - Y_pred_zero_avg\n",
    "\n",
    "        # Factual predictions use each unit's *observed* treatment: scoring Y(1)\n",
    "        # against every outcome would also grade control units on the treated arm.\n",
    "        # reshape(-1, 1) assumes single-output learners (one prediction column);\n",
    "        # it stops MSELoss from broadcasting an (n,) target against (n, 1) predictions.\n",
    "        T_obs = T_rct_test_tensor.detach().cpu().reshape(-1, 1).to(ite_mb_pb.dtype)\n",
    "        Y_obs = Y_rct_test_tensor.detach().cpu().reshape(-1, 1).to(ite_mb_pb.dtype)\n",
    "        factual_mb_pb = factual_predictions(Y_pred_one_avg, Y_pred_zero_avg, T_obs)\n",
    "        factual_baseline = factual_predictions(Y_pred_one_baseline, Y_pred_zero_baseline, T_obs)\n",
    "\n",
    "    # EPEHE: root-mean-squared error between estimated and oracle ITEs\n",
    "    epehe_mb_pb = torch.sqrt(mse(true_ite, ite_mb_pb)).item()\n",
    "    epehe_baseline = torch.sqrt(mse(true_ite, baseline_ite)).item()\n",
    "\n",
    "    # Factual loss: MSE of predictions under the observed treatment\n",
    "    factual_loss_mb_pb = mse(factual_mb_pb, Y_obs).item()\n",
    "    factual_loss_baseline = mse(factual_baseline, Y_obs).item()\n",
    "\n",
    "    # Append results to the lists\n",
    "    epehe_mb_pb_list.append(epehe_mb_pb)\n",
    "    epehe_baseline_list.append(epehe_baseline)\n",
    "    factual_loss_mb_pb_list.append(factual_loss_mb_pb)\n",
    "    factual_loss_baseline_list.append(factual_loss_baseline)\n",
    "\n",
    "# Compute mean and (population, ddof=0) standard deviation across runs\n",
    "epehe_mb_pb_mean = np.mean(epehe_mb_pb_list)\n",
    "epehe_mb_pb_std = np.std(epehe_mb_pb_list)\n",
    "epehe_baseline_mean = np.mean(epehe_baseline_list)\n",
    "epehe_baseline_std = np.std(epehe_baseline_list)\n",
    "\n",
    "factual_loss_mb_pb_mean = np.mean(factual_loss_mb_pb_list)\n",
    "factual_loss_mb_pb_std = np.std(factual_loss_mb_pb_list)\n",
    "factual_loss_baseline_mean = np.mean(factual_loss_baseline_list)\n",
    "factual_loss_baseline_std = np.std(factual_loss_baseline_list)\n",
    "\n",
    "# Print results\n",
    "print(f'MB+PB Model: EPEHE Mean: {epehe_mb_pb_mean}, Std: {epehe_mb_pb_std}')\n",
    "print(f'Baseline Model: EPEHE Mean: {epehe_baseline_mean}, Std: {epehe_baseline_std}')\n",
    "print(f'MB+PB Model: Factual Loss Mean: {factual_loss_mb_pb_mean}, Std: {factual_loss_mb_pb_std}')\n",
    "print(f'Baseline Model: Factual Loss Mean: {factual_loss_baseline_mean}, Std: {factual_loss_baseline_std}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the results folder if it doesn't exist\n",
    "results_folder = 'experiment_results'\n",
    "os.makedirs(results_folder, exist_ok=True)\n",
    "\n",
    "# Summary statistics across runs (std is np.std, i.e. population std, ddof=0),\n",
    "# plus the raw per-run values so results can be re-aggregated later\n",
    "# (e.g. sample std, confidence intervals, paired comparisons).\n",
    "results = {\n",
    "    'epehe_mb_pb_mean': epehe_mb_pb_mean,\n",
    "    'epehe_mb_pb_std': epehe_mb_pb_std,\n",
    "    'epehe_baseline_mean': epehe_baseline_mean,\n",
    "    'epehe_baseline_std': epehe_baseline_std,\n",
    "    'factual_loss_mb_pb_mean': factual_loss_mb_pb_mean,\n",
    "    'factual_loss_mb_pb_std': factual_loss_mb_pb_std,\n",
    "    'factual_loss_baseline_mean': factual_loss_baseline_mean,\n",
    "    'factual_loss_baseline_std': factual_loss_baseline_std,\n",
    "    'n_runs': len(epehe_mb_pb_list),\n",
    "    'epehe_mb_pb_runs': list(epehe_mb_pb_list),\n",
    "    'epehe_baseline_runs': list(epehe_baseline_list),\n",
    "    'factual_loss_mb_pb_runs': list(factual_loss_mb_pb_list),\n",
    "    'factual_loss_baseline_runs': list(factual_loss_baseline_list)\n",
    "}\n",
    "\n",
    "# Save the results to a pickle file\n",
    "results_file = os.path.join(results_folder, 'star_experiment_results.pkl')\n",
    "with open(results_file, 'wb') as f:\n",
    "    pickle.dump(results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
