{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "# from utils import plot_mesh\n",
    "from dataset import CFDGraphsDataset\n",
    "from plotter import plot_mesh\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.tri as mtri\n",
    "import numpy as np\n",
    "    \n",
    "filename = 'airfoil_ABC.ipynb'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "dir = './models/airfoil_gabi_1/'\n",
    "dir_plt = dir+'plt/'\n",
    "os.makedirs(dir, exist_ok=True)\n",
    "os.makedirs(dir_plt, exist_ok=True)\n",
    "import os\n",
    "pwd = os.getcwd()\n",
    "print(pwd)\n",
    "os.system(f\"jupyter nbconvert {pwd}/{filename} --to python --output {pwd}/{dir}run_file.py\" )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# #assert plot_type in {'graph', 'pressure', 'velocity_mag', 'velocity_x', 'velocity_y', 'sdf', 'surface_pressure'}\n",
    "# for plot_type in ['graph']:\n",
    "#     for i in range(1):\n",
    "#         plot_mesh(dataset[i], plot_type, plot_predicted=False, show=False, add_farfield_info=True)\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "with open(dir+'airfoil_thinned_train_1.pkl', 'rb') as f:\n",
    "    thin_dataset_1 = pickle.load(f)\n",
    "with open(dir+'airfoil_thinned_train_2.pkl', 'rb') as f:\n",
    "    thin_dataset_2 = pickle.load(f)\n",
    "    \n",
    "thin_dataset = thin_dataset_1 + thin_dataset_2\n",
    "\n",
    "print(len(thin_dataset))\n",
    "dataset_train = thin_dataset[:1000]\n",
    "dataset_test  = thin_dataset[1000:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(dataset_train[6])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# def plot_graph(graph, u, scatter_data = None, filename='graph'):\n",
    "#     pos = graph.pos.detach().cpu().numpy()\n",
    "#     face = graph.face.detach().cpu().numpy()\n",
    "#     contourf = plt.tricontourf(pos[:, 0], pos[:, 1], face.T, u , levels=20, cmap='viridis')\n",
    "#     # tri_plot = mtri.Triangulation(pos[:, 0], pos[:, 1], triangles=face.T)\n",
    "#     triang = mtri.Triangulation(pos[:, 0], pos[:, 1], face.T)\n",
    "#     plt.triplot(triang, color='black', alpha=0.1, linewidth=0.01)\n",
    "#     if scatter_data is not None:\n",
    "#         plt.scatter(scatter_data[0], scatter_data[1], c=scatter_data[2], s=10.,  cmap='viridis',\n",
    "#                     edgecolors='red', alpha=1)\n",
    "#     plt.colorbar(contourf)\n",
    "#     plt.tight_layout()\n",
    "#     plt.savefig(filename+'.pdf')\n",
    "    \n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.tri as mtri\n",
    "from matplotlib.colors import Normalize\n",
    "\n",
    "def plot_graph(graph, u, scatter_data=None, filename='graph'):\n",
    "    pos = graph.pos.detach().cpu().numpy()\n",
    "    face = graph.face.detach().cpu().numpy()\n",
    "    # u = u.detach().cpu().numpy()\n",
    "    \n",
    "    triang = mtri.Triangulation(pos[:, 0], pos[:, 1], face.T)\n",
    "\n",
    "    # Create shared normalization (min/max of u)\n",
    "    norm = Normalize(vmin=u.min(), vmax=u.max())\n",
    "\n",
    "    # Plot filled contour\n",
    "    contourf = plt.tricontourf(triang, u, levels=20, cmap='viridis', norm=norm)\n",
    "\n",
    "    # Plot triangle edges\n",
    "    plt.triplot(triang, color='black', alpha=0.05, linewidth=0.01)\n",
    "\n",
    "    # Overlay scatter points\n",
    "    if scatter_data is not None:\n",
    "        scatter_vals = scatter_data[2]\n",
    "        plt.scatter(\n",
    "            scatter_data[0], scatter_data[1],\n",
    "            c=scatter_vals,\n",
    "            cmap='viridis',\n",
    "            norm=norm,  # match color scale\n",
    "            s=10.,\n",
    "            edgecolors='red',\n",
    "            alpha=1\n",
    "        )\n",
    "\n",
    "    # Colorbar linked to the contour (and now scatter as well)\n",
    "    plt.colorbar(contourf)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(filename + '.pdf')\n",
    "    plt.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GCNConv, MLP\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.nn import pool\n",
    "\n",
    "from torch.nn import ModuleList\n",
    "from torch_geometric.nn import GCNConv\n",
    "from torch_geometric.nn.dense import Linear\n",
    "\n",
    "device='cuda'\n",
    "\n",
    "dim_z = 100\n",
    "\n",
    "dim_u = 3\n",
    "dim_x = 2\n",
    "dim_node_features = dim_x + dim_u\n",
    "class GCN_E(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(GCN_E, self).__init__()\n",
    "\n",
    "        self.c = dim_z\n",
    "        self.convs = ModuleList([\n",
    "            GCNConv(dim_node_features*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c), \n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c), \n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c, dim_z),\n",
    "        ])\n",
    "\n",
    "    def forward(self, u, graph):\n",
    "        pos, edge_index, edge_attr = graph.pos, graph.edge_index, graph.edge_attr\n",
    "        x = torch.concat([u, pos], dim=1)\n",
    "        \n",
    "        for ctr, conv in enumerate(self.convs[:-1]):\n",
    "            x_global = pool.global_mean_pool(x, graph.batch)\n",
    "\n",
    "            global_x_expanded = x_global[graph.batch]  # [num_nodes, hidden_dim]\n",
    "            \n",
    "            x = torch.cat([x, global_x_expanded], dim=1)  # [num_nodes, hidden_dim * 2]\n",
    "        \n",
    "            x = F.silu(conv(x, edge_index, edge_weight=edge_attr))\n",
    "\n",
    "        x = self.convs[-1](x, edge_index, edge_weight=edge_attr)\n",
    "\n",
    "        x = pool.global_mean_pool(x, graph.batch)\n",
    "        \n",
    "        return x\n",
    "\n",
    "\n",
    "class GCN_D(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(GCN_D, self).__init__()\n",
    "        self.c = dim_z\n",
    "        self.convs = ModuleList([\n",
    "            GCNConv((dim_z + 2)*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),  \n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),  \n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c, dim_u),\n",
    "        ])\n",
    "        \n",
    "    def forward(self, z, graph):\n",
    "\n",
    "        z_expanded = z[graph.batch]\n",
    "\n",
    "        x = torch.cat([z_expanded,  graph.pos], dim=1)\n",
    "\n",
    "        edge_index, edge_attr = graph.edge_index, graph.edge_attr\n",
    "\n",
    "        for ctr, conv in enumerate(self.convs[:-1]):\n",
    "            x_global = pool.global_mean_pool(x, graph.batch)\n",
    "            global_x_expanded = x_global[graph.batch]  # [num_nodes, hidden_dim]\n",
    "            \n",
    "            x = torch.cat([x, global_x_expanded], dim=1)  # [num_nodes, hidden_dim * 2]\n",
    "        \n",
    "            x = F.silu(conv(x, edge_index, edge_weight=edge_attr))\n",
    "            \n",
    "        x = self.convs[-1](x, edge_index, edge_weight=edge_attr)\n",
    "\n",
    "        return x\n",
    "    \n",
    "gcnE = GCN_E().to(device)\n",
    "gcnD = GCN_D().to(device)\n",
    "\n",
    "from torch.optim.lr_scheduler import ExponentialLR\n",
    "optimizerE = torch.optim.Adam(gcnE.parameters(), lr=0.001)\n",
    "schedulerE = ExponentialLR(optimizerE, gamma=0.999)\n",
    "optimizerD = torch.optim.Adam(gcnD.parameters(), lr=0.001)\n",
    "schedulerD = ExponentialLR(optimizerD, gamma=0.999)\n",
    "\n",
    "\n",
    "from GABI.stat import MMDLoss\n",
    "mmd = MMDLoss()\n",
    "\n",
    "def loss_function(data):\n",
    "        z = gcnE(data.x, data)\n",
    "        data_clone = data.clone()\n",
    "        # data_clone.x = torch.zeros( ( data_clone.x.shape[0], dim_z+nfreq*2) ).to(device)\n",
    "        u = gcnD(z, data_clone)\n",
    "        lossL = torch.mean((data.x - u)**2.)\n",
    "        Xd = torch.randn_like(z).to(device)\n",
    "        LossD = mmd(z, Xd)\n",
    "        loss = lossL + LossD\n",
    "        return loss, (lossL, LossD)\n",
    "\n",
    "\n",
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "print(count_parameters(gcnE))\n",
    "print(count_parameters(gcnD))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset_train = [data.to(device) for data in dataset_train]\n",
    "loader_train = DataLoader(dataset_train, batch_size=100, shuffle=True)\n",
    "data_train = next(iter(loader_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(data_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "gcnE.train()\n",
    "gcnD.train()\n",
    "\n",
    "LOSS = []\n",
    "LOSS_D = []\n",
    "LOSS_L = []\n",
    "import time\n",
    "time_train_start = time.time()\n",
    "for epoch in range(10_000):\n",
    "    start_time_b = time.time()\n",
    "    for data in loader_train:\n",
    "        optimizerE.zero_grad()\n",
    "        optimizerD.zero_grad()\n",
    "        data.to(device)\n",
    "        loss, aux = loss_function(data)\n",
    "        lossL, LossD = aux\n",
    "        LOSS.append(loss.cpu().detach().numpy())\n",
    "        LOSS_D.append(LossD.cpu().detach().numpy())\n",
    "        LOSS_L.append(lossL.cpu().detach().numpy())\n",
    "        loss.backward()\n",
    "        optimizerE.step()\n",
    "        optimizerD.step()\n",
    "    schedulerE.step()\n",
    "    schedulerD.step()\n",
    "    print(f'time epoch = {time.time() - start_time_b:.3f}s', )\n",
    "        \n",
    "    print(\"Epoch: \", epoch, \"Loss: \", loss.item())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time_train = time.time() - time_train_start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gcnE.eval()\n",
    "gcnD.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "plt.rcParams.update({\n",
    "    \"text.usetex\": True,\n",
    "    'text.latex.preamble': r'\\usepackage{amsfonts, amsmath, amssymb}',\n",
    "    \"font.family\": \"sans-serif\",\n",
    "    \"font.sans-serif\": [\"Helvetica\"],\n",
    "    'axes.labelsize':   18,\n",
    "    'axes.titlesize':   18,\n",
    "    'xtick.labelsize' : 16,\n",
    "    'ytick.labelsize' : 16,\n",
    "          })\n",
    "# latex font definition\n",
    "plt.rc('legend',fontsize=14)\n",
    "plt.rc('text', usetex=True)\n",
    "plt.rc('font', **{'family':'serif','serif':['Computer Modern Roman']})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.semilogy(LOSS)\n",
    "plt.grid()\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "plt.semilogy(LOSS_L, label='Loss $\\mathsf{L}$')\n",
    "plt.semilogy(LOSS_D, label='Loss $\\mathsf{d}$')\n",
    "plt.grid()\n",
    "plt.legend()\n",
    "plt.savefig(dir_plt+ 'losses.pdf')\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SAVE MODELS\n",
    "torch.save(gcnE.state_dict(), dir+'gcnE.model')\n",
    "torch.save(optimizerE.state_dict(), dir+'gcnE.opt')\n",
    "torch.save(gcnD.state_dict(), dir+'gcnD.model')\n",
    "torch.save(optimizerD.state_dict(), dir+'gcnD.opt')\n",
    "\n",
    "\n",
    "loss_dict = {'LOSS': LOSS, 'LOSS_L': LOSS_L, 'LOSS_D': LOSS_D, 'time_train':time_train,}\n",
    "import pickle\n",
    "with open(dir+'loss_dict.pkl', 'wb') as f:\n",
    "    pickle.dump(loss_dict, f)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(dir+'loss_dict.pkl', 'rb') as f:\n",
    "    loss_dict = pickle.load(f)\n",
    "LOSS = loss_dict['LOSS']\n",
    "LOSS_L = loss_dict['LOSS_L']\n",
    "LOSS_D = loss_dict['LOSS_D']\n",
    "time_train = loss_dict['time_train']\n",
    "\n",
    "print('time_train:', time_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LOAD MODELS\n",
    "gcnE.load_state_dict(torch.load(dir+'gcnE.model'))\n",
    "gcnE.eval()\n",
    "gcnD.load_state_dict(torch.load(dir+'gcnD.model'))\n",
    "gcnD.eval()\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = data_train.to(device)\n",
    "z = gcnE(data.x, data)\n",
    "z_hist = z.ravel().detach().cpu().numpy()\n",
    "plt.hist(z_hist, bins=30, density=True, label='Latent $z$ histogram')\n",
    "lin = np.linspace(-10, 10, 100)\n",
    "# plt.xlim(-5, 5)\n",
    "plt.plot(lin, np.exp(-lin**2/2)/np.sqrt(2*np.pi), label='Standard Gaussian')\n",
    "plt.legend()\n",
    "plt.savefig(dir_plt+ 'hist_latent.pdf')\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader_test = DataLoader(dataset_test, batch_size=1, shuffle=True)\n",
    "data_test = next(iter(loader_test))\n",
    "data_test = data_test.to(device)\n",
    "z = gcnE(data_test.x, data_test)\n",
    "u = gcnD(z, data_test)\n",
    "\n",
    "plot_graph(data_test, u[:,0].detach().cpu().numpy(), filename=dir_plt+'airf_recon_p')\n",
    "plt.close()\n",
    "plot_graph(data_test, data_test.x[:,0].detach().cpu().numpy(), filename=dir_plt+'airf_gt_p')\n",
    "plt.close()\n",
    "\n",
    "plot_graph(data_test, u[:,1].detach().cpu().numpy(), filename=dir_plt+'airf_recon_vx')\n",
    "plt.close()\n",
    "plot_graph(data_test, data_test.x[:,1].detach().cpu().numpy(), filename=dir_plt+'airf_gt_vx')\n",
    "plt.close()\n",
    "\n",
    "plot_graph(data_test, u[:,2].detach().cpu().numpy(), filename=dir_plt+'airf_recon_vy')\n",
    "plt.close()\n",
    "plot_graph(data_test, data_test.x[:,2].detach().cpu().numpy(), filename=dir_plt+'airf_gt_vy')\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_local_avg_physical_distance_vectorized(graph_batch, u_decode, flat_indices, radius):\n",
    "    \"\"\"\n",
    "    Vectorized computation of local average of u_decode using Euclidean distance.\n",
    "\n",
    "    Args:\n",
    "        graph_batch: PyG Batch or Data object with `.pos` attribute [N, 3].\n",
    "        u_decode: Tensor of shape [N, 3].\n",
    "        flat_indices: 1D tensor or list of node indices to compute average for.\n",
    "        radius: float — physical distance threshold.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape [len(flat_indices), 3] — locally averaged u_decode values.\n",
    "    \"\"\"\n",
    "    pos = graph_batch.pos  # [N, 3]\n",
    "    u_decode = u_decode.to(pos.device)\n",
    "    pos = pos.to(u_decode.device)\n",
    "\n",
    "    flat_indices = torch.tensor(flat_indices, device=pos.device)  # [M]\n",
    "    center_pos = pos[flat_indices]  # [M, 3]\n",
    "\n",
    "    # Compute pairwise distances: [M, N]\n",
    "    dists = torch.norm(pos.unsqueeze(0) - center_pos.unsqueeze(1), dim=2)  # [M, N]\n",
    "\n",
    "    # Create neighbor mask: [M, N]\n",
    "    neighbor_mask = dists <= radius\n",
    "\n",
    "    # Masked mean\n",
    "    masked_u = u_decode.unsqueeze(0).expand(len(flat_indices), -1, -1)  # [M, N, 3]\n",
    "    neighbor_mask = neighbor_mask.unsqueeze(-1)  # [M, N, 1]\n",
    "    masked_u = masked_u * neighbor_mask  # zero out non-neighbors\n",
    "\n",
    "    # Count valid neighbors for averaging (avoid division by zero)\n",
    "    counts = neighbor_mask.sum(dim=1)  # [M, 1]\n",
    "    counts = counts.clamp(min=1)\n",
    "\n",
    "    avg_u = masked_u.sum(dim=1) / counts  # [M, 3]\n",
    "\n",
    "    return avg_u\n",
    "\n",
    "import torch\n",
    "\n",
    "def compute_local_avg_physical_distance(graph_batch, u_decode, flat_indices, radius):\n",
    "    \"\"\"\n",
    "    Compute local average of u_decode based on Euclidean distance in position space.\n",
    "\n",
    "    Args:\n",
    "        graph_batch: PyG Batch or Data object with `.pos` attribute [N, 3].\n",
    "        u_decode: Tensor of shape [N, 3] — node attributes to average.\n",
    "        flat_indices: 1D tensor or list of node indices to compute average for.\n",
    "        radius: float — physical distance threshold.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape [len(flat_indices), 3] — locally averaged u_decode values.\n",
    "    \"\"\"\n",
    "    pos = graph_batch.pos  # [N, 3]\n",
    "    batch = graph_batch.batch  # [N]\n",
    "    u_decode = u_decode.to(pos.device)\n",
    "    flat_indices = torch.tensor(flat_indices, device=pos.device)\n",
    "\n",
    "    avg_values = []\n",
    "\n",
    "    for idx in flat_indices:\n",
    "        graph_id = batch[idx]  # get graph this node belongs to\n",
    "\n",
    "        # Mask nodes in same graph\n",
    "        same_graph_mask = batch == graph_id\n",
    "        pos_same_graph = pos[same_graph_mask]\n",
    "        u_same_graph = u_decode[same_graph_mask]\n",
    "        idx_in_graph = (same_graph_mask[:idx + 1]).sum() - 1  # convert global idx to local in subgraph\n",
    "\n",
    "        center_pos = pos[idx]\n",
    "        dists = torch.norm(pos_same_graph - center_pos, dim=1)  # distances in the same graph\n",
    "        neighbor_mask = dists <= radius\n",
    "\n",
    "        local_u = u_same_graph[neighbor_mask]\n",
    "        mean_value = local_u.mean(dim=0)\n",
    "        avg_values.append(mean_value)\n",
    "\n",
    "    return torch.stack(avg_values, dim=0)  # shape [len(flat_indices), 3]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### ABC\n",
    "# idx_text = 5 #6\n",
    "idx_test = 5\n",
    "data_test_graph = dataset_test[idx_test].clone()\n",
    "loader_test_gbase = DataLoader([data_test_graph], batch_size=1)\n",
    "data_test_graph  = next(iter(loader_test_gbase))\n",
    "n_obs = 20\n",
    "# n_obs = 100\n",
    "sigma_tc = torch.FloatTensor([0.01])\n",
    "# sigma_tc *= 0.\n",
    "\n",
    "radius = 1e-2\n",
    "# radius = 1e-8\n",
    "# radius = 20.\n",
    "# ObsIdx = np.random.choice(range(data_test_graph.pos.shape[0]), size=(n_obs,), replace=False)\n",
    "# ObsIdx = torch.tensor(ObsIdx)\n",
    "# y_n = (data_test_graph.x[:,0]).reshape(-1,)[ObsIdx] + sigma_tc * torch.randn(ObsIdx.shape[0])\n",
    "# y_n = y_n.to(device)\n",
    "\n",
    "indices = np.where(data_test_graph.airf_nodes == 1)[0]\n",
    "left_most = indices[torch.argmin(data_test_graph.pos[indices, 0])]\n",
    "high_most = indices[torch.argmax(data_test_graph.pos[indices, 1])]\n",
    "low_most  = indices[torch.argmin(data_test_graph.pos[indices, 1])]\n",
    "# indices = np.where(data_test_graph.airf_nodes*0 == 0)[0]\n",
    "ObsIdx = np.random.choice(indices, size=n_obs-3, replace=False)\n",
    "ObsIdx = torch.tensor(ObsIdx)\n",
    "ObsIdx = torch.cat( (ObsIdx, torch.tensor([left_most,high_most, low_most])), dim=0)  # add the first node\n",
    "# y_n = compute_local_avg_physical_distance(\n",
    "#     data_test_graph,\n",
    "#     data_test_graph.x[:, 0],\n",
    "#     ObsIdx,\n",
    "#     radius\n",
    "# )\n",
    "y_n  = data_test_graph.x[ObsIdx, 0]\n",
    "y_n += sigma_tc * torch.randn_like(y_n)\n",
    "y_n = y_n.to(device)\n",
    "    \n",
    "batch_size = 500\n",
    "graph_test = dataset_test[idx_test]\n",
    "graph_test_batch = [graph_test.clone() for _ in range(batch_size)] \n",
    "loader_test = DataLoader(graph_test_batch, batch_size=batch_size)\n",
    "graph_test_batch_loaded  = next(iter(loader_test)).to(device)\n",
    " \n",
    "graph_offsets = torch.arange(batch_size) * graph_test.x.shape[0]  # [B]\n",
    "global_indices = graph_offsets[:, None] + ObsIdx[None, :]  # [B, K]\n",
    "flat_indices = global_indices.reshape(-1)  # [B*K]\n",
    "\n",
    "u_decode = []\n",
    "z_acc = []\n",
    "u_acc = []\n",
    "n_loops = 100\n",
    "\n",
    "norms_list   = torch.tensor([np.inf])\n",
    "z_prior_list = torch.tensor(np.zeros((1, dim_z))).to(device, dtype=torch.float)  # Initialize with zeros\n",
    "\n",
    "with torch.no_grad():\n",
    "    for i in range(n_loops):\n",
    "        z_prior = torch.randn(batch_size, dim_z).to(device)\n",
    "        u_decode = gcnD(z_prior, graph_test_batch_loaded)\n",
    "        \n",
    "        airfoil_pressure = u_decode[flat_indices, 0]  # [B*K]\n",
    "        # airfoil_pressure = compute_local_avg_physical_distance(\n",
    "        #     graph_test_batch_loaded,\n",
    "        #     u_decode[:, 0],\n",
    "        #     flat_indices,\n",
    "        #     radius\n",
    "        # )\n",
    "\n",
    "        y_selected = airfoil_pressure  # [B*K, F]\n",
    "        # y_selected *=0.\n",
    "        y_selected = y_selected.view(batch_size, n_obs)\n",
    "        y_selected = y_selected + sigma_tc.to(device) * torch.randn_like(y_selected).to(device) # [B, K]\n",
    "        norms = torch.norm(y_selected - y_n[None, :], dim=-1)\n",
    "        norms_list = torch.cat( (norms_list, norms.detach().cpu()), dim=0 )\n",
    "        z_prior_list = torch.cat( (z_prior_list, z_prior.detach()), dim=0 )\n",
    "    \n",
    "sorted_index = torch.argsort(norms_list[:], dim=0)  # [n_loops, B*K]\n",
    "z_abc   = z_prior_list[sorted_index[:100]]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader_decode_ABC = DataLoader([graph_test], batch_size=1)\n",
    "graph_ABC_loaded = next(iter(loader_decode_ABC)).to(device)\n",
    "# z_abc = torch.vstack(z_acc)\n",
    "with torch.no_grad():\n",
    "    u_abc_decode = torch.vmap(gcnD, in_dims=(0, None))(z_abc[:, None, :], graph_ABC_loaded)\n",
    "# u_abc_decode.shape\n",
    "\n",
    "acc_ratio = z_abc.shape[0] / (batch_size * n_loops)\n",
    "print('acc_ratio = ', acc_ratio)\n",
    "print('num samples = ', z_abc.shape[0])\n",
    "\n",
    "u_min_norm = gcnD(z_prior_list[sorted_index[:1]], graph_ABC_loaded)\n",
    "u_min_norm = u_min_norm.cpu().detach().numpy()[:]\n",
    "\n",
    "print('best norm:', norms_list[sorted_index[0]])\n",
    "\n",
    "u_mean = np.mean(u_abc_decode.detach().cpu().numpy(), axis=0)\n",
    "u_std = np.std(u_abc_decode.detach().cpu().numpy(), axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader_test = DataLoader(dataset_test, batch_size=1, shuffle=True)\n",
    "data_test = next(iter(loader_test))\n",
    "data_test = data_test.to(device)\n",
    "z = gcnE(data_test.x, data_test)\n",
    "u = gcnD(z, data_test)\n",
    "\n",
    "pos = graph_test.pos.detach().cpu().numpy()\n",
    "\n",
    "plot_graph(graph_test, u_min_norm[:,0], \n",
    "           scatter_data=(pos[ObsIdx, 0], pos[ObsIdx, 1], y_n.detach().cpu().numpy()\n",
    "                         ), filename=dir_plt+'airf_p_infer_mnorm')\n",
    "plt.close()\n",
    "plot_graph(graph_test, u_mean[:,0], \n",
    "           scatter_data=(pos[ObsIdx, 0], pos[ObsIdx, 1], y_n.detach().cpu().numpy()\n",
    "                         ), filename=dir_plt+'airf_p_infer_mean')\n",
    "plt.close()\n",
    "plot_graph(graph_test, u_std[:,0], \n",
    "           scatter_data=(pos[ObsIdx, 0], pos[ObsIdx, 1], y_n.detach().cpu().numpy()\n",
    "                         ), filename=dir_plt+'airf_p_infer_std')\n",
    "plt.close()\n",
    "\n",
    "plot_graph(graph_test, u_mean[:,0] - graph_test.x[:,0].detach().cpu().numpy(), \n",
    "           scatter_data=(pos[ObsIdx, 0], pos[ObsIdx, 1], y_n.detach().cpu().numpy()\n",
    "                         ), filename=dir_plt+'airf_p_infer_error')\n",
    "plt.close()\n",
    "\n",
    "plot_graph(graph_test, graph_test.x[:,0].detach().cpu().numpy(),\n",
    "            scatter_data=(pos[ObsIdx, 0], pos[ObsIdx, 1], y_n.detach().cpu().numpy()),\n",
    "           filename=dir_plt+'airf_p_gt_test')\n",
    "plt.close()\n",
    "\n",
    "\n",
    "\n",
    "plot_graph(graph_test, u_min_norm[:,1], filename=dir_plt+'airf_vx_infer_mnorm')\n",
    "plt.close()\n",
    "plot_graph(graph_test, u_mean[:,1], filename=dir_plt+'airf_vx_infer_mean')\n",
    "plt.close()\n",
    "plot_graph(graph_test, u_std[:,1], filename=dir_plt+'airf_vx_infer_std')\n",
    "plt.close()\n",
    "plot_graph(graph_test,u_mean[:,1] - graph_test.x[:,1].detach().cpu().numpy()\n",
    "           , filename=dir_plt+'airf_vx_infer_error')\n",
    "plt.close()\n",
    "plot_graph(graph_test, graph_test.x[:,1].detach().cpu().numpy(),\n",
    "           filename=dir_plt+'airf_vx_gt_test')\n",
    "plt.close()\n",
    "\n",
    "plot_graph(graph_test, u_min_norm[:,2], filename=dir_plt+'airf_vy_infer_mnorm')\n",
    "plt.close()\n",
    "plot_graph(graph_test, u_mean[:,2], filename=dir_plt+'airf_vy_infer_mean')\n",
    "plt.close()\n",
    "plot_graph(graph_test, u_std[:,2], filename=dir_plt+'airf_vy_infer_std')\n",
    "plt.close()\n",
    "plot_graph(graph_test,u_mean[:,2] - graph_test.x[:,2].detach().cpu().numpy()\n",
    "           , filename=dir_plt+'airf_vy_infer_error')\n",
    "plt.close()\n",
    "plot_graph(graph_test, graph_test.x[:,2].detach().cpu().numpy(),\n",
    "           filename=dir_plt+'airf_vy_gt_test')\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(1,5):\n",
    "    z_rnd = torch.randn(1, dim_z).to(device)\n",
    "    u_prior_decode = gcnD(z_rnd, graph_ABC_loaded)\n",
    "    \n",
    "    plot_graph(graph_test, u_prior_decode[:,0].detach().cpu().numpy(),\n",
    "            filename=dir_plt+f'airf_p_prior_s{i}_test')\n",
    "    plt.close()\n",
    "    plot_graph(graph_test, u_prior_decode[:,1].detach().cpu().numpy(),\n",
    "            filename=dir_plt+f'airf_vx_prior_s{i}_test')\n",
    "    plt.close()\n",
    "    plot_graph(graph_test, u_prior_decode[:,2].detach().cpu().numpy(),\n",
    "            filename=dir_plt+f'airf_vy_prior_s{i}_test')\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(ObsIdx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# u_pred = u_mode[:,0].cpu().detach().numpy()\n",
    "u_pred = u_mean\n",
    "# u_pred = u_min_norm\n",
    "\n",
    "u_true =  graph_test.x[:].cpu().detach().numpy()\n",
    "\n",
    "for i, var in enumerate(['p', 'vx', 'vy']):\n",
    "    MAE = np.mean( np.abs(u_pred[:,i] - u_true[:,i]) )\n",
    "    prec_in_1std = np.mean( np.abs(u_pred[:,i] - u_true[:,i]) < u_std[:,i] )\n",
    "    prec_in_2std = np.mean( np.abs(u_pred[:,i] - u_true[:,i]) < 2*u_std[:,i] )\n",
    "    print(f'Variable {var}:')\n",
    "    print(f'MAE = {MAE:.4f}')\n",
    "    print(f'prec_in_1std = {prec_in_1std:.4f}')\n",
    "    print(f'prec_in_2std = {prec_in_2std:.4f}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizerE.zero_grad()\n",
    "optimizerD.zero_grad()\n",
    "\n",
    "gcnD.zero_grad()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def observation_map(graph, n_obs):\n",
    "    \n",
    "    indices = torch.where(graph.airf_nodes == 1)[0]\n",
    "    # mask_graph\n",
    "    left_most = indices[torch.argmin(graph.pos[indices, 0])]\n",
    "    high_most = indices[torch.argmax(graph.pos[indices, 1])]\n",
    "    low_most  = indices[torch.argmin(graph.pos[indices, 1])]\n",
    "    # indices = np.where(data_test_graph.airf_nodes*0 == 0)[0]\n",
    "    # ObsIdx = np.random.choice(indices, size=n_obs-3, replace=False)\n",
    "    perm = torch.randperm(indices.size(0), device=graph.x.device)[:n_obs-3]\n",
    "    ObsIdx = indices[perm]\n",
    "    ObsIdx = torch.tensor(ObsIdx)\n",
    "    ObsIdx = torch.cat( (ObsIdx, torch.tensor([left_most,high_most, low_most]).to(device)), dim=0)  # add the first node\n",
    "    \n",
    "    return ObsIdx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "\n",
    "sigma = 0.01\n",
    "n_obs = 20\n",
    "sigma_tc = torch.FloatTensor([sigma]).to(device)\n",
    "\n",
    "NOBS = np.arange(5,50)\n",
    "norm_nobs = (1./NOBS) / (np.sum(1./NOBS))  # normalize to sum to 1\n",
    "def get_rand_nobs():\n",
    "    return np.random.choice(NOBS, p=norm_nobs, size=1)[0]\n",
    "    \n",
    "@torch.no_grad()\n",
    "def run_pred(data_test_graph, batch_size=500, n_loops=100, n_acc=100):\n",
    "    \"\"\"Reconstruct the fields of one test graph by rejection-style ABC and score the result.\n",
    "\n",
    "    A random number of nodes is chosen through `observation_map`; column 0 of\n",
    "    `data_test_graph.x` at those nodes, perturbed with Gaussian noise of scale\n",
    "    `sigma_tc`, is the observation. Latent vectors are drawn from a standard\n",
    "    normal, decoded with `gcnD`, and the `n_acc` draws whose noisy decoded\n",
    "    observations are closest (Euclidean norm) to the data are kept. The mean and\n",
    "    standard deviation of their decoded fields are compared with `data_test_graph.x`.\n",
    "\n",
    "    Args:\n",
    "        data_test_graph: graph to evaluate; `.to(device)` is called on it.\n",
    "        batch_size: number of latent draws decoded per forward pass.\n",
    "        n_loops: number of forward passes (total draws = batch_size * n_loops).\n",
    "        n_acc: number of closest draws that are kept.\n",
    "\n",
    "    Returns:\n",
    "        (MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken). The first four are\n",
    "        lists with one entry per variable in the order p, vx, vy; time_taken is\n",
    "        the elapsed time in seconds from batch construction to the end of the\n",
    "        metric computation.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if batch_size, n_loops or n_acc is smaller than 1.\n",
    "    \"\"\"\n",
    "    if batch_size < 1 or n_loops < 1 or n_acc < 1:\n",
    "        raise ValueError('batch_size, n_loops and n_acc must all be >= 1')\n",
    "\n",
    "    data_test_graph.to(device)\n",
    "    n_obs = get_rand_nobs()\n",
    "    print('Number of observations:', n_obs)\n",
    "    ObsIdx = observation_map(data_test_graph, n_obs)\n",
    "\n",
    "    # Noisy observation of column 0; built out of place so the graph data is never modified.\n",
    "    y_clean = data_test_graph.x[ObsIdx, 0]\n",
    "    y_n = (y_clean + sigma_tc * torch.randn_like(y_clean)).to(device)\n",
    "\n",
    "    time_start = time.perf_counter()\n",
    "    graph_test = data_test_graph\n",
    "\n",
    "    # One batch holding batch_size copies of the same graph, so that a single\n",
    "    # decoder call evaluates batch_size latent draws at once.\n",
    "    replicas = [graph_test.clone() for _ in range(batch_size)]\n",
    "    graph_test_batch_loaded = next(iter(DataLoader(replicas, batch_size=batch_size))).to(device)\n",
    "\n",
    "    # Shift the per-graph observation indices by each copy's node offset in the batch.\n",
    "    graph_offsets = torch.arange(batch_size).to(device) * graph_test.x.shape[0]  # [B]\n",
    "    flat_indices = (graph_offsets[:, None] + ObsIdx[None, :]).reshape(-1)  # [B*K]\n",
    "\n",
    "    # Collect per-pass results and concatenate once after the loop.\n",
    "    norm_chunks = []\n",
    "    z_chunks = []\n",
    "    for _ in range(n_loops):\n",
    "        z_prior = torch.randn(batch_size, dim_z).to(device)\n",
    "        u_decode = gcnD(z_prior, graph_test_batch_loaded)\n",
    "\n",
    "        # -1 lets the observation count follow what observation_map actually returned.\n",
    "        y_selected = u_decode[flat_indices, 0].view(batch_size, -1)  # [B, K]\n",
    "        y_selected = y_selected + sigma_tc * torch.randn_like(y_selected)\n",
    "        norm_chunks.append(torch.norm(y_selected - y_n[None, :], dim=-1))  # [B]\n",
    "        z_chunks.append(z_prior)\n",
    "\n",
    "    norms_all = torch.cat(norm_chunks, dim=0)  # [n_loops*B]\n",
    "    z_all = torch.cat(z_chunks, dim=0)  # [n_loops*B, dim_z]\n",
    "    sorted_index = torch.argsort(norms_all, dim=0)\n",
    "    z_abc = z_all[sorted_index[:n_acc]]\n",
    "\n",
    "    # Decode the accepted latents on a single copy of the graph.\n",
    "    graph_ABC_loaded = next(iter(DataLoader([graph_test], batch_size=1))).to(device)\n",
    "    u_abc_decode = torch.vmap(gcnD, in_dims=(0, None))(z_abc[:, None, :], graph_ABC_loaded)\n",
    "    samples_decode_np = u_abc_decode.cpu().detach().numpy()\n",
    "\n",
    "    u_mean = np.mean(samples_decode_np, axis=0)\n",
    "    u_std = np.std(samples_decode_np, axis=0)\n",
    "    u_pred = u_mean\n",
    "    u_true = graph_test.x[:].cpu().detach().numpy()\n",
    "\n",
    "    MAE = []\n",
    "    MaxAE = []\n",
    "    prec_in_1std = []\n",
    "    prec_in_2std = []\n",
    "    for k, var in enumerate(['p', 'vx', 'vy']):\n",
    "        abs_err = np.abs(u_pred[:, k] - u_true[:, k])\n",
    "        MAE.append(np.mean(abs_err))\n",
    "        MaxAE.append(np.max(abs_err))\n",
    "        # Fraction of nodes whose error is below one / two predicted standard deviations.\n",
    "        prec_in_1std.append(np.mean(abs_err < u_std[:, k]))\n",
    "        prec_in_2std.append(np.mean(abs_err < 2*u_std[:, k]))\n",
    "        print(f'Variable {var}:')\n",
    "        print(f'MAE = {MAE[k]:.4f}')\n",
    "        print(f'MaxAE = {MaxAE[k]:.4f}')\n",
    "        print(f'prec_in_1std = {prec_in_1std[k]:.4f}')\n",
    "        print(f'prec_in_2std = {prec_in_2std[k]:.4f}')\n",
    "\n",
    "    time_taken = time.perf_counter() - time_start\n",
    "    print('time_taken:', time_taken)\n",
    "\n",
    "    return MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken\n",
    "\n",
    "    \n",
    "# run_pred()\n",
    "\n",
    "variables = ['p', 'vx', 'vy']\n",
    "metric_names = ['MAE', 'MaxAE', 'prec_in_1std', 'prec_in_2std']\n",
    "\n",
    "# pred_stats[metric][variable] collects one value per evaluated test graph;\n",
    "# pred_stats['time_taken'] collects the elapsed time of each run_pred call.\n",
    "pred_stats = {metric: {var: [] for var in variables} for metric in metric_names}\n",
    "pred_stats['time_taken'] = []\n",
    "\n",
    "# Evaluate at most 100 graphs, and never index past the end of the held-out split.\n",
    "N_test = min(100, len(dataset_test))\n",
    "if N_test == 0:\n",
    "    raise ValueError('dataset_test is empty; nothing to evaluate')\n",
    "\n",
    "for i in range(N_test):\n",
    "    with torch.no_grad():\n",
    "        print(f'Running prediction for test data {i+1}/{N_test}')\n",
    "        # Each metric comes back as a list of 3 values, one per entry of `variables`.\n",
    "        MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken = run_pred(dataset_test[i])\n",
    "\n",
    "    for metric, values in zip(metric_names, (MAE, MaxAE, prec_in_1std, prec_in_2std)):\n",
    "        for var, value in zip(variables, values):\n",
    "            pred_stats[metric][var].append(value)\n",
    "    pred_stats['time_taken'].append(time_taken)\n",
    "\n",
    "print('### Total Results ###')\n",
    "# Convert the collected lists to numpy arrays.\n",
    "for metric in metric_names:\n",
    "    for var in variables:\n",
    "        pred_stats[metric][var] = np.array(pred_stats[metric][var])\n",
    "pred_stats['time_taken'] = np.array(pred_stats['time_taken'])\n",
    "\n",
    "# test_score_dict[metric][variable] = (mean, std) over the evaluated graphs;\n",
    "# test_score_dict['time'] = (mean, std) of the per-graph elapsed time.\n",
    "test_score_dict = {metric: {} for metric in metric_names}\n",
    "for metric in metric_names:\n",
    "    for var in variables:\n",
    "        mean_val = np.mean(pred_stats[metric][var])\n",
    "        std_val = np.std(pred_stats[metric][var])\n",
    "        test_score_dict[metric][var] = (mean_val, std_val)\n",
    "        print(f\"{metric} ({var}): {mean_val:.4f} ± {std_val:.4f}\")\n",
    "\n",
    "mean_time = np.mean(pred_stats['time_taken'])\n",
    "std_time = np.std(pred_stats['time_taken'])\n",
    "print(f\"time_taken: {mean_time:.4f} ± {std_time:.4f}\")\n",
    "test_score_dict['time'] = (mean_time, std_time)\n",
    "\n",
    "import pickle\n",
    "with open(dir+'test_score_dict.pkl', 'wb') as f:\n",
    "    pickle.dump(test_score_dict, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gabi",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
