{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import sys\n",
    "# Add the parent directory to the import path so project-local packages used below\n",
    "# (e.g. GABI) can be imported -- presumably they live there; confirm.\n",
    "sys.path.append('../')\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.tri as mtri\n",
    "import numpy as np\n",
    "    \n",
    "# Name of this notebook; used below to export a .py snapshot next to the checkpoints.\n",
    "filename = 'Car_GABI_uf.ipynb'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration: device for all models/tensors, and whether to train from scratch\n",
    "# (TRAIN=True) or load previously saved checkpoints from dir_model (TRAIN=False).\n",
    "device='cuda'\n",
    "TRAIN = False\n",
    "\n",
    "# device='cpu'\n",
    "# TRAIN = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shlex\n",
    "# dir_data  = './data/helm_car_multforce_easy/'\n",
    "# dir_model = './model/helm_car_gabi_multiforce_easy/'\n",
    "dir_data  = './data/helm_car_frontForce_hdamp/'\n",
    "dir_model = './model/helm_car_frontForce_hdamp_uf/'\n",
    "\n",
    "# Output folders: checkpoints and loss history go to dir_model, figures to dir_plt.\n",
    "dir_plt = dir_model+'plt/'\n",
    "os.makedirs(dir_model, exist_ok=True)\n",
    "os.makedirs(dir_plt, exist_ok=True)\n",
    "pwd = os.getcwd()\n",
    "print(pwd)\n",
    "# Snapshot this notebook as a .py script next to the checkpoints for provenance.\n",
    "# Paths are shell-quoted so a working directory containing spaces does not break the command.\n",
    "nb_src = shlex.quote(f'{pwd}/{filename}')\n",
    "nb_dst = shlex.quote(f'{pwd}/{dir_model}run_file.py')\n",
    "os.system(f\"jupyter nbconvert {nb_src} --to python --output {nb_dst}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load trusted, locally generated files.\n",
    "with open(dir_data+'graphs.pkl', 'rb') as f:\n",
    "    dataset = pickle.load(f)\n",
    "print(len(dataset))\n",
    "assert len(dataset) > 0, 'graphs.pkl contains no samples'\n",
    "\n",
    "# Pool the nodal pressure of every sample into one flat float64 array for global scaling.\n",
    "# A single concatenate avoids the quadratic cost of growing an array in a loop and, unlike\n",
    "# seeding the pool with np.zeros((1,)), does not bias the statistics with a spurious 0.\n",
    "press_p = np.concatenate([np.asarray(d.x).reshape(-1) for d in dataset], axis=0).astype(np.float64)\n",
    "\n",
    "# Global statistics (min/max kept for reference; mean/std are what the scaling uses)\n",
    "press_p_min = np.min(press_p)\n",
    "press_p_max = np.max(press_p)\n",
    "press_p_mean = np.mean(press_p)\n",
    "press_p_std = np.std(press_p)\n",
    "\n",
    "print('press_p_min', press_p_min)\n",
    "print('press_p_max', press_p_max)\n",
    "print('press_p_mean', press_p_mean)\n",
    "print('press_p_std', press_p_std)\n",
    "\n",
    "# Standardise (z-score) each sample's pressure and stack the forcing f as a second node\n",
    "# feature, so x becomes [num_nodes, 2] = (scaled pressure, forcing).\n",
    "# Re-running this cell is safe because graphs.pkl is reloaded at the top.\n",
    "for d in dataset:\n",
    "    # Alternative min-max scaling: (d.x - press_p_min) / (press_p_max - press_p_min)\n",
    "    scaled = (d.x - press_p_mean) / press_p_std\n",
    "    # as_tensor avoids the extra copy/warning torch.tensor() gives when scaled is already a tensor\n",
    "    d.x = torch.cat((torch.as_tensor(scaled).to(torch.float)[:, None], d.f.to(torch.float)[:, None]), dim=1)\n",
    "\n",
    "print(dataset[0].x.shape)   \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset_train = dataset[:6000:5].copy()\n",
    "# dataset_test  = dataset[-6000::5].copy()\n",
    "\n",
    "# dataset_train = dataset[:5000:5].copy()\n",
    "# dataset_test  = dataset[5000::].copy()\n",
    "\n",
    "# Positional split: first 500 graphs for training, all remaining graphs for testing.\n",
    "# Presumably dataset is a plain list, so .copy() is shallow (graph objects are shared).\n",
    "dataset_train = dataset[:500:].copy()\n",
    "dataset_test  = dataset[500::].copy()\n",
    "\n",
    "print('len dataset_train', len(dataset_train))\n",
    "print('len dataset_test', len(dataset_test))\n",
    "# Drop the full-dataset name; the split lists keep their graphs alive.\n",
    "del dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick inspection of one training graph; only the last expression (edge_attr) is displayed.\n",
    "dataset_train[0]\n",
    "dataset_train[0].edge_attr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyvista as pv\n",
    "\n",
    "# Enable notebook plotting (use 'static', 'client', 'panel', or 'pythreejs')\n",
    "pv.set_jupyter_backend('static')\n",
    "\n",
    "def plot_graph(graph, u, filename=None):\n",
    "    \"\"\"Render a per-vertex field on the triangulated surface mesh of a graph.\n",
    "\n",
    "    graph    : graph with `pos` (vertex coordinates) and `face` (triangle indices, one column per face).\n",
    "    u        : per-vertex values (numpy array or torch tensor), shown as 'Pressure'.\n",
    "    filename : accepted for call compatibility but currently unused (nothing is saved).\n",
    "    \"\"\"\n",
    "    vertices = graph.pos.cpu().numpy()\n",
    "    faces = graph.face.cpu().numpy().T  # shape (n_faces, 3)\n",
    "\n",
    "    # PyVista expects faces as a flat array: [3, v0, v1, v2, 3, v0, v1, v2, ...]\n",
    "    faces_flat = np.hstack([np.full((faces.shape[0],1), 3), faces]).astype(np.int64).flatten()\n",
    "\n",
    "    mesh = pv.PolyData(vertices, faces_flat)\n",
    "    # Convert torch tensors (possibly on GPU) to numpy before handing them to VTK\n",
    "    if torch.is_tensor(u):\n",
    "        u = u.detach().cpu().numpy()\n",
    "    mesh.point_data['Pressure'] = u\n",
    "\n",
    "    p = pv.Plotter(notebook=True)\n",
    "    p.add_mesh(\n",
    "        mesh, \n",
    "        scalars='Pressure', \n",
    "        cmap='coolwarm', \n",
    "        show_scalar_bar=True,\n",
    "        scalar_bar_args={\n",
    "            'title': 'Pressure',\n",
    "            'vertical': True,      # True = vertical colorbar; False = horizontal\n",
    "            'position_x': 0.1,    # X position in normalized [0,1]\n",
    "            'position_y': 0.1,     # Y position in normalized [0,1]\n",
    "            'width': 0.05,\n",
    "            'height': 0.8\n",
    "        }\n",
    "    )\n",
    "    p.show_grid()\n",
    "    # camera_position = (position, focal_point, view_up). The view-up vector must be non-zero;\n",
    "    # (0, 0, 0) leaves the camera orientation undefined. Y-up matches the camera used later on.\n",
    "    p.camera_position = [np.array([2, 3, -2]), (0.0, 0.0, 0.0), (0, 1, 0)]\n",
    "    p.show()\n",
    "idx = np.random.randint(0, len(dataset_train))\n",
    "plot_graph(dataset_train[idx], dataset_train[idx].x[:,0], filename=dir_plt+'recon')\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GCNConv, MLP\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.nn import pool\n",
    "\n",
    "from torch.nn import ModuleList\n",
    "from torch_geometric.nn import GCNConv\n",
    "from torch_geometric.nn.dense import Linear\n",
    "\n",
    "# loader = DataLoader(data_list, batch_size=100, shuffle=True)\n",
    "# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# data = next(iter(loader)).to(device)\n",
    "# print(data.num_node_features)\n",
    "\n",
    "# Latent dimension: the encoder maps each whole graph to one vector of this size.\n",
    "# dim_z = 250\n",
    "dim_z = 100\n",
    "# dim_z = 50\n",
    "\n",
    "# Node features = dim_u field channels (scaled pressure, forcing) + dim_x coordinates.\n",
    "dim_u = 2\n",
    "dim_x = 3\n",
    "dim_node_features = dim_x + dim_u\n",
    "class GCN_E(torch.nn.Module):\n",
    "    \"\"\"Graph encoder: node fields u plus node positions -> one latent vector of size dim_z per graph.\n",
    "\n",
    "    Before each hidden GCN layer, the per-graph mean of the current node features is broadcast back\n",
    "    to the nodes and concatenated, which is why those layers take twice their width as input.\n",
    "    NOTE: attribute names (convs) determine the state_dict keys of saved checkpoints -- keep them stable.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super(GCN_E, self).__init__()\n",
    "\n",
    "        self.c = dim_z\n",
    "        self.convs = ModuleList([\n",
    "            GCNConv(dim_node_features*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            # GCNConv(self.c*2, self.c), \n",
    "            # GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c), \n",
    "            GCNConv(self.c*2, self.c),\n",
    "            # Last layer gets no global-mean concatenation, hence input width self.c\n",
    "            GCNConv(self.c, dim_z),\n",
    "        ])\n",
    "\n",
    "    def forward(self, u, graph):\n",
    "        \"\"\"Encode node fields u of a (batched) graph; returns [num_graphs, dim_z] via mean pooling.\"\"\"\n",
    "        pos, edge_index, edge_attr = graph.pos, graph.edge_index, graph.edge_attr\n",
    "        x = torch.concat([u, pos], dim=1)\n",
    "        \n",
    "        # edge_attr is passed as the GCN edge weight -- assumes it is one scalar per edge; TODO confirm\n",
    "        for ctr, conv in enumerate(self.convs[:-1]):\n",
    "            x_global = pool.global_mean_pool(x, graph.batch)\n",
    "\n",
    "            global_x_expanded = x_global[graph.batch]  # [num_nodes, hidden_dim]\n",
    "            \n",
    "            x = torch.cat([x, global_x_expanded], dim=1)  # [num_nodes, hidden_dim * 2]\n",
    "        \n",
    "            x = F.silu(conv(x, edge_index, edge_weight=edge_attr))\n",
    "\n",
    "        x = self.convs[-1](x, edge_index, edge_weight=edge_attr)\n",
    "\n",
    "        x = pool.global_mean_pool(x, graph.batch)\n",
    "        \n",
    "        return x\n",
    "\n",
    "\n",
    "class GCN_D(torch.nn.Module):\n",
    "    \"\"\"Graph decoder: broadcasts each graph's latent z to its nodes, appends node positions,\n",
    "    and predicts dim_u field channels per node (same global-mean concatenation as GCN_E).\n",
    "    NOTE: attribute names (convs) determine the state_dict keys of saved checkpoints.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super(GCN_D, self).__init__()\n",
    "        self.c = dim_z\n",
    "        self.convs = ModuleList([\n",
    "            # (dim_z + 3): latent + position coordinates; x2 for the concatenated graph mean\n",
    "            GCNConv((dim_z + 3)*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            # GCNConv(self.c*2, self.c),  \n",
    "            # GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),  \n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c, dim_u),\n",
    "        ])\n",
    "        \n",
    "    def forward(self, z, graph):\n",
    "        \"\"\"Decode z ([num_graphs, dim_z]) into node fields [num_nodes, dim_u] for the batched graph.\"\"\"\n",
    "\n",
    "        # Give every node the latent vector of the graph it belongs to\n",
    "        z_expanded = z[graph.batch]\n",
    "\n",
    "        x = torch.cat([z_expanded,  graph.pos], dim=1)\n",
    "\n",
    "        edge_index, edge_attr = graph.edge_index, graph.edge_attr\n",
    "\n",
    "        for ctr, conv in enumerate(self.convs[:-1]):\n",
    "            x_global = pool.global_mean_pool(x, graph.batch)\n",
    "            global_x_expanded = x_global[graph.batch]  # [num_nodes, hidden_dim]\n",
    "            \n",
    "            x = torch.cat([x, global_x_expanded], dim=1)  # [num_nodes, hidden_dim * 2]\n",
    "        \n",
    "            x = F.silu(conv(x, edge_index, edge_weight=edge_attr))\n",
    "            \n",
    "        x = self.convs[-1](x, edge_index, edge_weight=edge_attr)\n",
    "\n",
    "        return x\n",
    "    \n",
    "gcnE = GCN_E().to(device)\n",
    "gcnD = GCN_D().to(device)\n",
    "\n",
    "# # from torch.optim.lr_scheduler import ExponentialLR\n",
    "# optimizerE = torch.optim.AdamW(gcnE.parameters(), lr=0.0005)\n",
    "# # schedulerE = ExponentialLR(optimizerE, gamma=0.95)\n",
    "# optimizerD = torch.optim.AdamW(gcnD.parameters(), lr=0.0005)\n",
    "# # schedulerD = ExponentialLR(optimizerD, gamma=0.95)\n",
    "\n",
    "# n_epochs = 10_000\n",
    "n_epochs = 10_000\n",
    "\n",
    "\n",
    "# Separate AdamW optimizers for encoder and decoder; cosine decay from 1e-3 to 1e-5 over n_epochs\n",
    "# (the schedulers are stepped once per epoch in the training loop).\n",
    "from torch.optim.lr_scheduler import ExponentialLR, CosineAnnealingLR\n",
    "optimizerE = torch.optim.AdamW(gcnE.parameters(), lr=0.001)\n",
    "schedulerE = CosineAnnealingLR(optimizerE, T_max=n_epochs, eta_min=1e-5)\n",
    "optimizerD = torch.optim.AdamW(gcnD.parameters(), lr=0.001)\n",
    "schedulerD = CosineAnnealingLR(optimizerD, T_max=n_epochs, eta_min=1e-5)\n",
    "\n",
    "\n",
    "\n",
    "# Project-local distribution-matching loss (presumably maximum mean discrepancy; see GABI.stat)\n",
    "from GABI.stat import MMDLoss\n",
    "mmd = MMDLoss()\n",
    "\n",
    "def loss_function(data):\n",
    "        \"\"\"Autoencoder objective for a batched graph.\n",
    "\n",
    "        Sum of the mean-squared reconstruction error of data.x and an MMD term between the latent\n",
    "        codes and standard-normal samples. Returns (loss, (lossL, LossD)).\n",
    "        \"\"\"\n",
    "        z = gcnE(data.x, data)\n",
    "        # data_clone = data.clone()\n",
    "        # data_clone.x = torch.zeros( ( data_clone.x.shape[0], dim_z+nfreq*2) ).to(device)\n",
    "        u = gcnD(z, data)\n",
    "        lossL = torch.mean((data.x - u)**2.)\n",
    "        # Reference samples from N(0, I) with the same shape as the latent batch\n",
    "        Xd = torch.randn_like(z).to(device)\n",
    "        LossD = mmd(z, Xd)\n",
    "        loss = lossL + LossD\n",
    "        return loss, (lossL, LossD)\n",
    "\n",
    "\n",
    "def count_parameters(model):\n",
    "    \"\"\"Number of trainable (requires_grad) scalar parameters in model.\"\"\"\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "print(count_parameters(gcnE))\n",
    "print(count_parameters(gcnD))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset_train = [data.to(device) for data in dataset_train]\n",
    "# loader_train = DataLoader(dataset_train, batch_size=20, shuffle=True)\n",
    "# Shuffled mini-batches of 100 graphs; data_train is one batch kept for later inspection.\n",
    "loader_train = DataLoader(dataset_train, batch_size=100, shuffle=True)\n",
    "print(dataset_train[0])\n",
    "data_train = next(iter(loader_train))\n",
    "print(data_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data_train.to(device)\n",
    "# z = gcnE(data_train.x, data_train)\n",
    "# u = gcnD(z, data_train)\n",
    "\n",
    "# print(z.shape)\n",
    "# print(u.shape)\n",
    "# print(data_train.x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from IPython.display import clear_output\n",
    "\n",
    "if TRAIN:\n",
    "    print(\"Training mode: True\")\n",
    "    gcnE.train()\n",
    "    gcnD.train()\n",
    "\n",
    "    # Per-epoch loss histories; each entry is the MEAN over all mini-batches of that epoch\n",
    "    # (recording only the last batch gives a noisy, batch-order-dependent curve).\n",
    "    LOSS = []\n",
    "    LOSS_D = []\n",
    "    LOSS_L = []\n",
    "    import time\n",
    "    time_train_start = time.time()\n",
    "    for epoch in range(n_epochs):\n",
    "        start_time_b = time.time()\n",
    "        sum_loss, sum_L, sum_D, n_batches = 0.0, 0.0, 0.0, 0\n",
    "        for graph in loader_train:\n",
    "            optimizerE.zero_grad()\n",
    "            optimizerD.zero_grad()\n",
    "            graph.to(device)\n",
    "            loss, aux = loss_function(graph)\n",
    "            lossL, LossD = aux\n",
    "            loss.backward()\n",
    "            optimizerE.step()\n",
    "            optimizerD.step()\n",
    "            sum_loss += loss.item()\n",
    "            sum_L += lossL.item()\n",
    "            sum_D += LossD.item()\n",
    "            n_batches += 1\n",
    "        n_batches = max(n_batches, 1)  # guard against an empty loader\n",
    "        LOSS.append(sum_loss / n_batches)\n",
    "        LOSS_D.append(sum_D / n_batches)\n",
    "        LOSS_L.append(sum_L / n_batches)\n",
    "        # Learning-rate schedules advance once per epoch (T_max = n_epochs)\n",
    "        schedulerE.step()\n",
    "        schedulerD.step()\n",
    "\n",
    "        print(f'time epoch = {time.time() - start_time_b:.3f}s', )\n",
    "        print(schedulerE.get_last_lr()) # will print last learning rate.\n",
    "        print(schedulerD.get_last_lr()) # will print last learning rate.\n",
    "\n",
    "        # Intermediate checkpoints every 1000 epochs (epoch-suffixed file names)\n",
    "        if epoch % 1000 == 0:\n",
    "            print(\"## SAVING MODEL ##\")\n",
    "            torch.save(gcnE.state_dict(), dir_model+f'gcnE_{epoch}.model')\n",
    "            torch.save(optimizerE.state_dict(), dir_model+f'gcnE_{epoch}.opt')\n",
    "            torch.save(gcnD.state_dict(), dir_model+f'gcnD_{epoch}.model')\n",
    "            torch.save(optimizerD.state_dict(), dir_model+f'gcnD_{epoch}.opt')\n",
    "            loss_dict = {'LOSS': LOSS, 'LOSS_L': LOSS_L, 'LOSS_D': LOSS_D,\n",
    "                        'time_train':time.time()-time_train_start,}\n",
    "            with open(dir_model+f'loss_dict.pkl', 'wb') as f:\n",
    "                pickle.dump(loss_dict, f)\n",
    "\n",
    "        print(\"Epoch: \", epoch, \"Loss: \", LOSS[-1])\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "if TRAIN:\n",
    "\n",
    "    # Total wall-clock training time (includes per-epoch printing and checkpointing)\n",
    "    time_train = time.time() - time_train_start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "if TRAIN:\n",
    "\n",
    "    # Switch both networks to evaluation mode before inference\n",
    "    gcnE.eval()\n",
    "    gcnD.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SAVE MODELS\n",
    "# Final weights/optimizer states (no epoch suffix) plus the full loss history and training time.\n",
    "if TRAIN:\n",
    "    torch.save(gcnE.state_dict(), dir_model+'gcnE.model')\n",
    "    torch.save(optimizerE.state_dict(), dir_model+'gcnE.opt')\n",
    "    torch.save(gcnD.state_dict(), dir_model+'gcnD.model')\n",
    "    torch.save(optimizerD.state_dict(), dir_model+'gcnD.opt')\n",
    "\n",
    "\n",
    "    loss_dict = {'LOSS': LOSS, 'LOSS_L': LOSS_L, 'LOSS_D': LOSS_D, 'time_train':time_train,}\n",
    "    import pickle\n",
    "    with open(dir_model+'loss_dict.pkl', 'wb') as f:\n",
    "        pickle.dump(loss_dict, f)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Matplotlib styling for every subsequent figure: LaTeX text rendering with a Computer\n",
    "# Modern serif font. font.sans-serif is still registered but unused, since font.family is 'serif'.\n",
    "plt.rcParams.update({\n",
    "    'text.usetex': True,\n",
    "    'text.latex.preamble': r'\\usepackage{amsfonts, amsmath, amssymb}',\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': ['Computer Modern Roman'],\n",
    "    'font.sans-serif': ['Helvetica'],\n",
    "    'axes.labelsize': 18,\n",
    "    'axes.titlesize': 18,\n",
    "    'xtick.labelsize': 16,\n",
    "    'ytick.labelsize': 16,\n",
    "    'legend.fontsize': 14,\n",
    "})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the saved loss history from disk (works whether or not training ran in this session)\n",
    "with open(dir_model+'loss_dict.pkl', 'rb') as f:\n",
    "    loss_dict = pickle.load(f)\n",
    "LOSS = loss_dict['LOSS']\n",
    "LOSS_L = loss_dict['LOSS_L']\n",
    "LOSS_D = loss_dict['LOSS_D']\n",
    "time_train = loss_dict['time_train']\n",
    "\n",
    "print('time_train:', time_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total training loss per epoch (log scale)\n",
    "plt.semilogy(LOSS)\n",
    "plt.grid()\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "# Loss components: reconstruction MSE (L) and latent MMD term (d)\n",
    "plt.semilogy(LOSS_L, label='Loss $\\mathsf{L}$')\n",
    "plt.semilogy(LOSS_D, label='Loss $\\mathsf{d}$')\n",
    "plt.grid()\n",
    "plt.legend()\n",
    "# plt.savefig(dir_plt+ 'losses.pdf')\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Checkpoint selection:\n",
    "#   TRAIN=True                   -> reload the final weights just saved (gcnE.model, ...)\n",
    "#   TRAIN=False, epoch_load=None -> load the final weights\n",
    "#   TRAIN=False, epoch_load=k    -> load the intermediate checkpoint saved at epoch k\n",
    "epoch_load = None\n",
    "ckpt_suffix = f'_{epoch_load}' if (not TRAIN and epoch_load is not None) else ''\n",
    "\n",
    "# map_location lets checkpoints written on a GPU be restored when device='cpu'.\n",
    "for ckpt_name, ckpt_model, ckpt_opt in (('gcnE', gcnE, optimizerE), ('gcnD', gcnD, optimizerD)):\n",
    "    ckpt_model.load_state_dict(torch.load(dir_model+f'{ckpt_name}{ckpt_suffix}.model', map_location=device))\n",
    "    ckpt_model.eval()\n",
    "    ckpt_opt.load_state_dict(torch.load(dir_model+f'{ckpt_name}{ckpt_suffix}.opt', map_location=device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the encoder's latent codes on one training batch with the N(0, 1) target of the MMD term.\n",
    "data = data_train.to(device)\n",
    "with torch.no_grad():  # inference only: no autograd graph needed\n",
    "    z = gcnE(data.x, data)\n",
    "z_hist = z.ravel().cpu().numpy()\n",
    "plt.hist(z_hist, bins=30, density=True, label='Latent $z$ histogram')\n",
    "lin = np.linspace(-10, 10, 100)\n",
    "# plt.xlim(-5, 5)\n",
    "plt.plot(lin, np.exp(-lin**2/2)/np.sqrt(2*np.pi), label='Standard Gaussian')\n",
    "plt.legend()\n",
    "# plt.savefig(dir_plt+ 'hist_latent.pdf')\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyvista as pv\n",
    "\n",
    "# Enable notebook plotting (use 'static', 'client', 'panel', or 'pythreejs')\n",
    "pv.set_jupyter_backend('static')\n",
    "\n",
    "def plot_graph(graph, u, obs_idx=None, filename=None):\n",
    "    \"\"\"Render a per-vertex field on the triangulated surface of a graph, with mesh edges.\n",
    "\n",
    "    graph    : graph with `pos` (vertex coordinates) and `face` (triangle indices, one column per face).\n",
    "    u        : per-vertex values (numpy array or torch tensor).\n",
    "    obs_idx  : optional vertex indices (numpy array or torch tensor on any device), drawn as magenta spheres.\n",
    "    filename : if given, the figure is also exported with Plotter.save_graphic (vector format, e.g. .pdf).\n",
    "    \"\"\"\n",
    "    vertices = graph.pos.cpu().numpy()\n",
    "    faces = graph.face.cpu().numpy().T  # shape (n_faces, 3)\n",
    "\n",
    "    # PyVista expects faces as a flat array: [3, v0, v1, v2, 3, v0, v1, v2, ...]\n",
    "    faces_flat = np.hstack([np.full((faces.shape[0],1), 3), faces]).astype(np.int64).flatten()\n",
    "\n",
    "    mesh = pv.PolyData(vertices, faces_flat)\n",
    "    if torch.is_tensor(u):\n",
    "        u = u.detach().cpu().numpy()\n",
    "    mesh.point_data['helm'] = u\n",
    "\n",
    "    p = pv.Plotter(notebook=True)\n",
    "\n",
    "    p.add_mesh(\n",
    "        mesh,\n",
    "        scalars='helm',\n",
    "        cmap='coolwarm',\n",
    "        # clim=[-np.max(np.abs(u)), np.max(np.abs(u))],  # symmetric colour range\n",
    "        show_scalar_bar=True,\n",
    "        show_edges=True,   # draw the mesh lines\n",
    "        edge_color='black',\n",
    "        line_width=0.1,\n",
    "        scalar_bar_args={\n",
    "            'title': '',        # No title\n",
    "            'vertical': True,\n",
    "            'position_x': 0.1,\n",
    "            'position_y': 0.1,\n",
    "            'width': 0.05,\n",
    "            'height': 0.8,\n",
    "            'title_font_size': 20,    # Title font size\n",
    "            'label_font_size': 24     # Number labels font size\n",
    "        }\n",
    "    )\n",
    "\n",
    "    if obs_idx is not None:\n",
    "        # Observation indices may arrive as a (CUDA) torch tensor; index numpy with a CPU array\n",
    "        if torch.is_tensor(obs_idx):\n",
    "            obs_idx = obs_idx.detach().cpu().numpy()\n",
    "        point = vertices[np.asarray(obs_idx), :]  # shape (n_obs, 3)\n",
    "        p.add_mesh(pv.PolyData(point), color=\"magenta\", point_size=20, render_points_as_spheres=True)\n",
    "\n",
    "    # Custom camera: (position, focal_point, view_up) with Y as the up direction\n",
    "    p.show_axes()\n",
    "    p.camera_position = [np.array([2, 2.75, -2])/3, (0.3, 0.22, 0.25), (0, 1, 0)]\n",
    "    # Export BEFORE show(): show() may close the plotter (auto_close), after which\n",
    "    # save_graphic can no longer access the render window.\n",
    "    if filename is not None:\n",
    "        p.save_graphic(filename)\n",
    "    p.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualise a few training samples: column 0 = scaled pressure, column 1 = forcing (saved as PDF).\n",
    "idxs = [1, 2, 5]\n",
    "\n",
    "for i in idxs:\n",
    "    data = dataset_train[i]\n",
    "    plot_graph(data, data.x.cpu().numpy()[:,0], filename=dir_plt+f'data_car_u_{i}.pdf')\n",
    "    plt.close()\n",
    "    plot_graph(data, data.x.cpu().numpy()[:,1], filename=dir_plt+f'data_car_f_{i}.pdf')\n",
    "    plt.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Autoencoder reconstruction check on one random graph.\n",
    "# NOTE: this draws from dataset_train despite the variable names; use the commented line\n",
    "# to evaluate on held-out data instead.\n",
    "loader_test = DataLoader(dataset_train, batch_size=1, shuffle=True)\n",
    "# loader_test = DataLoader(dataset_test, batch_size=1, shuffle=True)\n",
    "data_test = next(iter(loader_test))\n",
    "\n",
    "data_test = data_test.to(device)\n",
    "with torch.no_grad():  # inference only\n",
    "    z = gcnE(data_test.x, data_test)\n",
    "    u = gcnD(z, data_test)[:,0].cpu()  # reconstructed scaled pressure (column 0)\n",
    "# u = u * STD + MEAN\n",
    "\n",
    "plot_graph(data_test, u.numpy(), filename=None)\n",
    "plt.close()\n",
    "plot_graph(data_test, data_test.x[:,0].detach().cpu().numpy(), filename=None)\n",
    "plt.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode random latent draws z ~ N(0, I) on the mesh of one test graph to inspect the learned prior.\n",
    "loader_test = DataLoader(dataset_test[7:8], batch_size=1, shuffle=True)\n",
    "\n",
    "data_test = next(iter(loader_test))\n",
    "data_test = data_test.to(device)\n",
    "for i in range(3):\n",
    "    z = torch.randn((1, dim_z)).to(device)\n",
    "    with torch.no_grad():  # inference only\n",
    "        u = gcnD(z, data_test).cpu().numpy()\n",
    "    # u = u * STD + MEAN\n",
    "\n",
    "    plot_graph(data_test, u[:,0], filename=dir_plt+f'prior_samples_u_{i}.pdf')\n",
    "    plt.close()\n",
    "    plot_graph(data_test, u[:,1], filename=dir_plt+f'prior_samples_f_{i}.pdf')\n",
    "    plt.close()\n",
    "\n",
    "# Ground-truth scaled pressure of the same graph. Only column 0 is passed: handing PyVista both\n",
    "# columns would colour the mesh by the magnitude of the (pressure, forcing) 2-vector.\n",
    "plot_graph(data_test, data_test.x[:,0].detach().cpu().numpy(), filename=None)\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### ABC: rejection sampling in latent space against sparse noisy pressure observations\n",
    "idx_test = 9\n",
    "\n",
    "data_test_graph = dataset_test[idx_test].clone()\n",
    "# data_test_graph = dataset_train[idx_test].clone()\n",
    "\n",
    "loader_test_gbase = DataLoader([data_test_graph], batch_size=1)\n",
    "data_test_graph  = next(iter(loader_test_gbase))\n",
    "n_obs = 20\n",
    "sigma_tc = torch.FloatTensor([0.01]).to(device)  # observation-noise std (scaled units)\n",
    "\n",
    "# Random sensor locations: n_obs distinct node indices of the test mesh\n",
    "ObsIdx = np.random.choice(range(data_test_graph.x.shape[0]), size=n_obs, replace=False)\n",
    "ObsIdx = torch.tensor(ObsIdx)\n",
    "\n",
    "# Noisy observations of the scaled-pressure channel (column 0)\n",
    "y_n  = data_test_graph.x[ObsIdx,0].to(device)\n",
    "print('yn.shape:', y_n.shape)\n",
    "y_n += sigma_tc * torch.randn_like(y_n).to(device)\n",
    "\n",
    "# batch_size identical copies of the test graph, decoded in parallel with different z\n",
    "batch_size = 500\n",
    "graph_test = data_test_graph\n",
    "graph_test_batch = [graph_test.clone() for _ in range(batch_size)] \n",
    "loader_test = DataLoader(graph_test_batch, batch_size=batch_size)\n",
    "graph_test_batch_loaded  = next(iter(loader_test)).to(device)\n",
    "\n",
    "# Map per-graph observation indices to node indices of the batched (concatenated) graph\n",
    "graph_offsets = torch.arange(batch_size) * graph_test.x.shape[0]  # [B]\n",
    "global_indices = graph_offsets[:, None] + ObsIdx[None, :]  # [B, K]\n",
    "flat_indices = global_indices.reshape(-1).to(device)  # [B*K]\n",
    "\n",
    "n_loops = 500\n",
    "\n",
    "# Collect per-loop results in lists and concatenate once at the end; calling torch.cat\n",
    "# on an ever-growing tensor inside the loop is quadratic in the number of loops.\n",
    "norms_chunks = []\n",
    "z_chunks = []\n",
    "with torch.no_grad():\n",
    "    for i in range(n_loops):\n",
    "        z_prior = torch.randn(batch_size, dim_z).to(device)\n",
    "        u_decode = gcnD(z_prior, graph_test_batch_loaded)\n",
    "\n",
    "        y_selected = u_decode[flat_indices, 0].view(batch_size, n_obs)\n",
    "        y_selected = y_selected + sigma_tc * torch.randn_like(y_selected)  # simulated noisy obs [B, K]\n",
    "        norms = torch.norm(y_selected - y_n, dim=-1)  # distance to the real observations [B]\n",
    "        norms_chunks.append(norms.cpu())\n",
    "        z_chunks.append(z_prior)\n",
    "\n",
    "norms_list = torch.cat(norms_chunks, dim=0)    # [n_loops*B]\n",
    "z_prior_list = torch.cat(z_chunks, dim=0)      # [n_loops*B, dim_z]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "   \n",
    "# Rejection step: keep the 100 prior draws whose simulated observations are closest to y_n.\n",
    "sorted_index = torch.argsort(norms_list[:], dim=0)  # 1-D indices into norms_list, ascending distance\n",
    "z_abc   = z_prior_list[sorted_index[:100]]\n",
    "\n",
    "print(norms_list[sorted_index])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader_decode_ABC = DataLoader([graph_test], batch_size=1)\n",
    "graph_ABC_loaded = next(iter(loader_decode_ABC)).to(device)\n",
    "# z_abc = torch.vstack(z_acc)\n",
    "# Decode every accepted latent sample on one copy of the graph, vmapping over the leading\n",
    "# sample axis of z_abc (each sample is fed as a [1, dim_z] latent).\n",
    "with torch.no_grad():\n",
    "    u_abc_decode = torch.vmap(gcnD, in_dims=(0, None))(z_abc[:, None, :], graph_ABC_loaded)\n",
    "# u_abc_decode.shape\n",
    "\n",
    "# Fraction of all prior draws that were kept\n",
    "acc_ratio = z_abc.shape[0] / (batch_size * n_loops)\n",
    "print('acc_ratio = ', acc_ratio)\n",
    "print('num samples = ', z_abc.shape[0])\n",
    "\n",
    "# Decode the single best-matching draw (smallest observation distance)\n",
    "u_min_norm = gcnD(z_prior_list[sorted_index[:1]], graph_ABC_loaded)\n",
    "u_min_norm = u_min_norm.cpu().detach().numpy()[:]\n",
    "\n",
    "print('best norm:', norms_list[sorted_index[0]])\n",
    "\n",
    "# Empirical posterior mean / std across accepted samples (axis 0 = sample axis)\n",
    "u_mean = np.mean(u_abc_decode.detach().cpu().numpy(), axis=0)[:]\n",
    "u_std = np.std(u_abc_decode.detach().cpu().numpy(), axis=0)[:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Posterior summaries vs ground truth, exported as PDFs to dir_plt.\n",
    "# Column 0 = scaled pressure (observation sites marked on some plots), column 1 = forcing.\n",
    "print('mean')\n",
    "plot_graph(data_test_graph, u_mean[:,0], obs_idx=ObsIdx, filename=dir_plt+'recon_mean_u.pdf')\n",
    "plt.close()\n",
    "plot_graph(data_test_graph, u_mean[:,1], filename=dir_plt+'recon_mean_f.pdf')\n",
    "plt.close()\n",
    "\n",
    "print('gt')\n",
    "plot_graph(data_test_graph, data_test_graph.x[:,0].detach().cpu().numpy(), filename=dir_plt+'gt_u.pdf')\n",
    "plt.close()\n",
    "plot_graph(data_test_graph, data_test_graph.x[:,1].detach().cpu().numpy(), filename=dir_plt+'gt_f.pdf')\n",
    "plt.close()\n",
    "\n",
    "print('std')\n",
    "plot_graph(data_test_graph, u_std[:,0], obs_idx=ObsIdx, filename=dir_plt+'recon_std_u.pdf')\n",
    "plt.close()\n",
    "plot_graph(data_test_graph, u_std[:,1], obs_idx=None, filename=dir_plt+'recon_std_f.pdf')\n",
    "plt.close()\n",
    "\n",
    "# Signed error: posterior mean minus ground truth\n",
    "print('error')\n",
    "plot_graph(data_test_graph, (u_mean[:,0] - data_test_graph.x[:,0].detach().cpu().numpy()), filename=dir_plt+'error_u.pdf')\n",
    "plt.close()\n",
    "plot_graph(data_test_graph, (u_mean[:,1] - data_test_graph.x[:,1].detach().cpu().numpy()), filename=dir_plt+'error_f.pdf')\n",
    "plt.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node indices of the observation sites used in the ABC run\n",
    "print(ObsIdx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Point-wise accuracy and calibration of the ABC posterior against ground truth, per channel.\n",
    "# u_pred = u_mode[:,0].cpu().detach().numpy()\n",
    "u_pred = u_mean\n",
    "# u_pred = u_min_norm\n",
    "\n",
    "u_true =  graph_test.x[:].cpu().detach().numpy()  # [num_nodes, 2]: (scaled pressure, forcing)\n",
    "\n",
    "for i, var in enumerate(['u', 'f']):\n",
    "    # Compare channel i with channel i: subtracting the full u_true would broadcast (N,) against (N, 2)\n",
    "    abs_err = np.abs(u_pred[:,i] - u_true[:,i])\n",
    "    MAE = np.mean(abs_err)\n",
    "    # Empirical coverage: fraction of nodes whose error lies within 1 / 2 posterior std\n",
    "    prec_in_1std = np.mean(abs_err < u_std[:,i])\n",
    "    prec_in_2std = np.mean(abs_err < 2*u_std[:,i])\n",
    "    print(f'Variable {var}:')\n",
    "    print(f'MAE = {MAE:.4f}')\n",
    "    print(f'prec_in_1std = {prec_in_1std:.4f}')\n",
    "    print(f'prec_in_2std = {prec_in_2std:.4f}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reset parameter gradients of both networks (gcnD.zero_grad overlaps with optimizerD.zero_grad).\n",
    "optimizerE.zero_grad()\n",
    "optimizerD.zero_grad()\n",
    "\n",
    "gcnD.zero_grad()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def observation_map(graph, n_obs):\n",
    "    \"\"\"Pick n_obs distinct node indices of graph uniformly at random; returned as a CPU int tensor.\"\"\"\n",
    "    n_nodes = graph.x.shape[0]\n",
    "    chosen = np.random.choice(range(n_nodes), size=n_obs, replace=False)\n",
    "    return torch.tensor(chosen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "\n",
    "# Observation-noise std (scaled units) and default number of sensors\n",
    "sigma = 0.01\n",
    "n_obs = 20\n",
    "sigma_tc = torch.FloatTensor([sigma]).to(device)\n",
    "\n",
    "# Candidate sensor counts 20..49, drawn with probability proportional to 1/n\n",
    "NOBS = np.arange(20,50)\n",
    "norm_nobs = (1./NOBS) / (np.sum(1./NOBS))  # normalize to sum to 1\n",
    "def get_rand_nobs():\n",
    "    \"\"\"Draw one random observation count from NOBS using the weights norm_nobs.\"\"\"\n",
    "    return np.random.choice(NOBS, p=norm_nobs, size=1)[0]\n",
    "    # return n_obs\n",
    "    \n",
    "# NOTE: redefines observation_map from an earlier cell; this version returns indices on `device`.\n",
    "def observation_map(data_test_graph, n_obs):\n",
    "    \"\"\"Pick n_obs distinct node indices of data_test_graph uniformly at random, as a tensor on `device`.\"\"\"\n",
    "    return torch.tensor(np.random.choice(range(data_test_graph.x.shape[0]), size=n_obs, replace=False)).to(device)\n",
    "    \n",
    "def run_pred(data_test_graph, batch_size=500, n_loops=100, n_acc=100):\n",
    "    \"\"\"Approximate-Bayesian-computation prediction for one test graph.\n",
    "\n",
    "    A random number of nodes (get_rand_nobs) is observed: feature 0 of the true\n",
    "    graph at those nodes is corrupted with Gaussian noise of std `sigma`.\n",
    "    Latent vectors are drawn from a standard normal prior, decoded with gcnD\n",
    "    `batch_size` at a time for `n_loops` iterations, and the `n_acc` latents\n",
    "    whose noisy decoded observations are closest in L2 norm to the data are\n",
    "    accepted. The accepted latents are decoded on the single graph and their\n",
    "    mean / std over samples give the prediction and its uncertainty.\n",
    "\n",
    "    Args:\n",
    "        data_test_graph: graph with node features `x`; `.to(device)` is called\n",
    "            on it without reassignment, so it is presumably moved in place.\n",
    "        batch_size: prior samples decoded per iteration.\n",
    "        n_loops: number of prior batches (n_loops * batch_size candidates).\n",
    "        n_acc: number of closest candidates kept.\n",
    "\n",
    "    Returns:\n",
    "        (MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken). The first four\n",
    "        are lists with one value per variable ('u', 'f'); time_taken is the\n",
    "        wall-clock seconds spent after the observations were drawn.\n",
    "    \"\"\"\n",
    "    data_test_graph.to(device)\n",
    "    n_obs = get_rand_nobs()\n",
    "    print('Number of observations:', n_obs)\n",
    "    ObsIdx = observation_map(data_test_graph, n_obs)\n",
    "\n",
    "    # Noisy observations of feature 0. Advanced indexing yields a copy, so the\n",
    "    # in-place add leaves the graph untouched.\n",
    "    y_n  = data_test_graph.x[ObsIdx, 0]\n",
    "    y_n += sigma_tc * torch.randn_like(y_n)\n",
    "    y_n = y_n.to(device)\n",
    "\n",
    "    time_start = time.time()\n",
    "    graph_test = data_test_graph\n",
    "    graph_test_batch = [graph_test.clone() for _ in range(batch_size)]\n",
    "    loader_test = DataLoader(graph_test_batch, batch_size=batch_size)\n",
    "    graph_test_batch_loaded = next(iter(loader_test)).to(device)\n",
    "\n",
    "    # The offset arithmetic assumes copy b's nodes occupy rows [b*N, (b+1)*N)\n",
    "    # of the batched graph; shift the observation indices accordingly.\n",
    "    n_nodes = graph_test.x.shape[0]\n",
    "    graph_offsets = torch.arange(batch_size, device=device) * n_nodes  # [B]\n",
    "    global_indices = graph_offsets[:, None] + ObsIdx[None, :]  # [B, K]\n",
    "    flat_indices = global_indices.reshape(-1)  # [B*K]\n",
    "\n",
    "    # Gather candidates in lists and concatenate once, instead of re-copying\n",
    "    # ever-growing tensors with torch.cat on every iteration.\n",
    "    norms_chunks = []\n",
    "    z_chunks = []\n",
    "    for _ in range(n_loops):\n",
    "        z_prior = torch.randn(batch_size, dim_z).to(device)\n",
    "        u_decode = gcnD(z_prior, graph_test_batch_loaded)\n",
    "\n",
    "        y_selected = u_decode[flat_indices, 0].view(batch_size, n_obs)  # [B, K]\n",
    "        y_selected = y_selected + sigma_tc * torch.randn_like(y_selected)\n",
    "        norms = torch.norm(y_selected - y_n, dim=-1)  # [B]\n",
    "        norms_chunks.append(norms.detach())\n",
    "        z_chunks.append(z_prior.detach())\n",
    "\n",
    "    norms_all = torch.cat(norms_chunks, dim=0)  # [n_loops*B]\n",
    "    z_all = torch.cat(z_chunks, dim=0)  # [n_loops*B, dim_z]\n",
    "    sorted_index = torch.argsort(norms_all, dim=0)\n",
    "    z_abc = z_all[sorted_index[:n_acc]]  # [n_acc, dim_z]\n",
    "\n",
    "    loader_decode_ABC = DataLoader([graph_test], batch_size=1)\n",
    "    graph_ABC_loaded = next(iter(loader_decode_ABC)).to(device)\n",
    "    # Decode every accepted latent (as a [1, dim_z] slice) on the single graph.\n",
    "    u_abc_decode = torch.vmap(gcnD, in_dims=(0, None))(z_abc[:, None, :], graph_ABC_loaded)\n",
    "\n",
    "    samples_decode_np = u_abc_decode.cpu().detach().numpy()\n",
    "    u_mean = np.mean(samples_decode_np, axis=0)\n",
    "    u_std = np.std(samples_decode_np, axis=0)\n",
    "\n",
    "    u_pred = u_mean\n",
    "    u_true = graph_test.x[:].cpu().detach().numpy()\n",
    "\n",
    "    MAE = []\n",
    "    MaxAE = []\n",
    "    prec_in_1std = []\n",
    "    prec_in_2std = []\n",
    "    for i, var in enumerate(['u', 'f']):\n",
    "        abs_err = np.abs(u_pred[:,i] - u_true[:,i])\n",
    "        MAE.append(np.mean(abs_err))\n",
    "        MaxAE.append(np.max(abs_err))\n",
    "        prec_in_1std.append(np.mean(abs_err < u_std[:,i]))\n",
    "        prec_in_2std.append(np.mean(abs_err < 2*u_std[:,i]))\n",
    "        print(f'Variable {var}:')\n",
    "        print(f'MAE = {MAE[i]:.4f}')\n",
    "        print(f'MaxAE = {MaxAE[i]:.4f}')\n",
    "        print(f'prec_in_1std = {prec_in_1std[i]:.4f}')\n",
    "        print(f'prec_in_2std = {prec_in_2std[i]:.4f}')\n",
    "\n",
    "    time_taken = time.time() - time_start\n",
    "    print('time_taken:', time_taken)\n",
    "    return MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken\n",
    "\n",
    "variables = ['u', 'f']\n",
    "\n",
    "# Per-test-case metric traces, keyed by metric name and then by variable.\n",
    "pred_stats = {\n",
    "    'MAE': {var: [] for var in variables},\n",
    "    'MaxAE': {var: [] for var in variables},\n",
    "    'prec_in_1std': {var: [] for var in variables},\n",
    "    'prec_in_2std': {var: [] for var in variables},\n",
    "    'time_taken': []\n",
    "}\n",
    "\n",
    "# Evaluate on at most the first 100 test graphs; clamp to the test-set size so\n",
    "# a smaller dataset_test does not raise IndexError.\n",
    "N_test = min(100, len(dataset_test))\n",
    "for i in range(N_test):\n",
    "    with torch.no_grad():\n",
    "        print(f'Running prediction for test data {i+1}/{N_test}')\n",
    "        # Each metric is returned as a list with one value per entry of `variables`.\n",
    "        MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken = run_pred(dataset_test[i])\n",
    "\n",
    "    for j, var in enumerate(variables):\n",
    "        pred_stats['MAE'][var].append(MAE[j])\n",
    "        pred_stats['MaxAE'][var].append(MaxAE[j])\n",
    "        pred_stats['prec_in_1std'][var].append(prec_in_1std[j])\n",
    "        pred_stats['prec_in_2std'][var].append(prec_in_2std[j])\n",
    "    pred_stats['time_taken'].append(time_taken)\n",
    "\n",
    "print('### Total Results ###')\n",
    "metrics = ['MAE', 'MaxAE', 'prec_in_1std', 'prec_in_2std']\n",
    "\n",
    "# Convert to numpy arrays\n",
    "for metric in metrics:\n",
    "    for var in variables:\n",
    "        pred_stats[metric][var] = np.array(pred_stats[metric][var])\n",
    "pred_stats['time_taken'] = np.array(pred_stats['time_taken'])\n",
    "\n",
    "# (mean, std) over the test cases for every metric/variable, plus timing.\n",
    "test_score_dict = {metric: {} for metric in metrics}\n",
    "\n",
    "# Print results\n",
    "for metric in metrics:\n",
    "    for var in variables:\n",
    "        mean_val = np.mean(pred_stats[metric][var])\n",
    "        std_val = np.std(pred_stats[metric][var])\n",
    "        test_score_dict[metric][var] = (mean_val, std_val)\n",
    "        print(f\"{metric} ({var}): {mean_val:.4f} ± {std_val:.4f}\")\n",
    "\n",
    "mean_time = np.mean(pred_stats['time_taken'])\n",
    "std_time = np.std(pred_stats['time_taken'])\n",
    "print(f\"time_taken: {mean_time:.4f} ± {std_time:.4f}\")\n",
    "test_score_dict['time'] = (mean_time, std_time)\n",
    "\n",
    "with open(dir_model+'test_score_dict.pkl', 'wb') as f:\n",
    "    pickle.dump(test_score_dict, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gabi",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
