{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d12e35d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/ex_hdd/graph/lib/python3.8/site-packages/torch_geometric/typing.py:18: UserWarning: An issue occurred while importing 'pyg-lib'. Disabling its usage. Stacktrace: /ex_hdd/graph/lib/python3.8/site-packages/libpyg.so: undefined symbol: _ZNK5torch8autograd4Node4nameB5cxx11Ev\n",
      "  warnings.warn(f\"An issue occurred while importing 'pyg-lib'. \"\n",
      "/ex_hdd/graph/lib/python3.8/site-packages/torch_geometric/typing.py:42: UserWarning: An issue occurred while importing 'torch-sparse'. Disabling its usage. Stacktrace: /ex_hdd/graph/lib/python3.8/site-packages/libpyg.so: undefined symbol: _ZNK5torch8autograd4Node4nameB5cxx11Ev\n",
      "  warnings.warn(f\"An issue occurred while importing 'torch-sparse'. \"\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import torch\n",
    "import h5py\n",
    "\n",
    "import pytorch_lightning as pl\n",
    "import numpy as np\n",
    "\n",
    "from utils import *\n",
    "from argparse import Namespace\n",
    "from torch import nn\n",
    "from torch_geometric.nn import MessagePassing, radius_graph, knn, knn_graph\n",
    "from sklearn.model_selection import train_test_split\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "from models.backbones.mlp import MLP\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "sys.path.append(\"../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f5f23b8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "hparams = Namespace(\n",
    "  # Optimization hyperparameters\n",
    "  factor=0.3,\n",
    "  step_size=50,\n",
    "  loss='l2',\n",
    "  lr=0.001,\n",
    "  weight_decay=0,\n",
    "  gpus=[1],\n",
    "  seed=1,\n",
    "  # Model hyperparameters\n",
    "  time_slice=25,\n",
    "  latent_dim=128,\n",
    "  num_message_passing_steps=5,\n",
    "  mlp_layers=4,\n",
    "  mlp_hidden=128,\n",
    "  radius=0.08,\n",
    "  neighbors=6,\n",
    "  n_chan=128,\n",
    "  teacher_forcing=True,\n",
    "  codec_neighbors=3,\n",
    "  noise=0,\n",
    "  interpolation='area',\n",
    "  dim=1,\n",
    "  samples=25\n",
    ")\n",
    "\n",
    "## setting parameters of ODE problem\n",
    "dirpath = 'save/MAgNet(E1_irregular)/'+str(hparams.seed)\n",
    "data_train = '1D/E1/irregular/CE_train_E1_graph_50.h5'\n",
    "data_test = '1D/E1/irregular/CE_test_E1_graph_50.h5'\n",
    "batch_size = 16\n",
    "num_workers = 20\n",
    "nt, nx, L = 250, 50, 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "af00b224",
   "metadata": {},
   "outputs": [],
   "source": [
    "class HDF5DatasetImplicitGNN(Dataset):\n",
    "    \n",
    "    def __init__(self, \n",
    "                 path,\n",
    "                 nt,\n",
    "                 nx,\n",
    "                 mode='train', \n",
    "                 samples = 32):\n",
    "        \n",
    "        assert mode in ['train', 'valid', 'test'], \"mode must belong to one of these ['train', 'val', 'test']\"\n",
    "        \n",
    "        f = h5py.File(path, 'r')\n",
    "        self.mode = mode\n",
    "        self.data = f[self.mode]\n",
    "        self.dataset = f'pde_{nt}-{nx}'\n",
    "        self.samples = samples\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.data[self.dataset].shape[0]\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \n",
    "        x = self.data['x'][idx]\n",
    "        # Normalize time coordinates\n",
    "        x = 2*(x-x.min())/(x.max()-x.min())-1\n",
    "        \n",
    "        t = self.data['t'][idx]\n",
    "        u_hr = torch.from_numpy(self.data[self.dataset][idx]).unsqueeze(1) # T, 1, L\n",
    "        T, _, L = u_hr.shape\n",
    "        u_lr = u_hr[:,:,::2] # T, 1, L//2\n",
    "        lr_coord = x[::2]\n",
    "        \n",
    "        \n",
    "#(Extrapolation)\n",
    "#        x = self.data['x'][idx, :125]  # slicing here\n",
    "        # Normalize time coordinates\n",
    "#        x = 2*(x-x.min())/(x.max()-x.min())-1\n",
    "        \n",
    "#        t = self.data['t'][idx, :125]  # slicing here\n",
    "#        u_hr = torch.from_numpy(self.data[self.dataset][idx, :125, :]).unsqueeze(1)  # slicing here\n",
    "#        T, _, L = u_hr.shape\n",
    "#        u_lr = u_hr[:,:,::2] # T, 1, L//2\n",
    "#        lr_coord = x[::2]    \n",
    "\n",
    "        if self.mode in ['train']:\n",
    "            indices_left = np.setdiff1d(np.arange(0,L), np.arange(0,L)[::2])\n",
    "            sample_lst = torch.tensor(sorted(np.random.choice(indices_left, self.samples, replace=False)))\n",
    "            hr_coord = x[sample_lst]\n",
    "\n",
    "            hr_points = u_hr[:,:,sample_lst].permute(0,2,1)\n",
    "\n",
    "            return_tensors = {\n",
    "            't': t,\n",
    "            'sample_idx': sample_lst,\n",
    "            'lr_frames': u_lr,\n",
    "            'hr_frames': u_hr,\n",
    "            'hr_points': hr_points, \n",
    "            'coords_hr': hr_coord,\n",
    "            'coords_lr': lr_coord\n",
    "            }\n",
    "        else:\n",
    "            indices_left = np.setdiff1d(np.arange(0,L), np.arange(0,L)[::2])\n",
    "            hr_coord = x[indices_left]\n",
    "\n",
    "            hr_points = u_hr[:,:,indices_left].permute(0,2,1)\n",
    "\n",
    "            return_tensors = {\n",
    "            't': t,\n",
    "            'lr_frames': u_lr,\n",
    "            'hr_frames': u_hr,\n",
    "            'hr_points': hr_points, \n",
    "            'coords_hr': hr_coord,\n",
    "            'coords_lr': lr_coord \n",
    "        }\n",
    "\n",
    "        return return_tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1f7470c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rel_L2_error(pred, true):\n",
    "    return (torch.sum((true-pred)**2, dim=-1)/torch.sum((true)**2, dim=-1))**0.5\n",
    "\n",
    "class Encoder(nn.Module):\n",
    "    def __init__(\n",
    "        self, \n",
    "        node_in, \n",
    "        node_out, \n",
    "        edge_in, \n",
    "        edge_out,\n",
    "        mlp_layers,\n",
    "        mlp_hidden,\n",
    "    ):\n",
    "        super(Encoder, self).__init__()\n",
    "        self.node_fn = nn.Sequential(\n",
    "            MLP(\n",
    "                in_dim=node_in, \n",
    "                hidden_list=[mlp_hidden]*mlp_layers, \n",
    "                out_dim=node_out),\n",
    "                nn.LayerNorm(node_out)\n",
    "        )\n",
    "        self.edge_fn = nn.Sequential(\n",
    "            MLP(\n",
    "                in_dim=edge_in, \n",
    "                hidden_list=[mlp_hidden]*mlp_layers, \n",
    "                out_dim=edge_out,\n",
    "            ),\n",
    "            nn.LayerNorm(edge_out)\n",
    "        )\n",
    "\n",
    "    def forward(self, x, edge_index, e_features): # global_features\n",
    "        # x: (E, node_in)\n",
    "        # edge_index: (2, E)\n",
    "        # e_features: (E, edge_in)\n",
    "        return self.node_fn(x), self.edge_fn(e_features)\n",
    "\n",
    "class InteractionNetwork(MessagePassing):\n",
    "    def __init__(\n",
    "        self, \n",
    "        node_in, \n",
    "        node_out, \n",
    "        edge_in, \n",
    "        edge_out,\n",
    "        mlp_layers,\n",
    "        mlp_hidden,\n",
    "    ):\n",
    "        super(InteractionNetwork, self).__init__(aggr='mean')\n",
    "        self.node_fn = nn.Sequential(\n",
    "            MLP(\n",
    "                in_dim=node_in+edge_out, \n",
    "                hidden_list=[mlp_hidden]*mlp_layers, \n",
    "                out_dim=node_out),\n",
    "                nn.LayerNorm(node_out))\n",
    "        self.edge_fn = nn.Sequential(\n",
    "            MLP(\n",
    "                in_dim=node_in+node_in+edge_in, \n",
    "                hidden_list=[mlp_hidden]*mlp_layers, \n",
    "                out_dim=edge_out\n",
    "            ),\n",
    "            nn.LayerNorm(edge_out)\n",
    "        )\n",
    "\n",
    "    def forward(self, x, edge_index, e_features):\n",
    "        # x: (E, node_in)\n",
    "        # edge_index: (2, E)\n",
    "        # e_features: (E, edge_in)\n",
    "        x_residual = x\n",
    "        e_features_residual = e_features\n",
    "        x, e_features = self.propagate(edge_index=edge_index, x=x, e_features=e_features)\n",
    "        return x+x_residual, e_features+e_features_residual\n",
    "\n",
    "    def message(self, edge_index, x_i, x_j, e_features):\n",
    "\n",
    "        e_features = torch.cat([x_i, x_j, e_features], dim=-1)\n",
    "        e_features = self.edge_fn(e_features)\n",
    "        return e_features\n",
    "\n",
    "    def update(self, x_updated, x, e_features):\n",
    "        # x_updated: (E, edge_out)\n",
    "        # x: (E, node_in)\n",
    "        x_updated = torch.cat([x_updated, x], dim=-1)\n",
    "        x_updated = self.node_fn(x_updated)\n",
    "        return x_updated, e_features\n",
    "\n",
    "class Processor(MessagePassing):\n",
    "    def __init__(\n",
    "        self, \n",
    "        node_in, \n",
    "        node_out, \n",
    "        edge_in, \n",
    "        edge_out,\n",
    "        num_message_passing_steps,\n",
    "        mlp_num_layers,\n",
    "        mlp_hidden_dim,\n",
    "    ):\n",
    "        super(Processor, self).__init__(aggr='max')\n",
    "        self.gnn_stacks = nn.ModuleList([\n",
    "            InteractionNetwork(\n",
    "                node_in=node_in, \n",
    "                node_out=node_out,\n",
    "                edge_in=edge_in, \n",
    "                edge_out=edge_out,\n",
    "                mlp_layers=mlp_num_layers,\n",
    "                mlp_hidden=mlp_hidden_dim,\n",
    "            ) for _ in range(num_message_passing_steps)])\n",
    "\n",
    "    def forward(self, x, edge_index, e_features):\n",
    "        for gnn in self.gnn_stacks:\n",
    "            x, e_features = gnn(x, edge_index, e_features)\n",
    "        return x, e_features\n",
    "\n",
    "class Decoder(nn.Module):\n",
    "    def __init__(\n",
    "        self, \n",
    "        node_in, \n",
    "        node_out,\n",
    "        mlp_layers,\n",
    "        mlp_hidden,\n",
    "    ):\n",
    "        super(Decoder, self).__init__()\n",
    "\n",
    "        self.node_fn = MLP(\n",
    "                in_dim=node_in, \n",
    "                hidden_list=[mlp_hidden]*mlp_layers, \n",
    "                out_dim=node_out)\n",
    "        \n",
    "\n",
    "    def forward(self, x):\n",
    "        # x: (E, node_in)\n",
    "        return self.node_fn(x)\n",
    "\n",
    "class MAgNetGNN(pl.LightningModule):\n",
    "    def __init__(self,hparams):\n",
    "    \n",
    "        super().__init__()\n",
    "        \n",
    "        self.save_hyperparameters()\n",
    "\n",
    "        # Training parameters\n",
    "        self.lr = hparams.lr\n",
    "        self.weight_decay = hparams.weight_decay\n",
    "        self.factor = hparams.factor\n",
    "        self.step_size = hparams.step_size\n",
    "        self.loss = hparams.loss\n",
    "        # Model parameters\n",
    "        self.time_slice = hparams.time_slice\n",
    "        self.num_message_passing_steps = hparams.num_message_passing_steps\n",
    "        self.latent_dim = hparams.latent_dim\n",
    "        self.mlp_layers = hparams.mlp_layers\n",
    "        self.mlp_hidden = hparams.mlp_hidden\n",
    "        self.n_chan = hparams.n_chan\n",
    "        self.radius = hparams.radius\n",
    "        self.codec_neighbors = hparams.codec_neighbors\n",
    "        self.teacher_forcing = hparams.teacher_forcing\n",
    "        self.noise = hparams.noise\n",
    "        self.interpolation = hparams.interpolation\n",
    "        self.neighbors = hparams.neighbors\n",
    "#        self.seed = hparams.seed\n",
    "#        pl.seed_everything(self.seed)\n",
    "        \n",
    "        if self.loss == 'l1':\n",
    "            self.criterion = nn.L1Loss()\n",
    "        elif self.loss == 'l2':\n",
    "            self.criterion = nn.MSELoss()\n",
    "        elif self.loss == 'smooth_l1':\n",
    "            self.criterion = nn.SmoothL1Loss()\n",
    "       \n",
    "        self.mse_criterion = nn.MSELoss()\n",
    "        self.mae_criterion = nn.L1Loss()\n",
    "\n",
    "        self.encoder = Encoder(\n",
    "            node_in=self.time_slice+2, \n",
    "            node_out=self.latent_dim,\n",
    "            edge_in=self.time_slice+1, \n",
    "            edge_out=self.latent_dim,\n",
    "            mlp_layers=self.mlp_layers,\n",
    "            mlp_hidden=self.mlp_hidden,\n",
    "        )\n",
    "        self.processor = Processor(\n",
    "            node_in=self.latent_dim, \n",
    "            node_out=self.latent_dim,\n",
    "            edge_in=self.latent_dim, \n",
    "            edge_out=self.latent_dim,\n",
    "            num_message_passing_steps=self.num_message_passing_steps,\n",
    "            mlp_num_layers=self.mlp_layers,\n",
    "            mlp_hidden_dim=self.mlp_hidden,\n",
    "        )\n",
    "\n",
    "        self.proj_head = nn.Linear(self.latent_dim+3, self.n_chan)\n",
    "        self.projector = MLP(\n",
    "                in_dim=self.n_chan, \n",
    "                hidden_list=[self.mlp_hidden]*self.mlp_layers, \n",
    "                out_dim=1)\n",
    "        \n",
    "        \n",
    "        self._encoder = Encoder(\n",
    "            node_in=self.time_slice+2, \n",
    "            node_out=self.latent_dim,\n",
    "            edge_in=self.time_slice+1, \n",
    "            edge_out=self.latent_dim,\n",
    "            mlp_layers=self.mlp_layers,\n",
    "            mlp_hidden=self.mlp_hidden,\n",
    "        )\n",
    "        self._processor = Processor(\n",
    "            node_in=self.latent_dim, \n",
    "            node_out=self.latent_dim,\n",
    "            edge_in=self.latent_dim, \n",
    "            edge_out=self.latent_dim,\n",
    "            num_message_passing_steps=self.num_message_passing_steps,\n",
    "            mlp_num_layers=self.mlp_layers,\n",
    "            mlp_hidden_dim=self.mlp_hidden,\n",
    "        )\n",
    "        self._decoder = Decoder(\n",
    "            node_in=self.latent_dim,\n",
    "            node_out=self.time_slice,\n",
    "            mlp_layers=self.mlp_layers,\n",
    "            mlp_hidden=self.mlp_hidden,\n",
    "        )\n",
    "    \n",
    "    def continuous_decoder(\n",
    "        self,\n",
    "        x_lr, \n",
    "        lr_encoded, \n",
    "        lr_coords, \n",
    "        hr_coords, \n",
    "        t):\n",
    "        '''\n",
    "        Args:\n",
    "            x_lr, [B, T, C, L]\n",
    "            lr_encoded, [B, L, C]: \n",
    "            lr_coords, [B, L, 1]\n",
    "            hr_coords, [B, N, 1]\n",
    "            t, [B, T]\n",
    "        '''\n",
    "        B, T, _, L = x_lr.shape\n",
    "        N = hr_coords.shape[1]\n",
    "\n",
    "        # Find nearest k low-res neighbors for each high-res coordinate (k=2 by default)\n",
    "        flat_lr_coords = lr_coords.reshape(B*L, -1)\n",
    "        batch_lr = torch.cat([torch.LongTensor([i]*L) for i in range(B)]).to(flat_lr_coords.device)\n",
    "        flat_hr_coords = hr_coords.reshape(B*N, -1)\n",
    "        batch_hr = torch.cat([torch.LongTensor([i]*N) for i in range(B)]).to(flat_hr_coords.device)\n",
    "        assign_index = knn(flat_lr_coords, flat_hr_coords, self.codec_neighbors, batch_lr, batch_hr)\n",
    "\n",
    "        lr_encoded_flat = lr_encoded.reshape(B*L, -1)\n",
    "        timesteps = t.unsqueeze(1).repeat(1,N,1) # B, N, T\n",
    "        timesteps = timesteps.reshape(B*N, -1) # B*N, T\n",
    "\n",
    "        out = []\n",
    "        for i in range(T):\n",
    "            weights = []\n",
    "            latents = []\n",
    "            x_lr_flat = x_lr[:,i].permute(0,2,1).reshape(B*L, -1)\n",
    "            timestep = timesteps[:,i:i+1]\n",
    "            for j in range(self.codec_neighbors):\n",
    "                q_feat = lr_encoded_flat[assign_index[1,j::self.codec_neighbors]]\n",
    "                q_inp = x_lr_flat[assign_index[1,j::self.codec_neighbors]]\n",
    "                q_coord = flat_lr_coords[assign_index[1,j::self.codec_neighbors]]\n",
    "                final_coord = q_coord-flat_hr_coords\n",
    "\n",
    "                final_input = torch.cat([q_feat, q_inp, final_coord, timestep], dim=-1)\n",
    "                if self.interpolation == 'area':\n",
    "                    weight = torch.norm(final_coord, 2, dim=-1)**2 # B*N, 1\n",
    "                    weight = weight.unsqueeze(-1)\n",
    "                elif self.interpolation == 'knn':\n",
    "                    weight = (1/(torch.norm(final_coord, 2, dim=-1)**2)).unsqueeze(-1)\n",
    "                elif self.interpolation == 'sph':\n",
    "                    weight = torch.pow(1 - (L*torch.norm(final_coord, 2, dim=-1)**2), 3).unsqueeze(-1)\n",
    "                latents.append(self.proj_head(final_input)) # B*N, C\n",
    "                weights.append(weight)\n",
    "            \n",
    "            if self.interpolation == 'area':\n",
    "                latent = (latents[0]*weights[1]+latents[1]*weights[0])/(weights[1]+weights[0])\n",
    "            else:\n",
    "                latent = (latents[0]*weights[0]+latents[1]*weights[1])/(weights[1]+weights[0])\n",
    "            out.append(latent)\n",
    "        \n",
    "        out = torch.stack(out, dim=1) # B*N, T, C\n",
    "        return out\n",
    "\n",
    "    \n",
    "    def _build_graph(self, u, x, t):\n",
    "        B, N, _ = u.shape\n",
    "\n",
    "        u_ = u.reshape(B*N, -1)\n",
    "        x_ = x.reshape(B*N, -1)\n",
    "\n",
    "        batch_ids = torch.cat([torch.LongTensor([i for _ in range(n)]) for i, n in enumerate(B*[N])]).to(self.device)\n",
    "        edges = radius_graph(x_, batch=batch_ids, r=self.radius, loop=True) # (2, n_edges)\n",
    "#        edges = knn_graph(x_, batch=batch_ids, k=self.neighbors, loop=True) # (2, n_edges)\n",
    "        receivers = edges[0, :]\n",
    "        senders = edges[1, :]\n",
    "        edge_index = torch.stack([senders, receivers])\n",
    "        \n",
    "        node_features = []\n",
    "        node_features.append(u_)\n",
    "        node_features.append(x_)\n",
    "        node_features.append(t[:,-1:].repeat(N, 1))\n",
    "        node_features = torch.cat(node_features, dim=-1)\n",
    "        \n",
    "        edge_features = []\n",
    "\n",
    "        edge_features.append((u_[senders]-u_[receivers]))\n",
    "        edge_features.append((x_[senders]-x_[receivers]))\n",
    "        edge_features = torch.cat(edge_features, dim=-1)\n",
    "\n",
    "        return node_features, edge_index, edge_features\n",
    "\n",
    "    def forward(\n",
    "        self, \n",
    "        x_lr,\n",
    "        lr_coords, \n",
    "        hr_coords, \n",
    "        t, \n",
    "        hr_last):\n",
    "        '''\n",
    "        Args:\n",
    "            x_lr: tensor of shape [B, T, C, L] that represents the low-resolution frames\n",
    "            lr_coords: tensor of shape [B, L, 1] that represents the L coordinates for sequence of low frames in the batch\n",
    "            hr_coords: tensor of shape [B, N, 1] that represents the N coordinates for sequence of points in the batch\n",
    "            t: tensor of shape [B, T] represents the time-coordinates for each sequence in the batch\n",
    "        '''\n",
    "        B, T, C, L = x_lr.shape\n",
    "        N = hr_coords.shape[1]\n",
    "        T_out = t.shape[1] - T\n",
    "\n",
    "        # Build graph and encode it\n",
    "        u = x_lr.permute(0,3,1,2) # B, L, T, C\n",
    "        u = u.reshape(B, L, -1) # B, L, C\n",
    "        node_features, edge_index, edge_features = self._build_graph(u, lr_coords, t[:,:T])\n",
    "        node_features, edge_features = self.encoder(node_features, edge_index, edge_features)\n",
    "        lr_encoded, _ = self.processor(node_features, edge_index, edge_features)\n",
    "\n",
    "        # Get interpolated features from low-res points\n",
    "        z = self.continuous_decoder(x_lr, lr_encoded, lr_coords, hr_coords, t) # B*N, T, C\n",
    "        hr_points = self.projector(z) # B*N, T, 1\n",
    "        \n",
    "        # Build Graph\n",
    "        hr_points = hr_points.reshape(B, N, T, -1) # B, N, T, C\n",
    "        hr_points = hr_points.reshape(B, N, -1) # B, N, C\n",
    "        lr_points = x_lr.permute(0,3,1,2) # B, L, T, C\n",
    "        lr_points = lr_points.reshape(B, L, -1) # B, L, C\n",
    "\n",
    "        all_coords = torch.cat([lr_coords, hr_coords], dim=1) # B, (L+N), 1\n",
    "\n",
    "        all_feats = torch.cat([lr_points, hr_points], dim=1) # B, (L+N), C\n",
    "\n",
    "        node_features, edge_index, edge_features = self._build_graph(all_feats, all_coords, t[:,:T])\n",
    "\n",
    "\n",
    "        node_features, edge_features = self._encoder(node_features, edge_index, edge_features)\n",
    "        node_features, _ = self._processor(node_features, edge_index, edge_features)\n",
    "        node_features = self._decoder(node_features) # B*(L+N), T_out\n",
    "        ret = node_features.reshape(B, -1, node_features.shape[-1]) # B, (L+N), T_out\n",
    "\n",
    "        outputs = []\n",
    "        tt = t.unsqueeze(1).repeat(1,L+N,1)\n",
    "\n",
    "        last_values = torch.cat([x_lr[:,-1].permute(0,2,1), hr_last], dim=1) # B, (L+N), 1\n",
    "\n",
    "        for i in range(T_out):\n",
    "            delta_t = tt[:,:,T+i:T+i+1]-tt[:,:,T-1:T]\n",
    "            op = ret[...,i].unsqueeze(-1) # B, (L+N), 1\n",
    "            outputs.append(last_values+delta_t*op)\n",
    "        \n",
    "        outputs = torch.stack(outputs, dim=1) # B, T, (L+N), 1\n",
    "\n",
    "        out_lr = outputs[:,:,:L]\n",
    "        out_hr = outputs[:,:,L:]\n",
    "        hr_points = hr_points.reshape(B, N, T, -1)\n",
    "        hr_points = hr_points.permute(0,2,1,3)\n",
    "\n",
    "        return out_hr, out_lr, hr_points\n",
    "    \n",
    "    def configure_optimizers(self):\n",
    "        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)\n",
    "        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.step_size, gamma=self.factor)\n",
    "        return {\n",
    "        \"optimizer\": optimizer,\n",
    "        \"lr_scheduler\": {\n",
    "            \"scheduler\": scheduler\n",
    "        },\n",
    "    }\n",
    "    \n",
    "    def training_step(self, train_batch, batch_idx):\n",
    "        t = train_batch['t'].float()\n",
    "        u = train_batch['lr_frames'].float()\n",
    "        u_values = train_batch['hr_points'].float()\n",
    "        coords = train_batch['coords_hr'].float()\n",
    "        lr_coords = train_batch['coords_lr'].float()\n",
    "\n",
    "        u_values_future = u_values[:,self.time_slice:] # B, T_future, N, 1\n",
    "        B, T_future = u_values_future.shape[:2]\n",
    "\n",
    "        u_values_hat = []\n",
    "        hr_values_hat = []\n",
    "\n",
    "        inp = u[:,:self.time_slice]\n",
    "        noise = self.noise*torch.randn(inp.shape).to(inp.device)\n",
    "        inp = inp+noise\n",
    "\n",
    "        hr_last = u_values[:,self.time_slice-1]\n",
    "        noise = self.noise*torch.randn(hr_last.shape).to(hr_last.device)\n",
    "        hr_last = hr_last+noise\n",
    "\n",
    "        for i in range(T_future//self.time_slice):\n",
    "            out_hr, out_lr, hr_points = self.forward(inp, lr_coords, coords, t[:,i*self.time_slice:(i+2)*self.time_slice], hr_last)\n",
    "            y_hat = torch.cat([out_hr, out_lr], dim=2)\n",
    "            u_values_hat.append(y_hat)\n",
    "            hr_values_hat.append(hr_points)\n",
    "        \n",
    "            if self.teacher_forcing:\n",
    "                inp = u[:,(i+1)*self.time_slice:(i+2)*self.time_slice] # B, T, C, L\n",
    "                hr_last = u_values[:,(i+2)*self.time_slice-1]\n",
    "            else:\n",
    "                inp = out_lr.permute(0,1,3,2)\n",
    "                hr_last = out_hr[:,-1]\n",
    "\n",
    "            noise = self.noise*torch.randn(inp.shape).to(inp.device)\n",
    "            inp = inp+noise\n",
    "\n",
    "            noise = self.noise*torch.randn(hr_last.shape).to(hr_last.device)\n",
    "            hr_last = hr_last+noise\n",
    "        \n",
    "        u_values_hat = torch.cat(u_values_hat, dim=1) # B, T_out, (N+L), 1 \n",
    "        hr_values_hat = torch.cat(hr_values_hat, dim=1) # B, T_in, N, 1\n",
    "        \n",
    "        target = torch.cat([u_values_future, u[:,self.time_slice:].permute(0,1,3,2)], dim=2)\n",
    "        loss = self.criterion(u_values_hat, target)+self.criterion(hr_values_hat, u_values[:,:-self.time_slice])\n",
    "        mae_loss = self.mae_criterion(u_values_hat, target)\n",
    "        interp_loss = self.mae_criterion(hr_values_hat, u_values[:,:-self.time_slice])\n",
    "        rel_error = rel_L2_error(u_values_hat.reshape(B, -1), target.reshape(B, -1))\n",
    "        rel_error = torch.mean(rel_error)\n",
    "        \n",
    "        self.log('train_loss', loss, prog_bar=True)\n",
    "        self.log('train_mae_loss', mae_loss, prog_bar=True)\n",
    "        self.log('train_interp_loss', interp_loss, prog_bar=True)\n",
    "        self.log('train_rel_error', rel_error, prog_bar=True)\n",
    "        \n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, val_batch, batch_idx):\n",
    "        t = val_batch['t'].float()\n",
    "        u = val_batch['lr_frames'].float() # B, T, 1, L\n",
    "        u_values = val_batch['hr_points'].float()\n",
    "        coords = val_batch['coords_hr'].float()\n",
    "        lr_coords = val_batch['coords_lr'].float()\n",
    "\n",
    "        u_values_future = u_values[:,self.time_slice:] # B, T_future, N, 1\n",
    "        T_future = u_values_future.shape[1]\n",
    "\n",
    "        u_values_hat = []\n",
    "        inp = u[:,:self.time_slice]\n",
    "        hr_last = u_values[:,self.time_slice-1]\n",
    "\n",
    "        for i in range(T_future//self.time_slice):\n",
    "            out_hr, out_lr, _ = self.forward(\n",
    "                inp, \n",
    "                lr_coords, \n",
    "                coords, \n",
    "                t[:,i*self.time_slice:(i+2)*self.time_slice], \n",
    "                hr_last)\n",
    "            y_hat = torch.cat([out_hr, out_lr], dim=2)\n",
    "            u_values_hat.append(y_hat)\n",
    "            \n",
    "            inp = out_lr.permute(0,1,3,2)\n",
    "            hr_last = out_hr[:,-1]\n",
    "        \n",
    "        u_values_hat = torch.cat(u_values_hat, dim=1)\n",
    "        target = torch.cat([u_values_future, u[:,self.time_slice:].permute(0,1,3,2)], dim=2)\n",
    "        loss = self.criterion(u_values_hat, target)\n",
    "        mae_loss = self.mae_criterion(u_values_hat, target)\n",
    "        B=u_values_hat.shape[0]\n",
    "        rel_error = rel_L2_error(u_values_hat.reshape(B, -1), target.reshape(B, -1))\n",
    "        rel_error = torch.mean(rel_error)\n",
    "        \n",
    "        self.log('val_loss', loss, prog_bar=True)\n",
    "        self.log('val_mae_loss', mae_loss, prog_bar=True)\n",
    "        self.log('val_rel_error', rel_error, prog_bar=True)\n",
    "        \n",
    "        return loss \n",
    "\n",
    "    def test_step(self, test_batch, batch_idx):\n",
    "        t = test_batch['t'].float()\n",
    "        u = test_batch['lr_frames'].float() # B, T, 1, L\n",
    "        u_values = test_batch['hr_points'].float()\n",
    "        coords = test_batch['coords_hr'].float()\n",
    "        lr_coords = test_batch['coords_lr'].float()\n",
    "\n",
    "        u_values_future = u_values[:,self.time_slice:] # B, T_future, N, 1\n",
    "        T_future = u_values_future.shape[1]\n",
    "\n",
    "        u_values_hat = []\n",
    "        inp = u[:,:self.time_slice]\n",
    "        hr_last = u_values[:,self.time_slice-1]\n",
    "\n",
    "        for i in range(T_future//self.time_slice):\n",
    "            out_hr, out_lr, _ = self.forward(\n",
    "                inp, \n",
    "                lr_coords, \n",
    "                coords, \n",
    "                t[:,i*self.time_slice:(i+2)*self.time_slice], \n",
    "                hr_last)\n",
    "            y_hat = torch.cat([out_hr, out_lr], dim=2)\n",
    "            u_values_hat.append(y_hat)\n",
    "            \n",
    "            inp = out_lr.permute(0,1,3,2)\n",
    "            hr_last = out_hr[:,-1]\n",
    "        \n",
    "        u_values_hat = torch.cat(u_values_hat, dim=1)\n",
    "        target = torch.cat([u_values_future, u[:,self.time_slice:].permute(0,1,3,2)], dim=2)\n",
    "        loss = self.criterion(u_values_hat, target)\n",
    "        mae_loss = self.mae_criterion(u_values_hat, target)\n",
    "        B=u_values_hat.shape[0]\n",
    "        rel_error = rel_L2_error(u_values_hat.reshape(B, -1), target.reshape(B, -1))\n",
    "        rel_error = torch.mean(rel_error)\n",
    "        \n",
    "        self.log('test_loss', loss, prog_bar=True)\n",
    "        self.log('test_mae_loss', mae_loss, prog_bar=True)\n",
    "        self.log('test_rel_error', rel_error, prog_bar=True)\n",
    "        \n",
    "        return {'test_loss': loss, 'test_rel_error': rel_error}   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "14c87745",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Global seed set to 1\n"
     ]
    }
   ],
   "source": [
    "# Load the data in a single mode\n",
    "train_dataset = HDF5DatasetImplicitGNN(data_train, nt, nx, 'train', hparams.samples)\n",
    "# Split data into training and test datasets\n",
    "train_dataset, valid_dataset = train_test_split(train_dataset, test_size=1/16, random_state=42)\n",
    "# Assuming 'valid' data is available in the HDF5 file\n",
    "test_dataset = HDF5DatasetImplicitGNN(data_test, nt, nx, 'valid')\n",
    "\n",
    "# Create the dataloaders using the custom collate function\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)\n",
    "valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)\n",
    "# Modify the test_loader to use the full dataset as a single batch\n",
    "test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), num_workers=num_workers)\n",
    "\n",
    "# Set the training and validation dataloaders in the PyTorch Lightning module\n",
    "model = MAgNetGNN(hparams)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "32c1dcb8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
      "\n",
      "  | Name          | Type      | Params\n",
      "--------------------------------------------\n",
      "0 | criterion     | MSELoss   | 0     \n",
      "1 | mse_criterion | MSELoss   | 0     \n",
      "2 | mae_criterion | L1Loss    | 0     \n",
      "3 | encoder       | Encoder   | 139 K \n",
      "4 | processor     | Processor | 1.1 M \n",
      "5 | proj_head     | Linear    | 16.9 K\n",
      "6 | projector     | MLP       | 66.2 K\n",
      "7 | _encoder      | Encoder   | 139 K \n",
      "8 | _processor    | Processor | 1.1 M \n",
      "9 | _decoder      | Decoder   | 69.3 K\n",
      "--------------------------------------------\n",
      "2.6 M     Trainable params\n",
      "0         Non-trainable params\n",
      "2.6 M     Total params\n",
      "10.318    Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation sanity check: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Global seed set to 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9fb895d414bd4fc185bc46a24206bf7c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 9, global step 1199: val_rel_error reached 0.49354 (best 0.49354), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=09.ckpt\" as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 19, global step 2399: val_rel_error reached 0.59875 (best 0.49354), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=19.ckpt\" as top 2\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 29, global step 3599: val_rel_error reached 0.64423 (best 0.49354), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=29.ckpt\" as top 3\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 39, global step 4799: val_rel_error reached 0.77284 (best 0.49354), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=39.ckpt\" as top 4\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 49, global step 5999: val_rel_error reached 0.55347 (best 0.49354), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=49.ckpt\" as top 5\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 59, global step 7199: val_rel_error reached 0.56771 (best 0.49354), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=59.ckpt\" as top 6\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 69, global step 8399: val_rel_error reached 0.45660 (best 0.45660), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=69.ckpt\" as top 7\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 79, global step 9599: val_rel_error reached 0.52843 (best 0.45660), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=79.ckpt\" as top 8\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 89, global step 10799: val_rel_error reached 0.44665 (best 0.44665), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=89.ckpt\" as top 9\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 99, global step 11999: val_rel_error reached 0.52734 (best 0.44665), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=99.ckpt\" as top 10\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 109, global step 13199: val_rel_error reached 0.51629 (best 0.44665), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=109.ckpt\" as top 11\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 119, global step 14399: val_rel_error reached 0.47354 (best 0.44665), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=119.ckpt\" as top 12\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 129, global step 15599: val_rel_error reached 0.45247 (best 0.44665), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=129.ckpt\" as top 13\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 139, global step 16799: val_rel_error reached 0.45503 (best 0.44665), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=139.ckpt\" as top 14\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 149, global step 17999: val_rel_error reached 0.44648 (best 0.44648), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=149.ckpt\" as top 15\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 159, global step 19199: val_rel_error reached 0.40394 (best 0.40394), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=159.ckpt\" as top 16\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 169, global step 20399: val_rel_error reached 0.41850 (best 0.40394), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=169.ckpt\" as top 17\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 179, global step 21599: val_rel_error reached 0.44256 (best 0.40394), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=179.ckpt\" as top 18\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 189, global step 22799: val_rel_error reached 0.41950 (best 0.40394), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=189.ckpt\" as top 19\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 199, global step 23999: val_rel_error reached 0.41509 (best 0.40394), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=199.ckpt\" as top 20\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 209, global step 25199: val_rel_error reached 0.42098 (best 0.40394), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=209.ckpt\" as top 21\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 219, global step 26399: val_rel_error reached 0.42312 (best 0.40394), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=219.ckpt\" as top 22\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 229, global step 27599: val_rel_error reached 0.42149 (best 0.40394), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=229.ckpt\" as top 23\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 239, global step 28799: val_rel_error reached 0.41716 (best 0.40394), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=239.ckpt\" as top 24\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 249, global step 29999: val_rel_error reached 0.41211 (best 0.40394), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/save/MAgNet(E1_irregular)/1/epoch=249.ckpt\" as top 25\n",
      "Saving latest checkpoint...\n"
     ]
    }
   ],
   "source": [
    "# Checkpointing: a checkpoint is written every 10 epochs and none is discarded\n",
    "# (save_top_k=-1); save_last additionally keeps the most recent state.\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    monitor=\"val_rel_error\",\n",
    "    mode=\"min\",\n",
    "    save_top_k=-1,\n",
    "    save_last=True,\n",
    "    every_n_epochs=10,\n",
    "    dirpath=dirpath,\n",
    "    filename=\"{epoch:02d}\",\n",
    "    verbose=True,\n",
    ")\n",
    "\n",
    "# Trainer: GPU run on the devices listed in hparams.gpus, capped at 250 epochs\n",
    "trainer = pl.Trainer(\n",
    "    callbacks=[checkpoint_callback],\n",
    "    default_root_dir=\"lightning_logs/GDON_1D\",\n",
    "    accelerator=\"gpu\",\n",
    "    gpus=hparams.gpus,\n",
    "    max_epochs=250,\n",
    "    log_every_n_steps=20,\n",
    ")\n",
    "\n",
    "# Fit the model: train_loader feeds training, valid_loader feeds validation\n",
    "trainer.fit(model, train_loader, valid_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a7407a78",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Test Loss: 0.3239303746377118, Average Test Relative Error: 0.49626575934235007\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(0.3239303746377118, 0.49626575934235007)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Checkpoint directory to evaluate and the HDF5 files for training / testing\n",
    "dirpath = 'save/MAgNet(E1_irregular)/' + str(2)\n",
    "data_train = '1D/E1/irregular/CE_train_E1_graph_50.h5'\n",
    "data_test = '1D/E1/irregular/CE_test_E1_graph_50.h5'\n",
    "\n",
    "# Build the dataset from the training file in 'train' mode\n",
    "train_dataset = HDF5DatasetImplicitGNN(\n",
    "    data_train, nt, nx, 'train', hparams.samples\n",
    ")\n",
    "# Hold out 1/16 of it as the validation split; random_state pins the split\n",
    "train_dataset, valid_dataset = train_test_split(\n",
    "    train_dataset,\n",
    "    test_size=1 / 16,\n",
    "    random_state=42,\n",
    ")\n",
    "# The test file is opened in 'valid' mode\n",
    "# NOTE(review): presumably 'valid' names the split stored in the test file - confirm against HDF5DatasetImplicitGNN\n",
    "test_dataset = HDF5DatasetImplicitGNN(data_test, nt, nx, 'valid')\n",
    "\n",
    "# No collate_fn is passed, so every loader relies on DataLoader's default collation\n",
    "train_loader = DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    num_workers=num_workers,\n",
    ")\n",
    "valid_loader = DataLoader(\n",
    "    valid_dataset,\n",
    "    batch_size=1,\n",
    "    shuffle=False,\n",
    "    num_workers=num_workers,\n",
    ")\n",
    "# The test loader delivers the whole test dataset as one single batch\n",
    "test_loader = DataLoader(\n",
    "    test_dataset,\n",
    "    batch_size=len(test_dataset),\n",
    "    num_workers=num_workers,\n",
    ")\n",
    "\n",
    "def compute_test_error(model, test_loader):\n",
    "    \"\"\"Evaluate ``model`` on every batch of ``test_loader`` and average the metrics.\n",
    "\n",
    "    As a side effect the model is moved to the CPU and switched to eval mode;\n",
    "    neither change is undone afterwards.\n",
    "\n",
    "    Args:\n",
    "        model: Object exposing ``to``, ``eval`` and ``test_step(batch, batch_idx=...)``.\n",
    "            ``test_step`` must return a mapping whose ``'test_loss'`` and\n",
    "            ``'test_rel_error'`` entries support ``.item()``.\n",
    "        test_loader: Iterable of batches to evaluate.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(avg_test_loss, avg_test_rel_error)``: the unweighted mean of the\n",
    "        per-batch values, i.e. every batch counts equally whatever its size.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``test_loader`` yields no batches.\n",
    "    \"\"\"\n",
    "    # Evaluation is done on the CPU, so move the model there first\n",
    "    model = model.to(\"cpu\")\n",
    "\n",
    "    # Ensure model is in eval mode\n",
    "    model.eval()\n",
    "\n",
    "    # Accumulators for loss and relative error\n",
    "    total_test_loss = 0.0\n",
    "    total_test_rel_error = 0.0\n",
    "    num_batches = 0\n",
    "\n",
    "    # Disable gradients to save memory\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, batch in enumerate(test_loader):\n",
    "            results = model.test_step(batch, batch_idx=batch_idx)\n",
    "            total_test_loss += results['test_loss'].item()\n",
    "            total_test_rel_error += results['test_rel_error'].item()\n",
    "            num_batches = batch_idx + 1\n",
    "\n",
    "    # An empty loader would otherwise surface as an opaque ZeroDivisionError below\n",
    "    if num_batches == 0:\n",
    "        raise ValueError(\"test_loader yielded no batches; cannot compute average test metrics\")\n",
    "\n",
    "    # Compute average values\n",
    "    avg_test_loss = total_test_loss / num_batches\n",
    "    avg_test_rel_error = total_test_rel_error / num_batches\n",
    "\n",
    "    # Report the averages and hand them back to the caller\n",
    "    print(f\"Average Test Loss: {avg_test_loss}, Average Test Relative Error: {avg_test_rel_error}\")\n",
    "    return avg_test_loss, avg_test_rel_error\n",
    "\n",
    "# Restore the model from the checkpoint file named epoch=249.ckpt inside dirpath\n",
    "checkpoint_path = f\"{dirpath}/epoch=249.ckpt\"\n",
    "model = MAgNetGNN.load_from_checkpoint(checkpoint_path)\n",
    "\n",
    "# Evaluate the restored model and display the averaged metrics.\n",
    "# NOTE(review): valid_loader is passed here although the metrics are labelled\n",
    "# \"Test\" - confirm whether test_loader was intended instead.\n",
    "compute_test_error(model, valid_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98615e01",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "graph",
   "language": "python",
   "name": "graph"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
