{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Current working directory: /hpc/home/aa671/phd/generating_confounders/Deconfounding-MD\n",
      "Device:  cuda:0\n"
     ]
    }
   ],
   "source": [
    "# Import libraries. Modules used directly in this notebook (os, torch) are imported\n",
    "# explicitly instead of relying on the project star-imports to provide them.\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "# Project modules (presumably the source of sample_jobs, snet, Generator and\n",
    "# BoundedContinuousFunctionModel used below).\n",
    "# NOTE(review): generate_data is star-imported twice. The second import re-binds any of its\n",
    "# names that train_models/utils may shadow, so it is kept to avoid changing which definition wins.\n",
    "from generate_data import *\n",
    "from train_models import *\n",
    "from utils import *\n",
    "from generate_data import *\n",
    "from models import *\n",
    "\n",
    "# Seed every RNG the experiments draw from. torch drives weight init, DataLoader shuffling and\n",
    "# torch.randn/randint sampling, so seeding NumPy alone does not make Restart & Run All reproducible.\n",
    "seed = 2024\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)  # also seeds the CUDA generators\n",
    "print(\"Current working directory:\", os.getcwd())\n",
    "\n",
    "# Check if CUDA is available.\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda:0\" if use_cuda else \"cpu\")\n",
    "print(\"Device: \", device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Size of unconfounded (RCT) features: (500, 6)\n",
      "Size of confounded (Observational) features: (518, 6)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/hpc/home/aa671/.local/lib/python3.8/site-packages/sklearn/base.py:443: UserWarning: X has feature names, but StandardScaler was fitted without feature names\n",
      "  warnings.warn(\n",
      "/hpc/home/aa671/.local/lib/python3.8/site-packages/sklearn/base.py:443: UserWarning: X has feature names, but StandardScaler was fitted without feature names\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "data = sample_jobs(n_unc=50)\n",
    "data_rct = sample_jobs(n_unc=500)\n",
    "# Print the shape of unconfounded and confounded features\n",
    "x_unc = data_rct['x_unc']\n",
    "x_conf = data['x_conf']\n",
    "\n",
    "print(f\"Size of unconfounded (RCT) features: {x_unc.shape}\")\n",
    "print(f\"Size of confounded (Observational) features: {x_conf.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_tensors(x, t, y, device='cpu'):\n",
    "    \"\"\"Return (X, T, Y) as float32 torch tensors placed on `device`.\n",
    "\n",
    "    Each input is passed to torch.tensor, so it must be array-like (e.g. a NumPy array);\n",
    "    the returned tuple keeps the input order.\n",
    "    \"\"\"\n",
    "    return tuple(torch.tensor(arr, dtype=torch.float32, device=device) for arr in (x, t, y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Size of test features: torch.Size([500, 6])\n",
      "Size of observational train features: torch.Size([518, 6])\n",
      "Size of RCT train: torch.Size([49, 1])\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import TensorDataset  # needed here, before any later cell imports it\n",
    "\n",
    "# Large unconfounded (RCT) draw: held out as the test set.\n",
    "x_unc = data_rct['x_unc']\n",
    "t_unc = data_rct['t_unc']\n",
    "y_unc = data_rct['y_unc']\n",
    "\n",
    "# Small unconfounded (RCT) draw: only its treatments and outcomes are used for training.\n",
    "t_unc_train = data['t_unc']\n",
    "y_unc_train = data['y_unc']\n",
    "\n",
    "# Confounded (observational) training data.\n",
    "x_conf = data['x_conf']\n",
    "t_conf = data['t_conf']\n",
    "y_conf = data['y_conf']\n",
    "\n",
    "# The two RCT draws already play the train/test roles, so no further split is made.\n",
    "X_rct_test = x_unc\n",
    "T_rct_train = t_unc_train\n",
    "Y_rct_train = y_unc_train\n",
    "T_rct_test = t_unc\n",
    "Y_rct_test = y_unc\n",
    "\n",
    "# Training tensors go to `device`. The RCT training set carries no covariates, so only T and Y\n",
    "# are converted (no throw-away device copy of the test covariates is made).\n",
    "T_rct_train_tensor = torch.tensor(T_rct_train, dtype=torch.float32, device=device)\n",
    "Y_rct_train_tensor = torch.tensor(Y_rct_train, dtype=torch.float32, device=device)\n",
    "X_obs_train_tensor, T_obs_train_tensor, Y_obs_train_tensor = convert_to_tensors(x_conf, t_conf, y_conf, device=device)\n",
    "\n",
    "# Test tensors stay on the CPU, where evaluation happens.\n",
    "X_rct_test_tensor, T_rct_test_tensor, Y_rct_test_tensor = convert_to_tensors(X_rct_test, T_rct_test, Y_rct_test, device='cpu')\n",
    "\n",
    "observation_data = TensorDataset(X_obs_train_tensor, T_obs_train_tensor, Y_obs_train_tensor)\n",
    "rct_data = TensorDataset(T_rct_train_tensor, Y_rct_train_tensor)\n",
    "\n",
    "# Print the sizes to check everything.\n",
    "print(f\"Size of test features: {X_rct_test_tensor.shape}\")\n",
    "print(f\"Size of observational train features: {X_obs_train_tensor.shape}\")\n",
    "print(f\"Size of RCT train: {T_rct_train_tensor.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training epochs shared by every model fitted below (baseline, pb, mb+pb and oracles).\n",
    "num_epochs = 1_500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "\n",
    "def train_baseline(observational_data, baseline_cate_learner, num_epochs=1000, batch_size=256, device='cpu', lr=0.001):\n",
    "    \"\"\"Fit an outcome model Y_hat = model([X, T]) by plain MSE regression with Adam.\n",
    "\n",
    "    Args:\n",
    "        observational_data: dataset yielding (X, T, Y) batches; X and T are concatenated\n",
    "            along dim 1 to form the model input.\n",
    "        baseline_cate_learner: torch module to train (updated in place).\n",
    "        num_epochs: passes over the data.\n",
    "        batch_size: DataLoader batch size (batches are shuffled).\n",
    "        device: device each batch is moved to before the forward pass.\n",
    "        lr: Adam learning rate.\n",
    "\n",
    "    Returns:\n",
    "        The trained module (the same object that was passed in).\n",
    "    \"\"\"\n",
    "    criterion = nn.MSELoss()\n",
    "    opt = optim.Adam(baseline_cate_learner.parameters(), lr=lr)\n",
    "    loader = DataLoader(observational_data, batch_size=batch_size, shuffle=True)\n",
    "    n_batches = len(loader)\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        running_loss = 0.0\n",
    "        for xb, tb, yb in loader:\n",
    "            xb, tb, yb = (batch.to(device) for batch in (xb, tb, yb))\n",
    "            prediction = baseline_cate_learner(torch.cat((xb, tb), dim=1))\n",
    "            # NOTE(review): assumes Y batches share the prediction's shape (e.g. (B, 1)); a 1-D Y\n",
    "            # would be broadcast by MSELoss -- confirm against the data generator.\n",
    "            batch_loss = criterion(prediction, yb)\n",
    "            running_loss += batch_loss.item()\n",
    "\n",
    "            opt.zero_grad()\n",
    "            batch_loss.backward()\n",
    "            opt.step()\n",
    "\n",
    "        # Report the mean batch loss every 100 epochs.\n",
    "        if not epoch % 100:\n",
    "            print(f'Epoch {epoch}, Loss MSE: {running_loss / n_batches}')\n",
    "\n",
    "    return baseline_cate_learner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, Loss MSE: 1.3313281536102295\n",
      "Epoch 100, Loss MSE: 0.033292693396409355\n",
      "Epoch 200, Loss MSE: 0.025210936621685203\n",
      "Epoch 300, Loss MSE: 0.030192839602629345\n",
      "Epoch 400, Loss MSE: 0.024184523785758454\n",
      "Epoch 500, Loss MSE: 0.024091920834810782\n",
      "Epoch 600, Loss MSE: 0.027587355114519596\n",
      "Epoch 700, Loss MSE: 0.02382471289214057\n",
      "Epoch 800, Loss MSE: 0.030192064121365547\n",
      "Epoch 900, Loss MSE: 0.02356916326001131\n",
      "Epoch 1000, Loss MSE: 0.02379178556536014\n",
      "Epoch 1100, Loss MSE: 0.03470160625874996\n",
      "Epoch 1200, Loss MSE: 0.023374529109181214\n",
      "Epoch 1300, Loss MSE: 0.023327981849433854\n",
      "Epoch 1400, Loss MSE: 0.027510570362210274\n"
     ]
    }
   ],
   "source": [
    "# Number of covariates; outcome models take [X, T], hence an input size of dim_X + 1.\n",
    "dim_X = X_obs_train_tensor.shape[1]\n",
    "\n",
    "# Baseline: regress Y on [X, T] using only the confounded observational data.\n",
    "baseline_cate_learner = snet(input_dim=dim_X + 1, hidden_dim=16)\n",
    "baseline_cate_learner.to(device)\n",
    "baseline_cate_learner = train_baseline(\n",
    "    observation_data,\n",
    "    baseline_cate_learner,\n",
    "    num_epochs=num_epochs,\n",
    "    device=device,\n",
    "    lr=0.001,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "\n",
    "def train_model_pb(observational_data, rct_data, model_f, cate_learner,\n",
    "                   alpha_start=100, alpha_end=0.01, num_epochs=500,\n",
    "                   balancing_iterations_start=5, balancing_iterations_end=50,\n",
    "                   batch_size=512, device='cpu',\n",
    "                   lr_te=0.001, lr_f=0.001):\n",
    "    \"\"\"Train `cate_learner` on observational data while matching RCT moments (pb method).\n",
    "\n",
    "    Each step fits the outcome model Y_hat = cate_learner([X, T]) by MSE on an observational\n",
    "    batch and, for each arm t in {0, 1}, adds the squared gap between mean(f(X) * Y_hat(X, t))\n",
    "    over the batch and mean(f(X') * Y_rct_t) over the RCT arm, where f = model_f and X' are\n",
    "    batch rows resampled to the arm's size (rct_data carries no covariates). model_f is then\n",
    "    updated `balancing_iterations` times to maximise that gap with the predictions frozen.\n",
    "\n",
    "    Schedule: (alpha_start, balancing_iterations_start) before epoch int(num_epochs * 2 / 3 - 100),\n",
    "    (alpha_end, balancing_iterations_end) after int(num_epochs * 2 / 3 + 100), and in between\n",
    "    alpha (the MSE weight) ramps linearly over 200 epochs with the midpoint iteration count.\n",
    "\n",
    "    Args:\n",
    "        observational_data: dataset yielding (X, T, Y) batches.\n",
    "        rct_data: dataset whose first two tensors are the RCT treatments and outcomes. They are\n",
    "            not moved, so they must already be on `device`; both arms must be non-empty.\n",
    "        model_f: adversarial test function applied to X.\n",
    "        cate_learner: outcome model applied to [X, T]; trained in place.\n",
    "        device: device each observational batch is moved to.\n",
    "        lr_te, lr_f: Adam learning rates for cate_learner and model_f.\n",
    "\n",
    "    Returns:\n",
    "        The trained `cate_learner`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the RCT data has no treated or no control units.\n",
    "    \"\"\"\n",
    "    # RCT outcomes split by arm; loop-invariant, so computed once.\n",
    "    T_rct = rct_data[:][0].view(-1, 1)\n",
    "    Y_rct = rct_data[:][1]\n",
    "    Y_rct_1 = Y_rct[T_rct == 1].view(-1, 1)\n",
    "    Y_rct_0 = Y_rct[T_rct == 0].view(-1, 1)\n",
    "    if Y_rct_1.shape[0] == 0 or Y_rct_0.shape[0] == 0:\n",
    "        raise ValueError('rct_data must contain both treated (T=1) and control (T=0) units')\n",
    "\n",
    "    mse = nn.MSELoss()\n",
    "    optimizer_te = optim.Adam(cate_learner.parameters(), lr=lr_te)\n",
    "    optimizer_f = optim.Adam(model_f.parameters(), lr=lr_f)\n",
    "    data_loader = DataLoader(observational_data, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "    ramp_start = int(num_epochs * 2 / 3 - 100)\n",
    "    ramp_end = int(num_epochs * 2 / 3 + 100)\n",
    "\n",
    "    def schedule(epoch):\n",
    "        # (alpha, balancing_iterations) for the given epoch.\n",
    "        if epoch < ramp_start:\n",
    "            return alpha_start, balancing_iterations_start\n",
    "        if epoch > ramp_end:\n",
    "            return alpha_end, balancing_iterations_end\n",
    "        alpha = alpha_start - (alpha_start - alpha_end) * (epoch - ramp_start) / 200\n",
    "        return alpha, int((balancing_iterations_start + balancing_iterations_end) / 2)\n",
    "\n",
    "    def sample_rows(X, n):\n",
    "        # n rows of X: without replacement when X has at least n rows, with replacement\n",
    "        # otherwise, so the result always pairs one-to-one with n RCT outcomes.\n",
    "        if X.shape[0] >= n:\n",
    "            idx = torch.randperm(X.shape[0], device=X.device)[:n]\n",
    "        else:\n",
    "            idx = torch.randint(0, X.shape[0], (n,), device=X.device)\n",
    "        return X[idx]\n",
    "\n",
    "    def f_gap(X_batch, Y_pred_arm, Y_rct_arm):\n",
    "        # Squared gap between mean f(X) * Y_hat over the batch and mean f(X') * Y_rct for one arm.\n",
    "        f_rct = model_f(sample_rows(X_batch, Y_rct_arm.shape[0])) * Y_rct_arm\n",
    "        f_pred = model_f(X_batch) * Y_pred_arm\n",
    "        return mse(f_pred.mean(), f_rct.mean())\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        alpha, balancing_iterations = schedule(epoch)\n",
    "        epoch_loss = 0.0\n",
    "\n",
    "        for X_batch, T_batch, Y_batch in data_loader:\n",
    "            X_batch, T_batch, Y_batch = X_batch.to(device), T_batch.to(device), Y_batch.to(device)\n",
    "\n",
    "            # Factual prediction plus both potential-outcome predictions.\n",
    "            ones = torch.ones(X_batch.shape[0], 1, device=device)\n",
    "            Y_pred = cate_learner(torch.cat((X_batch, T_batch), dim=1))\n",
    "            Y_pred_one = cate_learner(torch.cat((X_batch, ones), dim=1)).view(-1, 1)\n",
    "            Y_pred_zero = cate_learner(torch.cat((X_batch, torch.zeros_like(ones)), dim=1)).view(-1, 1)\n",
    "\n",
    "            # Learner step. The adversary below only uses detached predictions, so the graph\n",
    "            # does not have to be retained after this backward pass.\n",
    "            loss = (alpha * mse(Y_pred, Y_batch)\n",
    "                    + f_gap(X_batch, Y_pred_one, Y_rct_1)\n",
    "                    + f_gap(X_batch, Y_pred_zero, Y_rct_0))\n",
    "            optimizer_te.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer_te.step()\n",
    "            epoch_loss += loss.item()\n",
    "\n",
    "            # Adversary steps: model_f maximises the gap with the learner's predictions frozen.\n",
    "            Y_pred_one_fixed = Y_pred_one.detach()\n",
    "            Y_pred_zero_fixed = Y_pred_zero.detach()\n",
    "            for _ in range(balancing_iterations):\n",
    "                loss_f = -(f_gap(X_batch, Y_pred_one_fixed, Y_rct_1)\n",
    "                           + f_gap(X_batch, Y_pred_zero_fixed, Y_rct_0))\n",
    "                optimizer_f.zero_grad()\n",
    "                loss_f.backward()\n",
    "                optimizer_f.step()\n",
    "\n",
    "        # Log every 100 epochs.\n",
    "        if epoch % 100 == 0:\n",
    "            print(f'Epoch {epoch}, Loss: {epoch_loss / len(data_loader)}, Alpha: {alpha}, Balancing Iterations: {balancing_iterations}')\n",
    "\n",
    "    return cate_learner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, Loss: 132.70598347981772, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 5.581383466720581, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 3.9191433588663735, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 3.9136710166931152, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 3.970258196194967, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 3.5651170015335083, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 4.669689973195394, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 3.502026995023092, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 3.654472589492798, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 3.417949597040812, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 2.3843308289845786, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.15934467315673828, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.018467079227169354, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.0037731238019963107, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.002407668080801765, Alpha: 0.01, Balancing Iterations: 50\n"
     ]
    }
   ],
   "source": [
    "# pb method: a fresh outcome learner on [X, T] plus the adversarial test function f(X).\n",
    "cate_learner = snet(input_dim=dim_X + 1, hidden_dim=16)\n",
    "cate_learner.to(device)\n",
    "\n",
    "model_f = BoundedContinuousFunctionModel(input_dim=dim_X, output_dim=1)\n",
    "model_f.to(device)\n",
    "\n",
    "# Train with the pb method (observational MSE + RCT moment matching against model_f).\n",
    "cate_learner = train_model_pb(\n",
    "    observational_data=observation_data,\n",
    "    rct_data=rct_data,\n",
    "    model_f=model_f,\n",
    "    cate_learner=cate_learner,\n",
    "    alpha_start=100,\n",
    "    alpha_end=0.01,\n",
    "    num_epochs=num_epochs,\n",
    "    batch_size=200,\n",
    "    device=device,\n",
    "    lr_te=0.001,\n",
    "    lr_f=0.001,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, Loss MSE: 0.8359586894512177\n",
      "Epoch 100, Loss MSE: 0.14294926822185516\n",
      "Epoch 200, Loss MSE: 0.1429857388138771\n",
      "Epoch 300, Loss MSE: 0.14136195927858353\n",
      "Epoch 400, Loss MSE: 0.1415085420012474\n",
      "Epoch 500, Loss MSE: 0.14089980721473694\n",
      "Epoch 600, Loss MSE: 0.14179131761193275\n",
      "Epoch 700, Loss MSE: 0.1403287649154663\n",
      "Epoch 800, Loss MSE: 0.13919684290885925\n",
      "Epoch 900, Loss MSE: 0.13824927806854248\n",
      "Epoch 1000, Loss MSE: 0.13749440386891365\n",
      "Epoch 1100, Loss MSE: 0.13803290203213692\n",
      "Epoch 1200, Loss MSE: 0.1365506760776043\n",
      "Epoch 1300, Loss MSE: 0.13447891920804977\n",
      "Epoch 1400, Loss MSE: 0.1346089169383049\n"
     ]
    }
   ],
   "source": [
    "# Oracle: the same outcome model fitted on the large held-out RCT sample. Because that sample\n",
    "# is unconfounded, its T=1 minus T=0 predictions serve as the reference ITEs in the evaluation.\n",
    "rct_data_test = TensorDataset(X_rct_test_tensor, T_rct_test_tensor, Y_rct_test_tensor)\n",
    "\n",
    "oracle = snet(input_dim=dim_X + 1, hidden_dim=16)\n",
    "oracle.to(device)\n",
    "oracle = train_baseline(rct_data_test, oracle, num_epochs=num_epochs, device=device, lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPEHE for pb model: 0.0714312493801117\n",
      "EPEHE for baseline model: 0.7774333953857422\n"
     ]
    }
   ],
   "source": [
    "# Evaluate on the CPU: move the models (in place) and the test covariates there.\n",
    "oracle = oracle.to('cpu')\n",
    "baseline_cate_learner = baseline_cate_learner.to('cpu')\n",
    "cate_learner = cate_learner.to('cpu')\n",
    "X_rct_test_tensor = X_rct_test_tensor.to('cpu')\n",
    "\n",
    "# Inference only: eval() disables train-time behaviour such as dropout or batch-norm updates\n",
    "# (whether snet has any is not visible here), and no_grad() avoids building autograd graphs.\n",
    "for net in (oracle, baseline_cate_learner, cate_learner):\n",
    "    net.eval()\n",
    "\n",
    "T_one = torch.ones(X_rct_test_tensor.shape[0], 1)\n",
    "T_zero = torch.zeros(X_rct_test_tensor.shape[0], 1)\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Reference ITE: estimated by the oracle fitted on the large RCT sample (not ground truth).\n",
    "    Y_pred_one_test = oracle(torch.cat((X_rct_test_tensor, T_one), dim=1))\n",
    "    Y_pred_zero_test = oracle(torch.cat((X_rct_test_tensor, T_zero), dim=1))\n",
    "    true_ite = Y_pred_one_test - Y_pred_zero_test\n",
    "\n",
    "    # Baseline ITE (observational data only).\n",
    "    Y_pred_one_baseline = baseline_cate_learner(torch.cat((X_rct_test_tensor, T_one), dim=1))\n",
    "    Y_pred_zero_baseline = baseline_cate_learner(torch.cat((X_rct_test_tensor, T_zero), dim=1))\n",
    "    baseline_ite = Y_pred_one_baseline - Y_pred_zero_baseline\n",
    "\n",
    "    # ITE from the pb model.\n",
    "    Y_pred_one_pb = cate_learner(torch.cat((X_rct_test_tensor, T_one), dim=1))\n",
    "    Y_pred_zero_pb = cate_learner(torch.cat((X_rct_test_tensor, T_zero), dim=1))\n",
    "    ite_pb = Y_pred_one_pb - Y_pred_zero_pb\n",
    "\n",
    "# EPEHE: root mean squared difference from the reference ITE.\n",
    "mse = nn.MSELoss()\n",
    "epehe_pb = torch.sqrt(mse(true_ite, ite_pb))\n",
    "epehe_baseline = torch.sqrt(mse(true_ite, baseline_ite))\n",
    "\n",
    "print(f'EPEHE for pb model: {epehe_pb.item()}')\n",
    "print(f'EPEHE for baseline model: {epehe_baseline.item()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "\n",
    "def train_model_mb_plus_pb(observational_data, rct_data, model_g, model_f, generator, cate_learner,\n",
    "                           alpha_start=100, alpha_end=0.01, num_epochs=500,\n",
    "                           balancing_iterations_start=5, balancing_iterations_end=50,\n",
    "                           generator_input_dim=1, batch_size=512, device='cpu',\n",
    "                           lr_g=0.001, lr_te=0.001, lr_f=0.001):\n",
    "    \"\"\"Jointly train a latent generator and `cate_learner` against RCT moments (mb+pb method).\n",
    "\n",
    "    For every observational batch, U_hat = generator(Z) with Z ~ N(0, I) of width\n",
    "    `generator_input_dim` is appended to the covariates and Y_hat = cate_learner([X, U_hat, T])\n",
    "    is fitted by MSE. For each arm t in {0, 1} two penalties are added:\n",
    "      * f-term: squared gap between mean(f(X) * Y_hat(X, U_hat, t)) over the batch and\n",
    "        mean(f(X') * Y_rct_t), with X' batch rows resampled to the arm's size (rct_data\n",
    "        carries no covariates);\n",
    "      * g-term: squared gap between mean(g(Y_hat(X, U_hat, t))) and mean(g(Y_rct_t)).\n",
    "    cate_learner and generator minimise the total; f = model_f and g = model_g are then\n",
    "    updated `balancing_iterations` times to maximise the penalties with predictions frozen.\n",
    "\n",
    "    Schedule: (alpha_start, balancing_iterations_start) before epoch int(num_epochs * 2 / 3 - 100),\n",
    "    (alpha_end, balancing_iterations_end) after int(num_epochs * 2 / 3 + 100), and in between\n",
    "    alpha (the MSE weight) ramps linearly over 200 epochs with the midpoint iteration count.\n",
    "\n",
    "    Args:\n",
    "        observational_data: dataset yielding (X, T, Y) batches.\n",
    "        rct_data: dataset whose first two tensors are the RCT treatments and outcomes. They are\n",
    "            not moved, so they must already be on `device`; both arms must be non-empty.\n",
    "        model_g: adversarial test function applied to (n, 1) outcome columns.\n",
    "        model_f: adversarial test function applied to X.\n",
    "        generator: maps the noise Z to the latent U_hat.\n",
    "        cate_learner: outcome model applied to [X, U_hat, T]; trained in place.\n",
    "        generator_input_dim: width of the noise Z.\n",
    "        device: device each observational batch and Z are placed on.\n",
    "        lr_g: Adam learning rate for both generator and model_g.\n",
    "        lr_te, lr_f: Adam learning rates for cate_learner and model_f.\n",
    "\n",
    "    Returns:\n",
    "        (generator, cate_learner) after training.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the RCT data has no treated or no control units.\n",
    "    \"\"\"\n",
    "    # RCT outcomes split by arm; loop-invariant, so computed once.\n",
    "    T_rct = rct_data[:][0].view(-1, 1)\n",
    "    Y_rct = rct_data[:][1]\n",
    "    Y_rct_1 = Y_rct[T_rct == 1].view(-1, 1)\n",
    "    Y_rct_0 = Y_rct[T_rct == 0].view(-1, 1)\n",
    "    if Y_rct_1.shape[0] == 0 or Y_rct_0.shape[0] == 0:\n",
    "        raise ValueError('rct_data must contain both treated (T=1) and control (T=0) units')\n",
    "\n",
    "    mse = nn.MSELoss()\n",
    "    optimizer_g = optim.Adam(generator.parameters(), lr=lr_g)\n",
    "    optimizer_te = optim.Adam(cate_learner.parameters(), lr=lr_te)\n",
    "    optimizer_f = optim.Adam(model_f.parameters(), lr=lr_f)\n",
    "    optimizer_g_model = optim.Adam(model_g.parameters(), lr=lr_g)\n",
    "    data_loader = DataLoader(observational_data, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "    ramp_start = int(num_epochs * 2 / 3 - 100)\n",
    "    ramp_end = int(num_epochs * 2 / 3 + 100)\n",
    "\n",
    "    def schedule(epoch):\n",
    "        # (alpha, balancing_iterations) for the given epoch.\n",
    "        if epoch < ramp_start:\n",
    "            return alpha_start, balancing_iterations_start\n",
    "        if epoch > ramp_end:\n",
    "            return alpha_end, balancing_iterations_end\n",
    "        alpha = alpha_start - (alpha_start - alpha_end) * (epoch - ramp_start) / 200\n",
    "        return alpha, int((balancing_iterations_start + balancing_iterations_end) / 2)\n",
    "\n",
    "    def sample_rows(X, n):\n",
    "        # n rows of X: without replacement when X has at least n rows, with replacement\n",
    "        # otherwise, so the result always pairs one-to-one with n RCT outcomes.\n",
    "        if X.shape[0] >= n:\n",
    "            idx = torch.randperm(X.shape[0], device=X.device)[:n]\n",
    "        else:\n",
    "            idx = torch.randint(0, X.shape[0], (n,), device=X.device)\n",
    "        return X[idx]\n",
    "\n",
    "    def f_gap(X_batch, Y_pred_arm, Y_rct_arm):\n",
    "        # Squared gap between mean f(X) * Y_hat over the batch and mean f(X') * Y_rct for one arm.\n",
    "        f_rct = model_f(sample_rows(X_batch, Y_rct_arm.shape[0])) * Y_rct_arm\n",
    "        f_pred = model_f(X_batch) * Y_pred_arm\n",
    "        return mse(f_pred.mean(), f_rct.mean())\n",
    "\n",
    "    def g_gap(Y_pred_arm, Y_rct_arm):\n",
    "        # Squared gap between mean g(Y_hat) and mean g(Y_rct) for one arm.\n",
    "        return mse(model_g(Y_pred_arm).mean(), model_g(Y_rct_arm).mean())\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        alpha, balancing_iterations = schedule(epoch)\n",
    "        epoch_loss = 0.0\n",
    "\n",
    "        for X_batch, T_batch, Y_batch in data_loader:\n",
    "            X_batch, T_batch, Y_batch = X_batch.to(device), T_batch.to(device), Y_batch.to(device)\n",
    "\n",
    "            # Sample the latent U_hat for this batch.\n",
    "            Z = torch.randn(X_batch.shape[0], generator_input_dim, device=device)\n",
    "            U_hat = generator(Z)\n",
    "\n",
    "            # Factual prediction plus both potential-outcome predictions, sharing U_hat.\n",
    "            ones = torch.ones(X_batch.shape[0], 1, device=device)\n",
    "            Y_pred = cate_learner(torch.cat((X_batch, U_hat, T_batch), dim=1))\n",
    "            Y_pred_one = cate_learner(torch.cat((X_batch, U_hat, ones), dim=1)).view(-1, 1)\n",
    "            Y_pred_zero = cate_learner(torch.cat((X_batch, U_hat, torch.zeros_like(ones)), dim=1)).view(-1, 1)\n",
    "\n",
    "            # Learner/generator step. The adversary below only uses detached predictions, so\n",
    "            # the graph does not have to be retained after this backward pass.\n",
    "            loss = (alpha * mse(Y_pred, Y_batch)\n",
    "                    + f_gap(X_batch, Y_pred_one, Y_rct_1) + f_gap(X_batch, Y_pred_zero, Y_rct_0)\n",
    "                    + g_gap(Y_pred_one, Y_rct_1) + g_gap(Y_pred_zero, Y_rct_0))\n",
    "            optimizer_te.zero_grad()\n",
    "            optimizer_g.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer_te.step()\n",
    "            optimizer_g.step()\n",
    "            epoch_loss += loss.item()\n",
    "\n",
    "            # Adversary steps: model_f and model_g maximise the penalties with the predictions\n",
    "            # frozen. Detach the predictions (the *inputs* of g), not g's output: detaching\n",
    "            # model_g(...) would cut model_g out of loss_f and leave it untrained.\n",
    "            Y_pred_one_fixed = Y_pred_one.detach()\n",
    "            Y_pred_zero_fixed = Y_pred_zero.detach()\n",
    "            for _ in range(balancing_iterations):\n",
    "                loss_f = -(f_gap(X_batch, Y_pred_one_fixed, Y_rct_1) + f_gap(X_batch, Y_pred_zero_fixed, Y_rct_0)\n",
    "                           + g_gap(Y_pred_one_fixed, Y_rct_1) + g_gap(Y_pred_zero_fixed, Y_rct_0))\n",
    "                optimizer_f.zero_grad()\n",
    "                optimizer_g_model.zero_grad()\n",
    "                loss_f.backward()\n",
    "                optimizer_f.step()\n",
    "                optimizer_g_model.step()\n",
    "\n",
    "        # Log every 100 epochs.\n",
    "        if epoch % 100 == 0:\n",
    "            print(f'Epoch {epoch}, Loss: {epoch_loss / len(data_loader)}, Alpha: {alpha}, Balancing Iterations: {balancing_iterations}')\n",
    "\n",
    "    return generator, cate_learner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, Loss: 153.7868194580078, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 4.121299544970195, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 4.072532773017883, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 3.6777504285176597, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 3.681694825490316, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 3.840420047442118, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 3.8060447374979653, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 3.54782497882843, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 3.620096206665039, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 3.7090580463409424, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 1.975253959496816, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.17652573684851328, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.029620071252187092, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.007661931682378054, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.003325920940066377, Alpha: 0.01, Balancing Iterations: 50\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# mb+pb method: outcome learner on [X, U_hat, T], where U_hat is a latent from a generator.\n",
    "dim_X = X_obs_train_tensor.shape[1]\n",
    "gen_input_dim = 2   # width of the Gaussian noise Z fed to the generator\n",
    "gen_output_dim = 1  # width of the latent U_hat appended to the covariates\n",
    "\n",
    "# Adversarial test functions: f acts on the covariates, g on (n, 1) outcome columns.\n",
    "model_f = BoundedContinuousFunctionModel(input_dim=dim_X, output_dim=1)\n",
    "model_f.to(device)\n",
    "model_g = BoundedContinuousFunctionModel(input_dim=1, output_dim=1)\n",
    "model_g.to(device)\n",
    "\n",
    "generator = Generator(input_dim=gen_input_dim, output_dim=gen_output_dim)\n",
    "generator.to(device)\n",
    "\n",
    "# The learner's input is [X, U_hat, T], hence dim_X + 1 + gen_output_dim.\n",
    "cate_learner = snet(input_dim=dim_X + 1 + gen_output_dim, hidden_dim=16)\n",
    "cate_learner.to(device)\n",
    "\n",
    "generator, cate_learner = train_model_mb_plus_pb(\n",
    "    observational_data=observation_data,\n",
    "    rct_data=rct_data,\n",
    "    model_g=model_g,\n",
    "    model_f=model_f,\n",
    "    generator=generator,\n",
    "    cate_learner=cate_learner,\n",
    "    alpha_start=100,\n",
    "    alpha_end=0.01,\n",
    "    generator_input_dim=gen_input_dim,\n",
    "    num_epochs=num_epochs,\n",
    "    batch_size=200,\n",
    "    device=device,\n",
    "    lr_g=0.001,\n",
    "    lr_te=0.001,\n",
    "    lr_f=0.001,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, Loss MSE: 0.8356349468231201\n",
      "Epoch 100, Loss MSE: 0.14511456713080406\n",
      "Epoch 200, Loss MSE: 0.1427590176463127\n",
      "Epoch 300, Loss MSE: 0.14175209403038025\n",
      "Epoch 400, Loss MSE: 0.14152823388576508\n",
      "Epoch 500, Loss MSE: 0.14105702191591263\n",
      "Epoch 600, Loss MSE: 0.14162277802824974\n",
      "Epoch 700, Loss MSE: 0.14061838388442993\n",
      "Epoch 800, Loss MSE: 0.14082134142518044\n",
      "Epoch 900, Loss MSE: 0.1395990550518036\n",
      "Epoch 1000, Loss MSE: 0.13946270942687988\n",
      "Epoch 1100, Loss MSE: 0.1393289789557457\n",
      "Epoch 1200, Loss MSE: 0.13799501210451126\n",
      "Epoch 1300, Loss MSE: 0.13871362805366516\n",
      "Epoch 1400, Loss MSE: 0.13719722256064415\n"
     ]
    }
   ],
   "source": [
    "# Fresh oracle for the mb+pb comparison: same architecture and training recipe as the earlier\n",
    "# oracle, fitted on the large held-out RCT sample.\n",
    "rct_data_test = TensorDataset(X_rct_test_tensor, T_rct_test_tensor, Y_rct_test_tensor)\n",
    "oracle = snet(input_dim=dim_X + 1, hidden_dim=16)\n",
    "oracle.to(device)\n",
    "oracle = train_baseline(\n",
    "    rct_data_test, oracle, num_epochs=num_epochs, device=device, lr=0.001\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPEHE for mb+pb model: 0.05362270772457123\n",
      "EPEHE for baseline model: 0.7702730894088745\n"
     ]
    }
   ],
   "source": [
    "# Move all models and data to CPU for evaluation\n",
    "oracle = oracle.to('cpu')\n",
    "baseline_cate_learner = baseline_cate_learner.to('cpu')\n",
    "cate_learner = cate_learner.to('cpu')\n",
    "generator = generator.to('cpu')\n",
    "\n",
    "# Move test data to CPU\n",
    "X_rct_test_tensor = X_rct_test_tensor.to('cpu')\n",
    "T_one = torch.ones(X_rct_test_tensor.shape[0], 1)\n",
    "T_zero = torch.zeros(X_rct_test_tensor.shape[0], 1)\n",
    "\n",
    "# Oracle ITE (True ITE)\n",
    "Y_pred_one_test = oracle(torch.cat((X_rct_test_tensor, T_one), dim=1))\n",
    "Y_pred_zero_test = oracle(torch.cat((X_rct_test_tensor, T_zero), dim=1))\n",
    "true_ite = Y_pred_one_test - Y_pred_zero_test\n",
    "\n",
    "# Baseline ITE\n",
    "Y_pred_one_baseline = baseline_cate_learner(torch.cat((X_rct_test_tensor, T_one), dim=1))\n",
    "Y_pred_zero_baseline = baseline_cate_learner(torch.cat((X_rct_test_tensor, T_zero), dim=1))\n",
    "baseline_ite = Y_pred_one_baseline - Y_pred_zero_baseline\n",
    "\n",
    "# MB+PB ITE with averaging over multiple U samples\n",
    "n_samples = 10  # Number of U samples\n",
    "y_list_one = []\n",
    "y_list_zero = []\n",
    "\n",
    "for _ in range(n_samples):\n",
    "    # Sample U from generator\n",
    "    Z_test = torch.randn(X_rct_test_tensor.shape[0], gen_input_dim)\n",
    "\n",
    "    # Generate U_hat_test using generator\n",
    "    U_hat_test = generator(Z_test)\n",
    "\n",
    "    # Compute predictions for T=1\n",
    "    test_input_te_one = torch.cat((X_rct_test_tensor, U_hat_test, T_one), dim=1)\n",
    "    Y_pred_one_model = cate_learner(test_input_te_one)\n",
    "    y_list_one.append(Y_pred_one_model)\n",
    "\n",
    "    # Compute predictions for T=0\n",
    "    test_input_te_zero = torch.cat((X_rct_test_tensor, U_hat_test, T_zero), dim=1)\n",
    "    Y_pred_zero_model = cate_learner(test_input_te_zero)\n",
    "    y_list_zero.append(Y_pred_zero_model)\n",
    "\n",
    "# Average predictions over all samples of U\n",
    "Y_pred_one_avg = torch.mean(torch.stack(y_list_one), dim=0)\n",
    "Y_pred_zero_avg = torch.mean(torch.stack(y_list_zero), dim=0)\n",
    "\n",
    "# Compute ITE from averaged predictions\n",
    "ite_mb_pb = Y_pred_one_avg - Y_pred_zero_avg\n",
    "\n",
    "# Compute MSE and EPEHE\n",
    "mse = nn.MSELoss()\n",
    "epehe_mb_pb = torch.sqrt(mse(true_ite, ite_mb_pb))\n",
    "epehe_baseline = torch.sqrt(mse(true_ite, baseline_ite))\n",
    "\n",
    "# Optionally print or log the results\n",
    "print(f'EPEHE for mb+pb model: {epehe_mb_pb.item()}')\n",
    "print(f'EPEHE for baseline model: {epehe_baseline.item()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Factual loss for oracle: 0.13867010176181793\n",
      "Factual loss for baseline: 0.2881060540676117\n",
      "Factual loss for mb+pb model: 0.1562529355287552\n"
     ]
    }
   ],
   "source": [
    "# factual loss for the oracle\n",
    "factual_loss_oracle = mse(Y_pred_one_test, Y_rct_test_tensor)\n",
    "print(f'Factual loss for oracle: {factual_loss_oracle.item()}')\n",
    "\n",
    "# factual loss for the baseline\n",
    "Y_pred_baseline = baseline_cate_learner(torch.cat((X_rct_test_tensor, T_rct_test_tensor), dim=1))\n",
    "factual_loss_baseline = mse(Y_pred_baseline, Y_rct_test_tensor)\n",
    "print(f'Factual loss for baseline: {factual_loss_baseline.item()}')\n",
    "\n",
    "# factual loss for the mb+pb model\n",
    "Y_pred_model = cate_learner(torch.cat((X_rct_test_tensor, U_hat_test, T_rct_test_tensor), dim=1))\n",
    "factual_loss_model = mse(Y_pred_model, Y_rct_test_tensor)\n",
    "print(f'Factual loss for mb+pb model: {factual_loss_model.item()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 1/10\n",
      "Epoch 0, Loss MSE: 1.1716575622558594\n",
      "Epoch 100, Loss MSE: 0.027475386935596664\n",
      "Epoch 200, Loss MSE: 0.11168681395550568\n",
      "Epoch 300, Loss MSE: 0.025425665080547333\n",
      "Epoch 400, Loss MSE: 0.024384191492572427\n",
      "Epoch 500, Loss MSE: 0.024048878277729575\n",
      "Epoch 600, Loss MSE: 0.04385280857483546\n",
      "Epoch 700, Loss MSE: 0.0317945114026467\n",
      "Epoch 800, Loss MSE: 0.025119701866060495\n",
      "Epoch 900, Loss MSE: 0.03400368553896745\n",
      "Epoch 1000, Loss MSE: 0.031078992101053398\n",
      "Epoch 1100, Loss MSE: 0.027521762376030285\n",
      "Epoch 1200, Loss MSE: 0.023874834024657805\n",
      "Epoch 1300, Loss MSE: 0.02887062542140484\n",
      "Epoch 1400, Loss MSE: 0.025519614884008963\n",
      "Epoch 0, Loss: 102.18389638264973, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 3.984748045603434, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 3.7056237856547036, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 4.895785530408223, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 3.651490648587545, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 3.9788663387298584, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 3.581846594810486, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 3.5896712144215903, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 4.433554728825887, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 4.380411187807719, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 1.9721859296162922, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.1607560565074285, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.01606638915836811, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.002971056771154205, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.002274013706482947, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.5267947018146515\n",
      "Epoch 100, Loss MSE: 0.14482300728559494\n",
      "Epoch 200, Loss MSE: 0.1415700986981392\n",
      "Epoch 300, Loss MSE: 0.14188216626644135\n",
      "Epoch 400, Loss MSE: 0.14104381203651428\n",
      "Epoch 500, Loss MSE: 0.14078320562839508\n",
      "Epoch 600, Loss MSE: 0.14028862863779068\n",
      "Epoch 700, Loss MSE: 0.1397748813033104\n",
      "Epoch 800, Loss MSE: 0.13879087567329407\n",
      "Epoch 900, Loss MSE: 0.1379171796143055\n",
      "Epoch 1000, Loss MSE: 0.13733838498592377\n",
      "Epoch 1100, Loss MSE: 0.13699572160840034\n",
      "Epoch 1200, Loss MSE: 0.13678763061761856\n",
      "Epoch 1300, Loss MSE: 0.1359999030828476\n",
      "Epoch 1400, Loss MSE: 0.13530504703521729\n",
      "Run 2/10\n",
      "Epoch 0, Loss MSE: 1.5197114149729412\n",
      "Epoch 100, Loss MSE: 0.057019103318452835\n",
      "Epoch 200, Loss MSE: 0.04274422178665797\n",
      "Epoch 300, Loss MSE: 0.027332752322157223\n",
      "Epoch 400, Loss MSE: 0.030694957822561264\n",
      "Epoch 500, Loss MSE: 0.02938052856673797\n",
      "Epoch 600, Loss MSE: 0.03305089039107164\n",
      "Epoch 700, Loss MSE: 0.057653903029859066\n",
      "Epoch 800, Loss MSE: 0.025033873893941443\n",
      "Epoch 900, Loss MSE: 0.03305198500553767\n",
      "Epoch 1000, Loss MSE: 0.024197256532109652\n",
      "Epoch 1100, Loss MSE: 0.026080786405752104\n",
      "Epoch 1200, Loss MSE: 0.026797425622741382\n",
      "Epoch 1300, Loss MSE: 0.02370336277332778\n",
      "Epoch 1400, Loss MSE: 0.024035708978772163\n",
      "Epoch 0, Loss: 168.98950703938803, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 5.127988934516907, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 3.681480566660563, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 3.8282753626505532, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 4.1143583456675215, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 3.637691617012024, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 3.626002788543701, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 3.6367661158243814, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 3.6261774698893228, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 3.783367156982422, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 1.9064061244328816, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.18511147300402322, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.03429586191972097, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.008791127552588781, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.0033438647321114936, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.8830142617225647\n",
      "Epoch 100, Loss MSE: 0.14470025897026062\n",
      "Epoch 200, Loss MSE: 0.14229804277420044\n",
      "Epoch 300, Loss MSE: 0.14210806787014008\n",
      "Epoch 400, Loss MSE: 0.14131999760866165\n",
      "Epoch 500, Loss MSE: 0.14075720682740211\n",
      "Epoch 600, Loss MSE: 0.14087755233049393\n",
      "Epoch 700, Loss MSE: 0.14106786996126175\n",
      "Epoch 800, Loss MSE: 0.13993357867002487\n",
      "Epoch 900, Loss MSE: 0.1392943561077118\n",
      "Epoch 1000, Loss MSE: 0.13940921425819397\n",
      "Epoch 1100, Loss MSE: 0.13820067793130875\n",
      "Epoch 1200, Loss MSE: 0.13657186925411224\n",
      "Epoch 1300, Loss MSE: 0.135783351957798\n",
      "Epoch 1400, Loss MSE: 0.13486485183238983\n",
      "Run 3/10\n",
      "Epoch 0, Loss MSE: 1.3978720903396606\n",
      "Epoch 100, Loss MSE: 0.03749436264236768\n",
      "Epoch 200, Loss MSE: 0.025593299525401864\n",
      "Epoch 300, Loss MSE: 0.02543460309971124\n",
      "Epoch 400, Loss MSE: 0.024902079758855205\n",
      "Epoch 500, Loss MSE: 0.024589345324784517\n",
      "Epoch 600, Loss MSE: 0.06408122026671965\n",
      "Epoch 700, Loss MSE: 0.03235286163787047\n",
      "Epoch 800, Loss MSE: 0.034013623371720314\n",
      "Epoch 900, Loss MSE: 0.0242163488001097\n",
      "Epoch 1000, Loss MSE: 0.024692339162963133\n",
      "Epoch 1100, Loss MSE: 0.02768487048645814\n",
      "Epoch 1200, Loss MSE: 0.025313695892691612\n",
      "Epoch 1300, Loss MSE: 0.023739313376912225\n",
      "Epoch 1400, Loss MSE: 0.02328751518507488\n",
      "Epoch 0, Loss: 56.23856735229492, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 4.7438260316848755, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 3.842932144800822, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 3.685189684232076, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 3.669167677561442, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 3.49934458732605, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 3.514249841372172, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 3.8561034202575684, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 3.4819226264953613, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 4.246646006902059, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 1.8814613819122314, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.13602218528588614, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.00863336306065321, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.0022845377679914236, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.0020108675428976617, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.6796872913837433\n",
      "Epoch 100, Loss MSE: 0.14116429537534714\n",
      "Epoch 200, Loss MSE: 0.1403430812060833\n",
      "Epoch 300, Loss MSE: 0.1403985172510147\n",
      "Epoch 400, Loss MSE: 0.14080387353897095\n",
      "Epoch 500, Loss MSE: 0.14080969989299774\n",
      "Epoch 600, Loss MSE: 0.14031002670526505\n",
      "Epoch 700, Loss MSE: 0.13960516452789307\n",
      "Epoch 800, Loss MSE: 0.13819380849599838\n",
      "Epoch 900, Loss MSE: 0.13752510398626328\n",
      "Epoch 1000, Loss MSE: 0.13750602677464485\n",
      "Epoch 1100, Loss MSE: 0.13626568764448166\n",
      "Epoch 1200, Loss MSE: 0.13598497956991196\n",
      "Epoch 1300, Loss MSE: 0.13576640188694\n",
      "Epoch 1400, Loss MSE: 0.1343042217195034\n",
      "Run 4/10\n",
      "Epoch 0, Loss MSE: 1.709114710489909\n",
      "Epoch 100, Loss MSE: 0.04207369436820348\n",
      "Epoch 200, Loss MSE: 0.03083591101070245\n",
      "Epoch 300, Loss MSE: 0.03327474556863308\n",
      "Epoch 400, Loss MSE: 0.02506199370448788\n",
      "Epoch 500, Loss MSE: 0.03350412162641684\n",
      "Epoch 600, Loss MSE: 0.03432677003244559\n",
      "Epoch 700, Loss MSE: 0.02406880430256327\n",
      "Epoch 800, Loss MSE: 0.023403775788513787\n",
      "Epoch 900, Loss MSE: 0.024413239133233827\n",
      "Epoch 1000, Loss MSE: 0.46455190020302933\n",
      "Epoch 1100, Loss MSE: 0.02334856001349787\n",
      "Epoch 1200, Loss MSE: 0.026517124225695927\n",
      "Epoch 1300, Loss MSE: 0.022787387264543213\n",
      "Epoch 1400, Loss MSE: 0.022914658504305407\n",
      "Epoch 0, Loss: 85.87676747639973, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 3.7945450941721597, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 4.735350449879964, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 4.945220152537028, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 4.757333437601726, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 3.699249505996704, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 3.6718873580296836, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 3.667954762776693, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 4.487608432769775, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 4.585378448168437, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 1.7848923007647197, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.14701970418294272, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.01460055261850357, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.002670297399163246, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.0019005232801040013, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.6917324066162109\n",
      "Epoch 100, Loss MSE: 0.1451461911201477\n",
      "Epoch 200, Loss MSE: 0.14173396676778793\n",
      "Epoch 300, Loss MSE: 0.14141248911619186\n",
      "Epoch 400, Loss MSE: 0.13996345922350883\n",
      "Epoch 500, Loss MSE: 0.14108143001794815\n",
      "Epoch 600, Loss MSE: 0.14000345766544342\n",
      "Epoch 700, Loss MSE: 0.13969076424837112\n",
      "Epoch 800, Loss MSE: 0.13884126394987106\n",
      "Epoch 900, Loss MSE: 0.1387999840080738\n",
      "Epoch 1000, Loss MSE: 0.13725779950618744\n",
      "Epoch 1100, Loss MSE: 0.13671742379665375\n",
      "Epoch 1200, Loss MSE: 0.13552124425768852\n",
      "Epoch 1300, Loss MSE: 0.13483329862356186\n",
      "Epoch 1400, Loss MSE: 0.1349015086889267\n",
      "Run 5/10\n",
      "Epoch 0, Loss MSE: 0.7897202571233114\n",
      "Epoch 100, Loss MSE: 0.0367515596250693\n",
      "Epoch 200, Loss MSE: 0.025495828556207318\n",
      "Epoch 300, Loss MSE: 0.025577901125264663\n",
      "Epoch 400, Loss MSE: 0.026239453349262476\n",
      "Epoch 500, Loss MSE: 0.024697893667810906\n",
      "Epoch 600, Loss MSE: 0.024389441202705104\n",
      "Epoch 700, Loss MSE: 0.034706961363554\n",
      "Epoch 800, Loss MSE: 0.02669484820216894\n",
      "Epoch 900, Loss MSE: 0.024514419104283054\n",
      "Epoch 1000, Loss MSE: 0.024023376987315714\n",
      "Epoch 1100, Loss MSE: 0.023919435450807214\n",
      "Epoch 1200, Loss MSE: 0.024097162143637735\n",
      "Epoch 1300, Loss MSE: 0.02359511798325305\n",
      "Epoch 1400, Loss MSE: 0.02372028826115032\n",
      "Epoch 0, Loss: 137.78648376464844, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 4.102425654729207, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 3.7756950855255127, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 3.8736804326375327, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 4.640312274297078, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 3.924748659133911, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 3.623108665148417, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 4.921132683753967, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 3.806077798207601, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 3.56368895371755, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 1.9435461163520813, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.18503362933794656, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.03672362491488457, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.010744482899705568, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.003705889917910099, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.9612376987934113\n",
      "Epoch 100, Loss MSE: 0.14654354006052017\n",
      "Epoch 200, Loss MSE: 0.14369994401931763\n",
      "Epoch 300, Loss MSE: 0.1414012610912323\n",
      "Epoch 400, Loss MSE: 0.14127692580223083\n",
      "Epoch 500, Loss MSE: 0.140136931091547\n",
      "Epoch 600, Loss MSE: 0.14010879769921303\n",
      "Epoch 700, Loss MSE: 0.14009036868810654\n",
      "Epoch 800, Loss MSE: 0.1395486518740654\n",
      "Epoch 900, Loss MSE: 0.13973824679851532\n",
      "Epoch 1000, Loss MSE: 0.13867735862731934\n",
      "Epoch 1100, Loss MSE: 0.13731446117162704\n",
      "Epoch 1200, Loss MSE: 0.13697949796915054\n",
      "Epoch 1300, Loss MSE: 0.13488328829407692\n",
      "Epoch 1400, Loss MSE: 0.1340857669711113\n",
      "Run 6/10\n",
      "Epoch 0, Loss MSE: 1.5295063654581706\n",
      "Epoch 100, Loss MSE: 0.033905530658861004\n",
      "Epoch 200, Loss MSE: 0.026162532235806186\n",
      "Epoch 300, Loss MSE: 0.026347998529672623\n",
      "Epoch 400, Loss MSE: 0.030829137812058132\n",
      "Epoch 500, Loss MSE: 0.06172037745515505\n",
      "Epoch 600, Loss MSE: 0.02450567790462325\n",
      "Epoch 700, Loss MSE: 0.04574040696024895\n",
      "Epoch 800, Loss MSE: 0.02439237484941259\n",
      "Epoch 900, Loss MSE: 0.029088786492745083\n",
      "Epoch 1000, Loss MSE: 0.025889426469802856\n",
      "Epoch 1100, Loss MSE: 0.02453176483201484\n",
      "Epoch 1200, Loss MSE: 0.024564983662761126\n",
      "Epoch 1300, Loss MSE: 0.025229356562097866\n",
      "Epoch 1400, Loss MSE: 0.031080065295100212\n",
      "Epoch 0, Loss: 52.77665710449219, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 3.663210471471151, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 3.6902496814727783, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 3.6236950159072876, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 3.556501030921936, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 3.725273013114929, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 4.65170141061147, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 3.55344291528066, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 4.170839309692383, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 3.3718226750691733, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 1.858259677886963, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.13980508347352347, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.005913238351543744, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.0019908839215834937, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.0016061179727936785, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.855495423078537\n",
      "Epoch 100, Loss MSE: 0.14467915147542953\n",
      "Epoch 200, Loss MSE: 0.14243124425411224\n",
      "Epoch 300, Loss MSE: 0.14105403423309326\n",
      "Epoch 400, Loss MSE: 0.13975033536553383\n",
      "Epoch 500, Loss MSE: 0.13990513235330582\n",
      "Epoch 600, Loss MSE: 0.13947594910860062\n",
      "Epoch 700, Loss MSE: 0.13867200165987015\n",
      "Epoch 800, Loss MSE: 0.1381230764091015\n",
      "Epoch 900, Loss MSE: 0.13884500786662102\n",
      "Epoch 1000, Loss MSE: 0.1375291347503662\n",
      "Epoch 1100, Loss MSE: 0.13726213574409485\n",
      "Epoch 1200, Loss MSE: 0.13713311403989792\n",
      "Epoch 1300, Loss MSE: 0.13564543798565865\n",
      "Epoch 1400, Loss MSE: 0.13482989370822906\n",
      "Run 7/10\n",
      "Epoch 0, Loss MSE: 0.7096010049184164\n",
      "Epoch 100, Loss MSE: 0.030847812071442604\n",
      "Epoch 200, Loss MSE: 0.03159233182668686\n",
      "Epoch 300, Loss MSE: 0.03364989347755909\n",
      "Epoch 400, Loss MSE: 0.027088866879542668\n",
      "Epoch 500, Loss MSE: 0.02545091633995374\n",
      "Epoch 600, Loss MSE: 0.02478656576325496\n",
      "Epoch 700, Loss MSE: 0.024395651610878605\n",
      "Epoch 800, Loss MSE: 0.034917816519737244\n",
      "Epoch 900, Loss MSE: 0.026327223206559818\n",
      "Epoch 1000, Loss MSE: 0.02752232210089763\n",
      "Epoch 1100, Loss MSE: 0.02370260367752053\n",
      "Epoch 1200, Loss MSE: 0.024132780051634956\n",
      "Epoch 1300, Loss MSE: 0.027608271377782028\n",
      "Epoch 1400, Loss MSE: 0.025543290966500837\n",
      "Epoch 0, Loss: 63.48448944091797, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 3.9487290382385254, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 3.7376348972320557, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 3.784339427947998, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 3.6202234427134194, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 3.702706813812256, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 3.599709709485372, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 3.6816763083140054, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 3.3982857863108316, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 3.5134313901265464, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 1.924835244814555, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.1510993887980779, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.012623577378690243, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.0028011768978709974, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.002076260706720253, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.7982994914054871\n",
      "Epoch 100, Loss MSE: 0.14485348761081696\n",
      "Epoch 200, Loss MSE: 0.14327678829431534\n",
      "Epoch 300, Loss MSE: 0.14204447716474533\n",
      "Epoch 400, Loss MSE: 0.1416688859462738\n",
      "Epoch 500, Loss MSE: 0.1414138600230217\n",
      "Epoch 600, Loss MSE: 0.141249418258667\n",
      "Epoch 700, Loss MSE: 0.14117241278290749\n",
      "Epoch 800, Loss MSE: 0.13991009443998337\n",
      "Epoch 900, Loss MSE: 0.13978419452905655\n",
      "Epoch 1000, Loss MSE: 0.13843387365341187\n",
      "Epoch 1100, Loss MSE: 0.1375342458486557\n",
      "Epoch 1200, Loss MSE: 0.1364058032631874\n",
      "Epoch 1300, Loss MSE: 0.13575530052185059\n",
      "Epoch 1400, Loss MSE: 0.1351257935166359\n",
      "Run 8/10\n",
      "Epoch 0, Loss MSE: 1.4290262858072917\n",
      "Epoch 100, Loss MSE: 0.034545104329784714\n",
      "Epoch 200, Loss MSE: 0.026011959960063297\n",
      "Epoch 300, Loss MSE: 0.026005047373473644\n",
      "Epoch 400, Loss MSE: 0.05681194613377253\n",
      "Epoch 500, Loss MSE: 0.06826965635021527\n",
      "Epoch 600, Loss MSE: 0.0368817113339901\n",
      "Epoch 700, Loss MSE: 0.0368468351662159\n",
      "Epoch 800, Loss MSE: 0.028012257690231007\n",
      "Epoch 900, Loss MSE: 0.028964429162442684\n",
      "Epoch 1000, Loss MSE: 0.024228289859214176\n",
      "Epoch 1100, Loss MSE: 0.024959970420847338\n",
      "Epoch 1200, Loss MSE: 0.03144618651519219\n",
      "Epoch 1300, Loss MSE: 0.028348503013451893\n",
      "Epoch 1400, Loss MSE: 0.02374560459672163\n",
      "Epoch 0, Loss: 131.33087412516275, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 4.293428063392639, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 4.030023415883382, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 3.810687224070231, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 3.8178915977478027, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 4.792545398076375, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 3.7890143394470215, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 3.7198458512624106, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 3.4349211057027182, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 4.381944179534912, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 1.9756211439768474, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.17702394227186838, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.029894066974520683, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.007683811709284782, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.0030189845710992813, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.5489409863948822\n",
      "Epoch 100, Loss MSE: 0.1430509388446808\n",
      "Epoch 200, Loss MSE: 0.14194925129413605\n",
      "Epoch 300, Loss MSE: 0.14116107672452927\n",
      "Epoch 400, Loss MSE: 0.14019408077001572\n",
      "Epoch 500, Loss MSE: 0.13988502323627472\n",
      "Epoch 600, Loss MSE: 0.13898615539073944\n",
      "Epoch 700, Loss MSE: 0.13838230818510056\n",
      "Epoch 800, Loss MSE: 0.13693826645612717\n",
      "Epoch 900, Loss MSE: 0.1359325535595417\n",
      "Epoch 1000, Loss MSE: 0.13389664515852928\n",
      "Epoch 1100, Loss MSE: 0.13383418321609497\n",
      "Epoch 1200, Loss MSE: 0.13290511444211006\n",
      "Epoch 1300, Loss MSE: 0.13250995054841042\n",
      "Epoch 1400, Loss MSE: 0.13284841924905777\n",
      "Run 9/10\n",
      "Epoch 0, Loss MSE: 0.9204278389612833\n",
      "Epoch 100, Loss MSE: 0.5108007943878571\n",
      "Epoch 200, Loss MSE: 0.025607555852426838\n",
      "Epoch 300, Loss MSE: 0.12917540657023588\n",
      "Epoch 400, Loss MSE: 0.14737875324984392\n",
      "Epoch 500, Loss MSE: 0.026229992974549532\n",
      "Epoch 600, Loss MSE: 0.026715691822270553\n",
      "Epoch 700, Loss MSE: 0.02375374447243909\n",
      "Epoch 800, Loss MSE: 0.05899455274144808\n",
      "Epoch 900, Loss MSE: 0.026867639894286793\n",
      "Epoch 1000, Loss MSE: 0.03660871647298336\n",
      "Epoch 1100, Loss MSE: 0.031832681968808174\n",
      "Epoch 1200, Loss MSE: 0.02443428337574005\n",
      "Epoch 1300, Loss MSE: 0.033247131233414016\n",
      "Epoch 1400, Loss MSE: 0.023679784731939435\n",
      "Epoch 0, Loss: 161.6810302734375, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 4.298317591349284, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 4.732818603515625, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 3.71436870098114, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 3.6686501502990723, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 3.791015307108561, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 4.6163562933603925, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 3.493426243464152, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 3.502745509147644, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 3.52598774433136, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 2.0036227703094482, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.17178077002366385, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.020972386623422306, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.004485169968878229, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.002472432485471169, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.7529472410678864\n",
      "Epoch 100, Loss MSE: 0.14340707659721375\n",
      "Epoch 200, Loss MSE: 0.1408318281173706\n",
      "Epoch 300, Loss MSE: 0.14141977578401566\n",
      "Epoch 400, Loss MSE: 0.1406293362379074\n",
      "Epoch 500, Loss MSE: 0.14027605205774307\n",
      "Epoch 600, Loss MSE: 0.14007974416017532\n",
      "Epoch 700, Loss MSE: 0.13937439024448395\n",
      "Epoch 800, Loss MSE: 0.13876686245203018\n",
      "Epoch 900, Loss MSE: 0.1373586505651474\n",
      "Epoch 1000, Loss MSE: 0.13721884787082672\n",
      "Epoch 1100, Loss MSE: 0.1360204741358757\n",
      "Epoch 1200, Loss MSE: 0.13575024157762527\n",
      "Epoch 1300, Loss MSE: 0.13477149978280067\n",
      "Epoch 1400, Loss MSE: 0.13453823328018188\n",
      "Run 10/10\n",
      "Epoch 0, Loss MSE: 1.4710252682367961\n",
      "Epoch 100, Loss MSE: 0.0425597969442606\n",
      "Epoch 200, Loss MSE: 0.03120768504838149\n",
      "Epoch 300, Loss MSE: 0.024980703252367675\n",
      "Epoch 400, Loss MSE: 0.028378413679699104\n",
      "Epoch 500, Loss MSE: 0.026980972848832607\n",
      "Epoch 600, Loss MSE: 0.045193431278069816\n",
      "Epoch 700, Loss MSE: 0.0339517667889595\n",
      "Epoch 800, Loss MSE: 0.024140108725987375\n",
      "Epoch 900, Loss MSE: 0.024076673629072804\n",
      "Epoch 1000, Loss MSE: 0.024242499920849998\n",
      "Epoch 1100, Loss MSE: 0.02382274637542044\n",
      "Epoch 1200, Loss MSE: 0.024321492450932663\n",
      "Epoch 1300, Loss MSE: 0.02372427824108551\n",
      "Epoch 1400, Loss MSE: 0.023235253582242876\n",
      "Epoch 0, Loss: 85.51816813151042, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 100, Loss: 4.17865792910258, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 200, Loss: 3.881063540776571, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 300, Loss: 3.724332809448242, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 400, Loss: 4.71735417842865, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 500, Loss: 3.498849074045817, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 600, Loss: 3.490237593650818, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 700, Loss: 3.53003458182017, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 800, Loss: 3.4481625159581504, Alpha: 100, Balancing Iterations: 5\n",
      "Epoch 900, Loss: 3.4937217235565186, Alpha: 100.0, Balancing Iterations: 27\n",
      "Epoch 1000, Loss: 2.3041524489720664, Alpha: 50.005, Balancing Iterations: 27\n",
      "Epoch 1100, Loss: 0.15065243343512216, Alpha: 0.010000000000005116, Balancing Iterations: 27\n",
      "Epoch 1200, Loss: 0.008524418342858553, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1300, Loss: 0.002089928680409988, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 1400, Loss: 0.0022356457387407622, Alpha: 0.01, Balancing Iterations: 50\n",
      "Epoch 0, Loss MSE: 0.8179391324520111\n",
      "Epoch 100, Loss MSE: 0.1447092667222023\n",
      "Epoch 200, Loss MSE: 0.14212239533662796\n",
      "Epoch 300, Loss MSE: 0.14189323782920837\n",
      "Epoch 400, Loss MSE: 0.14184490218758583\n",
      "Epoch 500, Loss MSE: 0.14096280932426453\n",
      "Epoch 600, Loss MSE: 0.14075727760791779\n",
      "Epoch 700, Loss MSE: 0.13992856442928314\n",
      "Epoch 800, Loss MSE: 0.13900325447320938\n",
      "Epoch 900, Loss MSE: 0.13812844082713127\n",
      "Epoch 1000, Loss MSE: 0.13742323964834213\n",
      "Epoch 1100, Loss MSE: 0.13745279610157013\n",
      "Epoch 1200, Loss MSE: 0.13651084899902344\n",
      "Epoch 1300, Loss MSE: 0.13585980981588364\n",
      "Epoch 1400, Loss MSE: 0.13639681041240692\n",
      "MB+PB Model: EPEHE Mean: 0.08283810429275036, Std: 0.024412362414383944\n",
      "Baseline Model: EPEHE Mean: 0.8220659792423248, Std: 0.030175448489736953\n",
      "MB+PB Model: Factual Loss Mean: 0.1668328806757927, Std: 0.014193834627982077\n",
      "Baseline Model: Factual Loss Mean: 0.3825508296489716, Std: 0.02200866852223781\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from torch.utils.data import TensorDataset\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "\n",
    "# Repeat the whole train/evaluate pipeline n_runs times and report the mean and\n",
    "# (population, ddof=0) standard deviation of two metrics on the RCT test split:\n",
    "#   * EPEHE: root of the mean squared difference between a model's ITE and the\n",
    "#     oracle ITE. The oracle is an snet fitted on the RCT test data itself, so it\n",
    "#     is an estimate of the ITE, not ground-truth counterfactuals.\n",
    "#   * Factual loss: MSE between the prediction under each unit's observed\n",
    "#     treatment and its observed outcome.\n",
    "# NOTE(review): relies on names from earlier cells (seed, device, dim_X,\n",
    "# num_epochs, gen_input_dim, gen_output_dim, observation_data, rct_data and the\n",
    "# X/T/Y_rct_test_tensor tensors).\n",
    "\n",
    "n_runs = 10  # Number of independent repetitions\n",
    "n_samples = 10  # Generator draws of U averaged per test unit (MB+PB model)\n",
    "\n",
    "# Per-run results\n",
    "epehe_mb_pb_list = []\n",
    "epehe_baseline_list = []\n",
    "factual_loss_mb_pb_list = []\n",
    "factual_loss_baseline_list = []\n",
    "\n",
    "mse = torch.nn.MSELoss()\n",
    "\n",
    "# The oracle's training set is identical in every run, so build it once.\n",
    "rct_data_test = TensorDataset(X_rct_test_tensor, T_rct_test_tensor, Y_rct_test_tensor)\n",
    "\n",
    "# CPU copies for evaluation; the global test tensors are left unchanged so every\n",
    "# run starts from the same state.\n",
    "X_test = X_rct_test_tensor.to('cpu')\n",
    "T_test = T_rct_test_tensor.to('cpu').float()\n",
    "Y_test = Y_rct_test_tensor.to('cpu').float()\n",
    "n_test = X_test.shape[0]\n",
    "T_one = torch.ones(n_test, 1)\n",
    "T_zero = torch.zeros(n_test, 1)\n",
    "\n",
    "for run in range(n_runs):\n",
    "    print(f\"Run {run+1}/{n_runs}\")\n",
    "    # Seed torch per run so weight init and generator noise are reproducible.\n",
    "    torch.manual_seed(seed + run)\n",
    "\n",
    "    # Baseline CATE learner: trained on the observational data only\n",
    "    baseline_cate_learner = snet(input_dim=dim_X + 1, hidden_dim=16)\n",
    "    baseline_cate_learner.to(device)\n",
    "    baseline_cate_learner = train_baseline(observation_data, baseline_cate_learner, num_epochs=num_epochs, device=device, lr=0.001)\n",
    "\n",
    "    # MB+PB models\n",
    "    model_f = BoundedContinuousFunctionModel(input_dim=dim_X, output_dim=1)\n",
    "    model_f.to(device)\n",
    "\n",
    "    model_g = BoundedContinuousFunctionModel(input_dim=1, output_dim=1)\n",
    "    model_g.to(device)\n",
    "\n",
    "    generator = Generator(input_dim=gen_input_dim, output_dim=gen_output_dim)\n",
    "    generator.to(device)\n",
    "\n",
    "    cate_learner = snet(input_dim=dim_X + 1 + gen_output_dim, hidden_dim=16)\n",
    "    cate_learner.to(device)\n",
    "\n",
    "    generator, cate_learner = train_model_mb_plus_pb(\n",
    "        observational_data=observation_data,\n",
    "        rct_data=rct_data,\n",
    "        model_g=model_g,\n",
    "        model_f=model_f,\n",
    "        generator=generator,\n",
    "        cate_learner=cate_learner,\n",
    "        alpha_start=100,\n",
    "        alpha_end=0.01,\n",
    "        generator_input_dim=gen_input_dim,\n",
    "        num_epochs=num_epochs,\n",
    "        batch_size=200,\n",
    "        device=device,\n",
    "        lr_g=0.001,\n",
    "        lr_te=0.001,\n",
    "        lr_f=0.001\n",
    "    )\n",
    "\n",
    "    # Oracle: same architecture as the baseline, trained on the RCT test data\n",
    "    oracle = snet(input_dim=dim_X + 1, hidden_dim=16)\n",
    "    oracle.to(device)\n",
    "    oracle = train_baseline(rct_data_test, oracle, num_epochs=num_epochs, device=device, lr=0.001)\n",
    "\n",
    "    # Evaluate on CPU in inference mode: eval() switches off training-only\n",
    "    # behaviour (e.g. dropout) if the models have any, and no_grad() avoids\n",
    "    # building autograd graphs for the test predictions.\n",
    "    oracle = oracle.to('cpu').eval()\n",
    "    baseline_cate_learner = baseline_cate_learner.to('cpu').eval()\n",
    "    cate_learner = cate_learner.to('cpu').eval()\n",
    "    generator = generator.to('cpu').eval()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # Oracle ITE, used as the reference ITE\n",
    "        Y_pred_one_test = oracle(torch.cat((X_test, T_one), dim=1))\n",
    "        Y_pred_zero_test = oracle(torch.cat((X_test, T_zero), dim=1))\n",
    "        true_ite = Y_pred_one_test - Y_pred_zero_test\n",
    "\n",
    "        # Baseline ITE\n",
    "        Y_pred_one_baseline = baseline_cate_learner(torch.cat((X_test, T_one), dim=1))\n",
    "        Y_pred_zero_baseline = baseline_cate_learner(torch.cat((X_test, T_zero), dim=1))\n",
    "        baseline_ite = Y_pred_one_baseline - Y_pred_zero_baseline\n",
    "\n",
    "        # MB+PB ITE: average the potential-outcome predictions over n_samples\n",
    "        # draws of U_hat = generator(Z), Z ~ N(0, I)\n",
    "        y_list_one = []\n",
    "        y_list_zero = []\n",
    "        for _ in range(n_samples):\n",
    "            U_hat_test = generator(torch.randn(n_test, gen_input_dim))\n",
    "            y_list_one.append(cate_learner(torch.cat((X_test, U_hat_test, T_one), dim=1)))\n",
    "            y_list_zero.append(cate_learner(torch.cat((X_test, U_hat_test, T_zero), dim=1)))\n",
    "        Y_pred_one_avg = torch.stack(y_list_one).mean(dim=0)\n",
    "        Y_pred_zero_avg = torch.stack(y_list_zero).mean(dim=0)\n",
    "        ite_mb_pb = Y_pred_one_avg - Y_pred_zero_avg\n",
    "\n",
    "        # Factual predictions: Y(1) for treated units, Y(0) for controls.\n",
    "        # T and Y are reshaped to the prediction shape so MSELoss never broadcasts.\n",
    "        treated = T_test.reshape(Y_pred_one_avg.shape) > 0.5\n",
    "        Y_obs = Y_test.reshape(Y_pred_one_avg.shape)\n",
    "        factual_mb_pb = torch.where(treated, Y_pred_one_avg, Y_pred_zero_avg)\n",
    "        factual_baseline = torch.where(treated, Y_pred_one_baseline, Y_pred_zero_baseline)\n",
    "\n",
    "        # EPEHE (root PEHE against the oracle ITE) and factual MSE\n",
    "        epehe_mb_pb = torch.sqrt(mse(true_ite, ite_mb_pb)).item()\n",
    "        epehe_baseline = torch.sqrt(mse(true_ite, baseline_ite)).item()\n",
    "        factual_loss_mb_pb = mse(factual_mb_pb, Y_obs).item()\n",
    "        factual_loss_baseline = mse(factual_baseline, Y_obs).item()\n",
    "\n",
    "    # Append results to the lists\n",
    "    epehe_mb_pb_list.append(epehe_mb_pb)\n",
    "    epehe_baseline_list.append(epehe_baseline)\n",
    "    factual_loss_mb_pb_list.append(factual_loss_mb_pb)\n",
    "    factual_loss_baseline_list.append(factual_loss_baseline)\n",
    "\n",
    "# Mean and population standard deviation (np.std default ddof=0) across runs\n",
    "epehe_mb_pb_mean = np.mean(epehe_mb_pb_list)\n",
    "epehe_mb_pb_std = np.std(epehe_mb_pb_list)\n",
    "epehe_baseline_mean = np.mean(epehe_baseline_list)\n",
    "epehe_baseline_std = np.std(epehe_baseline_list)\n",
    "\n",
    "factual_loss_mb_pb_mean = np.mean(factual_loss_mb_pb_list)\n",
    "factual_loss_mb_pb_std = np.std(factual_loss_mb_pb_list)\n",
    "factual_loss_baseline_mean = np.mean(factual_loss_baseline_list)\n",
    "factual_loss_baseline_std = np.std(factual_loss_baseline_list)\n",
    "\n",
    "# Print results\n",
    "print(f'MB+PB Model: EPEHE Mean: {epehe_mb_pb_mean}, Std: {epehe_mb_pb_std}')\n",
    "print(f'Baseline Model: EPEHE Mean: {epehe_baseline_mean}, Std: {epehe_baseline_std}')\n",
    "print(f'MB+PB Model: Factual Loss Mean: {factual_loss_mb_pb_mean}, Std: {factual_loss_mb_pb_std}')\n",
    "print(f'Baseline Model: Factual Loss Mean: {factual_loss_baseline_mean}, Std: {factual_loss_baseline_std}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the summary metrics, plus the per-run values they were computed from,\n",
    "# so the comparison can be re-analysed (e.g. other dispersion measures or\n",
    "# paired tests) without retraining. The *_std entries are np.std with its\n",
    "# default ddof=0.\n",
    "results_folder = 'experiment_results'\n",
    "os.makedirs(results_folder, exist_ok=True)\n",
    "\n",
    "results = {\n",
    "    'epehe_mb_pb_mean': epehe_mb_pb_mean,\n",
    "    'epehe_mb_pb_std': epehe_mb_pb_std,\n",
    "    'epehe_baseline_mean': epehe_baseline_mean,\n",
    "    'epehe_baseline_std': epehe_baseline_std,\n",
    "    'factual_loss_mb_pb_mean': factual_loss_mb_pb_mean,\n",
    "    'factual_loss_mb_pb_std': factual_loss_mb_pb_std,\n",
    "    'factual_loss_baseline_mean': factual_loss_baseline_mean,\n",
    "    'factual_loss_baseline_std': factual_loss_baseline_std,\n",
    "    # Raw per-run values (one entry per run of the experiment loop)\n",
    "    'epehe_mb_pb_runs': list(epehe_mb_pb_list),\n",
    "    'epehe_baseline_runs': list(epehe_baseline_list),\n",
    "    'factual_loss_mb_pb_runs': list(factual_loss_mb_pb_list),\n",
    "    'factual_loss_baseline_runs': list(factual_loss_baseline_list),\n",
    "}\n",
    "\n",
    "# Write the results dict to a pickle file (overwrites any existing file)\n",
    "results_file = os.path.join(results_folder, 'nsw_experiment_results.pkl')\n",
    "with open(results_file, 'wb') as f:\n",
    "    pickle.dump(results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
