{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f10bd61e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/ex_hdd/graph/lib/python3.8/site-packages/torch_geometric/typing.py:18: UserWarning: An issue occurred while importing 'pyg-lib'. Disabling its usage. Stacktrace: /ex_hdd/graph/lib/python3.8/site-packages/libpyg.so: undefined symbol: _ZNK5torch8autograd4Node4nameB5cxx11Ev\n",
      "  warnings.warn(f\"An issue occurred while importing 'pyg-lib'. \"\n",
      "/ex_hdd/graph/lib/python3.8/site-packages/torch_geometric/typing.py:42: UserWarning: An issue occurred while importing 'torch-sparse'. Disabling its usage. Stacktrace: /ex_hdd/graph/lib/python3.8/site-packages/libpyg.so: undefined symbol: _ZNK5torch8autograd4Node4nameB5cxx11Ev\n",
      "  warnings.warn(f\"An issue occurred while importing 'torch-sparse'. \"\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import torch\n",
    "import random\n",
    "import h5py\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import pytorch_lightning as pl\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from utils import *\n",
    "from torch import nn\n",
    "from argparse import Namespace\n",
    "from torch_geometric.nn import MessagePassing, radius_graph, knn_graph, InstanceNorm\n",
    "from torch_geometric.data import Data\n",
    "from sklearn.model_selection import train_test_split\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "from models.backbones.mlp import MLP\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "sys.path.append(\"../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8c6794ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters read by MPNN2d; key order is the same as the original keyword order.\n",
    "hparams = Namespace(**{\n",
    "    # --- optimisation ---\n",
    "    'factor': 0.8,            # StepLR gamma\n",
    "    'step_size': 50,          # StepLR period\n",
    "    'loss': 'l2',             # 'l1', 'l2' or 'smooth_l1'\n",
    "    'lr': 1e-3,\n",
    "    'weight_decay': 0,\n",
    "    'seed': 2,\n",
    "    'gpus': [0],\n",
    "    # --- model ---\n",
    "    'hidden_features': 128,\n",
    "    'hidden_layer': 5,        # number of message-passing layers\n",
    "    'time_window': 10,        # frames consumed / produced per forward pass\n",
    "    'teacher_forcing': False,\n",
    "    'neighbors': 8,           # k of the k-NN graph\n",
    "    'regular': False,         # forwarded to the dataset's `regular` flag\n",
    "})\n",
    "\n",
    "# Output directory (one sub-folder per seed) and HDF5 data file.\n",
    "dirpath = f'save/MPNN(shallow_irregular)/{hparams.seed}'\n",
    "data = '2D/Shallow/shallow.h5'\n",
    "\n",
    "# DataLoader settings.\n",
    "batch_size = 16\n",
    "num_workers = 20\n",
    "\n",
    "# nt and nx are forwarded to the dataset constructor.\n",
    "# L is not referenced by the cells visible here - TODO confirm its meaning.\n",
    "nt = 50\n",
    "nx = 32\n",
    "L = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "15e60a1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class HDF5DatasetGraph_2d_MPNN(Dataset):\n",
    "    \"\"\"\n",
    "    Map-style dataset over the trajectories stored in one HDF5 file.\n",
    "\n",
    "    Each item is a dict with\n",
    "        'u': [N, T] values, one row per spatial point,\n",
    "        'x': [N, 2] coordinates of those points,\n",
    "        't': the file's time stamps with the last one dropped.\n",
    "\n",
    "    The HDF5 handle is opened lazily (see the `f` property) instead of in\n",
    "    `__init__`, so DataLoader worker processes created afterwards open their\n",
    "    own handle rather than sharing the parent's.\n",
    "\n",
    "    NOTE(review): in irregular mode 'u' is strided in time (`[::2]`) while 't'\n",
    "    is not, so the two have different lengths. Downstream code currently\n",
    "    relies on the longer 't'; confirm before changing either side.\n",
    "    NOTE(review): the regular branch pairs the first spatial axis with x\n",
    "    (default meshgrid order) while the irregular branch pairs it with y\n",
    "    (`i // W`); verify against the file layout.\n",
    "\n",
    "    Args:\n",
    "        path: path of the HDF5 file.\n",
    "        nt, res: only used to build the `self.dataset` label.\n",
    "        mode: 'train', 'valid' or 'test'.\n",
    "        regular: True -> strided regular sub-grid, False -> fixed random subset of grid points.\n",
    "        seed: seed of the random subset (irregular mode only).\n",
    "    \"\"\"\n",
    "\n",
    "    # Sizes used by the irregular sampler (presumably a 128x128 source grid - TODO confirm).\n",
    "    FULL_GRID_POINTS = 128 * 128\n",
    "    N_SAMPLED_POINTS = 32 * 32\n",
    "\n",
    "    def __init__(self, \n",
    "                 path,\n",
    "                 nt,\n",
    "                 res,\n",
    "                 mode='train', \n",
    "                 regular=True,\n",
    "                 seed=0):\n",
    "        \n",
    "        assert mode in ['train', 'valid', 'test'], \"mode must belong to one of these ['train', 'valid', 'test']\"\n",
    "        \n",
    "        self.path = path\n",
    "        self._f = None\n",
    "        # Fail fast on a missing/unreadable file, but do not keep this handle:\n",
    "        # the working handle is created on first access by whichever process needs it.\n",
    "        with h5py.File(path, 'r'):\n",
    "            pass\n",
    "\n",
    "        self.mode = mode\n",
    "        self.regular = regular\n",
    "        self.dataset = f'pde_{nt}-{res}'\n",
    "        self.seed = seed\n",
    "        \n",
    "        # Generate keys from '0000' to '0999'\n",
    "        all_keys = [str(i).zfill(4) for i in range(1000)]\n",
    "\n",
    "        # 20% test; then 25% of the remaining 80% (= 20% of the total) for\n",
    "        # validation, leaving 60% for training.\n",
    "        train_valid_keys, self.test_keys = train_test_split(all_keys, test_size=0.2, random_state=42)\n",
    "        self.train_keys, self.valid_keys = train_test_split(train_valid_keys, test_size=0.25, random_state=42)\n",
    "    \n",
    "        if self.mode == 'train':\n",
    "            self.keys = self.train_keys\n",
    "        elif self.mode == 'test':\n",
    "            self.keys = self.test_keys\n",
    "        else:  # For 'valid' mode\n",
    "            self.keys = self.valid_keys\n",
    "\n",
    "        if not self.regular:\n",
    "            # A private RNG yields the same sample as seeding the global one,\n",
    "            # without resetting the process-wide `random` state as a side effect.\n",
    "            rng = random.Random(seed)\n",
    "            self.sampled_indices = rng.sample(range(self.FULL_GRID_POINTS), self.N_SAMPLED_POINTS)\n",
    "\n",
    "    @property\n",
    "    def f(self):\n",
    "        \"\"\"Read-only HDF5 handle, opened on first use.\"\"\"\n",
    "        if self._f is None:\n",
    "            self._f = h5py.File(self.path, 'r')\n",
    "        return self._f\n",
    "\n",
    "    @staticmethod\n",
    "    def _to_node_major(data):\n",
    "        \"\"\"Turn a (T, ..., 1) array into a [points, T - 1] tensor (last frame dropped).\"\"\"\n",
    "        u = torch.from_numpy(data.squeeze(-1))\n",
    "        u = u.reshape(u.shape[0], -1).permute(1, 0)\n",
    "        return u[:, :-1]\n",
    "            \n",
    "    def __len__(self):\n",
    "        return len(self.keys)\n",
    "            \n",
    "    def __getitem__(self, idx):\n",
    "        sample = self.f[self.keys[idx]]\n",
    "        grid = sample['grid']\n",
    "        \n",
    "        if self.regular:\n",
    "            # Keep every time step; take every 4th point (offset 2) on both spatial axes.\n",
    "            u = self._to_node_major(sample['data'][::, 2::4, 2::4, :])\n",
    "            x = torch.from_numpy(grid['x'][2::4])\n",
    "            y = torch.from_numpy(grid['y'][2::4])\n",
    "            coords = torch.stack(torch.meshgrid(x, y), dim=-1).reshape(-1, 2)\n",
    "        else:\n",
    "            # Keep every 2nd time step on the full grid, then pick the sampled points.\n",
    "            u = self._to_node_major(sample['data'][::2])\n",
    "            x_full = torch.from_numpy(grid['x'][:])\n",
    "            y_full = torch.from_numpy(grid['y'][:])\n",
    "            W = len(x_full)\n",
    "\n",
    "            # Flat index i -> (column i % W, row i // W); gathered with tensor\n",
    "            # indexing instead of one .item() call per point.\n",
    "            cols = torch.tensor([i % W for i in self.sampled_indices], dtype=torch.long)\n",
    "            rows = torch.tensor([i // W for i in self.sampled_indices], dtype=torch.long)\n",
    "            coords = torch.stack((x_full[cols].to(torch.float32),\n",
    "                                  y_full[rows].to(torch.float32)), dim=-1)\n",
    "            u = u[torch.tensor(self.sampled_indices, dtype=torch.long), :]\n",
    "\n",
    "        t = torch.from_numpy(grid['t'][:])\n",
    "        t = t[:-1]  # drop the last time stamp\n",
    "        \n",
    "        return {\n",
    "            'u': u,\n",
    "            'x': coords,\n",
    "            't': t\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "276e447d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rel_L2_error(pred, true):\n",
    "    \"\"\"Relative L2 error over the last dimension: ||true - pred||_2 / ||true||_2.\"\"\"\n",
    "    residual_energy = ((true - pred) ** 2).sum(dim=-1)\n",
    "    reference_energy = (true ** 2).sum(dim=-1)\n",
    "    return (residual_energy / reference_energy) ** 0.5\n",
    "\n",
    "class Swish(nn.Module):\n",
    "    \"\"\"\n",
    "    Swish activation: x * sigmoid(beta * x), with a fixed (non-learnable) beta.\n",
    "    \"\"\"\n",
    "    def __init__(self, beta=1):\n",
    "        super().__init__()\n",
    "        self.beta = beta\n",
    "\n",
    "    def forward(self, x):\n",
    "        gate = torch.sigmoid(self.beta * x)\n",
    "        return x * gate\n",
    "\n",
    "\n",
    "class GNN_Layer(MessagePassing):\n",
    "    \"\"\"\n",
    "    Mean-aggregating message-passing layer followed by instance normalisation.\n",
    "    \"\"\"\n",
    "    def __init__(self,\n",
    "                 in_features: int,\n",
    "                 out_features: int,\n",
    "                 hidden_features: int,\n",
    "                 time_window: int,\n",
    "                 n_variables: int):\n",
    "        \"\"\"\n",
    "        Build the message and update networks.\n",
    "        Args:\n",
    "            in_features (int): width of the incoming node embedding\n",
    "            out_features (int): width of the outgoing node embedding\n",
    "            hidden_features (int): width of the hidden layers and of the messages\n",
    "            time_window (int): number of time steps carried by `u`\n",
    "            n_variables (int): width of the per-node `variables` tensor\n",
    "        \"\"\"\n",
    "        super().__init__(node_dim=-2, aggr='mean')\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.hidden_features = hidden_features\n",
    "\n",
    "        # Edge input is [x_i, x_j, u_i - u_j, pos_i - pos_j, variables_i];\n",
    "        # the constant 2 is the width reserved for the relative position.\n",
    "        message_in = 2 * in_features + time_window + 2 + n_variables\n",
    "        # Node-update input is [x, aggregated message, variables].\n",
    "        update_in = in_features + hidden_features + n_variables\n",
    "\n",
    "        # Sub-modules are created in the original order so seeded initialisation is unchanged.\n",
    "        self.message_net_1 = nn.Sequential(nn.Linear(message_in, hidden_features), Swish())\n",
    "        self.message_net_2 = nn.Sequential(nn.Linear(hidden_features, hidden_features), Swish())\n",
    "        self.update_net_1 = nn.Sequential(nn.Linear(update_in, hidden_features), Swish())\n",
    "        self.update_net_2 = nn.Sequential(nn.Linear(hidden_features, out_features), Swish())\n",
    "        self.norm = InstanceNorm(hidden_features)\n",
    "\n",
    "    def forward(self, x, u, pos, variables, edge_index, batch):\n",
    "        \"\"\"\n",
    "        Run one round of message passing, then normalise per graph in the batch.\n",
    "        \"\"\"\n",
    "        propagated = self.propagate(edge_index, x=x, u=u, pos=pos, variables=variables)\n",
    "        return self.norm(propagated, batch)\n",
    "\n",
    "    def message(self, x_i, x_j, u_i, u_j, pos_i, pos_j, variables_i):\n",
    "        \"\"\"\n",
    "        Edge message (formula 8 of the paper).\n",
    "        \"\"\"\n",
    "        edge_input = torch.cat((x_i, x_j, u_i - u_j, pos_i - pos_j, variables_i), dim=-1)\n",
    "        return self.message_net_2(self.message_net_1(edge_input))\n",
    "\n",
    "    def update(self, message, x, variables):\n",
    "        \"\"\"\n",
    "        Node update (formula 9 of the paper); residual when the widths match.\n",
    "        \"\"\"\n",
    "        node_input = torch.cat((x, message, variables), dim=-1)\n",
    "        delta = self.update_net_2(self.update_net_1(node_input))\n",
    "        if self.in_features != self.out_features:\n",
    "            return delta\n",
    "        return x + delta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e7fb626d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MPNN2d(pl.LightningModule):\n",
    "    \"\"\"\n",
    "    Message-passing neural PDE solver on a 2-D point cloud (Lightning module).\n",
    "\n",
    "    Pipeline: node-embedding MLP -> `hidden_layer` GNN_Layer blocks -> two\n",
    "    Conv1d layers that decode `time_window` future frames per node (temporal\n",
    "    bundling). Longer horizons are produced autoregressively in the steps.\n",
    "\n",
    "    Behavioural notes for maintainers:\n",
    "      * `forward` feeds BOTH spatial coordinates (`pos[:, 1:]`) to the network.\n",
    "        It previously took only `pos[:, 1]` and reached the expected width of\n",
    "        2 by broadcasting against a 2-element `L`, so the second coordinate\n",
    "        never reached the model. Checkpoints trained that way need retraining.\n",
    "      * The steps normalise coordinates by the per-axis extent (max |coord|)\n",
    "        of the point cloud rather than by the coordinates of the last node.\n",
    "      * No graph is built after the final roll-out window, and the target is\n",
    "        trimmed to the number of frames actually predicted.\n",
    "      * Unsupported `time_window` / `loss` values raise ValueError on construction.\n",
    "\n",
    "    Args:\n",
    "        hparams: namespace providing lr, weight_decay, factor, step_size, loss,\n",
    "            time_window, hidden_features, hidden_layer, teacher_forcing, seed\n",
    "            and neighbors.\n",
    "    \"\"\"\n",
    "\n",
    "    # time_window -> (kernel_1, stride_1, kernel_2) of the two Conv1d decoder layers.\n",
    "    _DECODER_CONFIGS = {\n",
    "        10: (16, 6, 10),\n",
    "        16: (16, 5, 8),\n",
    "        20: (15, 4, 10),\n",
    "        25: (16, 3, 14),\n",
    "        50: (12, 2, 10),\n",
    "    }\n",
    "\n",
    "    # Name of the training loss -> criterion class.\n",
    "    _CRITERIA = {\n",
    "        'l1': nn.L1Loss,\n",
    "        'l2': nn.MSELoss,\n",
    "        'smooth_l1': nn.SmoothL1Loss,\n",
    "    }\n",
    "\n",
    "    def __init__(self,hparams):\n",
    "    \n",
    "        super().__init__()\n",
    "        \n",
    "        self.save_hyperparameters()\n",
    "\n",
    "        # Training parameters\n",
    "        self.lr = hparams.lr\n",
    "        self.weight_decay = hparams.weight_decay\n",
    "        self.factor = hparams.factor\n",
    "        self.step_size = hparams.step_size\n",
    "        self.loss = hparams.loss\n",
    "        # Model parameters\n",
    "        self.out_features = hparams.time_window\n",
    "        self.hidden_features = hparams.hidden_features\n",
    "        self.hidden_layer = hparams.hidden_layer\n",
    "        self.time_window = hparams.time_window\n",
    "        self.teacher_forcing = hparams.teacher_forcing\n",
    "        self.seed = hparams.seed\n",
    "        self.n = hparams.neighbors\n",
    "\n",
    "        # Reject configurations that previously only failed later, with a\n",
    "        # missing-attribute or shape error at the first forward / loss call.\n",
    "        if self.time_window not in self._DECODER_CONFIGS:\n",
    "            raise ValueError(\n",
    "                f'Unsupported time_window={self.time_window}; '\n",
    "                f'expected one of {sorted(self._DECODER_CONFIGS)}.')\n",
    "        if self.loss not in self._CRITERIA:\n",
    "            raise ValueError(\n",
    "                f'Unsupported loss={self.loss!r}; expected one of {sorted(self._CRITERIA)}.')\n",
    "        kernel_1, stride_1, kernel_2 = self._DECODER_CONFIGS[self.time_window]\n",
    "        # Output length of the two un-padded convolutions applied to a\n",
    "        # hidden_features-long signal; it has to equal time_window.\n",
    "        decoded_len = (self.hidden_features - kernel_1) // stride_1 + 1 - kernel_2 + 1\n",
    "        if decoded_len != self.time_window:\n",
    "            raise ValueError(\n",
    "                f'Decoder for time_window={self.time_window} yields {decoded_len} frames '\n",
    "                f'with hidden_features={self.hidden_features}.')\n",
    "\n",
    "        # Set the seed for reproducibility\n",
    "        pl.seed_everything(self.seed)\n",
    "        \n",
    "        # Sub-modules are created in the same order as before (GNN layers,\n",
    "        # embedding, decoder) so seeded initialisation is unchanged.\n",
    "        self.gnn_layers = torch.nn.ModuleList(modules=(GNN_Layer(\n",
    "            in_features=self.hidden_features,\n",
    "            hidden_features=self.hidden_features,\n",
    "            out_features=self.hidden_features,\n",
    "            time_window=self.time_window,\n",
    "            n_variables=1  # the only per-node variable is the normalised time\n",
    "        ) for _ in range(self.hidden_layer)))\n",
    "\n",
    "        # Node input: time_window values + 2 coordinates + 1 time variable.\n",
    "        self.embedding_mlp = nn.Sequential(\n",
    "            nn.Linear(self.time_window + 3, self.hidden_features),\n",
    "            Swish(),\n",
    "            nn.Linear(self.hidden_features, self.hidden_features),\n",
    "            Swish()\n",
    "        )\n",
    "\n",
    "        # Decoder CNN, maps to different outputs (temporal bundling)\n",
    "        self.output_mlp = nn.Sequential(\n",
    "            nn.Conv1d(1, 8, kernel_1, stride=stride_1),\n",
    "            Swish(),\n",
    "            nn.Conv1d(8, 1, kernel_2, stride=1)\n",
    "        )\n",
    "\n",
    "        self.criterion = self._CRITERIA[self.loss]()\n",
    "        self.mse_criterion = nn.MSELoss()\n",
    "        self.mae_criterion = nn.L1Loss()\n",
    "    \n",
    "    def forward(self, data, L, tmax, dt):\n",
    "        \"\"\"\n",
    "        Predict the next `time_window` frames for every node of `data`.\n",
    "\n",
    "        Args:\n",
    "            data: graph with x [n_nodes, time_window], pos [n_nodes, 3] laid out\n",
    "                as (t, coord_0, coord_1), edge_index and batch.\n",
    "            L: spatial normaliser; a scalar or a per-axis tensor of shape [2].\n",
    "            tmax: time normaliser.\n",
    "            dt: time step between consecutive frames.\n",
    "        Returns:\n",
    "            Tensor [n_nodes, time_window].\n",
    "        \"\"\"\n",
    "        u = data.x\n",
    "        # Encode and normalize coordinate information\n",
    "        pos = data.pos\n",
    "        pos_x = pos[:, 1:] / L               # both spatial coordinates\n",
    "        pos_t = pos[:, 0][:, None] / tmax\n",
    "        edge_index = data.edge_index\n",
    "        batch = data.batch\n",
    "\n",
    "        variables = pos_t    # time is treated as equation variable\n",
    "        \n",
    "        # Encoder and processor (message passing)\n",
    "        node_input = torch.cat((u, pos_x, variables), -1)\n",
    "        h = self.embedding_mlp(node_input)\n",
    "        for layer in self.gnn_layers:\n",
    "            h = layer(h, u, pos_x, variables, edge_index, batch)\n",
    "\n",
    "        # Decoder (formula 10 in the paper): out_k = u_last + (k * dt) * diff_k\n",
    "        dt = torch.as_tensor(dt, dtype=u.dtype, device=u.device)\n",
    "        dt_steps = torch.cumsum(torch.ones(1, self.time_window, dtype=u.dtype, device=u.device) * dt, dim=1)\n",
    "        # [batch*n_nodes, hidden_dim] -> 1DCNN([batch*n_nodes, 1, hidden_dim]) -> [batch*n_nodes, time_window]\n",
    "        diff = self.output_mlp(h[:, None]).squeeze(1)\n",
    "        out = u[:, -1:] + dt_steps * diff\n",
    "\n",
    "        return out\n",
    "        \n",
    "    \n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"Adam with a StepLR schedule (`step_size`, gamma=`factor`).\"\"\"\n",
    "        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)\n",
    "        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.step_size, gamma=self.factor)\n",
    "        return {\n",
    "            'optimizer': optimizer,\n",
    "            'lr_scheduler': {\n",
    "                'scheduler': scheduler\n",
    "            },\n",
    "        }\n",
    "    \n",
    "    def _build_graph(self,\n",
    "                    data: torch.Tensor,\n",
    "                    t: torch.Tensor,\n",
    "                    x: torch.Tensor,\n",
    "                    steps: list):\n",
    "        \"\"\"\n",
    "        Assemble one batched k-NN graph from a window of frames.\n",
    "\n",
    "        data, [B, T, N]   window of T frames on N nodes\n",
    "        t, [B, T_total]   time stamps of each sample\n",
    "        x, [B, N, 2]      node coordinates; only x[0] is used, i.e. all samples\n",
    "                          of the batch are assumed to share the same points\n",
    "        steps, [B]        per sample, index into t of the last frame of the window\n",
    "        \"\"\"\n",
    "        B, T, N = data.shape\n",
    "\n",
    "        # Node features: each sample's [T, N] window becomes N rows of length T.\n",
    "        u = data.permute(0, 2, 1).reshape(B * N, T)\n",
    "        # The same coordinates are repeated for every sample of the batch.\n",
    "        x_pos = torch.cat([x[0]] * B, dim=0)\n",
    "        # Time stamp of the window's last frame, broadcast to the sample's nodes.\n",
    "        step_idx = torch.as_tensor(steps, dtype=torch.long, device=t.device)\n",
    "        t_last = t[torch.arange(B, device=t.device), step_idx]\n",
    "        t_pos = t_last.repeat_interleave(N)\n",
    "        # Graph membership of every node.\n",
    "        batch = torch.arange(B, device=data.device).repeat_interleave(N)\n",
    "\n",
    "        # Connect every node to its k nearest neighbours inside its own sample.\n",
    "        edge_index = knn_graph(x_pos, k=self.n, batch=batch, loop=False)\n",
    "        \n",
    "        graph = Data(x=u, edge_index=edge_index)\n",
    "        graph.pos = torch.cat((t_pos[:, None], x_pos), 1)\n",
    "        graph.batch = batch\n",
    "\n",
    "        return graph\n",
    "\n",
    "    def _rollout(self, batch, teacher_forcing=False):\n",
    "        \"\"\"\n",
    "        Autoregressively predict everything after the first `time_window` frames.\n",
    "\n",
    "        Returns (u_hat, target), both [B, n_windows * time_window, N].\n",
    "        \"\"\"\n",
    "        u = batch['u'].float().permute(0,2,1)      # [B, T, N]\n",
    "        x = batch['x'].float().squeeze(-1)\n",
    "        t = batch['t'].float()                     # [B, T_total]\n",
    "        B, _, N = u.shape\n",
    "        tw = self.time_window\n",
    "\n",
    "        dt = t[0][1] - t[0][0]\n",
    "        tmax = t[0,-1]\n",
    "        # Per-axis extent of the point cloud; the clamp avoids a division by zero.\n",
    "        L = x[0].abs().max(dim=0)[0].clamp_min(1e-12)\n",
    "\n",
    "        n_windows = (u.shape[1] - tw) // tw\n",
    "        if n_windows < 1:\n",
    "            raise ValueError(\n",
    "                f'Need at least {2 * tw} frames for time_window={tw}, got {u.shape[1]}.')\n",
    "        # Only whole windows are predicted, so compare against the same frames.\n",
    "        target = u[:, tw:tw + n_windows * tw, :]\n",
    "\n",
    "        graph = self._build_graph(u[:, :tw, :], t, x, steps=[tw - 1] * B)\n",
    "\n",
    "        u_hat = []\n",
    "        for i in range(n_windows):\n",
    "            y_hat = self(graph, L, tmax, dt)\n",
    "            y_hat = y_hat.reshape(B, N, -1).permute(0,2,1)\n",
    "            u_hat.append(y_hat)\n",
    "\n",
    "            if i + 1 == n_windows:\n",
    "                break  # nothing left to predict: no further graph is needed\n",
    "\n",
    "            # Next input window: ground truth (teacher forcing) or own prediction.\n",
    "            next_input = u[:, (i + 1) * tw:(i + 2) * tw, :] if teacher_forcing else y_hat\n",
    "            graph = self._build_graph(next_input, t, x, steps=[(i + 2) * tw - 1] * B)\n",
    "\n",
    "        return torch.cat(u_hat, dim=1), target\n",
    "\n",
    "    def _shared_step(self, batch, stage, teacher_forcing=False):\n",
    "        \"\"\"Roll out, compute the metrics, log them under the `stage` prefix; returns (loss, rel_error).\"\"\"\n",
    "        u_hat, target = self._rollout(batch, teacher_forcing=teacher_forcing)\n",
    "        B = target.shape[0]\n",
    "\n",
    "        loss = self.criterion(u_hat, target)\n",
    "        mae_loss = self.mae_criterion(u_hat, target)\n",
    "        rel_error = rel_L2_error(u_hat.reshape(B, -1), target.reshape(B, -1))\n",
    "        rel_error = torch.mean(rel_error)\n",
    "\n",
    "        self.log(f'{stage}_rel_error', rel_error, prog_bar=True)\n",
    "        self.log(f'{stage}_loss', loss, prog_bar=True)\n",
    "        self.log(f'{stage}_mae_loss', mae_loss, prog_bar=True)\n",
    "\n",
    "        return loss, rel_error\n",
    "\n",
    "    def training_step(self, train_batch, batch_idx):\n",
    "        \"\"\"Training roll-out (optionally teacher-forced); returns the loss.\"\"\"\n",
    "        loss, _ = self._shared_step(train_batch, 'train', teacher_forcing=self.teacher_forcing)\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, val_batch, batch_idx):\n",
    "        \"\"\"Free-running validation roll-out; returns the loss.\"\"\"\n",
    "        loss, _ = self._shared_step(val_batch, 'val')\n",
    "        return loss\n",
    "        \n",
    "    def test_step(self, test_batch, batch_idx):\n",
    "        \"\"\"Free-running test roll-out; returns the loss and the relative error.\"\"\"\n",
    "        loss, rel_error = self._shared_step(test_batch, 'test')\n",
    "        return {'test_loss': loss, 'test_rel_error': rel_error}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b182d4e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "#_target_: datamodule.h5_datamodule_2d.HDF5DatamoduleGraph_2d\n",
    "#name: h5_datamodule_graph_2d\n",
    "#train_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_train_B1_64.h5\n",
    "#val_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_64.h5\n",
    "#test_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_64.h5\n",
    "#nt_train: 50\n",
    "#res_train: 64\n",
    "#nt_val: 50\n",
    "#res_val: 64\n",
    "#nt_test: 50\n",
    "#res_test: 64\n",
    "#train_regular: True\n",
    "#val_regular: True\n",
    "#test_regular: True\n",
    "\n",
    "#num_workers: 3\n",
    "#batch_size: 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4d0e97df",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Global seed set to 2\n"
     ]
    }
   ],
   "source": [
    "# Silence every Python warning from here on.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# Instantiate the Lightning module from the hyperparameters.\n",
    "model = MPNN2d(hparams)\n",
    "\n",
    "# One dataset per split (built in the order train, valid, test), all reading the same HDF5 file.\n",
    "train_dataset, valid_dataset, test_dataset = (\n",
    "    HDF5DatasetGraph_2d_MPNN(data, nt, nx, split, hparams.regular, hparams.seed)\n",
    "    for split in ('train', 'valid', 'test')\n",
    ")\n",
    "\n",
    "# Default collation; only the training loader shuffles.\n",
    "train_loader = DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    num_workers=num_workers,\n",
    ")\n",
    "valid_loader = DataLoader(\n",
    "    valid_dataset,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=False,\n",
    "    num_workers=num_workers,\n",
    ")\n",
    "# The test loader serves the whole test split as one batch.\n",
    "test_loader = DataLoader(\n",
    "    test_dataset,\n",
    "    batch_size=len(test_dataset),\n",
    "    shuffle=False,\n",
    "    num_workers=num_workers,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6c567676",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
      "\n",
      "  | Name          | Type       | Params\n",
      "---------------------------------------------\n",
      "0 | gnn_layers    | ModuleList | 503 K \n",
      "1 | embedding_mlp | Sequential | 18.3 K\n",
      "2 | output_mlp    | Sequential | 217   \n",
      "3 | criterion     | MSELoss    | 0     \n",
      "4 | mse_criterion | MSELoss    | 0     \n",
      "5 | mae_criterion | L1Loss     | 0     \n",
      "---------------------------------------------\n",
      "521 K     Trainable params\n",
      "0         Non-trainable params\n",
      "521 K     Total params\n",
      "2.086     Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation sanity check: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Global seed set to 2\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "331457cd1efe4d81bb7a71106c7aa232",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 9, global step 379: val_rel_error reached 0.02428 (best 0.02428), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=09.ckpt\" as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 19, global step 759: val_rel_error reached 0.02185 (best 0.02185), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=19.ckpt\" as top 2\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 29, global step 1139: val_rel_error reached 0.01624 (best 0.01624), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=29.ckpt\" as top 3\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 39, global step 1519: val_rel_error reached 0.02902 (best 0.01624), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=39.ckpt\" as top 4\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 49, global step 1899: val_rel_error reached 0.01450 (best 0.01450), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=49.ckpt\" as top 5\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 59, global step 2279: val_rel_error reached 0.01351 (best 0.01351), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=59.ckpt\" as top 6\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 69, global step 2659: val_rel_error reached 0.01205 (best 0.01205), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=69.ckpt\" as top 7\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 79, global step 3039: val_rel_error reached 0.00982 (best 0.00982), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=79.ckpt\" as top 8\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 89, global step 3419: val_rel_error reached 0.00876 (best 0.00876), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=89.ckpt\" as top 9\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 99, global step 3799: val_rel_error reached 0.00944 (best 0.00876), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=99.ckpt\" as top 10\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 109, global step 4179: val_rel_error reached 0.00776 (best 0.00776), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=109.ckpt\" as top 11\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 119, global step 4559: val_rel_error reached 0.00742 (best 0.00742), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=119.ckpt\" as top 12\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 129, global step 4939: val_rel_error reached 0.00729 (best 0.00729), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=129.ckpt\" as top 13\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 139, global step 5319: val_rel_error reached 0.00654 (best 0.00654), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=139.ckpt\" as top 14\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 149, global step 5699: val_rel_error reached 0.00630 (best 0.00630), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=149.ckpt\" as top 15\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 159, global step 6079: val_rel_error reached 0.00606 (best 0.00606), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=159.ckpt\" as top 16\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 169, global step 6459: val_rel_error reached 0.00607 (best 0.00606), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=169.ckpt\" as top 17\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 179, global step 6839: val_rel_error reached 0.00568 (best 0.00568), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=179.ckpt\" as top 18\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 189, global step 7219: val_rel_error reached 0.00546 (best 0.00546), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=189.ckpt\" as top 19\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 199, global step 7599: val_rel_error reached 0.00529 (best 0.00529), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=199.ckpt\" as top 20\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 209, global step 7979: val_rel_error reached 0.00504 (best 0.00504), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=209.ckpt\" as top 21\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 219, global step 8359: val_rel_error reached 0.00585 (best 0.00504), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=219.ckpt\" as top 22\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 229, global step 8739: val_rel_error reached 0.00540 (best 0.00504), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=229.ckpt\" as top 23\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 239, global step 9119: val_rel_error reached 0.00486 (best 0.00486), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=239.ckpt\" as top 24\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 249, global step 9499: val_rel_error reached 0.00464 (best 0.00464), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MPNN(shallow_irregular)/2/epoch=249.ckpt\" as top 25\n",
      "Saving latest checkpoint...\n"
     ]
    }
   ],
   "source": [
    "# Define the checkpoint callback\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    dirpath=dirpath,\n",
    "    filename=\"{epoch:02d}\",\n",
    "    save_top_k=-1,\n",
    "    every_n_epochs=10,\n",
    "    verbose=True,\n",
    "    monitor=\"val_rel_error\",\n",
    "    mode=\"min\",\n",
    "    save_last=True,\n",
    ")\n",
    "\n",
    "# Initialize the trainer\n",
    "trainer = pl.Trainer(\n",
    "    accelerator=\"gpu\",\n",
    "    max_epochs=250,\n",
    "    log_every_n_steps=300,\n",
    "    gpus=hparams.gpus, \n",
    "    default_root_dir=\"lightning_logs\",\n",
    "    callbacks=[checkpoint_callback]\n",
    ")\n",
    "\n",
    "# Train the model with the created dataloaders\n",
    "trainer.fit(model, train_loader, valid_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9fc5ca8e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Global seed set to 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Test Loss: 0.050286468118429184, Average Test Relative Error: 0.21536745131015778\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(0.050286468118429184, 0.21536745131015778)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dirpath = 'save/MPNN(shallow_irregular)/'+str(1)\n",
    "data = '2D/Shallow/shallow.h5'\n",
    "\n",
    "# Load the data\n",
    "train_dataset =  HDF5DatasetGraph_2d_MPNN(data,nt,nx,'train', hparams.regular ,hparams.seed)\n",
    "valid_dataset = HDF5DatasetGraph_2d_MPNN(data,nt,nx,'valid', hparams.regular ,hparams.seed)\n",
    "test_dataset = HDF5DatasetGraph_2d_MPNN(data,nt,nx,'test', hparams.regular ,hparams.seed)\n",
    "\n",
    "# Create the dataloaders using the custom collate function\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)\n",
    "valid_loader = DataLoader(valid_dataset, batch_size=len(valid_dataset), shuffle=False, num_workers=num_workers)\n",
    "# Modify the test_loader to use the full dataset as a single batch\n",
    "test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False, num_workers=num_workers)\n",
    "\n",
    "def compute_test_error(model, test_loader):\n",
    "    # Move model to CPU\n",
    "    model = model.to(\"cpu\")\n",
    "    \n",
    "    # Ensure model is in eval mode\n",
    "    model.eval()\n",
    "    \n",
    "    # Accumulators for loss and relative error\n",
    "    total_test_loss = 0.0\n",
    "    total_test_rel_error = 0.0\n",
    "    \n",
    "    # Disable gradients to save memory\n",
    "    with torch.no_grad():\n",
    "        num_batches = 0\n",
    "        for batch in test_loader:\n",
    "            results = model.test_step(batch, batch_idx=num_batches)\n",
    "            total_test_loss += results['test_loss'].item()\n",
    "            total_test_rel_error += results['test_rel_error'].item()\n",
    "            num_batches += 1\n",
    "            \n",
    "    # Compute average values\n",
    "    avg_test_loss = total_test_loss / num_batches\n",
    "    avg_test_rel_error = total_test_rel_error / num_batches\n",
    "            \n",
    "    # Print or return the results\n",
    "    print(f\"Average Test Loss: {avg_test_loss}, Average Test Relative Error: {avg_test_rel_error}\")\n",
    "    return avg_test_loss, avg_test_rel_error\n",
    "\n",
    "checkpoint_path = dirpath+'/epoch=249.ckpt'\n",
    "model = MPNN2d.load_from_checkpoint(checkpoint_path)\n",
    "\n",
    "# Call the function to compute the test error\n",
    "compute_test_error(model, valid_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddffab11",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "graph",
   "language": "python",
   "name": "graph"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
