{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.tri as mtri\n",
    "import numpy as np\n",
    "    \n",
    "filename = 'Car_supervised_onehot.ipynb'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device='cuda'\n",
    "TRAIN = True\n",
    "\n",
    "# device='cpu'\n",
    "# TRAIN = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import subprocess\n",
    "\n",
    "# Input data directory and output directories for model artefacts / plots.\n",
    "dir_data  = './data/helm_car_frontForce_hdamp/'\n",
    "dir_model = './model/SUPER_onehot_helm_car/'\n",
    "dir_plt = dir_model+'plt/'\n",
    "os.makedirs(dir_model, exist_ok=True)\n",
    "os.makedirs(dir_plt, exist_ok=True)\n",
    "\n",
    "pwd = os.getcwd()\n",
    "print(pwd)\n",
    "\n",
    "# Snapshot this notebook as a plain script next to the model artefacts.\n",
    "# The arguments are passed as a list (no shell), so a working directory that\n",
    "# contains spaces or shell metacharacters no longer breaks the command.\n",
    "nbconvert_cmd = [\n",
    "    'jupyter', 'nbconvert', f'{pwd}/{filename}',\n",
    "    '--to', 'python',\n",
    "    '--output', f'{pwd}/{dir_model}run_file.py',\n",
    "]\n",
    "try:\n",
    "    nbconvert_status = subprocess.run(nbconvert_cmd, check=False).returncode\n",
    "except FileNotFoundError:\n",
    "    # `jupyter` is not on PATH; report it like a failed command instead of aborting.\n",
    "    nbconvert_status = 127\n",
    "if nbconvert_status != 0:\n",
    "    print(f'WARNING: nbconvert snapshot failed (exit code {nbconvert_status})')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset_train[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# NOTE(review): pickle.load can execute arbitrary code contained in the file;\n",
    "# only load graphs.pkl from a trusted source.\n",
    "with open(dir_data + 'graphs.pkl', 'rb') as f:\n",
    "    dataset = pickle.load(f)\n",
    "print(len(dataset))\n",
    "\n",
    "if len(dataset) == 0:\n",
    "    raise ValueError('graphs.pkl contains no graphs; cannot compute normalisation statistics')\n",
    "\n",
    "# Gather every raw node value into one flat float64 array for the global\n",
    "# statistics. A single concatenate call avoids re-copying the growing array on\n",
    "# every iteration, and no dummy seed element is needed (seeding the array with\n",
    "# np.zeros((1,)) would add an artificial 0.0 sample to min / max / mean / std).\n",
    "press_p = np.concatenate(\n",
    "    [np.asarray(dataset[i].x, dtype=np.float64).reshape(-1) for i in range(len(dataset))],\n",
    "    axis=0,\n",
    ")\n",
    "\n",
    "press_p_min = np.min(press_p)\n",
    "press_p_max = np.max(press_p)\n",
    "press_p_mean = np.mean(press_p)\n",
    "press_p_std = np.std(press_p)\n",
    "\n",
    "print('press_p_min', press_p_min)\n",
    "print('press_p_max', press_p_max)\n",
    "\n",
    "# Standardise each sample with the global mean / std (min and max are only\n",
    "# printed), then store [standardised value, f] as the two node-feature columns.\n",
    "for i in range(len(dataset)):\n",
    "    scaled = (dataset[i].x - press_p_mean) / press_p_std\n",
    "    # as_tensor accepts both ndarrays and tensors without re-wrapping a tensor.\n",
    "    scaled = torch.as_tensor(scaled).to(torch.float)\n",
    "\n",
    "    dataset[i].x = torch.cat((scaled[:, None], dataset[i].f.to(torch.float)[:, None]), dim=1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative split (strided sub-sampling), kept for reference:\n",
    "# dataset_train = dataset[:6000:5].copy()\n",
    "# dataset_test  = dataset[-6000::5].copy()\n",
    "\n",
    "# The first 500 graphs are used for training, everything after them for testing.\n",
    "dataset_train = dataset[:500].copy()\n",
    "dataset_test = dataset[500:].copy()\n",
    "\n",
    "print(f'len dataset_train {len(dataset_train)}')\n",
    "print(f'len dataset_test {len(dataset_test)}')\n",
    "\n",
    "# Drop the name of the full collection; only the two splits are used from here on.\n",
    "del dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick look at the first training graph. Only the value of the last\n",
    "# expression in a cell is displayed, so just `edge_attr` is shown here.\n",
    "dataset_train[0]\n",
    "dataset_train[0].edge_attr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyvista as pv\n",
    "\n",
    "# Enable notebook plotting (use 'static', 'client', 'panel', or 'pythreejs')\n",
    "pv.set_jupyter_backend('static')\n",
    "\n",
    "def plot_graph(graph, u, filename=None):\n",
    "    \"\"\"Render a triangulated surface graph coloured by a nodal scalar field.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    graph : object exposing `pos` and `face` tensors\n",
    "        `pos` supplies the vertex coordinates and `face` the triangle\n",
    "        connectivity (transposed to one row per triangle before plotting).\n",
    "    u : 2-D array or tensor\n",
    "        Nodal values; column 0 is shown as 'Pressure'.\n",
    "    filename : str, optional\n",
    "        Accepted for call compatibility but currently unused: no file is written.\n",
    "    \"\"\"\n",
    "    vertices = graph.pos.cpu().numpy()\n",
    "    faces = graph.face.cpu().numpy().T  # shape (n_faces, 3)\n",
    "\n",
    "    # PyVista expects faces as a flat array: [3, v0, v1, v2, 3, v0, v1, v2, ...]\n",
    "    faces_flat = np.hstack([np.full((faces.shape[0], 1), 3), faces]).astype(np.int64).flatten()\n",
    "\n",
    "    mesh = pv.PolyData(vertices, faces_flat)\n",
    "\n",
    "    # Hand PyVista a plain NumPy array even when `u` is a torch tensor that\n",
    "    # lives on the GPU or is attached to the autograd graph.\n",
    "    pressure = u[:, 0]\n",
    "    if torch.is_tensor(pressure):\n",
    "        pressure = pressure.detach().cpu().numpy()\n",
    "    mesh.point_data['Pressure'] = pressure\n",
    "\n",
    "    p = pv.Plotter(notebook=True)\n",
    "\n",
    "    # Colour the surface by the scalar and put a vertical colour bar on the left.\n",
    "    scalar_bar_args = {\n",
    "        'title': 'Pressure',\n",
    "        'vertical': True,     # True = vertical colorbar; False = horizontal\n",
    "        'position_x': 0.1,    # X position in normalized [0,1]\n",
    "        'position_y': 0.1,    # Y position in normalized [0,1]\n",
    "        'width': 0.05,\n",
    "        'height': 0.8,\n",
    "    }\n",
    "    p.add_mesh(\n",
    "        mesh,\n",
    "        scalars='Pressure',\n",
    "        cmap='coolwarm',\n",
    "        show_scalar_bar=True,\n",
    "        scalar_bar_args=scalar_bar_args,\n",
    "    )\n",
    "    p.show_grid()\n",
    "\n",
    "    # camera_position is (position, focal_point, view_up). A (0, 0, 0) view-up\n",
    "    # is a degenerate direction, so +Y is given explicitly.\n",
    "    # NOTE(review): VTK is expected to substitute +Y for a zero view-up, so the\n",
    "    # rendered view should be unchanged - confirm on a sample plot.\n",
    "    p.camera_position = [np.array([2, 3, -2]), (0.0, 0.0, 0.0), (0.0, 1.0, 0.0)]\n",
    "    p.show()\n",
    "    # p.save_graphic(dir_plt + 'graph_plot.pdf')\n",
    "# Preview a randomly chosen training graph. `filename` is passed, but\n",
    "# plot_graph never reads it, so nothing is written to dir_plt here.\n",
    "idx = np.random.randint(0, len(dataset_train))\n",
    "plot_graph(dataset_train[idx], dataset_train[idx].x, filename=dir_plt+'recon')\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GCNConv, MLP\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.nn import pool\n",
    "\n",
    "from torch.nn import ModuleList\n",
    "from torch_geometric.nn import GCNConv\n",
    "from torch_geometric.nn.dense import Linear\n",
    "\n",
    "# loader = DataLoader(data_list, batch_size=100, shuffle=True)\n",
    "# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# data = next(iter(loader)).to(device)\n",
    "# print(data.num_node_features)\n",
    "\n",
    "# Layer sizes read by the GCN class defined below.\n",
    "dim_z = 100  # width of every hidden GCN layer (stored as self.c)\n",
    "# dim_z = 50\n",
    "dim_y = 2  # enters the first-layer input size as (dim_y + 3) * 2\n",
    "dim_u = 1  # number of output channels of the last GCN layer\n",
    "dim_x = 3  # NOTE(review): presumably the spatial dimension of graph.pos - confirm\n",
    "dim_node_features = dim_x + dim_u  # not referenced by the code visible in this notebook\n",
    "\n",
    "\n",
    "class GCN(torch.nn.Module):\n",
    "    \"\"\"Stack of GCNConv layers whose inputs are augmented with a per-graph mean.\n",
    "\n",
    "    Before every layer except the last, the mean of the current node features\n",
    "    over each graph is copied back onto that graph's nodes and concatenated,\n",
    "    which doubles the feature width seen by the layer. The last layer maps the\n",
    "    hidden state to `dim_u` channels and is not followed by an activation.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.c = dim_z\n",
    "        width = self.c\n",
    "\n",
    "        # Layers are created in the same order as before (input, four hidden,\n",
    "        # output), so they keep the indices convs.0 ... convs.5.\n",
    "        layers = [GCNConv((dim_y + 3) * 2, width)]\n",
    "        layers += [GCNConv(width * 2, width) for _ in range(4)]\n",
    "        layers.append(GCNConv(width, dim_u))\n",
    "        self.convs = ModuleList(layers)\n",
    "\n",
    "    def forward(self, graph):\n",
    "        \"\"\"Return the per-node output of the network for a batched graph.\"\"\"\n",
    "        edge_index = graph.edge_index\n",
    "        edge_attr = graph.edge_attr\n",
    "\n",
    "        # Node input: feature columns followed by the node coordinates.\n",
    "        h = torch.cat([graph.x, graph.pos], dim=1)\n",
    "\n",
    "        for conv in self.convs[:-1]:\n",
    "            # One mean vector per graph, indexed back to one row per node.\n",
    "            graph_mean = pool.global_mean_pool(h, graph.batch)\n",
    "            h = torch.cat([h, graph_mean[graph.batch]], dim=1)\n",
    "            h = F.silu(conv(h, edge_index, edge_weight=edge_attr))\n",
    "\n",
    "        output_layer = self.convs[-1]\n",
    "        return output_layer(h, edge_index, edge_weight=edge_attr)\n",
    "    \n",
    "# Instantiate the network and move its parameters to the selected device.\n",
    "gcn = GCN().to(device)\n",
    "\n",
    "# # from torch.optim.lr_scheduler import ExponentialLR\n",
    "# optimizerE = torch.optim.AdamW(gcnE.parameters(), lr=0.0005)\n",
    "# # schedulerE = ExponentialLR(optimizerE, gamma=0.95)\n",
    "# optimizerD = torch.optim.AdamW(gcnD.parameters(), lr=0.0005)\n",
    "# # schedulerD = ExponentialLR(optimizerD, gamma=0.95)\n",
    "\n",
    "# Number of training epochs; also used as the period (T_max) of the cosine schedule.\n",
    "n_epochs = 10_000\n",
    "\n",
    "from torch.optim.lr_scheduler import ExponentialLR, CosineAnnealingLR\n",
    "# AdamW starting at lr=1e-3, annealed towards eta_min=1e-5 over n_epochs scheduler steps.\n",
    "optimizer = torch.optim.AdamW(gcn.parameters(), lr=0.001)\n",
    "scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=1e-5)\n",
    "\n",
    "\n",
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "print(count_parameters(gcn))    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# n_obs = 10\n",
    "# Standard deviation of the Gaussian noise added to the observed node values\n",
    "# (multiplies torch.randn_like(...) where the observations are built).\n",
    "sigma_tc = torch.tensor([0.01]).to(device)\n",
    "\n",
    "NOBS = np.arange(20,50)\n",
    "norm_nobs = (1./NOBS) / (np.sum(1./NOBS))  # normalize to sum to 1\n",
    "def get_rand_nobs():\n",
    "    return np.random.choice(NOBS, p=norm_nobs, size=1)[0]\n",
    "\n",
    "def loss_function(graph):\n",
    "    \"\"\"Reconstruction loss from sparse, noisy observations of one mini-batch.\n",
    "\n",
    "    For every graph in the batch a random subset of nodes is selected; its\n",
    "    size is drawn by get_rand_nobs() and capped at the number of nodes. The\n",
    "    network input is a two-column feature: column 0 holds the noisy value of\n",
    "    graph.x[:, 0] at the selected nodes (zero elsewhere) and column 1 is a\n",
    "    0/1 flag marking the selected nodes. The prediction is compared with the\n",
    "    clean target on all nodes.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    graph : batched graph object providing `x`, `batch` and `clone()`\n",
    "        Column 0 of `x` is the field to reconstruct. The object itself is not\n",
    "        modified; the masked input is written to a clone.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        `(loss, None)`: the scalar mean-squared error and an unused\n",
    "        placeholder kept so existing `loss, aux = ...` callers keep working.\n",
    "    \"\"\"\n",
    "    num_graphs = graph.batch.max().item() + 1\n",
    "    idx_list = []\n",
    "\n",
    "    for graph_idx in range(num_graphs):\n",
    "        # Batch-level indices of the nodes that belong to this graph.\n",
    "        node_idx = (graph.batch == graph_idx).nonzero(as_tuple=False).view(-1)\n",
    "        # Random subset without replacement (fewer if the graph is small).\n",
    "        n_obs_train = get_rand_nobs()\n",
    "        k = min(n_obs_train, node_idx.size(0))\n",
    "        perm = torch.randperm(node_idx.size(0), device=graph.x.device)[:k]\n",
    "        idx_list.append(node_idx[perm])\n",
    "\n",
    "    idx = torch.cat(idx_list, dim=0)\n",
    "\n",
    "    # Noisy copies of the observed values.\n",
    "    y = graph.x[idx, 0].clone()\n",
    "    y += torch.randn_like(y) * sigma_tc\n",
    "\n",
    "    # Network input: [observed value, observation flag], zero at unobserved nodes.\n",
    "    graph_copy = graph.clone()\n",
    "    graph_copy.x = graph.x.new_zeros((graph.x.size(0), 2))\n",
    "    graph_copy.x[idx, 0] = y\n",
    "    graph_copy.x[idx, 1] = 1\n",
    "\n",
    "    u = gcn(graph_copy)\n",
    "\n",
    "    # Compare only against as many leading columns of graph.x as the network\n",
    "    # predicts. Without the slice an (N, 1) prediction broadcasts against every\n",
    "    # column of graph.x, so the extra feature column is silently treated as a\n",
    "    # second regression target for the same output channel.\n",
    "    target = graph.x[:, :u.size(1)]\n",
    "    loss = torch.mean((target - u) ** 2.)\n",
    "\n",
    "    return loss, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mini-batches of 100 graphs, reshuffled every epoch.\n",
    "loader_train = DataLoader(dataset_train, batch_size=100, shuffle=True)\n",
    "print(dataset_train[0])\n",
    "# Pull one batch to inspect what a collated mini-batch looks like.\n",
    "data_train = next(iter(loader_train))\n",
    "print(data_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from IPython.display import clear_output\n",
    "\n",
    "if TRAIN:\n",
    "    import time\n",
    "\n",
    "    print(\"Training mode: True\")\n",
    "    gcn.train()\n",
    "\n",
    "    # Loss history. LOSS_D and LOSS_L are never appended to in this loop; they\n",
    "    # are only carried along so the saved dictionary keeps its keys.\n",
    "    LOSS, LOSS_D, LOSS_L = [], [], []\n",
    "    time_train_start = time.time()\n",
    "\n",
    "    for epoch in range(n_epochs):\n",
    "        start_time_b = time.time()\n",
    "\n",
    "        # One optimiser step per mini-batch.\n",
    "        for graph in loader_train:\n",
    "            optimizer.zero_grad()\n",
    "            graph.to(device)\n",
    "            loss, aux = loss_function(graph)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        # Only the loss of the last mini-batch of the epoch is recorded.\n",
    "        last_loss = loss.item()\n",
    "        LOSS.append(last_loss)\n",
    "        scheduler.step()\n",
    "\n",
    "        print(f'time epoch = {time.time() - start_time_b:.3f}s', )\n",
    "        print(scheduler.get_last_lr()) # will print last learning rate.\n",
    "\n",
    "        # Checkpoint weights, optimiser state and loss history every 1000 epochs.\n",
    "        if not epoch % 1000:\n",
    "            print(\"## SAVING MODEL ##\")\n",
    "            torch.save(gcn.state_dict(), dir_model+f'gcn_{epoch}.model')\n",
    "            torch.save(optimizer.state_dict(), dir_model+f'gcn_{epoch}.opt')\n",
    "            loss_dict = {\n",
    "                'LOSS': LOSS,\n",
    "                'LOSS_L': LOSS_L,\n",
    "                'LOSS_D': LOSS_D,\n",
    "                'time_train': time.time()-time_train_start,\n",
    "            }\n",
    "            with open(dir_model+f'loss_dict.pkl', 'wb') as f:\n",
    "                pickle.dump(loss_dict, f)\n",
    "\n",
    "        print(\"Epoch: \", epoch, \"Loss: \", last_loss)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if TRAIN:\n",
    "\n",
    "    # Wall-clock seconds since training started, measured when this cell runs.\n",
    "    time_train = time.time() - time_train_start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if TRAIN:\n",
    "\n",
    "    # Training is finished: switch the network to evaluation mode.\n",
    "    gcn.eval()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the final weights, the optimiser state and the loss history.\n",
    "if TRAIN:\n",
    "    import pickle\n",
    "\n",
    "    for state_owner, file_name in ((gcn, 'gcn.model'), (optimizer, 'gcn.opt')):\n",
    "        torch.save(state_owner.state_dict(), dir_model + file_name)\n",
    "\n",
    "    loss_dict = dict(LOSS=LOSS, LOSS_L=LOSS_L, LOSS_D=LOSS_D, time_train=time_train)\n",
    "    with open(dir_model + 'loss_dict.pkl', 'wb') as f:\n",
    "        pickle.dump(loss_dict, f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Matplotlib styling: LaTeX text rendering with a Computer Modern serif font.\n",
    "# All settings are applied in one call; the values are the ones that were in\n",
    "# effect after the former update() / rc() sequence (serif family active,\n",
    "# Helvetica kept as the sans-serif entry).\n",
    "plt.rcParams.update({\n",
    "    'text.usetex': True,\n",
    "    'text.latex.preamble': r'\\usepackage{amsfonts, amsmath, amssymb}',\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': ['Computer Modern Roman'],\n",
    "    'font.sans-serif': ['Helvetica'],\n",
    "    'axes.labelsize': 18,\n",
    "    'axes.titlesize': 18,\n",
    "    'xtick.labelsize': 16,\n",
    "    'ytick.labelsize': 16,\n",
    "    'legend.fontsize': 14,\n",
    "})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the loss history back from disk; this runs whether or not TRAIN is set.\n",
    "# NOTE(review): pickle.load can execute code from the file - only load trusted files.\n",
    "with open(dir_model+'loss_dict.pkl', 'rb') as f:\n",
    "    loss_dict = pickle.load(f)\n",
    "LOSS = loss_dict['LOSS']\n",
    "time_train = loss_dict['time_train']\n",
    "\n",
    "print('time_train:', time_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recorded loss per epoch on a logarithmic y-axis.\n",
    "plt.semilogy(LOSS)\n",
    "plt.grid()\n",
    "plt.show()\n",
    "plt.close()\n",
    "# NOTE(review): if re-enabled, savefig should come before show()/close(),\n",
    "# otherwise the current figure is no longer the loss plot - confirm.\n",
    "# plt.savefig(dir_plt+ 'losses.pdf')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Reload the final checkpoint (gcn.model / gcn.opt) from dir_model.\n",
    "# This runs for both TRAIN = True and TRAIN = False: when the load was guarded\n",
    "# by `if TRAIN:` a TRAIN = False session evaluated a randomly initialised\n",
    "# network. A missing checkpoint now raises instead of going unnoticed.\n",
    "# map_location lets a checkpoint written on one device be restored on another.\n",
    "gcn.load_state_dict(torch.load(dir_model + 'gcn.model', map_location=device))\n",
    "gcn.eval()\n",
    "optimizer.load_state_dict(torch.load(dir_model + 'gcn.opt', map_location=device))\n",
    "\n",
    "# To evaluate an intermediate checkpoint instead, load\n",
    "# dir_model + f'gcn_{epoch}.model' and dir_model + f'gcn_{epoch}.opt'\n",
    "# for an epoch at which the training loop saved one (every 1000 epochs).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyvista as pv\n",
    "\n",
    "# Enable notebook plotting (use 'static', 'client', 'panel', or 'pythreejs')\n",
    "pv.set_jupyter_backend('static')\n",
    "\n",
    "def plot_graph(graph, u, filename=None):\n",
    "    \"\"\"Render a triangulated surface graph coloured by the nodal field `u`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    graph : object exposing `pos` and `face` tensors\n",
    "        `pos` supplies the vertex coordinates and `face` the triangle\n",
    "        connectivity (transposed to one row per triangle before plotting).\n",
    "    u : array or tensor\n",
    "        Nodal values, stored on the mesh as the point array 'pressure'.\n",
    "    filename : str, optional\n",
    "        Accepted for call compatibility but currently unused: no file is written.\n",
    "    \"\"\"\n",
    "    vertices = graph.pos.cpu().numpy()\n",
    "    faces = graph.face.cpu().numpy().T  # shape (n_faces, 3)\n",
    "\n",
    "    # PyVista expects faces as a flat array: [3, v0, v1, v2, 3, v0, v1, v2, ...]\n",
    "    faces_flat = np.hstack([np.full((faces.shape[0], 1), 3), faces]).astype(np.int64).flatten()\n",
    "\n",
    "    mesh = pv.PolyData(vertices, faces_flat)\n",
    "\n",
    "    # Hand PyVista a plain NumPy array even when `u` is a torch tensor that\n",
    "    # lives on the GPU or is attached to the autograd graph.\n",
    "    if torch.is_tensor(u):\n",
    "        u = u.detach().cpu().numpy()\n",
    "    mesh.point_data['pressure'] = u\n",
    "\n",
    "    p = pv.Plotter(notebook=True)\n",
    "\n",
    "    # Colour the surface by the scalar and put a vertical colour bar on the left.\n",
    "    scalar_bar_args = {\n",
    "        'title': 'Helmholtz Solution',\n",
    "        'vertical': True,     # True = vertical colorbar; False = horizontal\n",
    "        'position_x': 0.1,    # X position in normalized [0,1]\n",
    "        'position_y': 0.1,    # Y position in normalized [0,1]\n",
    "        'width': 0.05,\n",
    "        'height': 0.8,\n",
    "    }\n",
    "    p.add_mesh(\n",
    "        mesh,\n",
    "        scalars='pressure',\n",
    "        cmap='coolwarm',\n",
    "        show_scalar_bar=True,\n",
    "        scalar_bar_args=scalar_bar_args,\n",
    "    )\n",
    "\n",
    "    # Preset views are available as p.view_xy() / p.view_xz() / p.view_yz().\n",
    "    p.show_axes()\n",
    "\n",
    "    # camera_position is (position, focal_point, view_up). A (0, 0, 0) view-up\n",
    "    # is a degenerate direction, so +Y is given explicitly.\n",
    "    # NOTE(review): VTK is expected to substitute +Y for a zero view-up, so the\n",
    "    # rendered view should be unchanged - confirm on a sample plot.\n",
    "    p.camera_position = [np.array([2, 3, -2]), (0.0, 0.0, 0.0), (0.0, 1.0, 0.0)]\n",
    "    p.show()\n",
    "    # p.save_graphic(dir_plt + 'graph_plot.pdf')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# n_obs = 20\n",
    "# NOTE(review): `sigma` is not read anywhere below; the noise actually applied\n",
    "# to the observations is `sigma_tc`.\n",
    "sigma = 0.01\n",
    "\n",
    "# Evaluate on at most the first 100 held-out graphs, one graph per batch.\n",
    "loader_test = DataLoader(dataset_test[:100], batch_size=1, shuffle=True)\n",
    "# loader_test = DataLoader([data_test_list[4]], batch_size=1, shuffle=True)\n",
    "N_test = len(dataset_test[:100])\n",
    "if N_test == 0:\n",
    "    raise ValueError('dataset_test is empty: nothing to evaluate')\n",
    "\n",
    "mae_list = []      # overall MAE per graph\n",
    "mae_dim_list = []  # per-output-channel MAE per graph\n",
    "\n",
    "gcn.eval()\n",
    "\n",
    "# Start timer\n",
    "start_time = time.time()\n",
    "\n",
    "with torch.no_grad():\n",
    "    for graph in loader_test:\n",
    "        graph.to(device)\n",
    "        graph_copy = graph.clone()\n",
    "\n",
    "        batch = graph.batch\n",
    "        num_graphs = batch.max().item() + 1\n",
    "        idx_list = []\n",
    "\n",
    "        for graph_idx in range(num_graphs):\n",
    "            node_idx = (batch == graph_idx).nonzero(as_tuple=False).view(-1)\n",
    "            # Random subset of observed nodes (fewer if the graph is small).\n",
    "            n_obs_train = get_rand_nobs()\n",
    "            k = min(n_obs_train, node_idx.size(0))\n",
    "            perm = torch.randperm(node_idx.size(0), device=graph.x.device)[:k]\n",
    "            idx_list.append(node_idx[perm])\n",
    "\n",
    "        idx = torch.cat(idx_list, dim=0)\n",
    "\n",
    "        # Noisy copies of the observed values.\n",
    "        y = graph.x[idx, 0].clone()\n",
    "        y += torch.randn_like(y) * sigma_tc\n",
    "\n",
    "        # Network input: [observed value, observation flag], zero at unobserved nodes.\n",
    "        graph_copy.x = graph.x.new_zeros((graph.x.size(0), 2))\n",
    "        graph_copy.x[idx, 0] = y\n",
    "        graph_copy.x[idx, 1] = 1\n",
    "\n",
    "        pred = gcn(graph_copy)\n",
    "\n",
    "        # Score only the columns the network predicts. Without the slice an\n",
    "        # (N, 1) prediction broadcasts against every column of graph.x and the\n",
    "        # error against the extra feature column is mixed into the MAE.\n",
    "        true = graph.x[:, :pred.size(1)]\n",
    "\n",
    "        for i in range(num_graphs):\n",
    "            # Nodes belonging to graph i\n",
    "            mask = (batch == i)\n",
    "            abs_err = torch.abs(pred[mask] - true[mask])  # [nodes_in_graph_i, n_outputs]\n",
    "\n",
    "            # Overall MAE for graph i\n",
    "            mae_list.append(torch.mean(abs_err).item())\n",
    "\n",
    "            # MAE per output channel for graph i\n",
    "            mae_dim_list.append(torch.mean(abs_err, dim=0).cpu().numpy())\n",
    "\n",
    "# End timer\n",
    "total_time = time.time() - start_time\n",
    "\n",
    "# Stack to shape [num_graphs, n_outputs]\n",
    "mae_dim_array = np.vstack(mae_dim_list)\n",
    "\n",
    "# Compute final stats\n",
    "mean_mae = np.mean(mae_list)\n",
    "std_mae = np.std(mae_list)\n",
    "\n",
    "# Mean and standard deviation per output channel\n",
    "mean_mae_per_dim = np.mean(mae_dim_array, axis=0)\n",
    "std_mae_per_dim = np.std(mae_dim_array, axis=0)\n",
    "\n",
    "print(f\"Mean MAE over {N_test} runs: {mean_mae:.6f}\")\n",
    "print(f\"Stddev of MAE: {std_mae:.6f}\")\n",
    "print(f\"Evaluation time: {total_time/N_test:.5f} seconds\")\n",
    "\n",
    "# Label updated with the fix: only the predicted channel (u) is scored now.\n",
    "print(\"Per-dimension MAE statistics (u):\")\n",
    "print(f\"  Mean:    {mean_mae_per_dim}\")\n",
    "print(f\"  Stddev:  {std_mae_per_dim}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gabi",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
