{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f10bd61e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/ex_hdd/graph/lib/python3.8/site-packages/torch_geometric/typing.py:18: UserWarning: An issue occurred while importing 'pyg-lib'. Disabling its usage. Stacktrace: /ex_hdd/graph/lib/python3.8/site-packages/libpyg.so: undefined symbol: _ZNK5torch8autograd4Node4nameB5cxx11Ev\n",
      "  warnings.warn(f\"An issue occurred while importing 'pyg-lib'. \"\n",
      "/ex_hdd/graph/lib/python3.8/site-packages/torch_geometric/typing.py:42: UserWarning: An issue occurred while importing 'torch-sparse'. Disabling its usage. Stacktrace: /ex_hdd/graph/lib/python3.8/site-packages/libpyg.so: undefined symbol: _ZNK5torch8autograd4Node4nameB5cxx11Ev\n",
      "  warnings.warn(f\"An issue occurred while importing 'torch-sparse'. \"\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import torch\n",
    "import random\n",
    "import h5py\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import pytorch_lightning as pl\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Put the parent directory on the import path BEFORE the local imports below\n",
    "# (`utils`, `models.backbones.mlp`); appending it after them could not help\n",
    "# resolve them. The guard keeps re-running this cell from growing sys.path.\n",
    "if \"../\" not in sys.path:\n",
    "    sys.path.append(\"../\")\n",
    "\n",
    "# TODO: replace the star import with the explicit names this notebook uses.\n",
    "from utils import *\n",
    "from torch import nn\n",
    "from argparse import Namespace\n",
    "from torch_geometric.nn import MessagePassing, radius_graph, knn_graph, InstanceNorm\n",
    "from torch_geometric.data import Data\n",
    "from sklearn.model_selection import train_test_split\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "from models.backbones.mlp import MLP\n",
    "from torch.utils.data import DataLoader, Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8c6794ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "hparams = Namespace(\n",
    "    # Optimization hyperparameters\n",
    "    factor=0.3,             # StepLR gamma\n",
    "    step_size=50,           # StepLR step_size\n",
    "    loss='l2',              # 'l1' | 'l2' | 'smooth_l1'\n",
    "    lr=1e-3,\n",
    "    weight_decay=0,\n",
    "    seed=0,\n",
    "    gpus=[1],\n",
    "\n",
    "    # Model hyperparameters\n",
    "    hidden_features=128,\n",
    "    hidden_layer=5,         # number of message-passing layers\n",
    "    time_window=10,         # temporal bundling width (decoder supports 10, 16, 20, 25, 50)\n",
    "    teacher_forcing=False,\n",
    "    neighbors=1,            # radius multiplier used when building the graph\n",
    "    regular=True,           # True: regular 32x32 sub-grid, False: random node subset\n",
    ")\n",
    "\n",
    "# Encode the grid type in the checkpoint directory so regular and irregular runs\n",
    "# do not write into the same folder (it used to be hard-coded to 'shallow_regular').\n",
    "grid_tag = 'regular' if hparams.regular else 'irregular'\n",
    "dirpath = f'save/MPNN(shallow_{grid_tag})/{hparams.seed}'\n",
    "data = '2D/Shallow/shallow.h5'  # HDF5 file path\n",
    "batch_size = 16\n",
    "num_workers = 20\n",
    "# nt, nx -> HDF5DatasetGraph_2d_MPNN(nt, res); the global L is not referenced by the model/dataset cells\n",
    "nt, nx, L = 50, 32, 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "4ee27407",
   "metadata": {},
   "outputs": [],
   "source": [
    "class HDF5DatasetGraph_2d_MPNN(Dataset):\n",
    "    \"\"\"\n",
    "    Graph dataset over a 2D HDF5 file with one group per trajectory ('0000'..'0999').\n",
    "\n",
    "    Each item is {'u': (n_nodes, T), 'x': (n_nodes, 2), 't': (T,)}, with the last\n",
    "    stored time step dropped. The 1000 trajectories are split 60/20/20 into\n",
    "    train/valid/test with a fixed random_state, so every mode sees the same split.\n",
    "\n",
    "    Args:\n",
    "        path: HDF5 file path.\n",
    "        nt, res: only used to build the `self.dataset` tag string.\n",
    "        mode: 'train', 'valid' or 'test'.\n",
    "        regular: True -> every 4th grid point starting at index 2 (128x128 -> 32x32);\n",
    "            False -> a fixed random subset of 32*32 of the 128*128 grid points.\n",
    "        seed: seed for that random subset (regular=False only).\n",
    "    \"\"\"\n",
    "    def __init__(self, \n",
    "                 path,\n",
    "                 nt,\n",
    "                 res,\n",
    "                 mode='train', \n",
    "                 regular=True,\n",
    "                 seed=0):\n",
    "\n",
    "        assert mode in ['train', 'valid', 'test'], \"mode must belong to one of these ['train', 'valid', 'test']\"\n",
    "\n",
    "        # Open once so a bad path still fails here, but do not keep the handle:\n",
    "        # an h5py handle must not be shared with forked DataLoader workers\n",
    "        # (num_workers > 0). Each process opens its own copy lazily via `self.f`.\n",
    "        with h5py.File(path, 'r'):\n",
    "            pass\n",
    "        self.path = path\n",
    "        self._f = None\n",
    "        self._f_pid = None\n",
    "        self.mode = mode\n",
    "        self.regular = regular\n",
    "        self.dataset = f'pde_{nt}-{res}'\n",
    "        self.seed = seed\n",
    "\n",
    "        # Generate keys from '0000' to '0999'\n",
    "        all_keys = [str(i).zfill(4) for i in range(1000)]\n",
    "\n",
    "        # 80/20 (train+valid)/test, then 25% of that 80% -> 20% of the total as validation\n",
    "        train_valid_keys, self.test_keys = train_test_split(all_keys, test_size=0.2, random_state=42)\n",
    "        self.train_keys, self.valid_keys = train_test_split(train_valid_keys, test_size=0.25, random_state=42)\n",
    "\n",
    "        self.keys = {'train': self.train_keys, 'valid': self.valid_keys, 'test': self.test_keys}[mode]\n",
    "\n",
    "        if not self.regular:\n",
    "            # Private RNG: same indices as random.seed(seed); random.sample(...),\n",
    "            # without reseeding the global `random` module as a side effect.\n",
    "            self.sampled_indices = random.Random(seed).sample(range(128 * 128), 32 * 32)\n",
    "\n",
    "    @property\n",
    "    def f(self):\n",
    "        \"\"\"Read-only h5py handle owned by the current process (opened on first use).\"\"\"\n",
    "        import os\n",
    "        pid = os.getpid()\n",
    "        if self._f is None or self._f_pid != pid:\n",
    "            self._f = h5py.File(self.path, 'r')\n",
    "            self._f_pid = pid\n",
    "        return self._f\n",
    "\n",
    "    def __getstate__(self):\n",
    "        # Never pickle the open handle (e.g. spawn-started workers); it is reopened lazily.\n",
    "        state = self.__dict__.copy()\n",
    "        state['_f'] = None\n",
    "        state['_f_pid'] = None\n",
    "        return state\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.keys)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        key = self.keys[idx]\n",
    "        group = self.f[key]\n",
    "        grid = group['grid']\n",
    "\n",
    "        if self.regular:\n",
    "            data = group['data'][:, 2::4, 2::4, :]  # (101, 128, 128, 1) -> (101, 32, 32, 1)\n",
    "            x = torch.from_numpy(grid['x'][2::4])  # (128,) -> (32,)\n",
    "            y = torch.from_numpy(grid['y'][2::4])  # (128,) -> (32,)\n",
    "            # torch.meshgrid defaults to 'ij' indexing: flat node k <-> (x[k // len(y)], y[k % len(y)]),\n",
    "            # matching the row-major flattening of `u` below.\n",
    "            coords = torch.stack(torch.meshgrid(x, y), dim=-1).reshape(-1, 2)  # (1024, 2)\n",
    "            u = torch.from_numpy(data.squeeze(-1))\n",
    "            u = u.reshape(u.shape[0], -1).permute(1, 0)[:, :-1]  # (101, 1024) -> (1024, 100)\n",
    "        else:\n",
    "            data = group['data'][:]\n",
    "            u = torch.from_numpy(data.squeeze(-1))\n",
    "            u = u.reshape(u.shape[0], -1).permute(1, 0)[:, :-1]  # (101, 128*128) -> (128*128, 100)\n",
    "            x_full = torch.from_numpy(grid['x'][:])\n",
    "            y_full = torch.from_numpy(grid['y'][:])\n",
    "            idx_t = torch.tensor(self.sampled_indices, dtype=torch.long)\n",
    "            # Same 'ij' convention as the regular branch. (The previous (x[k % W], y[k // W])\n",
    "            # mapping produced coordinates transposed relative to the regular grid.)\n",
    "            ny = len(y_full)\n",
    "            coords = torch.stack((x_full[idx_t // ny], y_full[idx_t % ny]), dim=-1).to(torch.float32)\n",
    "            u = u[idx_t, :]\n",
    "\n",
    "        t = torch.from_numpy(grid['t'][:])[:-1]  # (101,) -> (100,)\n",
    "\n",
    "        return {\n",
    "            'u': u,\n",
    "            'x': coords,\n",
    "            't': t\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "276e447d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rel_L2_error(pred, true):\n",
    "    \"\"\"Relative L2 error sqrt(sum((true - pred)^2) / sum(true^2)) along the last dim.\"\"\"\n",
    "    residual_sq = ((true - pred) ** 2).sum(dim=-1)\n",
    "    reference_sq = (true ** 2).sum(dim=-1)\n",
    "    return (residual_sq / reference_sq) ** 0.5\n",
    "\n",
    "class Swish(nn.Module):\n",
    "    \"\"\"\n",
    "    Swish activation: f(x) = x * sigmoid(beta * x).\n",
    "    \"\"\"\n",
    "    def __init__(self, beta=1):\n",
    "        super().__init__()\n",
    "        self.beta = beta  # scale inside the sigmoid; beta=1 gives SiLU\n",
    "\n",
    "    def forward(self, x):\n",
    "        gate = torch.sigmoid(self.beta * x)\n",
    "        return x * gate\n",
    "\n",
    "\n",
    "class GNN_Layer(MessagePassing):\n",
    "    \"\"\"\n",
    "    One message-passing layer with mean aggregation: an MLP message function,\n",
    "    an MLP node update (residual when input/output widths match) and instance\n",
    "    normalization per graph.\n",
    "    \"\"\"\n",
    "    def __init__(self,\n",
    "                 in_features: int,\n",
    "                 out_features: int,\n",
    "                 hidden_features: int,\n",
    "                 time_window: int,\n",
    "                 n_variables: int):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            in_features (int): width of the incoming node embedding\n",
    "            out_features (int): width of the produced node embedding\n",
    "            hidden_features (int): width of the message/update MLPs\n",
    "            time_window (int): length of the raw solution window `u` per node (temporal bundling)\n",
    "            n_variables (int): number of per-node equation variables fed to message and update\n",
    "        \"\"\"\n",
    "        super().__init__(node_dim=-2, aggr='mean')\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.hidden_features = hidden_features\n",
    "\n",
    "        # message input = [x_i, x_j, u_i - u_j, pos_i - pos_j (2 values), variables_i]\n",
    "        message_in = 2 * in_features + time_window + 2 + n_variables\n",
    "        # update input = [x, aggregated message, variables]\n",
    "        update_in = in_features + hidden_features + n_variables\n",
    "\n",
    "        self.message_net_1 = nn.Sequential(nn.Linear(message_in, hidden_features), Swish())\n",
    "        self.message_net_2 = nn.Sequential(nn.Linear(hidden_features, hidden_features), Swish())\n",
    "        self.update_net_1 = nn.Sequential(nn.Linear(update_in, hidden_features), Swish())\n",
    "        self.update_net_2 = nn.Sequential(nn.Linear(hidden_features, out_features), Swish())\n",
    "        self.norm = InstanceNorm(hidden_features)\n",
    "\n",
    "    def forward(self, x, u, pos, variables, edge_index, batch):\n",
    "        \"\"\"\n",
    "        One round of message passing, then instance norm per graph id in `batch`.\n",
    "        \"\"\"\n",
    "        updated = self.propagate(edge_index, x=x, u=u, pos=pos, variables=variables)\n",
    "        return self.norm(updated, batch)\n",
    "\n",
    "    def message(self, x_i, x_j, u_i, u_j, pos_i, pos_j, variables_i):\n",
    "        \"\"\"\n",
    "        Edge message (formula 8 of the paper).\n",
    "        \"\"\"\n",
    "        edge_features = torch.cat((x_i, x_j, u_i - u_j, pos_i - pos_j, variables_i), dim=-1)\n",
    "        return self.message_net_2(self.message_net_1(edge_features))\n",
    "\n",
    "    def update(self, message, x, variables):\n",
    "        \"\"\"\n",
    "        Node update (formula 9 of the paper), residual when widths match.\n",
    "        \"\"\"\n",
    "        node_features = torch.cat((x, message, variables), dim=-1)\n",
    "        delta = self.update_net_2(self.update_net_1(node_features))\n",
    "        return x + delta if self.in_features == self.out_features else delta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "c1e40e65",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MPNN2d(pl.LightningModule):\n",
    "    \"\"\"\n",
    "    Message-passing PDE solver with temporal bundling for 2D fields on a node set.\n",
    "\n",
    "    Each forward step maps the last `time_window` states of every node, its\n",
    "    normalized (x, y) position and the normalized current time to the next\n",
    "    `time_window` states. The training/validation/test steps roll the model out\n",
    "    autoregressively over the rest of the trajectory.\n",
    "    \"\"\"\n",
    "\n",
    "    # time_window -> (kernel_1, stride_1, kernel_2) of the 1D-CNN decoder. Each spec\n",
    "    # turns a 128-wide node embedding into exactly time_window outputs.\n",
    "    _DECODER_SPECS = {\n",
    "        10: (16, 6, 10),\n",
    "        16: (16, 5, 8),\n",
    "        20: (15, 4, 10),\n",
    "        25: (16, 3, 14),\n",
    "        50: (12, 2, 10),\n",
    "    }\n",
    "    _LOSSES = {'l1': nn.L1Loss, 'l2': nn.MSELoss, 'smooth_l1': nn.SmoothL1Loss}\n",
    "\n",
    "    def __init__(self, hparams):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            hparams: Namespace with lr, weight_decay, factor, step_size, loss,\n",
    "                time_window, hidden_features, hidden_layer, teacher_forcing,\n",
    "                neighbors and seed.\n",
    "        Raises:\n",
    "            ValueError: unknown `loss`, unsupported `time_window`, or a\n",
    "                `hidden_features` for which the decoder does not emit\n",
    "                `time_window` outputs.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        self.save_hyperparameters()\n",
    "\n",
    "        # Training parameters\n",
    "        self.lr = hparams.lr\n",
    "        self.weight_decay = hparams.weight_decay\n",
    "        self.factor = hparams.factor\n",
    "        self.step_size = hparams.step_size\n",
    "        self.loss = hparams.loss\n",
    "        # Model parameters\n",
    "        self.out_features = hparams.time_window\n",
    "        self.hidden_features = hparams.hidden_features\n",
    "        self.hidden_layer = hparams.hidden_layer\n",
    "        self.time_window = hparams.time_window\n",
    "        self.teacher_forcing = hparams.teacher_forcing\n",
    "        self.n = hparams.neighbors\n",
    "        self.seed = hparams.seed\n",
    "        pl.seed_everything(self.seed)\n",
    "\n",
    "        # Validate up front instead of failing later with an AttributeError\n",
    "        # (missing output_mlp / criterion) on the first training step.\n",
    "        if self.loss not in self._LOSSES:\n",
    "            raise ValueError(f'unknown loss {self.loss!r}; expected one of {sorted(self._LOSSES)}')\n",
    "        if self.time_window not in self._DECODER_SPECS:\n",
    "            raise ValueError(f'unsupported time_window {self.time_window}; '\n",
    "                             f'expected one of {sorted(self._DECODER_SPECS)}')\n",
    "        k1, s1, k2 = self._DECODER_SPECS[self.time_window]\n",
    "        decoded_len = (self.hidden_features - k1) // s1 + 1 - k2 + 1\n",
    "        if decoded_len != self.time_window:\n",
    "            raise ValueError(f'decoder for time_window={self.time_window} yields {decoded_len} '\n",
    "                             f'outputs with hidden_features={self.hidden_features}')\n",
    "\n",
    "        # `hidden_layer` identical message-passing layers; n_variables=1 because time\n",
    "        # is the only equation variable. Construction order (layers, embedding,\n",
    "        # decoder) is unchanged so seeded initialization stays the same.\n",
    "        self.gnn_layers = torch.nn.ModuleList(\n",
    "            GNN_Layer(in_features=self.hidden_features,\n",
    "                      hidden_features=self.hidden_features,\n",
    "                      out_features=self.hidden_features,\n",
    "                      time_window=self.time_window,\n",
    "                      n_variables=1)\n",
    "            for _ in range(self.hidden_layer))\n",
    "\n",
    "        # Node encoder input: time_window states + 2 coordinates + 1 time variable\n",
    "        self.embedding_mlp = nn.Sequential(\n",
    "            nn.Linear(self.time_window + 3, self.hidden_features),\n",
    "            Swish(),\n",
    "            nn.Linear(self.hidden_features, self.hidden_features),\n",
    "            Swish()\n",
    "        )\n",
    "\n",
    "        # Decoder CNN, maps to different outputs (temporal bundling)\n",
    "        self.output_mlp = nn.Sequential(nn.Conv1d(1, 8, k1, stride=s1),\n",
    "                                        Swish(),\n",
    "                                        nn.Conv1d(8, 1, k2, stride=1))\n",
    "\n",
    "        self.criterion = self._LOSSES[self.loss]()\n",
    "        self.mse_criterion = nn.MSELoss()\n",
    "        self.mae_criterion = nn.L1Loss()\n",
    "\n",
    "    def forward(self, data, L, tmax, dt):\n",
    "        \"\"\"\n",
    "        One temporal-bundling step.\n",
    "\n",
    "        Args:\n",
    "            data: graph from `_build_graph`; data.x is (n_nodes, time_window) and\n",
    "                data.pos is (n_nodes, 3) = [t, x, y].\n",
    "            L: per-axis scale for the (x, y) coordinates (scalar or shape (2,)).\n",
    "            tmax: scale for the time coordinate.\n",
    "            dt: time step (0-d tensor).\n",
    "        Returns:\n",
    "            (n_nodes, time_window) prediction of the next window.\n",
    "        \"\"\"\n",
    "        u = data.x\n",
    "        pos = data.pos\n",
    "        # Normalize BOTH spatial coordinates. The former `pos[:, 1][:, None] / L` kept\n",
    "        # only x (broadcast against the 2-vector L), so y never reached the model.\n",
    "        pos_x = pos[:, 1:] / L\n",
    "        pos_t = pos[:, 0:1] / tmax\n",
    "        edge_index = data.edge_index\n",
    "        batch = data.batch\n",
    "\n",
    "        # Time is treated as the (only) equation variable\n",
    "        variables = pos_t\n",
    "\n",
    "        # Encoder and processor (message passing)\n",
    "        node_input = torch.cat((u, pos_x, variables), -1)\n",
    "        h = self.embedding_mlp(node_input)\n",
    "        for layer in self.gnn_layers:\n",
    "            h = layer(h, u, pos_x, variables, edge_index, batch)\n",
    "\n",
    "        # Decoder (formula 10 in the paper): out_k = u_last + (k * dt) * diff_k, k = 1..time_window\n",
    "        cum_dt = torch.cumsum(torch.ones(1, self.time_window, device=dt.device) * dt, dim=1)\n",
    "        # [n_nodes, hidden] -> Conv1d on [n_nodes, 1, hidden] -> [n_nodes, time_window]\n",
    "        diff = self.output_mlp(h[:, None]).squeeze(1)\n",
    "        return u[:, -1:] + cum_dt * diff\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)\n",
    "        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.step_size, gamma=self.factor)\n",
    "        return {\n",
    "            \"optimizer\": optimizer,\n",
    "            \"lr_scheduler\": {\n",
    "                \"scheduler\": scheduler\n",
    "            },\n",
    "        }\n",
    "\n",
    "    def _build_graph(self,\n",
    "                    data: torch.Tensor,\n",
    "                    t: torch.Tensor,\n",
    "                    x: torch.Tensor,\n",
    "                    steps: list):\n",
    "        \"\"\"\n",
    "        Pack a batch of windows into one disconnected PyG graph.\n",
    "\n",
    "        data, [B, T, N]   window of states per node\n",
    "        t, [B, T_total]   time stamps; the node time is t[b, steps[b]]\n",
    "        x, [B, N, 2]      node coordinates (only x[0] is used: samples share the nodes)\n",
    "        steps, [B]\n",
    "        Returns Data with x [B*N, T], pos [B*N, 3] = [t, x, y], batch [B*N], edge_index.\n",
    "        \"\"\"\n",
    "        B, T, N = data.shape\n",
    "        # node b*N + n holds data[b, :, n]\n",
    "        u = data.transpose(1, 2).reshape(B * N, T)\n",
    "        x_pos = x[0].repeat(B, 1)\n",
    "        step_idx = torch.as_tensor(steps, dtype=torch.long, device=t.device)\n",
    "        t_pos = t[torch.arange(B, device=t.device), step_idx].repeat_interleave(N)\n",
    "        batch = torch.arange(B, device=data.device).repeat_interleave(N)\n",
    "\n",
    "        # radius = neighbors * |dx - dy| with dx/dy the offsets to the next node along\n",
    "        # each grid axis. NOTE(review): assumes a regular sqrt(N) x sqrt(N) grid in 'ij'\n",
    "        # order; for randomly sampled nodes (regular=False) this radius is not meaningful.\n",
    "        dx = x[0][1] - x[0][0]\n",
    "        dy = x[0][int(N ** 0.5)] - x[0][0]\n",
    "        dr = torch.norm(dx - dy, p=2)\n",
    "        radius = self.n * dr + 0.0001\n",
    "\n",
    "        edge_index = radius_graph(x_pos, r=radius, batch=batch, loop=False)\n",
    "\n",
    "        graph = Data(x=u, edge_index=edge_index)\n",
    "        graph.pos = torch.cat((t_pos[:, None], x_pos), 1)\n",
    "        graph.batch = batch\n",
    "        return graph\n",
    "\n",
    "    def _rollout(self, sample_batch, teacher_forcing=False):\n",
    "        \"\"\"\n",
    "        Autoregressive rollout over one batch from HDF5DatasetGraph_2d_MPNN.\n",
    "\n",
    "        The first `time_window` states seed the model; each `forward` call predicts\n",
    "        the next window. With `teacher_forcing`, the ground-truth window is fed\n",
    "        back in instead of the prediction.\n",
    "\n",
    "        Returns:\n",
    "            (u_hat, target), both (B, n_windows * time_window, N).\n",
    "        Raises:\n",
    "            ValueError: if the trajectory is shorter than two windows.\n",
    "        \"\"\"\n",
    "        tw = self.time_window\n",
    "        u = sample_batch['u'].float().permute(0, 2, 1)  # (B, N, T) -> (B, T, N)\n",
    "        x = sample_batch['x'].float().squeeze(-1)       # (B, N, 2) node coordinates\n",
    "        t = sample_batch['t'].float()                   # (B, T)\n",
    "        B, T, N = u.shape\n",
    "        dt = t[0][1] - t[0][0]\n",
    "        tmax = t[0, -1]\n",
    "        # Per-axis coordinate scale. Equals x[0, -1] (the last node) on the regular,\n",
    "        # ascending grid, and stays meaningful for randomly sampled nodes.\n",
    "        L = x[0].max(dim=0).values\n",
    "\n",
    "        n_windows = (T - tw) // tw\n",
    "        if n_windows < 1:\n",
    "            raise ValueError(f'need at least {2 * tw} time steps, got {T}')\n",
    "        # Trim the target to whole windows so it matches the rollout length.\n",
    "        target = u[:, tw:tw + n_windows * tw, :]\n",
    "\n",
    "        graph = self._build_graph(u[:, :tw, :], t, x, steps=[tw - 1] * B)\n",
    "        preds = []\n",
    "        for i in range(n_windows):\n",
    "            y_hat = self.forward(graph, L, tmax, dt).reshape(B, N, -1).permute(0, 2, 1)\n",
    "            preds.append(y_hat)\n",
    "            if i + 1 == n_windows:\n",
    "                break  # no further prediction needs a new graph\n",
    "            window = u[:, (i + 1) * tw:(i + 2) * tw, :] if teacher_forcing else y_hat\n",
    "            graph = self._build_graph(window, t, x, steps=[(i + 2) * tw - 1] * B)\n",
    "\n",
    "        return torch.cat(preds, dim=1), target\n",
    "\n",
    "    def _log_metrics(self, stage, u_hat, target):\n",
    "        \"\"\"Log `<stage>_rel_error`, `<stage>_loss`, `<stage>_mae_loss`; return (loss, rel_error).\"\"\"\n",
    "        B = u_hat.shape[0]\n",
    "        loss = self.criterion(u_hat, target)\n",
    "        mae_loss = self.mae_criterion(u_hat, target)\n",
    "        rel_error = torch.mean(rel_L2_error(u_hat.reshape(B, -1), target.reshape(B, -1)))\n",
    "\n",
    "        self.log(f'{stage}_rel_error', rel_error, prog_bar=True)\n",
    "        self.log(f'{stage}_loss', loss, prog_bar=True)\n",
    "        self.log(f'{stage}_mae_loss', mae_loss, prog_bar=True)\n",
    "        return loss, rel_error\n",
    "\n",
    "    def training_step(self, train_batch, batch_idx):\n",
    "        u_hat, target = self._rollout(train_batch, teacher_forcing=self.teacher_forcing)\n",
    "        loss, _ = self._log_metrics('train', u_hat, target)\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, val_batch, batch_idx):\n",
    "        u_hat, target = self._rollout(val_batch)\n",
    "        loss, _ = self._log_metrics('val', u_hat, target)\n",
    "        return loss\n",
    "\n",
    "    def test_step(self, test_batch, batch_idx):\n",
    "        u_hat, target = self._rollout(test_batch)\n",
    "        loss, rel_error = self._log_metrics('test', u_hat, target)\n",
    "        return {'test_loss': loss, 'test_rel_error': rel_error}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "b182d4e0",
   "metadata": {},
   "outputs": [],
   "source": [
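    "# NOTE: reference only. This is a commented-out datamodule config (Burgers B1 paths)\n",
    "# kept for comparison; nothing in this cell runs or affects the cells below.\n",
    "#_target_: datamodule.h5_datamodule_2d.HDF5DatamoduleGraph_2d\n",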
    "#name: h5_datamodule_graph_2d\n",
    "#train_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_train_B1_64.h5\n",
    "#val_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_64.h5\n",
    "#test_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_64.h5\n",
    "#nt_train: 50\n",
    "#res_train: 64\n",
    "#nt_val: 50\n",
    "#res_val: 64\n",
    "#nt_test: 50\n",
    "#res_test: 64\n",
    "#train_regular: True\n",
    "#val_regular: True\n",
    "#test_regular: True\n",
    "\n",
    "#num_workers: 3\n",
    "#batch_size: 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4d0e97df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Silence all library warnings from here on.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# Build the model (its constructor seeds everything with hparams.seed).\n",
    "model = MPNN2d(hparams)\n",
    "\n",
    "# Datasets: same HDF5 file, deterministic train/valid/test split of the trajectories.\n",
    "split_kwargs = dict(regular=hparams.regular, seed=hparams.seed)\n",
    "train_dataset = HDF5DatasetGraph_2d_MPNN(data, nt, nx, 'train', **split_kwargs)\n",
    "valid_dataset = HDF5DatasetGraph_2d_MPNN(data, nt, nx, 'valid', **split_kwargs)\n",
    "test_dataset = HDF5DatasetGraph_2d_MPNN(data, nt, nx, 'test', **split_kwargs)\n",
    "\n",
    "# DataLoaders use the default collate (each item is a dict of tensors).\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)\n",
    "valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)\n",
    "# The whole test split is evaluated as a single batch.\n",
    "test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False, num_workers=num_workers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f2cd275",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
      "\n",
      "  | Name          | Type       | Params\n",
      "---------------------------------------------\n",
      "0 | gnn_layers    | ModuleList | 503 K \n",
      "1 | embedding_mlp | Sequential | 18.3 K\n",
      "2 | output_mlp    | Sequential | 217   \n",
      "3 | criterion     | MSELoss    | 0     \n",
      "4 | mse_criterion | MSELoss    | 0     \n",
      "5 | mae_criterion | L1Loss     | 0     \n",
      "---------------------------------------------\n",
      "521 K     Trainable params\n",
      "0         Non-trainable params\n",
      "521 K     Total params\n",
      "2.086     Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation sanity check: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0edeece4d94f4f7c8c02f36d2e4af850",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 9, global step 1499: val_rel_error reached 0.00596 (best 0.00596), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/MPNN(shallow_regular)/epoch=09.ckpt\" as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 19, global step 2999: val_rel_error reached 0.00370 (best 0.00370), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/MPNN(shallow_regular)/epoch=19.ckpt\" as top 2\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 29, global step 4499: val_rel_error reached 0.00562 (best 0.00370), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/MPNN(shallow_regular)/epoch=29.ckpt\" as top 3\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 39, global step 5999: val_rel_error reached 0.00284 (best 0.00284), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/MPNN(shallow_regular)/epoch=39.ckpt\" as top 4\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 49, global step 7499: val_rel_error reached 0.00268 (best 0.00268), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/MPNN(shallow_regular)/epoch=49.ckpt\" as top 5\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 59, global step 8999: val_rel_error reached 0.00221 (best 0.00221), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/MPNN(shallow_regular)/epoch=59.ckpt\" as top 6\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 69, global step 10499: val_rel_error reached 0.00268 (best 0.00221), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/MPNN(shallow_regular)/epoch=69.ckpt\" as top 7\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 79, global step 11999: val_rel_error reached 0.00213 (best 0.00213), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/MPNN(shallow_regular)/epoch=79.ckpt\" as top 8\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 89, global step 13499: val_rel_error reached 0.00195 (best 0.00195), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/MPNN(shallow_regular)/epoch=89.ckpt\" as top 9\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 99, global step 14999: val_rel_error reached 0.00175 (best 0.00175), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/MPNN(shallow_regular)/epoch=99.ckpt\" as top 10\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 109, global step 16499: val_rel_error reached 0.00179 (best 0.00175), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/MPNN(shallow_regular)/epoch=109.ckpt\" as top 11\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 119, global step 17999: val_rel_error reached 0.00163 (best 0.00163), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/MPNN(shallow_regular)/epoch=119.ckpt\" as top 12\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 129, global step 19499: val_rel_error reached 0.00157 (best 0.00157), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/MPNN(shallow_regular)/epoch=129.ckpt\" as top 13\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 139, global step 20999: val_rel_error reached 0.00165 (best 0.00157), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/MPNN(shallow_regular)/epoch=139.ckpt\" as top 14\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 149, global step 22499: val_rel_error reached 0.00154 (best 0.00154), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/MPNN(shallow_regular)/epoch=149.ckpt\" as top 15\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Define the checkpoint callback\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    dirpath=dirpath,\n",
    "    filename=\"{epoch:02d}\",\n",
    "    save_top_k=-1,\n",
    "    every_n_epochs=10,\n",
    "    verbose=True,\n",
    "    monitor=\"val_rel_error\",\n",
    "    mode=\"min\",\n",
    "    save_last=True,\n",
    ")\n",
    "\n",
    "# Initialize the trainer\n",
    "trainer = pl.Trainer(\n",
    "    accelerator=\"gpu\",\n",
    "    max_epochs=250,\n",
    "    log_every_n_steps=20,\n",
    "    gpus=hparams.gpus, \n",
    "    default_root_dir=\"lightning_logs/GDON_1D\",\n",
    "    callbacks=[checkpoint_callback]\n",
    ")\n",
    "\n",
    "# Train the model with the created dataloaders\n",
    "trainer.fit(model, train_loader, valid_loader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "graph",
   "language": "python",
   "name": "graph"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
