{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import torch.multiprocessing as mp\n",
    "# mp.set_start_method('spawn')\n",
    "\n",
    "import torch\n",
    "import numpy as np \n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from scipy.spatial import Delaunay\n",
    "from scipy.sparse import lil_matrix\n",
    "from scipy.sparse.linalg import spsolve\n",
    "\n",
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.tri as mtri\n",
    "\n",
    "from GABI.solver.heat_rect_unstruct import make_sln_graph, plot_slngraph\n",
    "\n",
    "\n",
    "device = 'cuda' \n",
    "\n",
    "filename = 'supervised_onehot.ipynb'\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "dir_data = './models/airfoil_gabi_1/'\n",
    "dir_save = './models/airfoil_SUPER_ONEHOT_1/'\n",
    "dir_plt = dir_save+'plt/'\n",
    "os.makedirs(dir_save, exist_ok=True)\n",
    "os.makedirs(dir_plt, exist_ok=True)\n",
    "\n",
    "import os\n",
    "pwd = os.getcwd()\n",
    "print(pwd)\n",
    "os.system(f\"jupyter nbconvert {pwd}/{filename} --to python --output {pwd}/{dir_save}run_file.py\" )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "with open(dir_data+'airfoil_thinned_train_1.pkl', 'rb') as f:\n",
    "    thin_dataset_1 = pickle.load(f)\n",
    "with open(dir_data+'airfoil_thinned_train_2.pkl', 'rb') as f:\n",
    "    thin_dataset_2 = pickle.load(f)\n",
    "    \n",
    "thin_dataset = thin_dataset_1 + thin_dataset_2\n",
    "\n",
    "print(len(thin_dataset))\n",
    "dataset_train = thin_dataset[:1000]\n",
    "dataset_test  = thin_dataset[-100:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    \n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.tri as mtri\n",
    "from matplotlib.colors import Normalize\n",
    "\n",
    "def plot_graph(graph, u, scatter_data=None, filename='graph'):\n",
    "    pos = graph.pos.detach().cpu().numpy()\n",
    "    face = graph.face.detach().cpu().numpy()\n",
    "    # u = u.detach().cpu().numpy()\n",
    "    \n",
    "    triang = mtri.Triangulation(pos[:, 0], pos[:, 1], face.T)\n",
    "\n",
    "    # Create shared normalization (min/max of u)\n",
    "    norm = Normalize(vmin=u.min(), vmax=u.max())\n",
    "\n",
    "    # Plot filled contour\n",
    "    contourf = plt.tricontourf(triang, u, levels=20, cmap='viridis', norm=norm)\n",
    "\n",
    "    # Plot triangle edges\n",
    "    plt.triplot(triang, color='black', alpha=0.05, linewidth=0.01)\n",
    "\n",
    "    # Overlay scatter points\n",
    "    if scatter_data is not None:\n",
    "        scatter_vals = scatter_data[2]\n",
    "        plt.scatter(\n",
    "            scatter_data[0], scatter_data[1],\n",
    "            c=scatter_vals,\n",
    "            cmap='viridis',\n",
    "            norm=norm,  # match color scale\n",
    "            s=10.,\n",
    "            edgecolors='red',\n",
    "            alpha=1\n",
    "        )\n",
    "\n",
    "    # Colorbar linked to the contour (and now scatter as well)\n",
    "    plt.colorbar(contourf)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(filename + '.pdf')\n",
    "    plt.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GCNConv, MLP\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.nn import pool\n",
    "\n",
    "from torch.nn import ModuleList\n",
    "from torch_geometric.nn import GCNConv\n",
    "from torch_geometric.nn.dense import Linear\n",
    "\n",
    "# loader = DataLoader(data_list, batch_size=100, shuffle=True)\n",
    "# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# data = next(iter(loader)).to(device)\n",
    "# print(data.num_node_features)\n",
    "\n",
    "dim_z = 100\n",
    "# dim_z = 50\n",
    "dim_y = 2\n",
    "dim_u = 3\n",
    "dim_x = 2\n",
    "dim_node_features = dim_x + dim_u\n",
    "\n",
    "\n",
    "class GCN(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(GCN, self).__init__()\n",
    "        self.c = dim_z\n",
    "        self.convs = ModuleList([\n",
    "            GCNConv((dim_y + 2)*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),  \n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c), \n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c, dim_u),\n",
    "        ])\n",
    "        \n",
    "    def forward(self, graph):\n",
    "                \n",
    "        x = torch.cat([graph.x,  graph.pos], dim=1)\n",
    "\n",
    "        edge_index, edge_attr = graph.edge_index, graph.edge_attr\n",
    "\n",
    "        for ctr, conv in enumerate(self.convs[:-1]):\n",
    "            x_global = pool.global_mean_pool(x, graph.batch)\n",
    "            global_x_expanded = x_global[graph.batch]  # [num_nodes, hidden_dim]\n",
    "            \n",
    "            x = torch.cat([x, global_x_expanded], dim=1)  # [num_nodes, hidden_dim * 2]\n",
    "        \n",
    "            x = F.silu(conv(x, edge_index, edge_weight=edge_attr))\n",
    "            \n",
    "        x = self.convs[-1](x, edge_index, edge_weight=edge_attr)\n",
    "\n",
    "        return x\n",
    "    \n",
    "gcn = GCN().to(device)\n",
    "\n",
    "from torch.optim.lr_scheduler import ExponentialLR\n",
    "optimizer = torch.optim.Adam(gcn.parameters(), lr=0.001)\n",
    "scheduler = ExponentialLR(optimizer, gamma=0.999)\n",
    "\n",
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "print(count_parameters(gcn))\n",
    "print(count_parameters(gcn))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def observation_map(graph,  graph_idx, n_obs):\n",
    "    \n",
    "    indices = torch.where(graph.airf_nodes * (graph.batch == graph_idx) == 1)[0]\n",
    "    # mask_graph\n",
    "    left_most = indices[torch.argmin(graph.pos[indices, 0])]\n",
    "    high_most = indices[torch.argmax(graph.pos[indices, 1])]\n",
    "    low_most  = indices[torch.argmin(graph.pos[indices, 1])]\n",
    "    # indices = np.where(data_test_graph.airf_nodes*0 == 0)[0]\n",
    "    # ObsIdx = np.random.choice(indices, size=n_obs-3, replace=False)\n",
    "    perm = torch.randperm(indices.size(0), device=graph.x.device)[:n_obs-3]\n",
    "    ObsIdx = indices[perm]\n",
    "    ObsIdx = torch.tensor(ObsIdx)\n",
    "    ObsIdx = torch.cat( (ObsIdx, torch.tensor([left_most,high_most, low_most]).to(device)), dim=0)  # add the first node\n",
    "    \n",
    "    return ObsIdx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NOBS = np.arange(5,50)\n",
    "norm_nobs = (1./NOBS) / (np.sum(1./NOBS))  # normalize to sum to 1\n",
    "def get_rand_nobs():\n",
    "    return np.random.choice(NOBS, p=norm_nobs, size=1)[0]\n",
    "\n",
    "samples = [get_rand_nobs() for _ in range(10_000)]\n",
    "\n",
    "\n",
    "plt.hist(samples, density=True, bins=48)\n",
    "plt.plot(NOBS, 1./NOBS / (np.sum(1./NOBS)))\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "np.unique(samples, return_counts=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "sigma_tc = torch.tensor([0.01]).to(device)\n",
    "\n",
    "def loss_function(graph):\n",
    "\n",
    "    graph_copy = graph.clone()\n",
    "\n",
    "    num_graphs = graph_copy.batch.max().item() + 1\n",
    "    idx_list = []\n",
    "\n",
    "    for graph_idx in range(num_graphs):\n",
    "        # Find node indices belonging to this graph\n",
    "        # node_idx = (graph_copy.batch == graph_idx).nonzero(as_tuple=False).view(-1)\n",
    "        # Randomly permute and select n_obs (or fewer if graph has fewer nodes)\n",
    "        # k = min(n_obs, node_idx.size(0))\n",
    "        # perm = torch.randperm(node_idx.size(0), device=graph.x.device)[:k]\n",
    "        n_obs = get_rand_nobs()\n",
    "        node_obs_idx = observation_map(graph_copy, graph_idx, n_obs)        \n",
    "        idx_list.append(node_obs_idx)\n",
    "\n",
    "    # Concatenate selected indices\n",
    "    idx = torch.cat(idx_list, dim=0)\n",
    "\n",
    "    # # Zero out features\n",
    "    # graph_copy.x.zero_()\n",
    "\n",
    "    # # Copy and add noise\n",
    "    # y = graph.x[idx, :1].clone()\n",
    "    # y += torch.randn_like(y) * sigma_tc\n",
    "\n",
    "    # # Place noisy observations\n",
    "    # graph_copy.x = graph_copy.x[:,:1]\n",
    "    # graph_copy.x[idx] = y\n",
    "    \n",
    "    # Zero out features\n",
    "    graph_copy.x.zero_()\n",
    "    graph_copy.x = torch.cat((graph_copy.x[:,:1], graph_copy.x[:,:1]), dim=1)\n",
    "\n",
    "    # Copy and add noise\n",
    "    y = graph.x[idx, 0].clone()\n",
    "    y += torch.randn_like(y) * sigma_tc\n",
    "\n",
    "    # Place noisy observations\n",
    "    graph_copy.x[idx, 0] = y\n",
    "    graph_copy.x[idx, 1] = 1\n",
    "\n",
    "    u = gcn(graph_copy)\n",
    "    loss = torch.mean((graph.x - u)**2.)\n",
    "\n",
    "    return loss, (None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "loader_train = DataLoader(dataset_train, batch_size=100, shuffle=True)\n",
    "data_train = next(iter(loader_train))\n",
    "\n",
    "gcn.train()\n",
    "\n",
    "LOSS = []\n",
    "import time\n",
    "time_train_start = time.time()\n",
    "# for epoch in range(100_000):\n",
    "for epoch in range(10_000):\n",
    "\n",
    "    start_time_b = time.time()\n",
    "    for data in loader_train:\n",
    "        optimizer.zero_grad()\n",
    "        data.to(device)\n",
    "        loss, aux = loss_function(data)\n",
    "        LOSS.append(loss.cpu().detach().numpy())\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    scheduler.step()\n",
    "    print(f'time epoch = {time.time() - start_time_b:.3f}s', )\n",
    "        \n",
    "    print(\"Epoch: \", epoch, \"Loss: \", loss.item())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time_train = time.time() - time_train_start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(time_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gcn.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SAVE MODELS\n",
    "torch.save(gcn.state_dict(), dir_save+'gcn.model')\n",
    "torch.save(optimizer.state_dict(), dir_save+'gcn.opt')\n",
    "\n",
    "loss_dict = {'LOSS': LOSS, 'time_train':time_train,}\n",
    "import pickle\n",
    "with open(dir_save+'loss_dict.pkl', 'wb') as f:\n",
    "    pickle.dump(loss_dict, f)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(dir_save+'loss_dict.pkl', 'rb') as f:\n",
    "    loss_dict = pickle.load(f)\n",
    "LOSS = loss_dict['LOSS']\n",
    "time_train = loss_dict['time_train']\n",
    "print('time_train', time_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.semilogy(LOSS)\n",
    "plt.grid()\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LOAD MODELS\n",
    "gcn.load_state_dict(torch.load(dir_save+'gcn.model'))\n",
    "gcn.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx_text = 0\n",
    "graph_test = dataset_test[idx_text]\n",
    "loader_decode_ABC = DataLoader([graph_test], batch_size=1)\n",
    "graph_loaded = next(iter(loader_decode_ABC)).to(device)\n",
    "# u_abc_decode = torch.vmap(gcnD, in_dims=(0, None), chunk_size=100)(z_grid[:, None, :], graph_ABC_loaded)\n",
    "\n",
    "# n_obs = 20\n",
    "sigma = torch.FloatTensor([0.01])\n",
    "\n",
    "# ObsIdx = np.random.choice(range(graph_test.pos.shape[0]), size=(n_obs,), replace=False)\n",
    "# ObsIdx = torch.tensor(ObsIdx)\n",
    "# y_n = (graph_test.x).reshape(-1,)[ObsIdx] + sigma * torch.randn(ObsIdx.shape[0])\n",
    "# y_n = y_n.to(device)\n",
    "\n",
    "graph_copy = graph_loaded.clone()\n",
    "# for graph_idx in range(num_graphs):\n",
    "    # Find node indices belonging to this graph\n",
    "node_idx = (graph_copy.batch == 0).nonzero(as_tuple=False).view(-1)\n",
    "# Randomly permute and select n_obs (or fewer if graph has fewer nodes)\n",
    "# k = min(n_obs, node_idx.size(0))\n",
    "n_obs = 20\n",
    "node_obs_idx = observation_map(graph_copy, 0, n_obs)\n",
    "\n",
    "# perm = torch.randperm(node_idx.size(0), device=graph.x.device)[:k]\n",
    "# idx_list = node_obs_idx\n",
    "\n",
    "# Concatenate selected indices\n",
    "ObsIdx = node_obs_idx\n",
    "\n",
    "# Zero out features\n",
    "# graph_copy.x.zero_()\n",
    "\n",
    "# # Copy and add noise\n",
    "# y = graph_loaded.x[ObsIdx, :1].clone()\n",
    "# y += torch.randn_like(y) * sigma_tc\n",
    "\n",
    "# # Place noisy observations\n",
    "# graph_copy.x = graph_copy.x[:,:1]\n",
    "# graph_copy.x[ObsIdx] = y\n",
    "\n",
    "graph_copy.x.zero_()\n",
    "# graph_copy.x = torch.cat((graph_copy.x, graph_copy.x), dim=1)\n",
    "graph_copy.x = torch.cat((graph_copy.x[:,:1], graph_copy.x[:,:1]), dim=1)\n",
    "\n",
    "# Copy and add noise\n",
    "y = graph_loaded.x[ObsIdx, 0].clone()\n",
    "y += torch.randn_like(y) * sigma_tc\n",
    "\n",
    "# Place noisy observations\n",
    "graph_copy.x[ObsIdx, 0] = y\n",
    "graph_copy.x[ObsIdx, 1] = 1\n",
    "\n",
    "u_decode = gcn(graph_copy)\n",
    "\n",
    "\n",
    "pos = graph_test.pos.detach().cpu().numpy()\n",
    "ObsIdx = ObsIdx.detach().cpu().numpy()\n",
    "plot_graph(graph_test, graph_test.x[:,0].detach().cpu().numpy(), \n",
    "           scatter_data=(pos[ObsIdx, 0], pos[ObsIdx, 1], y.detach().cpu().numpy()\n",
    "                         ), filename=dir_plt+'airf_p_true')\n",
    "plt.close()\n",
    "\n",
    "plot_graph(graph_test, u_decode[:,0].detach().cpu().numpy(), \n",
    "           scatter_data=(pos[ObsIdx, 0], pos[ObsIdx, 1], y.detach().cpu().numpy()\n",
    "                         ), filename=dir_plt+'airf_p_infer_pred')\n",
    "plt.close()\n",
    "\n",
    "plot_graph(graph_test, graph_test.x[:,1].detach().cpu().numpy(), \n",
    "           scatter_data=(pos[ObsIdx, 0], pos[ObsIdx, 1], y.detach().cpu().numpy()\n",
    "                         ), filename=dir_plt+'airf_vx_true')\n",
    "plt.close()\n",
    "\n",
    "plot_graph(graph_test, u_decode[:,1].detach().cpu().numpy(), \n",
    "           scatter_data=(pos[ObsIdx, 0], pos[ObsIdx, 1], y.detach().cpu().numpy()\n",
    "                         ), filename=dir_plt+'airf_vx_infer_pred')\n",
    "plt.close()\n",
    "\n",
    "plot_graph(graph_test, graph_test.x[:,2].detach().cpu().numpy(), \n",
    "           scatter_data=(pos[ObsIdx, 0], pos[ObsIdx, 1], y.detach().cpu().numpy()\n",
    "                         ), filename=dir_plt+'airf_vy_true')\n",
    "plt.close()\n",
    "\n",
    "plot_graph(graph_test, u_decode[:,2].detach().cpu().numpy(), \n",
    "           scatter_data=(pos[ObsIdx, 0], pos[ObsIdx, 1], y.detach().cpu().numpy()\n",
    "                         ), filename=dir_plt+'airf_vy_infer_pred')\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "# n_obs = 20\n",
    "sigma = 0.01\n",
    "\n",
    "\n",
    "mae_list = []\n",
    "\n",
    "# Start timer\n",
    "start_time = time.time()\n",
    "\n",
    "loader_test = DataLoader(dataset_test, batch_size=1, shuffle=True)\n",
    "# loader_test = DataLoader([data_test_list[4]], batch_size=1, shuffle=True)\n",
    "N_test = len(dataset_test)\n",
    "\n",
    "mae_list = []  # overall MAE per graph\n",
    "mae_dim_list = []  # per-dimension MAE per graph\n",
    "\n",
    "with torch.no_grad():\n",
    "    for graph in loader_test:\n",
    "        graph.to(device)\n",
    "        graph_copy = graph.clone().to(device)\n",
    "        \n",
    "\n",
    "        num_graphs = graph_copy.batch.max().item() + 1\n",
    "        idx_list = []\n",
    "\n",
    "        for graph_idx in range(num_graphs):\n",
    "            n_obs = get_rand_nobs()\n",
    "            node_obs_idx = observation_map(graph_copy, graph_idx, n_obs)        \n",
    "            idx_list.append(node_obs_idx)\n",
    "\n",
    "        # Concatenate selected indices\n",
    "        idx = torch.cat(idx_list, dim=0).to(device)\n",
    "\n",
    "        # # Zero out features\n",
    "        # graph_copy.x.zero_()\n",
    "\n",
    "        # # Copy and add noise\n",
    "        # y = graph.x[idx, :1].clone()\n",
    "        # y += torch.randn_like(y) * sigma_tc\n",
    "\n",
    "        # # Place noisy observations\n",
    "        # graph_copy.x = graph_copy.x[:,:1]\n",
    "        # graph_copy.x[idx] = y\n",
    "        \n",
    "                # Zero out features\n",
    "        graph_copy.x.zero_()\n",
    "        # graph_copy.x = torch.cat((graph_copy.x, graph_copy.x), dim=1)\n",
    "        graph_copy.x = torch.cat((graph_copy.x[:,:1], graph_copy.x[:,:1]), dim=1)\n",
    "\n",
    "\n",
    "        # Copy and add noise\n",
    "        y = graph.x[idx, 0].clone()\n",
    "        y += torch.randn_like(y) * sigma_tc\n",
    "\n",
    "        # Place noisy observations\n",
    "        graph_copy.x[idx, 0] = y\n",
    "        graph_copy.x[idx, 1] = 1\n",
    "\n",
    "        # Decode/predict with your GCN\n",
    "        u_decode = gcn(graph_copy)\n",
    "\n",
    "        true = graph.x\n",
    "        pred = u_decode\n",
    "\n",
    "        batch = graph.batch\n",
    "        num_graphs = batch.max().item() + 1\n",
    "\n",
    "\n",
    "        for i in range(num_graphs):\n",
    "            # Mask nodes belonging to graph i\n",
    "            mask = (batch == i)\n",
    "            true_i = true[mask]    # shape: [num_nodes_in_graph_i, 3]\n",
    "            pred_i = pred[mask]    # shape: [num_nodes_in_graph_i, 3]\n",
    "\n",
    "            # Overall MAE for graph i\n",
    "            mae_i = torch.mean(torch.abs(pred_i - true_i)).item()\n",
    "            mae_list.append(mae_i)\n",
    "\n",
    "            # Per-dimension MAE for graph i\n",
    "            mae_per_dim = torch.mean(torch.abs(pred_i - true_i), dim=0)  # shape: [3]\n",
    "            mae_dim_list.append(mae_per_dim.cpu().numpy())  # convert to numpy for easier aggregation\n",
    "\n",
    "# End timer\n",
    "total_time = time.time() - start_time\n",
    "\n",
    "# Convert per-dimension MAEs to numpy array: shape [num_graphs, 3]\n",
    "mae_dim_array = np.vstack(mae_dim_list)\n",
    "\n",
    "# Compute final stats\n",
    "mean_mae = np.mean(mae_list)\n",
    "std_mae = np.std(mae_list)\n",
    "\n",
    "# Compute mean and stddev per dimension\n",
    "mean_mae_per_dim = np.mean(mae_dim_array, axis=0)\n",
    "std_mae_per_dim = np.std(mae_dim_array, axis=0)\n",
    "\n",
    "print(f\"Mean MAE over {N_test} runs: {mean_mae:.6f}\")\n",
    "print(f\"Stddev of MAE: {std_mae:.6f}\")\n",
    "print(f\"Evaluation time: {total_time/N_test:.5f} seconds\")\n",
    "\n",
    "print(\"Per-dimension MAE statistics (p, vx, vy):\")\n",
    "print(f\"  Mean:    {mean_mae_per_dim}\")\n",
    "print(f\"  Stddev:  {std_mae_per_dim}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gabi",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
