{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running Experiment for ACTG dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Current working directory: /hpc/home/aa671/phd/generating_confounders/Deconfounding-MD\n",
      "Device:  cuda:0\n"
     ]
    }
   ],
   "source": [
    "# import libraries\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from generate_data import *\n",
    "from train_models import *\n",
    "from utils import *\n",
    "from generate_data import *\n",
    "from models import *\n",
    "\n",
    "seed = 2024\n",
    "np.random.seed(seed)\n",
    "print(\"Current working directory:\", os.getcwd())\n",
    "\n",
    "# check if CUDA is available\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda:0\" if use_cuda else \"cpu\")\n",
    "print(\"Device: \", device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Size of unconfounded (RCT) features: (1000, 12)\n",
      "Size of confounded (Observational) features: (580, 12)\n"
     ]
    }
   ],
   "source": [
    "data = upload_actg(n_unc=50)\n",
    "data_rct = upload_actg(n_unc=1000)\n",
    "# Print the shape of unconfounded and confounded features\n",
    "x_unc = data_rct['x_unc']\n",
    "x_conf = data['x_conf']\n",
    "\n",
    "print(f\"Size of unconfounded (RCT) features: {x_unc.shape}\")\n",
    "print(f\"Size of confounded (Observational) features: {x_conf.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_tensors(x, t, y, device='cpu'):\n",
    "    X_tensor = torch.tensor(x, dtype=torch.float32, device=device)\n",
    "    T_tensor = torch.tensor(t, dtype=torch.float32, device=device)\n",
    "    Y_tensor = torch.tensor(y, dtype=torch.float32, device=device)\n",
    "    return X_tensor, T_tensor, Y_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Size of test features: torch.Size([1000, 12])\n",
      "Size of observational train features: torch.Size([580, 12])\n",
      "Size of RCT train: torch.Size([50, 1])\n"
     ]
    }
   ],
   "source": [
    "# Get the unconfounded (RCT) and confounded (observational) data\n",
    "x_unc = data_rct['x_unc']\n",
    "t_unc = data_rct['t_unc']\n",
    "y_unc = data_rct['y_unc']\n",
    "\n",
    "t_unc_train = data['t_unc']\n",
    "y_unc_train = data['y_unc']\n",
    "\n",
    "x_conf = data['x_conf']\n",
    "t_conf = data['t_conf']\n",
    "y_conf = data['y_conf']\n",
    "\n",
    "# Split the RCT (unconfounded) data into training and test sets\n",
    "# _, X_rct_test, T_rct_train, T_rct_test, Y_rct_train, Y_rct_test = train_test_split(x_unc, t_unc, y_unc, test_size=0.9, random_state=42)\n",
    "X_rct_test = x_unc\n",
    "T_rct_train = t_unc_train\n",
    "Y_rct_train = y_unc_train\n",
    "T_rct_test = t_unc\n",
    "Y_rct_test = y_unc\n",
    "\n",
    "# Convert RCT (unconfounded) and observational (confounded) data into tensors\n",
    "X_rct_test_tensor, T_rct_train_tensor, Y_rct_train_tensor = convert_to_tensors(X_rct_test, T_rct_train, Y_rct_train, device=device)\n",
    "X_rct_test_tensor, T_rct_test_tensor, Y_rct_test_tensor = convert_to_tensors(X_rct_test, T_rct_test, Y_rct_test)\n",
    "\n",
    "X_obs_train_tensor, T_obs_train_tensor, Y_obs_train_tensor = convert_to_tensors(x_conf, t_conf, y_conf, device=device)\n",
    "\n",
    "observation_data = TensorDataset(X_obs_train_tensor, T_obs_train_tensor, Y_obs_train_tensor)\n",
    "rct_data = TensorDataset(T_rct_train_tensor, Y_rct_train_tensor)\n",
    "# Print the sizes to check everything\n",
    "print(f\"Size of test features: {X_rct_test_tensor.shape}\")\n",
    "print(f\"Size of observational train features: {X_obs_train_tensor.shape}\")\n",
    "print(f\"Size of RCT train: {T_rct_train_tensor.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 1500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "\n",
    "def train_baseline(observational_data, baseline_cate_learner, num_epochs=1000, batch_size=256, device='cpu', lr=0.001):\n",
    "    # Loss function\n",
    "    mse = nn.MSELoss()\n",
    "\n",
    "    # Optimizer\n",
    "    optimizer = optim.Adam(baseline_cate_learner.parameters(), lr=lr)\n",
    "\n",
    "    # Create a DataLoader for batching\n",
    "    data_loader = DataLoader(observational_data, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "    # Training loop\n",
    "    for epoch in range(num_epochs):\n",
    "        epoch_loss = 0.0\n",
    "        for X_batch, T_batch, Y_batch in data_loader:\n",
    "            # Move data to device (if using GPU)\n",
    "            X_batch, T_batch, Y_batch = X_batch.to(device), T_batch.to(device), Y_batch.to(device)\n",
    "            \n",
    "            # Concatenate X and T as input to the model\n",
    "            XT = torch.cat((X_batch, T_batch), dim=1)\n",
    "            \n",
    "            # Forward pass\n",
    "            Y_pred = baseline_cate_learner(XT)\n",
    "            \n",
    "            # Calculate loss\n",
    "            loss = mse(Y_pred, Y_batch)\n",
    "            epoch_loss += loss.item()\n",
    "\n",
    "            # Backward pass and optimize\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        # Print loss every 100 epochs\n",
    "        if epoch % 100 == 0:\n",
    "            print(f'Epoch {epoch}, Loss MSE: {epoch_loss / len(data_loader)}')\n",
    "\n",
    "    return baseline_cate_learner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "\n",
    "def train_model_mb_plus_pb(observational_data, rct_data, model_g, model_f, generator, cate_learner, \n",
    "                           alpha_start=100, alpha_end=0.01, num_epochs=500, \n",
    "                           balancing_iterations_start=5, balancing_iterations_end=50, \n",
    "                           generator_input_dim=1, batch_size=512, device='cpu', \n",
    "                           lr_g=0.001, lr_te=0.001, lr_f=0.001):\n",
    "    \n",
    "    # Extract RCT data (treatment and outcome)\n",
    "    T_rct, Y_rct = rct_data[:][0], rct_data[:][1]\n",
    "    T_rct, Y_rct = T_rct.view(-1, 1), Y_rct  # Ensure correct shapes\n",
    "\n",
    "    # Loss function\n",
    "    mse = nn.MSELoss()\n",
    "\n",
    "    # Optimizers for different models\n",
    "    optimizer_g = optim.Adam(generator.parameters(), lr=lr_g)\n",
    "    optimizer_te = optim.Adam(cate_learner.parameters(), lr=lr_te)\n",
    "    optimizer_f = optim.Adam(model_f.parameters(), lr=lr_f)\n",
    "    optimizer_g_model = optim.Adam(model_g.parameters(), lr=lr_g)\n",
    "\n",
    "    # Prepare DataLoader for the observational dataset\n",
    "    data_loader = DataLoader(observational_data, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "    # Training loop\n",
    "    for epoch in range(num_epochs):\n",
    "        epoch_loss = 0.0\n",
    "\n",
    "        # Update alpha and balancing iterations based on epoch\n",
    "        if epoch < int((num_epochs * 2 / 3) - 100):\n",
    "            alpha = alpha_start\n",
    "            balancing_iterations = balancing_iterations_start\n",
    "        elif epoch > int((num_epochs * 2 / 3) + 100):\n",
    "            alpha = alpha_end\n",
    "            balancing_iterations = balancing_iterations_end\n",
    "        else:\n",
    "            balancing_iterations = int((balancing_iterations_start + balancing_iterations_end) / 2)\n",
    "            alpha = alpha_start - (alpha_start - alpha_end) * (epoch - int(num_epochs * 2 / 3 - 100)) / 200\n",
    "\n",
    "        # Loop through observational dataset in batches\n",
    "        for X_batch, T_batch, Y_batch in data_loader:\n",
    "            X_batch, T_batch, Y_batch = X_batch.to(device), T_batch.to(device), Y_batch.to(device)\n",
    "\n",
    "            # Generate U_hat from generator\n",
    "            Z = torch.randn(X_batch.shape[0], generator_input_dim, device=device)\n",
    "            U_hat = generator(Z)\n",
    "\n",
    "            # Prepare inputs for CATE learner\n",
    "            input_te = torch.cat((X_batch, U_hat, T_batch), dim=1)\n",
    "            Y_pred = cate_learner(input_te)\n",
    "\n",
    "            # Predict for T=1 and T=0\n",
    "            input_te_one = torch.cat((X_batch, U_hat, torch.ones(X_batch.shape[0], 1, device=device)), dim=1)\n",
    "            Y_pred_one = cate_learner(input_te_one)\n",
    "\n",
    "            input_te_zero = torch.cat((X_batch, U_hat, torch.zeros(X_batch.shape[0], 1, device=device)), dim=1)\n",
    "            Y_pred_zero = cate_learner(input_te_zero)\n",
    "\n",
    "\n",
    "            if X_batch.shape[0] > Y_rct[T_rct == 1].shape[0]:\n",
    "                idx = torch.randint(0, X_batch.shape[0], (Y_rct[T_rct == 1].shape[0],), device=device)\n",
    "                X_batch_small = X_batch[idx]\n",
    "            else:\n",
    "                X_batch_small = X_batch  # Default to full batch if the condition is false\n",
    "\n",
    "\n",
    "            f_Y_rct_1 = model_f(X_batch_small) * Y_rct[T_rct == 1].view(-1, 1)\n",
    "            f_Y_pred_1 = model_f(X_batch) * Y_pred_one.view(-1, 1)\n",
    "\n",
    "            if X_batch.shape[0] > Y_rct[T_rct == 0].shape[0]:\n",
    "                idx = torch.randint(0, X_batch.shape[0], (Y_rct[T_rct == 0].shape[0],), device=device)\n",
    "                X_batch_small = X_batch[idx]\n",
    "            else:\n",
    "                X_batch_small = X_batch  # Default to full batch if the condition is false\n",
    "\n",
    "\n",
    "            f_Y_rct_0 = model_f(X_batch_small) * Y_rct[T_rct == 0].view(-1, 1)\n",
    "            f_Y_pred_0 = model_f(X_batch) * Y_pred_zero.view(-1, 1)\n",
    "\n",
    "            g_Y_pred_1 = model_g(Y_pred_one.view(-1, 1))\n",
    "            g_Y_rct_1 = model_g(Y_rct[T_rct == 1].view(-1, 1))\n",
    "            g_Y_pred_0 = model_g(Y_pred_zero.view(-1, 1))\n",
    "            g_Y_rct_0 = model_g(Y_rct[T_rct == 0].view(-1, 1))\n",
    "\n",
    "            # Compute losses\n",
    "            loss1 = mse(Y_pred, Y_batch)\n",
    "            loss2 = mse(f_Y_pred_1.mean(), f_Y_rct_1.mean())\n",
    "            loss3 = mse(f_Y_pred_0.mean(), f_Y_rct_0.mean())\n",
    "            loss2p = mse(g_Y_pred_1.mean(), g_Y_rct_1.mean())\n",
    "            loss3p = mse(g_Y_pred_0.mean(), g_Y_rct_0.mean())\n",
    "            loss = alpha * loss1 + loss2 + loss3 + loss2p + loss3p\n",
    "\n",
    "            # Update generator and CATE learner\n",
    "            optimizer_te.zero_grad()\n",
    "            optimizer_g.zero_grad()\n",
    "            loss.backward(retain_graph=True)\n",
    "            optimizer_te.step()\n",
    "            optimizer_g.step()\n",
    "\n",
    "            epoch_loss += loss.item()\n",
    "\n",
    "            # Balancing iterations\n",
    "            for _ in range(balancing_iterations):\n",
    "\n",
    "                if X_batch.shape[0] > Y_rct[T_rct == 1].view(-1, 1).shape[0]:\n",
    "                    idx = np.random.choice(X_batch.shape[0], Y_rct[T_rct == 1].view(-1, 1).shape[0], replace=False)\n",
    "                    X_batch_small = X_batch[idx]\n",
    "                f_Y_rct_1 = model_f(X_batch_small)*Y_rct[T_rct == 1].view(-1, 1).detach()\n",
    "                f_Y_pred_1 = model_f(X_batch)*Y_pred_one.view(-1, 1).detach()\n",
    "                if X_batch.shape[0] > Y_rct[T_rct == 0].view(-1, 1).shape[0]:\n",
    "                    idx = np.random.choice(X_batch.shape[0], Y_rct[T_rct == 0].view(-1, 1).shape[0], replace=False)\n",
    "                    X_batch_small = X_batch[idx]\n",
    "                f_Y_rct_0 = model_f(X_batch_small)*Y_rct[T_rct == 0].view(-1, 1).detach()\n",
    "                f_Y_pred_0 = model_f(X_batch)*Y_pred_zero.view(-1, 1).detach()\n",
    "\n",
    "                g_Y_pred_1 = model_g(Y_pred_one.view(-1, 1)).detach()\n",
    "                g_Y_rct_1 = model_g(Y_rct[T_rct == 1].view(-1, 1)).detach()\n",
    "                g_Y_pred_0 = model_g(Y_pred_zero.view(-1, 1)).detach()\n",
    "                g_Y_rct_0 = model_g(Y_rct[T_rct == 0].view(-1, 1)).detach()\n",
    "\n",
    "                loss4 = mse(f_Y_rct_1.mean(), f_Y_pred_1.mean())\n",
    "                loss5 = mse(f_Y_rct_0.mean(), f_Y_pred_0.mean())\n",
    "                loss4p = mse(g_Y_pred_1.mean(), g_Y_rct_1.mean())\n",
    "                loss5p = mse(g_Y_pred_0.mean(), g_Y_rct_0.mean())\n",
    "\n",
    "                loss_f = -loss4 - loss5 -  loss4p - loss5p\n",
    "                #print(\"loss_f = \", loss_f)\n",
    "                optimizer_f.zero_grad()\n",
    "                optimizer_g_model.zero_grad()\n",
    "                loss_f.backward()\n",
    "                optimizer_f.step()\n",
    "                optimizer_g_model.step()\n",
    "\n",
    "        # Logging and printing every 100 epochs\n",
    "        if epoch % 100 == 0:\n",
    "            print(f'Epoch {epoch}, Loss: {epoch_loss / len(data_loader)}, Alpha: {alpha}, Balancing Iterations: {balancing_iterations}')\n",
    "\n",
    "    return generator, cate_learner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, Loss: 98.96991729736328, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 61.370073318481445, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 49.617103576660156, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 63.39825630187988, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 41.63632774353027, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 43.617279052734375, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 57.22336387634277, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 38.26298904418945, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 43.89860153198242, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 41.87715148925781, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 19.350014686584473, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.3843371272087097, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.2026965394616127, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.04582648351788521, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.051820600405335426, Alpha: 0.01, Balancing Iterations: 50\n"
     ]
    }
   ],
   "source": [
    "dim_X = X_obs_train_tensor.shape[1]\n",
    "\n",
    "# Training loop for the baseline\n",
    "baseline_cate_learner = snet(input_dim=dim_X+1, hidden_dim=16)\n",
    "baseline_cate_learner.to(device)\n",
    "\n",
    "baseline_cate_learner =  train_baseline(observation_data, baseline_cate_learner,num_epochs=num_epochs, device=device, lr=0.001)\n",
    "\n",
    "\n",
    "gen_input_dim = 2\n",
    "gen_output_dim = 1   \n",
    "\n",
    "# Define model_f, model_g, and the generator\n",
    "model_f = BoundedContinuousFunctionModel(input_dim=dim_X, output_dim=1)\n",
    "model_f.to(device)\n",
    "\n",
    "model_g = BoundedContinuousFunctionModel(input_dim=1, output_dim=1)\n",
    "model_g.to(device)\n",
    "\n",
    "generator = Generator(input_dim=gen_input_dim, output_dim=gen_output_dim)\n",
    "generator.to(device)\n",
    "\n",
    "# Define the CATE learner\n",
    "cate_learner = snet(input_dim=dim_X + 1 + gen_output_dim, hidden_dim=16)\n",
    "cate_learner.to(device)\n",
    "\n",
    "# Train the new models with the mb+pb method\n",
    "generator, cate_learner = train_model_mb_plus_pb(\n",
    "    observational_data=observation_data,\n",
    "    rct_data=rct_data,\n",
    "    model_g=model_g,\n",
    "    model_f=model_f,\n",
    "    generator=generator,\n",
    "    cate_learner=cate_learner,\n",
    "    alpha_start=100,\n",
    "    alpha_end=.01,\n",
    "    generator_input_dim=gen_input_dim,\n",
    "    num_epochs=num_epochs,\n",
    "    batch_size=500,\n",
    "    device=device,\n",
    "    lr_g=0.001,\n",
    "    lr_te=0.001,\n",
    "    lr_f=0.001\n",
    ")\n",
    "\n",
    "# Train a baseline model on the RCT_test call it oracle\n",
    "oracle = snet(input_dim=dim_X + 1, hidden_dim=16)\n",
    "oracle.to(device)\n",
    "rct_data_test = TensorDataset(X_rct_test_tensor, T_rct_test_tensor, Y_rct_test_tensor)\n",
    "\n",
    "oracle = train_baseline(rct_data_test, oracle, num_epochs=num_epochs, device=device, lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPEHE for mb+pb model: 0.5331896543502808\n",
      "EPEHE for baseline model: 1.2511985301971436\n"
     ]
    }
   ],
   "source": [
    "# Move all models and data to CPU for evaluation\n",
    "oracle = oracle.to('cpu')\n",
    "baseline_cate_learner = baseline_cate_learner.to('cpu')\n",
    "cate_learner = cate_learner.to('cpu')\n",
    "generator = generator.to('cpu')\n",
    "\n",
    "# Move test data to CPU\n",
    "X_rct_test_tensor = X_rct_test_tensor.to('cpu')\n",
    "T_one = torch.ones(X_rct_test_tensor.shape[0], 1)\n",
    "T_zero = torch.zeros(X_rct_test_tensor.shape[0], 1)\n",
    "\n",
    "# Oracle ITE (True ITE)\n",
    "Y_pred_one_test = oracle(torch.cat((X_rct_test_tensor, T_one), dim=1))\n",
    "Y_pred_zero_test = oracle(torch.cat((X_rct_test_tensor, T_zero), dim=1))\n",
    "true_ite = Y_pred_one_test - Y_pred_zero_test\n",
    "\n",
    "# Baseline ITE\n",
    "Y_pred_one_baseline = baseline_cate_learner(torch.cat((X_rct_test_tensor, T_one), dim=1))\n",
    "Y_pred_zero_baseline = baseline_cate_learner(torch.cat((X_rct_test_tensor, T_zero), dim=1))\n",
    "baseline_ite = Y_pred_one_baseline - Y_pred_zero_baseline\n",
    "\n",
    "# MB+PB ITE with averaging over multiple U samples\n",
    "n_samples = 10  # Number of U samples\n",
    "y_list_one = []\n",
    "y_list_zero = []\n",
    "\n",
    "for _ in range(n_samples):\n",
    "    # Sample U from generator\n",
    "    Z_test = torch.randn(X_rct_test_tensor.shape[0], gen_input_dim)\n",
    "\n",
    "    # Generate U_hat_test using generator\n",
    "    U_hat_test = generator(Z_test)\n",
    "\n",
    "    # Compute predictions for T=1\n",
    "    test_input_te_one = torch.cat((X_rct_test_tensor, U_hat_test, T_one), dim=1)\n",
    "    Y_pred_one_model = cate_learner(test_input_te_one)\n",
    "    y_list_one.append(Y_pred_one_model)\n",
    "\n",
    "    # Compute predictions for T=0\n",
    "    test_input_te_zero = torch.cat((X_rct_test_tensor, U_hat_test, T_zero), dim=1)\n",
    "    Y_pred_zero_model = cate_learner(test_input_te_zero)\n",
    "    y_list_zero.append(Y_pred_zero_model)\n",
    "\n",
    "# Average predictions over all samples of U\n",
    "Y_pred_one_avg = torch.mean(torch.stack(y_list_one), dim=0)\n",
    "Y_pred_zero_avg = torch.mean(torch.stack(y_list_zero), dim=0)\n",
    "\n",
    "# Compute ITE from averaged predictions\n",
    "ite_mb_pb = Y_pred_one_avg - Y_pred_zero_avg\n",
    "\n",
    "# Compute MSE and EPEHE\n",
    "mse = nn.MSELoss()\n",
    "epehe_mb_pb = torch.sqrt(mse(true_ite, ite_mb_pb))\n",
    "epehe_baseline = torch.sqrt(mse(true_ite, baseline_ite))\n",
    "\n",
    "# Optionally print or log the results\n",
    "print(f'EPEHE for mb+pb model: {epehe_mb_pb.item()}')\n",
    "print(f'EPEHE for baseline model: {epehe_baseline.item()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Factual loss for oracle: 0.5530259013175964\n",
      "Factual loss for baseline: 1.0099655389785767\n",
      "Factual loss for mb+pb model: 0.7719661593437195\n"
     ]
    }
   ],
   "source": [
    "# factual loss for the oracle\n",
    "factual_loss_oracle = mse(Y_pred_one_test, Y_rct_test_tensor)\n",
    "print(f'Factual loss for oracle: {factual_loss_oracle.item()}')\n",
    "\n",
    "# factual loss for the baseline\n",
    "Y_pred_baseline = baseline_cate_learner(torch.cat((X_rct_test_tensor, T_rct_test_tensor), dim=1))\n",
    "factual_loss_baseline = mse(Y_pred_baseline, Y_rct_test_tensor)\n",
    "print(f'Factual loss for baseline: {factual_loss_baseline.item()}')\n",
    "\n",
    "# factual loss for the mb+pb model\n",
    "Y_pred_model = cate_learner(torch.cat((X_rct_test_tensor, U_hat_test, T_rct_test_tensor), dim=1))\n",
    "factual_loss_model = mse(Y_pred_model, Y_rct_test_tensor)\n",
    "print(f'Factual loss for mb+pb model: {factual_loss_model.item()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 1/10\n",
      "Epoch 0, Loss MSE: 1.236999789873759\n",
      "Epoch 100, Loss MSE: 0.5480282107988993\n",
      "Epoch 200, Loss MSE: 0.4637209077676137\n",
      "Epoch 300, Loss MSE: 0.4139104833205541\n",
      "Epoch 400, Loss MSE: 0.4578157365322113\n",
      "Epoch 500, Loss MSE: 0.3738388915856679\n",
      "Epoch 600, Loss MSE: 0.3777015209197998\n",
      "Epoch 700, Loss MSE: 0.36702293157577515\n",
      "Epoch 800, Loss MSE: 0.4507797559102376\n",
      "Epoch 900, Loss MSE: 0.3685257335503896\n",
      "Epoch 1000, Loss MSE: 0.33672745029131573\n",
      "Epoch 1100, Loss MSE: 0.3125021656354268\n",
      "Epoch 1200, Loss MSE: 0.3203025857607524\n",
      "Epoch 1300, Loss MSE: 0.29900046189626056\n",
      "Epoch 1400, Loss MSE: 0.2864809234937032\n",
      "Epoch 0, Loss: 105.64417775472005, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 57.312615712483726, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 52.48102951049805, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 48.30866559346517, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 45.5583127339681, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 43.22358067830404, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 41.1629581451416, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 39.29313596089681, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 37.69321314493815, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 36.64659563700358, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 18.184054692586262, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.26571381588776904, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.06473127007484436, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.032204135631521545, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.02633092428247134, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.7570654600858688\n",
      "Epoch 100, Loss MSE: 0.5893290936946869\n",
      "Epoch 200, Loss MSE: 0.5779097974300385\n",
      "Epoch 300, Loss MSE: 0.5606968551874161\n",
      "Epoch 400, Loss MSE: 0.5470044240355492\n",
      "Epoch 500, Loss MSE: 0.5340165421366692\n",
      "Epoch 600, Loss MSE: 0.5233957320451736\n",
      "Epoch 700, Loss MSE: 0.5115816742181778\n",
      "Epoch 800, Loss MSE: 0.5013859122991562\n",
      "Epoch 900, Loss MSE: 0.49395306408405304\n",
      "Epoch 1000, Loss MSE: 0.4839552268385887\n",
      "Epoch 1100, Loss MSE: 0.47312621772289276\n",
      "Epoch 1200, Loss MSE: 0.46154844760894775\n",
      "Epoch 1300, Loss MSE: 0.4525892063975334\n",
      "Epoch 1400, Loss MSE: 0.44323965907096863\n",
      "Run 2/10\n",
      "Epoch 0, Loss MSE: 0.9690614541371664\n",
      "Epoch 100, Loss MSE: 0.5366472999254862\n",
      "Epoch 200, Loss MSE: 0.5497907996177673\n",
      "Epoch 300, Loss MSE: 0.46428726116816205\n",
      "Epoch 400, Loss MSE: 0.39469894270102185\n",
      "Epoch 500, Loss MSE: 0.42246273159980774\n",
      "Epoch 600, Loss MSE: 0.3844056228796641\n",
      "Epoch 700, Loss MSE: 0.36493255694707233\n",
      "Epoch 800, Loss MSE: 0.3623976508776347\n",
      "Epoch 900, Loss MSE: 0.34796924392382306\n",
      "Epoch 1000, Loss MSE: 0.35925153891245526\n",
      "Epoch 1100, Loss MSE: 0.36375803252061206\n",
      "Epoch 1200, Loss MSE: 0.34074151515960693\n",
      "Epoch 1300, Loss MSE: 0.4332222143809001\n",
      "Epoch 1400, Loss MSE: 0.30523782471815747\n",
      "Epoch 0, Loss: 110.4988276163737, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 55.95766067504883, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 52.304911295572914, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 49.352734883626304, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 45.56033452351888, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 43.98831558227539, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 41.4951171875, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 39.682786305745445, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 38.08641815185547, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 36.90400250752767, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 17.825291633605957, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.37320329745610553, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.04616035086413225, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.03896638688941797, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.01831075455993414, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.7364463359117508\n",
      "Epoch 100, Loss MSE: 0.5918829888105392\n",
      "Epoch 200, Loss MSE: 0.5701417848467827\n",
      "Epoch 300, Loss MSE: 0.5542175024747849\n",
      "Epoch 400, Loss MSE: 0.5372312664985657\n",
      "Epoch 500, Loss MSE: 0.5220615491271019\n",
      "Epoch 600, Loss MSE: 0.5085191801190376\n",
      "Epoch 700, Loss MSE: 0.4942484572529793\n",
      "Epoch 800, Loss MSE: 0.4780382439494133\n",
      "Epoch 900, Loss MSE: 0.4683239758014679\n",
      "Epoch 1000, Loss MSE: 0.4546471983194351\n",
      "Epoch 1100, Loss MSE: 0.4412410855293274\n",
      "Epoch 1200, Loss MSE: 0.43093500286340714\n",
      "Epoch 1300, Loss MSE: 0.4245978146791458\n",
      "Epoch 1400, Loss MSE: 0.41285184025764465\n",
      "Run 3/10\n",
      "Epoch 0, Loss MSE: 1.0367048581441243\n",
      "Epoch 100, Loss MSE: 0.6864302357037863\n",
      "Epoch 200, Loss MSE: 0.5175585150718689\n",
      "Epoch 300, Loss MSE: 0.47228070100148517\n",
      "Epoch 400, Loss MSE: 0.4491687019666036\n",
      "Epoch 500, Loss MSE: 0.42257586121559143\n",
      "Epoch 600, Loss MSE: 0.40446072816848755\n",
      "Epoch 700, Loss MSE: 0.39136544863382977\n",
      "Epoch 800, Loss MSE: 0.38344621658325195\n",
      "Epoch 900, Loss MSE: 0.3503074149290721\n",
      "Epoch 1000, Loss MSE: 0.4776814480622609\n",
      "Epoch 1100, Loss MSE: 0.36695529023806256\n",
      "Epoch 1200, Loss MSE: 0.33961350719134015\n",
      "Epoch 1300, Loss MSE: 0.4367881218592326\n",
      "Epoch 1400, Loss MSE: 0.3247630496819814\n",
      "Epoch 0, Loss: 103.95512390136719, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 57.30693689982096, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 50.619851430257164, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 45.62238438924154, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 43.26140594482422, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 41.79385248819987, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 39.00629742940267, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 38.05361302693685, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 35.06039810180664, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 34.10814221700033, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 16.40002218882243, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.32472120225429535, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.0631268247961998, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.022552601993083954, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.026396742711464565, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.7610913515090942\n",
      "Epoch 100, Loss MSE: 0.5901754274964333\n",
      "Epoch 200, Loss MSE: 0.5651359483599663\n",
      "Epoch 300, Loss MSE: 0.5565057992935181\n",
      "Epoch 400, Loss MSE: 0.5470445901155472\n",
      "Epoch 500, Loss MSE: 0.5356089994311333\n",
      "Epoch 600, Loss MSE: 0.5232575684785843\n",
      "Epoch 700, Loss MSE: 0.509981669485569\n",
      "Epoch 800, Loss MSE: 0.49551209062337875\n",
      "Epoch 900, Loss MSE: 0.4846445992588997\n",
      "Epoch 1000, Loss MSE: 0.47254006564617157\n",
      "Epoch 1100, Loss MSE: 0.4645285978913307\n",
      "Epoch 1200, Loss MSE: 0.45457571744918823\n",
      "Epoch 1300, Loss MSE: 0.4456195905804634\n",
      "Epoch 1400, Loss MSE: 0.4340885281562805\n",
      "Run 4/10\n",
      "Epoch 0, Loss MSE: 0.9582136670748392\n",
      "Epoch 100, Loss MSE: 0.5781585772832235\n",
      "Epoch 200, Loss MSE: 0.5066219965616862\n",
      "Epoch 300, Loss MSE: 0.4373468856016795\n",
      "Epoch 400, Loss MSE: 0.42094968756039935\n",
      "Epoch 500, Loss MSE: 0.395768145720164\n",
      "Epoch 600, Loss MSE: 0.41116275389989215\n",
      "Epoch 700, Loss MSE: 0.40018896261850995\n",
      "Epoch 800, Loss MSE: 0.3403189579645793\n",
      "Epoch 900, Loss MSE: 0.34604030350844067\n",
      "Epoch 1000, Loss MSE: 0.34088101983070374\n",
      "Epoch 1100, Loss MSE: 0.44375357031822205\n",
      "Epoch 1200, Loss MSE: 0.3507743825515111\n",
      "Epoch 1300, Loss MSE: 0.2984754840532939\n",
      "Epoch 1400, Loss MSE: 0.30515416463216144\n",
      "Epoch 0, Loss: 98.51461283365886, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 56.737325032552086, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 51.29206212361654, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 47.365071614583336, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 44.370104471842446, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 42.37612342834473, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 40.735154469807945, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 38.992069244384766, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 37.708970387776695, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 36.892520904541016, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 17.799084345499676, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.3074880391359329, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.08522878587245941, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.025281498829523723, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.020161036712427933, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.8285296410322189\n",
      "Epoch 100, Loss MSE: 0.5972977578639984\n",
      "Epoch 200, Loss MSE: 0.5818743258714676\n",
      "Epoch 300, Loss MSE: 0.5653620809316635\n",
      "Epoch 400, Loss MSE: 0.5524436682462692\n",
      "Epoch 500, Loss MSE: 0.5387108325958252\n",
      "Epoch 600, Loss MSE: 0.5279161781072617\n",
      "Epoch 700, Loss MSE: 0.5143338218331337\n",
      "Epoch 800, Loss MSE: 0.5068312436342239\n",
      "Epoch 900, Loss MSE: 0.496794618666172\n",
      "Epoch 1000, Loss MSE: 0.48789627104997635\n",
      "Epoch 1100, Loss MSE: 0.47603315114974976\n",
      "Epoch 1200, Loss MSE: 0.467690072953701\n",
      "Epoch 1300, Loss MSE: 0.4602241516113281\n",
      "Epoch 1400, Loss MSE: 0.45352547615766525\n",
      "Run 5/10\n",
      "Epoch 0, Loss MSE: 1.138347903887431\n",
      "Epoch 100, Loss MSE: 0.5403875211874644\n",
      "Epoch 200, Loss MSE: 0.49970988432566327\n",
      "Epoch 300, Loss MSE: 0.42437665661176044\n",
      "Epoch 400, Loss MSE: 0.4210998018582662\n",
      "Epoch 500, Loss MSE: 0.41360774636268616\n",
      "Epoch 600, Loss MSE: 0.39796461661656696\n",
      "Epoch 700, Loss MSE: 0.4043552180131276\n",
      "Epoch 800, Loss MSE: 0.37293267250061035\n",
      "Epoch 900, Loss MSE: 0.3712771534919739\n",
      "Epoch 1000, Loss MSE: 0.390071302652359\n",
      "Epoch 1100, Loss MSE: 0.32594333092371625\n",
      "Epoch 1200, Loss MSE: 0.4490925172964732\n",
      "Epoch 1300, Loss MSE: 0.30917446811993915\n",
      "Epoch 1400, Loss MSE: 0.3023207485675812\n",
      "Epoch 0, Loss: 94.97729746500652, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 56.92800267537435, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 52.987918853759766, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 46.437522888183594, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 43.72728983561198, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 41.72987620035807, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 40.50074768066406, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 39.02854919433594, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 38.448394775390625, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 36.38243548075358, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 18.371354738871258, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.2895701825618744, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.024742068722844124, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.020881129428744316, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.052922278021772705, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.7874857038259506\n",
      "Epoch 100, Loss MSE: 0.5880642086267471\n",
      "Epoch 200, Loss MSE: 0.5692251920700073\n",
      "Epoch 300, Loss MSE: 0.5569609254598618\n",
      "Epoch 400, Loss MSE: 0.5473916083574295\n",
      "Epoch 500, Loss MSE: 0.5384663343429565\n",
      "Epoch 600, Loss MSE: 0.5236893370747566\n",
      "Epoch 700, Loss MSE: 0.5170175209641457\n",
      "Epoch 800, Loss MSE: 0.501143142580986\n",
      "Epoch 900, Loss MSE: 0.4910762459039688\n",
      "Epoch 1000, Loss MSE: 0.47999683767557144\n",
      "Epoch 1100, Loss MSE: 0.47175130993127823\n",
      "Epoch 1200, Loss MSE: 0.4618583023548126\n",
      "Epoch 1300, Loss MSE: 0.45374248921871185\n",
      "Epoch 1400, Loss MSE: 0.44329606741666794\n",
      "Run 6/10\n",
      "Epoch 0, Loss MSE: 0.9947458108266195\n",
      "Epoch 100, Loss MSE: 0.5506321986516317\n",
      "Epoch 200, Loss MSE: 0.5000030795733134\n",
      "Epoch 300, Loss MSE: 0.4464600781599681\n",
      "Epoch 400, Loss MSE: 0.42330898841222125\n",
      "Epoch 500, Loss MSE: 0.4278258780638377\n",
      "Epoch 600, Loss MSE: 0.4262424409389496\n",
      "Epoch 700, Loss MSE: 0.38858476281166077\n",
      "Epoch 800, Loss MSE: 0.4070653021335602\n",
      "Epoch 900, Loss MSE: 0.41080212593078613\n",
      "Epoch 1000, Loss MSE: 0.3336130380630493\n",
      "Epoch 1100, Loss MSE: 0.36221494277318317\n",
      "Epoch 1200, Loss MSE: 0.34251973032951355\n",
      "Epoch 1300, Loss MSE: 0.303252433737119\n",
      "Epoch 1400, Loss MSE: 0.31291282176971436\n",
      "Epoch 0, Loss: 102.67494201660156, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 57.87841033935547, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 54.373181660970054, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 47.43968963623047, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 44.86452992757162, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 42.54813893636068, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 41.12045415242513, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 39.073464711507164, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 38.6923942565918, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 37.48781077067057, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 17.940091451009113, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.3410811424255371, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.054110618929068245, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.026591607679923374, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.027854466810822487, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.9675715416669846\n",
      "Epoch 100, Loss MSE: 0.6008592396974564\n",
      "Epoch 200, Loss MSE: 0.5857566744089127\n",
      "Epoch 300, Loss MSE: 0.5640308260917664\n",
      "Epoch 400, Loss MSE: 0.554861918091774\n",
      "Epoch 500, Loss MSE: 0.5415068119764328\n",
      "Epoch 600, Loss MSE: 0.5303475484251976\n",
      "Epoch 700, Loss MSE: 0.5134350433945656\n",
      "Epoch 800, Loss MSE: 0.49722930788993835\n",
      "Epoch 900, Loss MSE: 0.4832301139831543\n",
      "Epoch 1000, Loss MSE: 0.47202254831790924\n",
      "Epoch 1100, Loss MSE: 0.45876771211624146\n",
      "Epoch 1200, Loss MSE: 0.44894086569547653\n",
      "Epoch 1300, Loss MSE: 0.44123512506484985\n",
      "Epoch 1400, Loss MSE: 0.4291602894663811\n",
      "Run 7/10\n",
      "Epoch 0, Loss MSE: 1.1614803870519002\n",
      "Epoch 100, Loss MSE: 0.6043495337168375\n",
      "Epoch 200, Loss MSE: 0.5028942922751108\n",
      "Epoch 300, Loss MSE: 0.49459991852442425\n",
      "Epoch 400, Loss MSE: 0.4233700931072235\n",
      "Epoch 500, Loss MSE: 0.4104879895846049\n",
      "Epoch 600, Loss MSE: 0.4048011600971222\n",
      "Epoch 700, Loss MSE: 0.3861825962861379\n",
      "Epoch 800, Loss MSE: 0.3601839741071065\n",
      "Epoch 900, Loss MSE: 0.3650293250878652\n",
      "Epoch 1000, Loss MSE: 0.3448880612850189\n",
      "Epoch 1100, Loss MSE: 0.34341780344645184\n",
      "Epoch 1200, Loss MSE: 0.45419397950172424\n",
      "Epoch 1300, Loss MSE: 0.2965920865535736\n",
      "Epoch 1400, Loss MSE: 0.3190114696820577\n",
      "Epoch 0, Loss: 113.19939931233723, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 58.556392669677734, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 54.24186706542969, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 49.06672032674154, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 46.56646855672201, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 45.736165364583336, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 44.20341491699219, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 42.0616569519043, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 40.80319849650065, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 38.96255366007487, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 19.282740592956543, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.3001754879951477, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.09229631846149762, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.027901288742820423, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.019879679506023724, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.8066384643316269\n",
      "Epoch 100, Loss MSE: 0.5972559303045273\n",
      "Epoch 200, Loss MSE: 0.5737757310271263\n",
      "Epoch 300, Loss MSE: 0.5612631291151047\n",
      "Epoch 400, Loss MSE: 0.5471615344285965\n",
      "Epoch 500, Loss MSE: 0.5377855971455574\n",
      "Epoch 600, Loss MSE: 0.525036558508873\n",
      "Epoch 700, Loss MSE: 0.5154746025800705\n",
      "Epoch 800, Loss MSE: 0.5076769217848778\n",
      "Epoch 900, Loss MSE: 0.497680626809597\n",
      "Epoch 1000, Loss MSE: 0.49069277197122574\n",
      "Epoch 1100, Loss MSE: 0.4804423749446869\n",
      "Epoch 1200, Loss MSE: 0.47729961574077606\n",
      "Epoch 1300, Loss MSE: 0.4670252352952957\n",
      "Epoch 1400, Loss MSE: 0.4615495949983597\n",
      "Run 8/10\n",
      "Epoch 0, Loss MSE: 0.8835107485453287\n",
      "Epoch 100, Loss MSE: 0.5384795367717743\n",
      "Epoch 200, Loss MSE: 0.5005953013896942\n",
      "Epoch 300, Loss MSE: 0.465695599714915\n",
      "Epoch 400, Loss MSE: 0.4233289162317912\n",
      "Epoch 500, Loss MSE: 0.5746555030345917\n",
      "Epoch 600, Loss MSE: 0.4131675561269124\n",
      "Epoch 700, Loss MSE: 0.3929341236750285\n",
      "Epoch 800, Loss MSE: 0.40436206261316937\n",
      "Epoch 900, Loss MSE: 0.35532692074775696\n",
      "Epoch 1000, Loss MSE: 0.33724212646484375\n",
      "Epoch 1100, Loss MSE: 0.35932162404060364\n",
      "Epoch 1200, Loss MSE: 0.46013590693473816\n",
      "Epoch 1300, Loss MSE: 0.36915626128514606\n",
      "Epoch 1400, Loss MSE: 0.3065074384212494\n",
      "Epoch 0, Loss: 102.73403676350911, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 55.662123362223305, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 50.66300710042318, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 45.472740173339844, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 42.53323872884115, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 40.48030217488607, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 38.77593231201172, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 37.04689280192057, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 35.56268183390299, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 34.94050216674805, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 16.63876438140869, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.29908721645673114, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.033923386596143246, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.021307353240748245, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.016320482827723026, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.7301872670650482\n",
      "Epoch 100, Loss MSE: 0.5886007994413376\n",
      "Epoch 200, Loss MSE: 0.5632955655455589\n",
      "Epoch 300, Loss MSE: 0.5476886332035065\n",
      "Epoch 400, Loss MSE: 0.5354106575250626\n",
      "Epoch 500, Loss MSE: 0.5213294923305511\n",
      "Epoch 600, Loss MSE: 0.5045025050640106\n",
      "Epoch 700, Loss MSE: 0.49144256114959717\n",
      "Epoch 800, Loss MSE: 0.48010731488466263\n",
      "Epoch 900, Loss MSE: 0.46879708766937256\n",
      "Epoch 1000, Loss MSE: 0.4570910856127739\n",
      "Epoch 1100, Loss MSE: 0.44566502422094345\n",
      "Epoch 1200, Loss MSE: 0.4358729273080826\n",
      "Epoch 1300, Loss MSE: 0.4246695414185524\n",
      "Epoch 1400, Loss MSE: 0.4158751145005226\n",
      "Run 9/10\n",
      "Epoch 0, Loss MSE: 1.0320590535799663\n",
      "Epoch 100, Loss MSE: 0.5412596166133881\n",
      "Epoch 200, Loss MSE: 0.4895339210828145\n",
      "Epoch 300, Loss MSE: 0.4563160240650177\n",
      "Epoch 400, Loss MSE: 0.4435529410839081\n",
      "Epoch 500, Loss MSE: 0.48366719484329224\n",
      "Epoch 600, Loss MSE: 0.438029279311498\n",
      "Epoch 700, Loss MSE: 0.4041055639584859\n",
      "Epoch 800, Loss MSE: 0.4057255784670512\n",
      "Epoch 900, Loss MSE: 0.36194337407747906\n",
      "Epoch 1000, Loss MSE: 0.4236624638239543\n",
      "Epoch 1100, Loss MSE: 0.5105123917261759\n",
      "Epoch 1200, Loss MSE: 0.3867940406004588\n",
      "Epoch 1300, Loss MSE: 0.36286203066507977\n",
      "Epoch 1400, Loss MSE: 0.33955862124760944\n",
      "Epoch 0, Loss: 106.3080571492513, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 57.115105946858726, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 52.281359354654946, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 46.25699361165365, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 44.77047475179037, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 42.86131922403971, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 40.41543197631836, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 39.05549875895182, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 37.64070510864258, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 36.995548248291016, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 17.899365425109863, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.2377821753422419, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.05888728052377701, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.03150379626701275, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.020386225854357082, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.8055175691843033\n",
      "Epoch 100, Loss MSE: 0.5937163382768631\n",
      "Epoch 200, Loss MSE: 0.5696235299110413\n",
      "Epoch 300, Loss MSE: 0.5558479875326157\n",
      "Epoch 400, Loss MSE: 0.5433873385190964\n",
      "Epoch 500, Loss MSE: 0.5316196233034134\n",
      "Epoch 600, Loss MSE: 0.5209898352622986\n",
      "Epoch 700, Loss MSE: 0.5076206028461456\n",
      "Epoch 800, Loss MSE: 0.4989045113325119\n",
      "Epoch 900, Loss MSE: 0.4879082590341568\n",
      "Epoch 1000, Loss MSE: 0.4782505929470062\n",
      "Epoch 1100, Loss MSE: 0.4689495638012886\n",
      "Epoch 1200, Loss MSE: 0.45718954503536224\n",
      "Epoch 1300, Loss MSE: 0.4479287713766098\n",
      "Epoch 1400, Loss MSE: 0.43675311654806137\n",
      "Run 10/10\n",
      "Epoch 0, Loss MSE: 1.0022720694541931\n",
      "Epoch 100, Loss MSE: 0.6693107088406881\n",
      "Epoch 200, Loss MSE: 0.5333839853604635\n",
      "Epoch 300, Loss MSE: 0.49419721961021423\n",
      "Epoch 400, Loss MSE: 0.6395145654678345\n",
      "Epoch 500, Loss MSE: 0.4419424335161845\n",
      "Epoch 600, Loss MSE: 0.5095146397749583\n",
      "Epoch 700, Loss MSE: 0.41261371970176697\n",
      "Epoch 800, Loss MSE: 0.39974602063496906\n",
      "Epoch 900, Loss MSE: 0.3695472578207652\n",
      "Epoch 1000, Loss MSE: 0.3611255685488383\n",
      "Epoch 1100, Loss MSE: 0.35595226287841797\n",
      "Epoch 1200, Loss MSE: 0.3515968322753906\n",
      "Epoch 1300, Loss MSE: 0.32850273450215656\n",
      "Epoch 1400, Loss MSE: 0.3102637032667796\n",
      "Epoch 0, Loss: 106.53043619791667, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 57.914415995279946, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 53.5789426167806, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 48.80848821004232, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 45.37356948852539, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 42.911513010660805, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 41.03943634033203, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 39.82873853047689, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 38.12968063354492, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 36.50725173950195, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 18.2984717686971, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.31329547862211865, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.03660444232324759, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.033648028038442135, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.014103728501747051, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.7258763015270233\n",
      "Epoch 100, Loss MSE: 0.5946936905384064\n",
      "Epoch 200, Loss MSE: 0.5781261026859283\n",
      "Epoch 300, Loss MSE: 0.5686081945896149\n",
      "Epoch 400, Loss MSE: 0.5579067468643188\n",
      "Epoch 500, Loss MSE: 0.5478210970759392\n",
      "Epoch 600, Loss MSE: 0.5322550684213638\n",
      "Epoch 700, Loss MSE: 0.5193987116217613\n",
      "Epoch 800, Loss MSE: 0.5085727497935295\n",
      "Epoch 900, Loss MSE: 0.4944651946425438\n",
      "Epoch 1000, Loss MSE: 0.48315175622701645\n",
      "Epoch 1100, Loss MSE: 0.47412319481372833\n",
      "Epoch 1200, Loss MSE: 0.46303698420524597\n",
      "Epoch 1300, Loss MSE: 0.4537624195218086\n",
      "Epoch 1400, Loss MSE: 0.4393491670489311\n",
      "MB+PB Model: EPEHE Mean: 0.5217272967100144, Std: 0.051163058044328864\n",
      "Baseline Model: EPEHE Mean: 1.1896092414855957, Std: 0.020641717645905472\n",
      "MB+PB Model: Factual Loss Mean: 0.7167988657951355, Std: 0.029811268033345002\n",
      "Baseline Model: Factual Loss Mean: 1.2581061244010925, Std: 0.04765352489082916\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from torch.utils.data import TensorDataset\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "\n",
    "# Initialize lists to store results\n",
    "epehe_mb_pb_list = []\n",
    "epehe_baseline_list = []\n",
    "factual_loss_mb_pb_list = []\n",
    "factual_loss_baseline_list = []\n",
    "\n",
    "n_runs = 10  # Number of runs\n",
    "\n",
    "# Run the training and evaluation loop 10 times\n",
    "for run in range(n_runs):\n",
    "    print(f\"Run {run+1}/{n_runs}\")\n",
    "\n",
    "    # Initialize and train the baseline CATE learner\n",
    "    baseline_cate_learner = snet(input_dim=dim_X + 1, hidden_dim=16)\n",
    "    baseline_cate_learner.to(device)\n",
    "    baseline_cate_learner = train_baseline(observation_data, baseline_cate_learner, num_epochs=num_epochs, device=device, lr=0.001)\n",
    "\n",
    "    # Initialize and train the mb+pb models\n",
    "    model_f = BoundedContinuousFunctionModel(input_dim=dim_X, output_dim=1)\n",
    "    model_f.to(device)\n",
    "\n",
    "    model_g = BoundedContinuousFunctionModel(input_dim=1, output_dim=1)\n",
    "    model_g.to(device)\n",
    "\n",
    "    generator = Generator(input_dim=gen_input_dim, output_dim=gen_output_dim)\n",
    "    generator.to(device)\n",
    "\n",
    "    cate_learner = snet(input_dim=dim_X + 1 + gen_output_dim, hidden_dim=16)\n",
    "    cate_learner.to(device)\n",
    "\n",
    "    generator, cate_learner = train_model_mb_plus_pb(\n",
    "        observational_data=observation_data,\n",
    "        rct_data=rct_data,\n",
    "        model_g=model_g,\n",
    "        model_f=model_f,\n",
    "        generator=generator,\n",
    "        cate_learner=cate_learner,\n",
    "        alpha_start=100,\n",
    "        alpha_end=0.01,\n",
    "        generator_input_dim=gen_input_dim,\n",
    "        num_epochs=num_epochs,\n",
    "        batch_size=200,\n",
    "        device=device,\n",
    "        lr_g=0.001,\n",
    "        lr_te=0.001,\n",
    "        lr_f=0.001\n",
    "    )\n",
    "\n",
    "    # Train the oracle model (baseline on RCT test data)\n",
    "    oracle = snet(input_dim=dim_X + 1, hidden_dim=16)\n",
    "    oracle.to(device)\n",
    "    rct_data_test = TensorDataset(X_rct_test_tensor, T_rct_test_tensor, Y_rct_test_tensor)\n",
    "    oracle = train_baseline(rct_data_test, oracle, num_epochs=num_epochs, device=device, lr=0.001)\n",
    "\n",
    "    # Move all models and data to CPU for evaluation\n",
    "    oracle = oracle.to('cpu')\n",
    "    baseline_cate_learner = baseline_cate_learner.to('cpu')\n",
    "    cate_learner = cate_learner.to('cpu')\n",
    "    generator = generator.to('cpu')\n",
    "    X_rct_test_tensor = X_rct_test_tensor.to('cpu')\n",
    "    \n",
    "    # Oracle ITE (True ITE)\n",
    "    T_one = torch.ones(X_rct_test_tensor.shape[0], 1)\n",
    "    T_zero = torch.zeros(X_rct_test_tensor.shape[0], 1)\n",
    "    Y_pred_one_test = oracle(torch.cat((X_rct_test_tensor, T_one), dim=1))\n",
    "    Y_pred_zero_test = oracle(torch.cat((X_rct_test_tensor, T_zero), dim=1))\n",
    "    true_ite = Y_pred_one_test - Y_pred_zero_test\n",
    "\n",
    "    # Baseline ITE\n",
    "    Y_pred_one_baseline = baseline_cate_learner(torch.cat((X_rct_test_tensor, T_one), dim=1))\n",
    "    Y_pred_zero_baseline = baseline_cate_learner(torch.cat((X_rct_test_tensor, T_zero), dim=1))\n",
    "    baseline_ite = Y_pred_one_baseline - Y_pred_zero_baseline\n",
    "\n",
    "    # MB+PB ITE with averaging over multiple U samples\n",
    "    n_samples = 10\n",
    "    y_list_one = []\n",
    "    y_list_zero = []\n",
    "\n",
    "    for _ in range(n_samples):\n",
    "        # Sample U from generator\n",
    "        Z_test = torch.randn(X_rct_test_tensor.shape[0], gen_input_dim)\n",
    "\n",
    "        # Generate U_hat_test using generator\n",
    "        U_hat_test = generator(Z_test)\n",
    "\n",
    "        # Compute predictions for T=1\n",
    "        test_input_te_one = torch.cat((X_rct_test_tensor, U_hat_test, T_one), dim=1)\n",
    "        Y_pred_one_model = cate_learner(test_input_te_one)\n",
    "        y_list_one.append(Y_pred_one_model)\n",
    "\n",
    "        # Compute predictions for T=0\n",
    "        test_input_te_zero = torch.cat((X_rct_test_tensor, U_hat_test, T_zero), dim=1)\n",
    "        Y_pred_zero_model = cate_learner(test_input_te_zero)\n",
    "        y_list_zero.append(Y_pred_zero_model)\n",
    "\n",
    "    # Average predictions over all samples of U\n",
    "    Y_pred_one_avg = torch.mean(torch.stack(y_list_one), dim=0)\n",
    "    Y_pred_zero_avg = torch.mean(torch.stack(y_list_zero), dim=0)\n",
    "\n",
    "    # Compute ITE from averaged predictions\n",
    "    ite_mb_pb = Y_pred_one_avg - Y_pred_zero_avg\n",
    "\n",
    "    # Compute EPEHE (error in PEHE) and factual loss\n",
    "    mse = nn.MSELoss()\n",
    "    epehe_mb_pb = torch.sqrt(mse(true_ite, ite_mb_pb)).item()\n",
    "    epehe_baseline = torch.sqrt(mse(true_ite, baseline_ite)).item()\n",
    "\n",
    "    factual_loss_mb_pb = mse(Y_pred_one_avg, Y_rct_test_tensor).item()\n",
    "    factual_loss_baseline = mse(Y_pred_one_baseline, Y_rct_test_tensor).item()\n",
    "\n",
    "    # Append results to the lists\n",
    "    epehe_mb_pb_list.append(epehe_mb_pb)\n",
    "    epehe_baseline_list.append(epehe_baseline)\n",
    "    factual_loss_mb_pb_list.append(factual_loss_mb_pb)\n",
    "    factual_loss_baseline_list.append(factual_loss_baseline)\n",
    "\n",
    "# Compute mean and standard deviation for EPEHE and factual loss\n",
    "epehe_mb_pb_mean = np.mean(epehe_mb_pb_list)\n",
    "epehe_mb_pb_std = np.std(epehe_mb_pb_list)\n",
    "epehe_baseline_mean = np.mean(epehe_baseline_list)\n",
    "epehe_baseline_std = np.std(epehe_baseline_list)\n",
    "\n",
    "factual_loss_mb_pb_mean = np.mean(factual_loss_mb_pb_list)\n",
    "factual_loss_mb_pb_std = np.std(factual_loss_mb_pb_list)\n",
    "factual_loss_baseline_mean = np.mean(factual_loss_baseline_list)\n",
    "factual_loss_baseline_std = np.std(factual_loss_baseline_list)\n",
    "\n",
    "# Print results\n",
    "print(f'MB+PB Model: EPEHE Mean: {epehe_mb_pb_mean}, Std: {epehe_mb_pb_std}')\n",
    "print(f'Baseline Model: EPEHE Mean: {epehe_baseline_mean}, Std: {epehe_baseline_std}')\n",
    "print(f'MB+PB Model: Factual Loss Mean: {factual_loss_mb_pb_mean}, Std: {factual_loss_mb_pb_std}')\n",
    "print(f'Baseline Model: Factual Loss Mean: {factual_loss_baseline_mean}, Std: {factual_loss_baseline_std}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the experiment_results_folder if it doesn't exist\n",
    "results_folder = 'experiment_results'\n",
    "os.makedirs(results_folder, exist_ok=True)\n",
    "\n",
    "# Prepare the results to be saved\n",
    "results = {\n",
    "    'epehe_mb_pb_mean': epehe_mb_pb_mean,\n",
    "    'epehe_mb_pb_std': epehe_mb_pb_std,\n",
    "    'epehe_baseline_mean': epehe_baseline_mean,\n",
    "    'epehe_baseline_std': epehe_baseline_std,\n",
    "    'factual_loss_mb_pb_mean': factual_loss_mb_pb_mean,\n",
    "    'factual_loss_mb_pb_std': factual_loss_mb_pb_std,\n",
    "    'factual_loss_baseline_mean': factual_loss_baseline_mean,\n",
    "    'factual_loss_baseline_std': factual_loss_baseline_std\n",
    "}\n",
    "\n",
    "# Save the results to a pickle file\n",
    "results_file = os.path.join(results_folder, 'actg_experiment_results.pkl')\n",
    "with open(results_file, 'wb') as f:\n",
    "    pickle.dump(results, f)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
