{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import torch\n",
    "import numpy as np \n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from scipy.spatial import Delaunay\n",
    "from scipy.sparse import lil_matrix\n",
    "from scipy.sparse.linalg import spsolve\n",
    "\n",
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.loader import DataLoader\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.tri as mtri\n",
    "\n",
    "import sys \n",
    "sys.path.append('../')\n",
    "from utils import plot_hr_graph\n",
    "\n",
    "device = 'cuda' \n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "with open('../data/hr_train.pkl', 'rb') as f:\n",
    "    data_list = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(4):\n",
    "    plot_hr_graph(data_list[i], data_list[i].x[:,0].cpu().detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import MLP, Linear\n",
    "from torch_geometric.utils import to_dense_batch\n",
    "from models.stat import MMDLoss\n",
    "\n",
    "class TransformerReconstructor(torch.nn.Module):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__()\n",
    "        assert {'node_feature_dim', 'latent_dim', 'use_boundary_encoding', 'use_pos', 'be_dim',\n",
    "                'pos_dim', 'n_layers', 'n_heads'}.issubset(kwargs)\n",
    "\n",
    "        additional_inputs  = kwargs['be_dim'] if kwargs['use_boundary_encoding'] else 0\n",
    "        additional_inputs += kwargs['pos_dim'] if kwargs['use_pos'] else 0\n",
    "        \n",
    "        self.tokenizer = MLP(in_channels=kwargs['node_feature_dim']+additional_inputs, \n",
    "                                hidden_channels=kwargs['latent_dim'], out_channels=kwargs['latent_dim'], \n",
    "                                num_layers=2, act='relu', norm='layer')\n",
    "        \n",
    "        # Create transformer encoder with layers\n",
    "        encoder_layer = nn.TransformerEncoderLayer(\n",
    "            d_model=kwargs['latent_dim'],\n",
    "            nhead=kwargs['n_heads'],\n",
    "            dim_feedforward=kwargs['latent_dim'],\n",
    "            dropout=0.0,\n",
    "            activation='gelu',\n",
    "            batch_first=True,\n",
    "            norm_first=False\n",
    "        )\n",
    "        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=kwargs['n_layers'])\n",
    "        \n",
    "        self.detokenizer = MLP(in_channels=kwargs['latent_dim'], \n",
    "                                hidden_channels=kwargs['latent_dim'], out_channels=kwargs['node_feature_dim'], \n",
    "                                num_layers=2, act='relu', norm=None, plain_last=True)\n",
    "        \n",
    "        self.use_boundary_encoding = kwargs['use_boundary_encoding']\n",
    "        self.use_pos = kwargs['use_pos']\n",
    "\n",
    "    def forward(self, data):\n",
    "        if self.use_boundary_encoding:\n",
    "            data.x = torch.cat([data.x, data.boundary_encoding], dim=1)\n",
    "\n",
    "        if self.use_pos:\n",
    "            data.x = torch.cat([data.x, data.pos], dim=1)\n",
    "        \n",
    "        # Project the input node features to a higher-dimensional space (token space)\n",
    "        data.x = self.tokenizer(data.x)\n",
    "        \n",
    "        # Convert to dense batch with mask\n",
    "        x_dense, mask = to_dense_batch(data.x, data.batch)\n",
    "        # x_dense: [batch_size, max_num_nodes, latent_dim]\n",
    "        # mask: [batch_size, max_num_nodes]\n",
    "        \n",
    "        # Create attention mask for transformer\n",
    "        # We need to create a mask where True values are masked (ignored)\n",
    "        # PyTorch transformer expects shape [batch_size, max_num_nodes] for src_key_padding_mask\n",
    "        src_key_padding_mask = ~mask  # Invert mask: True where padding should be ignored\n",
    "        \n",
    "        # Encode through transformer layers with masking\n",
    "        x_encoded = self.encoder(x_dense, src_key_padding_mask=src_key_padding_mask)\n",
    "\n",
    "        # Convert back to flat tensor (only keep valid nodes)\n",
    "        # Extract only the valid (non-padded) nodes\n",
    "        batch_size = x_dense.size(0)\n",
    "        x_list = []\n",
    "        for i in range(batch_size):\n",
    "            valid_nodes = mask[i].sum().item()\n",
    "            x_list.append(x_encoded[i, :valid_nodes])\n",
    "\n",
    "        x = torch.cat(x_list, dim=0)\n",
    "        x = self.detokenizer(x)\n",
    "        \n",
    "        return x\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = TransformerReconstructor(node_feature_dim=2, latent_dim=64, use_boundary_encoding=False, use_pos=True, \n",
    "                                   be_dim=2, pos_dim=2, n_layers=4, n_heads=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "print(count_parameters(model))\n",
    "print(count_parameters(model.encoder))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_obs = 10\n",
    "sigma_tc = torch.tensor([0.01]).to(device)\n",
    "\n",
    "device = 'cuda'\n",
    "model = model.to(device)\n",
    "optimizer= torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "def loss_function(graph):\n",
    "\n",
    "    graph_copy = graph.clone()\n",
    "\n",
    "    num_graphs = graph_copy.batch.max().item() + 1\n",
    "    idx_list = []\n",
    "\n",
    "    for graph_idx in range(num_graphs):\n",
    "        # Find node indices belonging to this graph\n",
    "        node_idx = (graph_copy.batch == graph_idx).nonzero(as_tuple=False).view(-1)\n",
    "        # Randomly permute and select n_obs (or fewer if graph has fewer nodes)\n",
    "        k = min(n_obs, node_idx.size(0))\n",
    "        perm = torch.randperm(node_idx.size(0), device=graph.x.device)[:k]\n",
    "        idx_list.append(node_idx[perm])\n",
    "\n",
    "    # Concatenate selected indices\n",
    "    idx = torch.cat(idx_list, dim=0)\n",
    "\n",
    "    # Zero out features\n",
    "    graph_copy.x.zero_()\n",
    "    graph_copy.x = torch.cat((graph_copy.x, graph_copy.x), dim=1)\n",
    "\n",
    "    # Copy and add noise\n",
    "    y = graph.x[idx, 0].clone()\n",
    "    y += torch.randn_like(y) * sigma_tc\n",
    "\n",
    "    # Place noisy observations\n",
    "    graph_copy.x[idx, 0] = y\n",
    "    graph_copy.x[idx, 1] = 1\n",
    "\n",
    "    u = model(graph_copy)\n",
    "    loss = torch.mean((graph.x - u)**2.)\n",
    "\n",
    "    return loss, (None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "loader_train = DataLoader(data_list, batch_size=100, shuffle=True)\n",
    "data_train = next(iter(loader_train))\n",
    "model.train()\n",
    "LOSS = []\n",
    "time_train_start = time.time()\n",
    "\n",
    "# Create checkpoint directory if it doesn't exist\n",
    "checkpoint_dir = \"../runs/transformer_supervised/\"\n",
    "os.makedirs(checkpoint_dir, exist_ok=True)\n",
    "\n",
    "# for epoch in range(100_000):\n",
    "for epoch in range(20_000):\n",
    "    start_time_b = time.time()\n",
    "    for data in loader_train:\n",
    "        optimizer.zero_grad()\n",
    "        data.to(device)\n",
    "        loss, aux = loss_function(data)\n",
    "        LOSS.append(loss.cpu().detach().numpy())\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    # scheduler.step()\n",
    "    \n",
    "    print(f'time epoch = {time.time() - start_time_b:.3f}s')\n",
    "    print(\"Epoch: \", epoch, \"Loss: \", loss.item())\n",
    "    \n",
    "    # Save checkpoint every 50 epochs\n",
    "    if (epoch + 1) % 50 == 0:\n",
    "        checkpoint_path = os.path.join(checkpoint_dir, f\"checkpoint_epoch_{epoch+1}.pt\")\n",
    "        torch.save({\n",
    "            'epoch': epoch + 1,\n",
    "            'model_state_dict': model.state_dict(),\n",
    "            'optimizer_state_dict': optimizer.state_dict(),\n",
    "            'loss': loss.item(),\n",
    "            'loss_history': LOSS,\n",
    "            'training_time': time.time() - time_train_start\n",
    "        }, checkpoint_path)\n",
    "        print(f\"Checkpoint saved at epoch {epoch+1}: {checkpoint_path}\")\n",
    "\n",
    "# Save final checkpoint\n",
    "final_checkpoint_path = os.path.join(checkpoint_dir, \"final_checkpoint.pt\")\n",
    "torch.save({\n",
    "    'epoch': epoch + 1,\n",
    "    'model_state_dict': model.state_dict(),\n",
    "    'optimizer_state_dict': optimizer.state_dict(),\n",
    "    'loss': loss.item(),\n",
    "    'loss_history': LOSS,\n",
    "    'training_time': time.time() - time_train_start\n",
    "}, final_checkpoint_path)\n",
    "print(f\"Final checkpoint saved: {final_checkpoint_path}\")\n",
    "\n",
    "print(f\"Total training time: {time.time() - time_train_start:.2f}s\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint = torch.load(\"../runs/transformer_supervised/checkpoint_epoch_1150.pt\")\n",
    "model.load_state_dict(checkpoint['model_state_dict']); optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
    "LOSS = checkpoint['loss_history']\n",
    "training_time = checkpoint['training_time']\n",
    "print(f\"Total training time: {training_time:.2f} seconds ({training_time/3600:.2f} hours)\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.semilogy(LOSS)\n",
    "plt.grid()\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "with open('../data/hr_test.pkl', 'rb') as f:\n",
    "    data_test_list = pickle.load(f)\n",
    "data_test_list = data_test_list[:1000]\n",
    "N_test = len(data_test_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "idx_text = 0\n",
    "graph_test = data_test_list[idx_text]\n",
    "loader_decode_ABC = DataLoader([graph_test], batch_size=1)\n",
    "graph_loaded = next(iter(loader_decode_ABC)).to(device)\n",
    "\n",
    "n_obs = 10\n",
    "sigma = torch.FloatTensor([0.01])\n",
    "\n",
    "ObsIdx = np.random.choice(range(graph_test.pos.shape[0]), size=(n_obs,), replace=False)\n",
    "ObsIdx = torch.tensor(ObsIdx)\n",
    "y_n = (graph_test.x).reshape(-1,)[ObsIdx] + sigma * torch.randn(ObsIdx.shape[0])\n",
    "y_n = y_n.to(device)\n",
    "\n",
    "graph_copy = graph_loaded.clone()\n",
    "perm = torch.randperm(graph_copy.x.shape[0])\n",
    "idx = perm[:n_obs * (torch.max(graph_copy.batch)+1)]\n",
    "\n",
    "# graph_copy.x *= 0.\n",
    "# y = graph_loaded.x[idx].clone()\n",
    "# y += torch.randn_like(y) * sigma_tc\n",
    "# graph_copy.x[idx] = y\n",
    "\n",
    "graph_copy.x.zero_()\n",
    "graph_copy.x = torch.cat((graph_copy.x, graph_copy.x), dim=1)\n",
    "\n",
    "# Copy and add noise\n",
    "y = graph_loaded.x[idx, 0].clone()\n",
    "y += torch.randn_like(y) * sigma_tc\n",
    "\n",
    "# Place noisy observations\n",
    "graph_copy.x[idx, 0] = y\n",
    "graph_copy.x[idx, 1] = 1\n",
    "\n",
    "\n",
    "u_decode = model(graph_copy)\n",
    "\n",
    "print('data')\n",
    "plot_hr_graph(graph_test, graph_test.x[:,0].cpu().detach().numpy(),\n",
    "              ObsIdx=ObsIdx)\n",
    "\n",
    "print('u_min_norm')\n",
    "plot_hr_graph(graph_test, u_decode[:,0].detach().cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import time\n",
    "\n",
    "# Number of runs\n",
    "num_runs = 100\n",
    "n_obs = 10\n",
    "sigma = 0.01\n",
    "\n",
    "mae_list = []\n",
    "\n",
    "# Start timer\n",
    "start_time = time.time()\n",
    "\n",
    "for i in range(num_runs):\n",
    "    # Clone test graph\n",
    "    graph_copy = graph_loaded.clone()\n",
    "\n",
    "    # Random observation indices\n",
    "    ObsIdx = np.random.choice(range(graph_copy.pos.shape[0]), size=(n_obs,), replace=False)\n",
    "    ObsIdx = torch.tensor(ObsIdx, device=graph_copy.x.device)\n",
    "\n",
    "    # Add noise to true values at observation points\n",
    "    # y_obs = graph_loaded.x[ObsIdx].clone()\n",
    "    # y_noisy = y_obs + sigma * torch.randn_like(y_obs)\n",
    "\n",
    "    # Zero out all node features\n",
    "    # graph_copy.x *= 0.0\n",
    "\n",
    "    # Place noisy observations at sampled locations\n",
    "    # graph_copy.x[ObsIdx] = y_noisy\n",
    "\n",
    "    # Zero out features\n",
    "    graph_copy.x.zero_()\n",
    "    graph_copy.x = torch.cat((graph_copy.x, graph_copy.x), dim=1)\n",
    "\n",
    "    # Copy and add noise\n",
    "    y = graph_loaded.x[ObsIdx, 0].clone()\n",
    "    y += torch.randn_like(y) * sigma_tc\n",
    "\n",
    "    # Place noisy observations\n",
    "    graph_copy.x[ObsIdx, 0] = y\n",
    "    graph_copy.x[ObsIdx, 1] = 1\n",
    "\n",
    "\n",
    "    # Decode/predict\n",
    "    u_decode = model(graph_copy)\n",
    "\n",
    "    # Compute MAE for this run\n",
    "    true = graph_loaded.x.reshape(-1)\n",
    "    pred = u_decode[:,0].reshape(-1)\n",
    "    mae = torch.mean(torch.abs(pred - true)).item()\n",
    "    mae_list.append(mae)\n",
    "\n",
    "# End timer\n",
    "total_time = time.time() - start_time\n",
    "\n",
    "# Compute final stats\n",
    "mean_mae = np.mean(mae_list)\n",
    "std_mae = np.std(mae_list)\n",
    "\n",
    "print(f\"Mean MAE over {num_runs} runs: {mean_mae:.6f}\")\n",
    "print(f\"Stddev of MAE: {std_mae:.6f}\")\n",
    "print(f\"Evaluation time: {total_time/100:.5f} seconds\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "n_obs = 10\n",
    "sigma = 0.01\n",
    "\n",
    "\n",
    "mae_list = []\n",
    "\n",
    "# Start timer\n",
    "start_time = time.time()\n",
    "\n",
    "loader_test = DataLoader(data_test_list, batch_size=1, shuffle=True)\n",
    "# loader_test = DataLoader([data_test_list[4]], batch_size=1, shuffle=True)\n",
    "\n",
    "for graph in loader_test:\n",
    "    graph.to(device)\n",
    "    graph_copy = graph.clone().to(device)\n",
    "    \n",
    "\n",
    "    num_graphs = graph_copy.batch.max().item() + 1\n",
    "    idx_list = []\n",
    "\n",
    "    for graph_idx in range(num_graphs):\n",
    "        # Find node indices belonging to this graph\n",
    "        node_idx = (graph_copy.batch == graph_idx).nonzero(as_tuple=False).view(-1)\n",
    "        # Randomly permute and select n_obs (or fewer if graph has fewer nodes)\n",
    "        k = min(n_obs, node_idx.size(0))\n",
    "        perm = torch.randperm(node_idx.size(0), device=graph.x.device)[:k]\n",
    "        idx_list.append(node_idx[perm])\n",
    "\n",
    "    # Concatenate selected indices\n",
    "    idx = torch.cat(idx_list, dim=0).to(device)\n",
    "\n",
    "    # # Zero out features\n",
    "    # graph_copy.x.zero_()\n",
    "\n",
    "    # # Copy and add noise\n",
    "    # y = graph.x[idx].clone()\n",
    "    # y += torch.randn_like(y) * sigma_tc\n",
    "\n",
    "    # # Place noisy observations\n",
    "    # graph_copy.x[idx] = y\n",
    "    \n",
    "        # Zero out features\n",
    "    graph_copy.x.zero_()\n",
    "    graph_copy.x = torch.cat((graph_copy.x, graph_copy.x), dim=1)\n",
    "\n",
    "    # Copy and add noise\n",
    "    y = graph.x[idx, 0].clone()\n",
    "    y += torch.randn_like(y) * sigma_tc\n",
    "\n",
    "    # Place noisy observations\n",
    "    graph_copy.x[idx, 0] = y\n",
    "    graph_copy.x[idx, 1] = 1\n",
    "\n",
    "    # Decode/predict with your GCN\n",
    "    u_decode = model(graph_copy)\n",
    "\n",
    "    true =   graph.x\n",
    "    pred = u_decode\n",
    "\n",
    "    batch = graph.batch\n",
    "    num_graphs = batch.max().item() + 1\n",
    "\n",
    "    for i in range(num_graphs):\n",
    "        # Mask nodes belonging to graph i\n",
    "        mask = (batch == i)\n",
    "        true_i = true[mask]\n",
    "        pred_i = pred[mask]\n",
    "        mae_i = torch.mean(torch.abs(pred_i - true_i)).item()\n",
    "        mae_list.append(mae_i)\n",
    "    \n",
    "# End timer\n",
    "total_time = time.time() - start_time\n",
    "\n",
    "# Compute final stats\n",
    "mean_mae = np.mean(mae_list)\n",
    "std_mae = np.std(mae_list)\n",
    "\n",
    "print(f\"Mean MAE over {N_test} runs: {mean_mae:.6f}\")\n",
    "print(f\"Stddev of MAE: {std_mae:.6f}\")\n",
    "print(f\"Evaluation time: {total_time/1000:.5f} seconds\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Mean MAE over {N_test} runs: {mean_mae:.6e}\")\n",
    "print(f\"Stddev of MAE: {std_mae:.6e}\")\n",
    "print(f\"Evaluation time: {total_time/1000:.5e} seconds\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "n_obs = 30\n",
    "sigma = 0.01\n",
    "\n",
    "\n",
    "mae_list = []\n",
    "\n",
    "# Start timer\n",
    "start_time = time.time()\n",
    "\n",
    "loader_test = DataLoader(data_test_list, batch_size=1, shuffle=True)\n",
    "# loader_test = DataLoader([data_test_list[4]], batch_size=1, shuffle=True)\n",
    "\n",
    "for graph in loader_test:\n",
    "    graph.to(device)\n",
    "    graph_copy = graph.clone().to(device)\n",
    "    \n",
    "\n",
    "    num_graphs = graph_copy.batch.max().item() + 1\n",
    "    idx_list = []\n",
    "\n",
    "    for graph_idx in range(num_graphs):\n",
    "        # Find node indices belonging to this graph\n",
    "        node_idx = (graph_copy.batch == graph_idx).nonzero(as_tuple=False).view(-1)\n",
    "        # Randomly permute and select n_obs (or fewer if graph has fewer nodes)\n",
    "        k = min(n_obs, node_idx.size(0))\n",
    "        perm = torch.randperm(node_idx.size(0), device=graph.x.device)[:k]\n",
    "        idx_list.append(node_idx[perm])\n",
    "\n",
    "    # Concatenate selected indices\n",
    "    idx = torch.cat(idx_list, dim=0).to(device)\n",
    "\n",
    "    # # Zero out features\n",
    "    # graph_copy.x.zero_()\n",
    "\n",
    "    # # Copy and add noise\n",
    "    # y = graph.x[idx].clone()\n",
    "    # y += torch.randn_like(y) * sigma_tc\n",
    "\n",
    "    # # Place noisy observations\n",
    "    # graph_copy.x[idx] = y\n",
    "    \n",
    "        # Zero out features\n",
    "    graph_copy.x.zero_()\n",
    "    graph_copy.x = torch.cat((graph_copy.x, graph_copy.x), dim=1)\n",
    "\n",
    "    # Copy and add noise\n",
    "    y = graph.x[idx, 0].clone()\n",
    "    y += torch.randn_like(y) * sigma_tc\n",
    "\n",
    "    # Place noisy observations\n",
    "    graph_copy.x[idx, 0] = y\n",
    "    graph_copy.x[idx, 1] = 1\n",
    "\n",
    "    # Decode/predict with your GCN\n",
    "    u_decode = model(graph_copy)\n",
    "\n",
    "    true =  graph.x\n",
    "    pred = u_decode\n",
    "\n",
    "    batch = graph.batch\n",
    "    num_graphs = batch.max().item() + 1\n",
    "\n",
    "    for i in range(num_graphs):\n",
    "        # Mask nodes belonging to graph i\n",
    "        mask = (batch == i)\n",
    "        true_i = true[mask]\n",
    "        pred_i = pred[mask]\n",
    "        mae_i = torch.mean(torch.abs(pred_i - true_i)).item()\n",
    "        mae_list.append(mae_i)\n",
    "    \n",
    "# End timer\n",
    "total_time = time.time() - start_time\n",
    "\n",
    "# Compute final stats\n",
    "mean_mae = np.mean(mae_list)\n",
    "std_mae = np.std(mae_list)\n",
    "\n",
    "print(f\"Mean MAE over {N_test} runs: {mean_mae:.6f}\")\n",
    "print(f\"Stddev of MAE: {std_mae:.6f}\")\n",
    "print(f\"Evaluation time: {total_time/1000:.5f} seconds\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Mean MAE over {N_test} runs: {mean_mae:.6e}\")\n",
    "print(f\"Stddev of MAE: {std_mae:.6e}\")\n",
    "print(f\"Evaluation time: {total_time/1000:.5e} seconds\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
