{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cdac2e17",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/ex_hdd/graph/lib/python3.8/site-packages/torch_geometric/typing.py:18: UserWarning: An issue occurred while importing 'pyg-lib'. Disabling its usage. Stacktrace: /ex_hdd/graph/lib/python3.8/site-packages/libpyg.so: undefined symbol: _ZNK5torch8autograd4Node4nameB5cxx11Ev\n",
      "  warnings.warn(f\"An issue occurred while importing 'pyg-lib'. \"\n",
      "/ex_hdd/graph/lib/python3.8/site-packages/torch_geometric/typing.py:42: UserWarning: An issue occurred while importing 'torch-sparse'. Disabling its usage. Stacktrace: /ex_hdd/graph/lib/python3.8/site-packages/libpyg.so: undefined symbol: _ZNK5torch8autograd4Node4nameB5cxx11Ev\n",
      "  warnings.warn(f\"An issue occurred while importing 'torch-sparse'. \"\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import sys\n",
    "import random\n",
    "import h5py\n",
    "\n",
    "import numpy as np\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "from utils import *\n",
    "from torch import nn\n",
    "from argparse import Namespace\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch_geometric.nn import MessagePassing, radius_graph, knn, knn_graph\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "from models.backbones.mlp import MLP\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "sys.path.append(\"../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "466b0309",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration; MAgNetGNN2d reads most of these fields in its constructor.\n",
    "hparams = Namespace(\n",
    "  # Optimization hyperparameters\n",
    "  factor=0.3,             # gamma of the StepLR schedule\n",
    "  step_size=50,           # period of the StepLR schedule\n",
    "  loss='l1',              # one of 'l1', 'l2', 'smooth_l1'\n",
    "  lr=0.001,\n",
    "  weight_decay=0,\n",
    "  dim=2,\n",
    "  seed=0,                 # seeds the dataset's choice of grid locations\n",
    "  gpus=[0],\n",
    "  # Model hyperparameters\n",
    "  time_slice=10,          # frames per input window and per predicted window\n",
    "  latent_dim=128,\n",
    "  num_message_passing_steps=5,\n",
    "  mlp_layers=4,\n",
    "  mlp_hidden=128,\n",
    "  radius=0.08,            # connectivity radius passed to radius_graph\n",
    "  neighbors=8,\n",
    "  n_chan=128,\n",
    "  teacher_forcing=True,\n",
    "  codec_neighbors=4,      # k of the low-res neighbour lookup in continuous_decoder\n",
    "  noise=0,                # scale of the Gaussian noise added to training inputs\n",
    "  interpolation='area',   # one of 'area', 'knn', 'sph'\n",
    "  samples=32              # query points drawn per training item\n",
    ")\n",
    "\n",
    "dirpath = 'MAgNet(shallow_irregular)'\n",
    "data = '2D/Shallow/shallow.h5'\n",
    "batch_size = 8\n",
    "num_workers = 20\n",
    "nt, nx, L = 50, 32, 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "028a40eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "class HDF5DatasetImplicitGNN_2d_irregular(Dataset):\n",
    "    '''Map-style dataset of irregularly sampled 2D trajectories stored in an HDF5 file.\n",
    "\n",
    "    A fixed set of 32*32 grid locations is drawn once with `seed`. Even positions of\n",
    "    that set act as low-resolution input nodes; odd positions are the candidate\n",
    "    high-resolution query points.\n",
    "\n",
    "    Args:\n",
    "        path: HDF5 file whose groups are keyed '0000'..'0999'; each group is read\n",
    "            through its 'data' and 'grid' ('x', 'y', 't') entries.\n",
    "        nt: stored on the instance (not used inside this class).\n",
    "        res: stored on the instance (not used inside this class).\n",
    "        mode: 'train', 'valid' or 'test'; selects the key subset that is served.\n",
    "        samples: number of query points drawn per item in 'train' mode.\n",
    "        seed: seed for the one-off choice of grid locations.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, path, nt, res, mode='train', samples=256, seed=0):\n",
    "        assert mode in ['train', 'valid', 'test'], \"mode must belong to one of these ['train', 'valid', 'test']\"\n",
    "\n",
    "        self.f = h5py.File(path, 'r')\n",
    "        self.mode = mode\n",
    "        self.nt = nt\n",
    "        self.res = res\n",
    "        self.samples = samples\n",
    "\n",
    "        # Generate keys from '0000' to '0999'\n",
    "        all_keys = [str(i).zfill(4) for i in range(1000)]\n",
    "\n",
    "        # 20% of all keys go to test; 25% of the remaining 80% (i.e. 20% of the\n",
    "        # total) go to validation, which leaves 60% for training.\n",
    "        train_valid_keys, self.test_keys = train_test_split(all_keys, test_size=0.2, random_state=42)\n",
    "        self.train_keys, self.valid_keys = train_test_split(train_valid_keys, test_size=0.25, random_state=42)\n",
    "\n",
    "        self.keys = {\n",
    "            'train': self.train_keys,\n",
    "            'valid': self.valid_keys,\n",
    "            'test': self.test_keys,\n",
    "        }[mode]\n",
    "\n",
    "        # Number of grid points along x, read from the first trajectory of this split;\n",
    "        # the same count is assumed along y.\n",
    "        W = len(self.f[self.keys[0]]['grid']['x'][:])\n",
    "\n",
    "        # Set seed and sample 1024 grid locations randomly\n",
    "        random.seed(seed)\n",
    "        self.sampled_coords = random.sample([(i, j) for i in range(W) for j in range(W)], 32*32)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.keys)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # Resolve the index through this split's key list. Building the key from\n",
    "        # `idx` itself would make every mode read '0000', '0001', ... and thereby\n",
    "        # ignore the train/valid/test split computed in __init__.\n",
    "        key = self.keys[idx]\n",
    "        group = self.f[key]\n",
    "        grid = group['grid']\n",
    "\n",
    "        rows = [i for i, _ in self.sampled_coords]\n",
    "        cols = [j for _, j in self.sampled_coords]\n",
    "\n",
    "        # Drop the last stored frame; squeeze(-1) removes the trailing axis when it has size 1\n",
    "        data = torch.from_numpy(group['data'][:-1, :, :, :]).squeeze(-1)\n",
    "        u_hr = data[:, rows, cols]  # (T, P): values at the P sampled locations\n",
    "        u_hr = u_hr.reshape(u_hr.shape[0], 1, -1)  # (T, 1, P)\n",
    "\n",
    "        t = grid['t'][:-1]\n",
    "        x_full = grid['x'][:]\n",
    "        y_full = grid['y'][:]\n",
    "\n",
    "        coords = np.array([[x_full[i], y_full[j]] for i, j in self.sampled_coords])  # (P, 2)\n",
    "        coords = 2*(coords-coords.min(0))/(coords.max(0)-coords.min(0))-1  # rescale each axis to [-1, 1]\n",
    "\n",
    "        T, _, N = u_hr.shape\n",
    "        u_lr = u_hr[:, :, ::2]  # (T, 1, ceil(N/2)): even positions are the low-res nodes\n",
    "        lr_coord = coords[::2]\n",
    "\n",
    "        # Positions that are not low-res nodes, i.e. the odd ones\n",
    "        indices_left = np.setdiff1d(np.arange(0, N), np.arange(0, N)[::2])\n",
    "\n",
    "        if self.mode in ['train']:\n",
    "            # Random subset of the query candidates, kept in ascending order\n",
    "            sample_lst = torch.tensor(sorted(np.random.choice(indices_left, self.samples, replace=False)))\n",
    "            hr_coord = coords[sample_lst]\n",
    "            hr_points = u_hr[:, :, sample_lst].permute(0, 2, 1)  # (T, samples, 1)\n",
    "\n",
    "            return {\n",
    "                't': t,\n",
    "                'sample_idx': sample_lst,\n",
    "                'lr_frames': u_lr,\n",
    "                'hr_frames': u_hr,\n",
    "                'hr_points': hr_points,\n",
    "                'coords_hr': hr_coord,\n",
    "                'coords_lr': lr_coord\n",
    "            }\n",
    "\n",
    "        # 'valid' / 'test': evaluate on every query candidate\n",
    "        hr_coord = coords[indices_left]\n",
    "        hr_points = u_hr[:, :, indices_left].permute(0, 2, 1)\n",
    "\n",
    "        return_tensors = {\n",
    "            't': t,\n",
    "            'lr_frames': u_lr,\n",
    "            'hr_frames': u_hr,\n",
    "            'hr_points': hr_points,\n",
    "            'coords_hr': hr_coord,\n",
    "            'coords_lr': lr_coord\n",
    "        }\n",
    "\n",
    "        return return_tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1bc880cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rel_L2_error(pred, true):\n",
    "    '''Relative L2 error of `pred` with respect to `true`, reduced over the last dimension.'''\n",
    "    squared_residual = (true - pred).pow(2).sum(dim=-1)\n",
    "    squared_reference = true.pow(2).sum(dim=-1)\n",
    "    return (squared_residual / squared_reference).pow(0.5)\n",
    "\n",
    "class Encoder(nn.Module):\n",
    "    '''Embeds node and edge features with two independent MLP + LayerNorm stacks.'''\n",
    "\n",
    "    def __init__(self, node_in, node_out, edge_in, edge_out, mlp_layers, mlp_hidden):\n",
    "        super().__init__()\n",
    "        # Node branch is built first, then the edge branch\n",
    "        node_mlp = MLP(in_dim=node_in, hidden_list=[mlp_hidden]*mlp_layers, out_dim=node_out)\n",
    "        self.node_fn = nn.Sequential(node_mlp, nn.LayerNorm(node_out))\n",
    "        edge_mlp = MLP(in_dim=edge_in, hidden_list=[mlp_hidden]*mlp_layers, out_dim=edge_out)\n",
    "        self.edge_fn = nn.Sequential(edge_mlp, nn.LayerNorm(edge_out))\n",
    "\n",
    "    def forward(self, x, edge_index, e_features):\n",
    "        '''Return (embedded nodes, embedded edges); `edge_index` is accepted but not used.'''\n",
    "        node_latent = self.node_fn(x)\n",
    "        edge_latent = self.edge_fn(e_features)\n",
    "        return node_latent, edge_latent\n",
    "\n",
    "class InteractionNetwork(MessagePassing):\n",
    "    '''Single message-passing layer with mean aggregation and residual connections.\n",
    "\n",
    "    `edge_fn` turns the concatenation [x_i, x_j, e_features] into a message;\n",
    "    `node_fn` consumes the aggregated messages concatenated with the node features.\n",
    "    Both are an MLP followed by LayerNorm.\n",
    "    '''\n",
    "    def __init__(\n",
    "        self, \n",
    "        node_in, \n",
    "        node_out, \n",
    "        edge_in, \n",
    "        edge_out,\n",
    "        mlp_layers,\n",
    "        mlp_hidden,\n",
    "    ):\n",
    "        super(InteractionNetwork, self).__init__(aggr='mean')\n",
    "        # Input width = node features + aggregated messages (see update)\n",
    "        self.node_fn = nn.Sequential(\n",
    "            MLP(\n",
    "                in_dim=node_in+edge_out, \n",
    "                hidden_list=[mlp_hidden]*mlp_layers, \n",
    "                out_dim=node_out),\n",
    "                nn.LayerNorm(node_out))\n",
    "        # Input width = both endpoint node features + edge features (see message)\n",
    "        self.edge_fn = nn.Sequential(\n",
    "            MLP(\n",
    "                in_dim=node_in+node_in+edge_in, \n",
    "                hidden_list=[mlp_hidden]*mlp_layers, \n",
    "                out_dim=edge_out\n",
    "            ),\n",
    "            nn.LayerNorm(edge_out)\n",
    "        )\n",
    "\n",
    "    def forward(self, x, edge_index, e_features):\n",
    "        # x: node features, presumably (num_nodes, node_in) - TODO confirm\n",
    "        # edge_index: (2, E)\n",
    "        # e_features: (E, edge_in)\n",
    "        # The residual sums below need node_out == node_in and edge_out == edge_in.\n",
    "        x_residual = x\n",
    "        e_features_residual = e_features\n",
    "        x, e_features = self.propagate(edge_index=edge_index, x=x, e_features=e_features)\n",
    "        # NOTE(review): update() returns the `e_features` argument it is handed; if\n",
    "        # propagate passes the original kwarg there (not the message output), the second\n",
    "        # value returned here is 2 * e_features - confirm against PyG's propagate semantics.\n",
    "        return x+x_residual, e_features+e_features_residual\n",
    "\n",
    "    def message(self, edge_index, x_i, x_j, e_features):\n",
    "        # Per-edge message: edge_fn applied to [x_i, x_j, e_features].\n",
    "        # Parameter names matter: PyG fills them in by name (x_i / x_j from `x`).\n",
    "\n",
    "        e_features = torch.cat([x_i, x_j, e_features], dim=-1)\n",
    "        e_features = self.edge_fn(e_features)\n",
    "        return e_features\n",
    "\n",
    "    def update(self, x_updated, x, e_features):\n",
    "        # x_updated: aggregated messages per node, width edge_out\n",
    "        # x: node features, width node_in\n",
    "        x_updated = torch.cat([x_updated, x], dim=-1)\n",
    "        x_updated = self.node_fn(x_updated)\n",
    "        return x_updated, e_features\n",
    "\n",
    "class Processor(MessagePassing):\n",
    "    '''Applies `num_message_passing_steps` InteractionNetwork layers one after another.'''\n",
    "\n",
    "    def __init__(self, node_in, node_out, edge_in, edge_out,\n",
    "                 num_message_passing_steps, mlp_num_layers, mlp_hidden_dim):\n",
    "        super().__init__(aggr='max')\n",
    "        layers = []\n",
    "        for _ in range(num_message_passing_steps):\n",
    "            layers.append(\n",
    "                InteractionNetwork(\n",
    "                    node_in=node_in,\n",
    "                    node_out=node_out,\n",
    "                    edge_in=edge_in,\n",
    "                    edge_out=edge_out,\n",
    "                    mlp_layers=mlp_num_layers,\n",
    "                    mlp_hidden=mlp_hidden_dim,\n",
    "                )\n",
    "            )\n",
    "        self.gnn_stacks = nn.ModuleList(layers)\n",
    "\n",
    "    def forward(self, x, edge_index, e_features):\n",
    "        '''Thread node and edge features through every layer; the graph stays fixed.'''\n",
    "        nodes, edges = x, e_features\n",
    "        for layer in self.gnn_stacks:\n",
    "            nodes, edges = layer(nodes, edge_index, edges)\n",
    "        return nodes, edges\n",
    "\n",
    "class Decoder(nn.Module):\n",
    "    '''Maps latent node features to the output width with a single MLP (no normalisation).'''\n",
    "\n",
    "    def __init__(self, node_in, node_out, mlp_layers, mlp_hidden):\n",
    "        super().__init__()\n",
    "        hidden_sizes = [mlp_hidden for _ in range(mlp_layers)]\n",
    "        self.node_fn = MLP(in_dim=node_in, hidden_list=hidden_sizes, out_dim=node_out)\n",
    "\n",
    "    def forward(self, x):\n",
    "        '''Decode node features `x` whose last dimension is `node_in`.'''\n",
    "        decoded = self.node_fn(x)\n",
    "        return decoded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c35f17b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MAgNetGNN2d(pl.LightningModule):\n",
    "    '''Graph-network forecaster operating on irregular 2D point sets.\n",
    "\n",
    "    Stage 1 (encoder, processor, proj_head, projector) encodes the low-resolution\n",
    "    nodes and interpolates values at the query coordinates. Stage 2 (_encoder,\n",
    "    _processor, _decoder) runs on the joint graph of low-resolution nodes and query\n",
    "    points and predicts one rate per future step, which `forward` integrates from the\n",
    "    last known value.\n",
    "    '''\n",
    "\n",
    "    # Supported values of hparams.loss / hparams.interpolation\n",
    "    _CRITERIA = {'l1': nn.L1Loss, 'l2': nn.MSELoss, 'smooth_l1': nn.SmoothL1Loss}\n",
    "    _INTERPOLATIONS = ('area', 'knn', 'sph')\n",
    "\n",
    "    def __init__(self,hparams):\n",
    "        '''Read the hyperparameters from `hparams` and build all sub-networks.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if `hparams.loss` or `hparams.interpolation` is not supported,\n",
    "                or if `hparams.codec_neighbors` is smaller than 2.\n",
    "        '''\n",
    "        super().__init__()\n",
    "\n",
    "        self.save_hyperparameters()\n",
    "\n",
    "        # Training parameters\n",
    "        self.lr = hparams.lr\n",
    "        self.weight_decay = hparams.weight_decay\n",
    "        self.factor = hparams.factor\n",
    "        self.step_size = hparams.step_size\n",
    "        self.loss = hparams.loss\n",
    "        # Model parameters\n",
    "        self.time_slice = hparams.time_slice\n",
    "        self.num_message_passing_steps = hparams.num_message_passing_steps\n",
    "        self.latent_dim = hparams.latent_dim\n",
    "        self.mlp_layers = hparams.mlp_layers\n",
    "        self.mlp_hidden = hparams.mlp_hidden\n",
    "        self.n_chan = hparams.n_chan\n",
    "        self.radius = hparams.radius\n",
    "        self.codec_neighbors = hparams.codec_neighbors\n",
    "        self.teacher_forcing = hparams.teacher_forcing\n",
    "        self.noise = hparams.noise\n",
    "        self.interpolation = hparams.interpolation\n",
    "        self.neighbors = hparams.neighbors\n",
    "\n",
    "        # Fail here instead of with an AttributeError / unbound local deep inside a step\n",
    "        if self.loss not in self._CRITERIA:\n",
    "            raise ValueError(f'unsupported loss {self.loss!r}; expected one of {sorted(self._CRITERIA)}')\n",
    "        if self.interpolation not in self._INTERPOLATIONS:\n",
    "            raise ValueError(f'unsupported interpolation {self.interpolation!r}; expected one of {list(self._INTERPOLATIONS)}')\n",
    "        if self.codec_neighbors < 2:\n",
    "            raise ValueError('codec_neighbors must be at least 2: continuous_decoder blends two neighbours')\n",
    "\n",
    "        self.criterion = self._CRITERIA[self.loss]()\n",
    "        self.mse_criterion = nn.MSELoss()\n",
    "        self.mae_criterion = nn.L1Loss()\n",
    "\n",
    "        # Stage 1. Node input = time_slice values + 2 coordinates + 1 time stamp;\n",
    "        # edge input = time_slice value differences + 2 coordinate differences.\n",
    "        self.encoder = Encoder(\n",
    "            node_in=self.time_slice+3,\n",
    "            node_out=self.latent_dim,\n",
    "            edge_in=self.time_slice+2,\n",
    "            edge_out=self.latent_dim,\n",
    "            mlp_layers=self.mlp_layers,\n",
    "            mlp_hidden=self.mlp_hidden,\n",
    "        )\n",
    "        self.processor = Processor(\n",
    "            node_in=self.latent_dim,\n",
    "            node_out=self.latent_dim,\n",
    "            edge_in=self.latent_dim,\n",
    "            edge_out=self.latent_dim,\n",
    "            num_message_passing_steps=self.num_message_passing_steps,\n",
    "            mlp_num_layers=self.mlp_layers,\n",
    "            mlp_hidden_dim=self.mlp_hidden,\n",
    "        )\n",
    "\n",
    "        # +4 = 1 neighbour value + 2 coordinate offsets + 1 time stamp (see continuous_decoder)\n",
    "        self.proj_head = nn.Linear(self.latent_dim+4, self.n_chan)\n",
    "        self.projector = MLP(\n",
    "                in_dim=self.n_chan,\n",
    "                hidden_list=[self.mlp_hidden]*self.mlp_layers,\n",
    "                out_dim=1)\n",
    "\n",
    "        # Stage 2: same layout, separate weights\n",
    "        self._encoder = Encoder(\n",
    "            node_in=self.time_slice+3,\n",
    "            node_out=self.latent_dim,\n",
    "            edge_in=self.time_slice+2,\n",
    "            edge_out=self.latent_dim,\n",
    "            mlp_layers=self.mlp_layers,\n",
    "            mlp_hidden=self.mlp_hidden,\n",
    "        )\n",
    "        self._processor = Processor(\n",
    "            node_in=self.latent_dim,\n",
    "            node_out=self.latent_dim,\n",
    "            edge_in=self.latent_dim,\n",
    "            edge_out=self.latent_dim,\n",
    "            num_message_passing_steps=self.num_message_passing_steps,\n",
    "            mlp_num_layers=self.mlp_layers,\n",
    "            mlp_hidden_dim=self.mlp_hidden,\n",
    "        )\n",
    "        self._decoder = Decoder(\n",
    "            node_in=self.latent_dim,\n",
    "            node_out=self.time_slice,\n",
    "            mlp_layers=self.mlp_layers,\n",
    "            mlp_hidden=self.mlp_hidden,\n",
    "        )\n",
    "\n",
    "    def continuous_decoder(\n",
    "        self,\n",
    "        x_lr,\n",
    "        lr_encoded,\n",
    "        lr_coords,\n",
    "        hr_coords,\n",
    "        t):\n",
    "        '''Interpolate latent features at the query coordinates for every input frame.\n",
    "\n",
    "        Args:\n",
    "            x_lr: [B, T, C, L] low-resolution input frames\n",
    "            lr_encoded: [B*L, latent] (or [B, L, latent]) encoded low-resolution nodes\n",
    "            lr_coords: [B, L, D] coordinates of the low-resolution nodes\n",
    "            hr_coords: [B, N, D] coordinates of the query points\n",
    "            t: [B, >=T] time stamps; only the first T columns are read\n",
    "        Returns:\n",
    "            [B*N, T, n_chan] blended latent features\n",
    "        '''\n",
    "        B, T, _, L = x_lr.shape\n",
    "        N = hr_coords.shape[1]\n",
    "\n",
    "        # For each query point find its `codec_neighbors` nearest low-res nodes\n",
    "        # within the same sample of the batch\n",
    "        flat_lr_coords = lr_coords.reshape(B*L, -1)\n",
    "        flat_hr_coords = hr_coords.reshape(B*N, -1)\n",
    "        batch_lr = torch.arange(B, device=flat_lr_coords.device).repeat_interleave(L)\n",
    "        batch_hr = torch.arange(B, device=flat_hr_coords.device).repeat_interleave(N)\n",
    "        assign_index = knn(flat_lr_coords, flat_hr_coords, self.codec_neighbors, batch_lr, batch_hr)\n",
    "\n",
    "        lr_encoded_flat = lr_encoded.reshape(B*L, -1)\n",
    "        timesteps = t.repeat_interleave(N, dim=0) # B*N, T_total\n",
    "\n",
    "        # Everything that does not depend on the frame index is computed once.\n",
    "        # NOTE(review): only the neighbours at positions 0 and 1 of each query point are\n",
    "        # blended, whatever `codec_neighbors` is - confirm whether blending all k was intended.\n",
    "        blend = []\n",
    "        for j in range(2):\n",
    "            lr_idx = assign_index[1, j::self.codec_neighbors] # B*N indices into the flat low-res nodes\n",
    "            offset = flat_lr_coords[lr_idx]-flat_hr_coords # B*N, D\n",
    "            sq_dist = torch.norm(offset, 2, dim=-1)**2\n",
    "            if self.interpolation == 'area':\n",
    "                weight = sq_dist\n",
    "            elif self.interpolation == 'knn':\n",
    "                weight = 1/sq_dist\n",
    "            elif self.interpolation == 'sph':\n",
    "                weight = torch.pow(1 - L*sq_dist, 3)\n",
    "            else:\n",
    "                raise ValueError(f'unsupported interpolation {self.interpolation!r}')\n",
    "            blend.append((lr_idx, lr_encoded_flat[lr_idx], offset, weight.unsqueeze(-1)))\n",
    "\n",
    "        (idx0, feat0, off0, w0), (idx1, feat1, off1, w1) = blend\n",
    "        if self.interpolation == 'area':\n",
    "            # Each latent is weighted by the squared distance of the *other* neighbour\n",
    "            c0, c1 = w1, w0\n",
    "        else:\n",
    "            c0, c1 = w0, w1\n",
    "        total = w1+w0\n",
    "\n",
    "        out = []\n",
    "        for i in range(T):\n",
    "            x_lr_flat = x_lr[:,i].permute(0,2,1).reshape(B*L, -1) # B*L, C\n",
    "            timestep = timesteps[:,i:i+1]\n",
    "            z0 = self.proj_head(torch.cat([feat0, x_lr_flat[idx0], off0, timestep], dim=-1)) # B*N, n_chan\n",
    "            z1 = self.proj_head(torch.cat([feat1, x_lr_flat[idx1], off1, timestep], dim=-1))\n",
    "            out.append((z0*c0+z1*c1)/total)\n",
    "\n",
    "        out = torch.stack(out, dim=1) # B*N, T, n_chan\n",
    "        return out\n",
    "\n",
    "    def _build_graph(self, u, x, t):\n",
    "        '''Build a radius graph (with self loops) over the points of every sample.\n",
    "\n",
    "        Args:\n",
    "            u: [B, N, F] per-node values\n",
    "            x: [B, N, D] node coordinates\n",
    "            t: [B, T] time stamps; only the last column is used\n",
    "        Returns:\n",
    "            node_features [B*N, F+D+1], edge_index [2, E], edge_features [E, F+D]\n",
    "        '''\n",
    "        B, N, _ = u.shape\n",
    "\n",
    "        u_ = u.reshape(B*N, -1)\n",
    "        x_ = x.reshape(B*N, -1)\n",
    "\n",
    "        # Node k of the flattened batch belongs to sample k // N\n",
    "        batch_ids = torch.arange(B, device=x_.device).repeat_interleave(N)\n",
    "        # Alternative connectivity: knn_graph(x_, batch=batch_ids, k=self.neighbors, loop=True)\n",
    "        edges = radius_graph(x_, batch=batch_ids, r=self.radius, loop=True)\n",
    "        receivers = edges[0, :]\n",
    "        senders = edges[1, :]\n",
    "        edge_index = torch.stack([senders, receivers])\n",
    "\n",
    "        # Last time stamp of each sample, repeated for that sample's N nodes. It has to\n",
    "        # follow the batch-major node order above: repeat(N, 1) would tile the batch and\n",
    "        # hand node k the time of sample k % B instead of sample k // N.\n",
    "        t_last = t[:,-1:].repeat_interleave(N, dim=0) # B*N, 1\n",
    "        node_features = torch.cat([u_, x_, t_last], dim=-1)\n",
    "\n",
    "        edge_features = torch.cat([u_[senders]-u_[receivers], x_[senders]-x_[receivers]], dim=-1)\n",
    "\n",
    "        return node_features, edge_index, edge_features\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        x_lr,\n",
    "        lr_coords,\n",
    "        hr_coords,\n",
    "        t,\n",
    "        hr_last):\n",
    "        '''\n",
    "        Args:\n",
    "            x_lr: tensor of shape [B, T, C, L] that represents the low-resolution frames\n",
    "            lr_coords: tensor of shape [B, L, D] with the coordinates of the L low-resolution nodes\n",
    "            hr_coords: tensor of shape [B, N, D] with the coordinates of the N query points\n",
    "            t: tensor of shape [B, T+T_out]: time stamps of the input frames followed by those to predict\n",
    "            hr_last: tensor of shape [B, N, C] with the query-point values at the last input frame\n",
    "        Returns:\n",
    "            out_hr: [B, T_out, N, 1] predictions at the query points\n",
    "            out_lr: [B, T_out, L, 1] predictions at the low-resolution nodes\n",
    "            hr_points: [B, T, N, 1] interpolated query-point values for the input frames\n",
    "        '''\n",
    "        B, T, _, L = x_lr.shape\n",
    "        N = hr_coords.shape[1]\n",
    "        T_out = t.shape[1] - T\n",
    "\n",
    "        # Stage 1: graph over the low-resolution nodes only\n",
    "        u = x_lr.permute(0,3,1,2).reshape(B, L, -1) # B, L, T*C\n",
    "        node_features, edge_index, edge_features = self._build_graph(u, lr_coords, t[:,:T])\n",
    "        node_features, edge_features = self.encoder(node_features, edge_index, edge_features)\n",
    "        lr_encoded, _ = self.processor(node_features, edge_index, edge_features)\n",
    "\n",
    "        # Get interpolated features from low-res points\n",
    "        z = self.continuous_decoder(x_lr, lr_encoded, lr_coords, hr_coords, t) # B*N, T, n_chan\n",
    "        hr_points = self.projector(z) # B*N, T, 1\n",
    "        hr_points = hr_points.reshape(B, N, T, -1).reshape(B, N, -1) # B, N, T\n",
    "\n",
    "        # Stage 2: joint graph, low-resolution nodes first, query points after them\n",
    "        all_coords = torch.cat([lr_coords, hr_coords], dim=1) # B, (L+N), D\n",
    "        all_feats = torch.cat([u, hr_points], dim=1) # B, (L+N), T\n",
    "\n",
    "        node_features, edge_index, edge_features = self._build_graph(all_feats, all_coords, t[:,:T])\n",
    "        node_features, edge_features = self._encoder(node_features, edge_index, edge_features)\n",
    "        node_features, _ = self._processor(node_features, edge_index, edge_features)\n",
    "        node_features = self._decoder(node_features) # B*(L+N), time_slice\n",
    "        ret = node_features.reshape(B, -1, node_features.shape[-1]) # B, (L+N), time_slice\n",
    "\n",
    "        last_values = torch.cat([x_lr[:,-1].permute(0,2,1), hr_last], dim=1) # B, (L+N), 1\n",
    "\n",
    "        # Output i = last known value + (t_i - t_last) * predicted rate i\n",
    "        outputs = []\n",
    "        for i in range(T_out):\n",
    "            delta_t = (t[:,T+i]-t[:,T-1]).reshape(B, 1, 1)\n",
    "            op = ret[...,i].unsqueeze(-1) # B, (L+N), 1\n",
    "            outputs.append(last_values+delta_t*op)\n",
    "\n",
    "        outputs = torch.stack(outputs, dim=1) # B, T_out, (L+N), 1\n",
    "\n",
    "        out_lr = outputs[:,:,:L]\n",
    "        out_hr = outputs[:,:,L:]\n",
    "        hr_points = hr_points.reshape(B, N, T, -1).permute(0,2,1,3) # B, T, N, 1\n",
    "\n",
    "        return out_hr, out_lr, hr_points\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        '''Adam with a StepLR schedule (period `step_size`, gamma `factor`).'''\n",
    "        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)\n",
    "        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.step_size, gamma=self.factor)\n",
    "        return {\n",
    "            'optimizer': optimizer,\n",
    "            'lr_scheduler': {\n",
    "                'scheduler': scheduler\n",
    "            },\n",
    "        }\n",
    "\n",
    "    def _perturb(self, x):\n",
    "        '''Return `x` plus Gaussian noise scaled by `self.noise`, drawn on the device of `x`.'''\n",
    "        return x+self.noise*torch.randn_like(x)\n",
    "\n",
    "    def _evaluate_rollout(self, batch):\n",
    "        '''Autoregressive rollout on `batch` (no teacher forcing, no noise).\n",
    "\n",
    "        Shared by validation_step and test_step.\n",
    "\n",
    "        Returns:\n",
    "            (loss, mae_loss, rel_error) computed over all predicted frames.\n",
    "        '''\n",
    "        t = batch['t'].float()\n",
    "        u = batch['lr_frames'].float() # B, T, 1, L\n",
    "        u_values = batch['hr_points'].float() # B, T, N, 1\n",
    "        coords = batch['coords_hr'].float()\n",
    "        lr_coords = batch['coords_lr'].float()\n",
    "\n",
    "        ts = self.time_slice\n",
    "        u_values_future = u_values[:,ts:] # B, T_future, N, 1\n",
    "        B, T_future = u_values_future.shape[:2]\n",
    "\n",
    "        u_values_hat = []\n",
    "        inp = u[:,:ts]\n",
    "        hr_last = u_values[:,ts-1]\n",
    "\n",
    "        for i in range(T_future//ts):\n",
    "            out_hr, out_lr, _ = self.forward(\n",
    "                inp,\n",
    "                lr_coords,\n",
    "                coords,\n",
    "                t[:,i*ts:(i+2)*ts],\n",
    "                hr_last)\n",
    "            u_values_hat.append(torch.cat([out_hr, out_lr], dim=2))\n",
    "\n",
    "            # Feed the predictions back in as the next input window\n",
    "            inp = out_lr.permute(0,1,3,2)\n",
    "            hr_last = out_hr[:,-1]\n",
    "\n",
    "        u_values_hat = torch.cat(u_values_hat, dim=1)\n",
    "        target = torch.cat([u_values_future, u[:,ts:].permute(0,1,3,2)], dim=2)\n",
    "        loss = self.criterion(u_values_hat, target)\n",
    "        mae_loss = self.mae_criterion(u_values_hat, target)\n",
    "        rel_error = rel_L2_error(u_values_hat.reshape(B, -1), target.reshape(B, -1))\n",
    "        rel_error = torch.mean(rel_error)\n",
    "\n",
    "        return loss, mae_loss, rel_error\n",
    "\n",
    "    def training_step(self, train_batch, batch_idx):\n",
    "        '''Windowed rollout with optional teacher forcing and input noise.\n",
    "\n",
    "        The returned loss is the forecast loss plus the loss of the interpolated\n",
    "        query-point values on the input frames.\n",
    "        '''\n",
    "        t = train_batch['t'].float()\n",
    "        u = train_batch['lr_frames'].float() # B, T, 1, L\n",
    "        u_values = train_batch['hr_points'].float() # B, T, N, 1\n",
    "        coords = train_batch['coords_hr'].float()\n",
    "        lr_coords = train_batch['coords_lr'].float()\n",
    "\n",
    "        ts = self.time_slice\n",
    "        u_values_future = u_values[:,ts:] # B, T_future, N, 1\n",
    "        B, T_future = u_values_future.shape[:2]\n",
    "\n",
    "        u_values_hat = []\n",
    "        hr_values_hat = []\n",
    "\n",
    "        inp = self._perturb(u[:,:ts])\n",
    "        hr_last = self._perturb(u_values[:,ts-1])\n",
    "\n",
    "        for i in range(T_future//ts):\n",
    "            out_hr, out_lr, hr_points = self.forward(inp, lr_coords, coords, t[:,i*ts:(i+2)*ts], hr_last)\n",
    "            u_values_hat.append(torch.cat([out_hr, out_lr], dim=2))\n",
    "            hr_values_hat.append(hr_points)\n",
    "\n",
    "            if self.teacher_forcing:\n",
    "                # Next window starts from the ground truth\n",
    "                inp = u[:,(i+1)*ts:(i+2)*ts] # B, T, C, L\n",
    "                hr_last = u_values[:,(i+2)*ts-1]\n",
    "            else:\n",
    "                # Next window starts from the model's own predictions\n",
    "                inp = out_lr.permute(0,1,3,2)\n",
    "                hr_last = out_hr[:,-1]\n",
    "\n",
    "            inp = self._perturb(inp)\n",
    "            hr_last = self._perturb(hr_last)\n",
    "\n",
    "        u_values_hat = torch.cat(u_values_hat, dim=1) # B, T_out, (N+L), 1\n",
    "        hr_values_hat = torch.cat(hr_values_hat, dim=1) # B, T_in, N, 1\n",
    "\n",
    "        target = torch.cat([u_values_future, u[:,ts:].permute(0,1,3,2)], dim=2)\n",
    "        interp_target = u_values[:,:-ts]\n",
    "        loss = self.criterion(u_values_hat, target)+self.criterion(hr_values_hat, interp_target)\n",
    "        mae_loss = self.mae_criterion(u_values_hat, target)\n",
    "        interp_loss = self.mae_criterion(hr_values_hat, interp_target)\n",
    "        rel_error = rel_L2_error(u_values_hat.reshape(B, -1), target.reshape(B, -1))\n",
    "        rel_error = torch.mean(rel_error)\n",
    "\n",
    "        self.log('train_rel_error', rel_error, prog_bar=True)\n",
    "        self.log('train_loss', loss, prog_bar=True)\n",
    "        self.log('train_mae_loss', mae_loss, prog_bar=True)\n",
    "        self.log('train_interp_loss', interp_loss, prog_bar=True)\n",
    "\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, val_batch, batch_idx):\n",
    "        '''Log the rollout metrics under the `val_` prefix and return the loss.'''\n",
    "        loss, mae_loss, rel_error = self._evaluate_rollout(val_batch)\n",
    "\n",
    "        self.log('val_rel_error', rel_error, prog_bar=True)\n",
    "        self.log('val_loss', loss, prog_bar=True)\n",
    "        self.log('val_mae_loss', mae_loss, prog_bar=True)\n",
    "\n",
    "        return loss\n",
    "\n",
    "    def test_step(self, test_batch, batch_idx):\n",
    "        '''Log the rollout metrics under the `test_` prefix and return them as a dict.'''\n",
    "        loss, mae_loss, rel_error = self._evaluate_rollout(test_batch)\n",
    "\n",
    "        self.log('test_rel_error', rel_error, prog_bar=True)\n",
    "        self.log('test_loss', loss, prog_bar=True)\n",
    "        self.log('test_mae_loss', mae_loss, prog_bar=True)\n",
    "\n",
    "        return {'test_loss': loss, 'test_rel_error': rel_error}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c9ab33ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One dataset per split; all three pass the same seed for the grid-location sampling\n",
    "train_dataset = HDF5DatasetImplicitGNN_2d_irregular(\n",
    "    data, nt, nx, mode='train', samples=hparams.samples, seed=hparams.seed)\n",
    "valid_dataset = HDF5DatasetImplicitGNN_2d_irregular(\n",
    "    data, nt, nx, mode='valid', seed=hparams.seed)\n",
    "test_dataset = HDF5DatasetImplicitGNN_2d_irregular(\n",
    "    data, nt, nx, mode='test', seed=hparams.seed)\n",
    "\n",
    "# Dataloaders (default collation); only the training loader shuffles\n",
    "train_loader = DataLoader(\n",
    "    train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)\n",
    "valid_loader = DataLoader(\n",
    "    valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)\n",
    "# The test loader serves the whole test split as one batch\n",
    "test_loader = DataLoader(\n",
    "    test_dataset, batch_size=len(test_dataset), shuffle=False, num_workers=num_workers)\n",
    "\n",
    "# Instantiate the Lightning module from the hyperparameters\n",
    "model = MAgNetGNN2d(hparams)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b49f7ebb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
      "\n",
      "  | Name          | Type      | Params\n",
      "--------------------------------------------\n",
      "0 | criterion     | L1Loss    | 0     \n",
      "1 | mse_criterion | MSELoss   | 0     \n",
      "2 | mae_criterion | L1Loss    | 0     \n",
      "3 | encoder       | Encoder   | 136 K \n",
      "4 | processor     | Processor | 1.1 M \n",
      "5 | proj_head     | Linear    | 17.0 K\n",
      "6 | projector     | MLP       | 66.2 K\n",
      "7 | _encoder      | Encoder   | 136 K \n",
      "8 | _processor    | Processor | 1.1 M \n",
      "9 | _decoder      | Decoder   | 67.3 K\n",
      "--------------------------------------------\n",
      "2.6 M     Trainable params\n",
      "0         Non-trainable params\n",
      "2.6 M     Total params\n",
      "10.282    Total estimated model params size (MB)\n",
      "/home/rakhoon/.local/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:631: UserWarning: Checkpoint directory /ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/MAgNet(shallow_irregular) exists and is not empty.\n",
      "  rank_zero_warn(f\"Checkpoint directory {dirpath} exists and is not empty.\")\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation sanity check: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/rakhoon/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:432: UserWarning: The number of training samples (75) is smaller than the logging interval Trainer(log_every_n_steps=300). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
      "  rank_zero_warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3ce3ea4aa89c498e8026743477a35bd1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 9, global step 749: val_rel_error reached 0.09214 (best 0.09214), saving model to \"/ex_hdd/sungwoong/SungWoong/Graph Simulator/gdon_ipynb/MAgNet(shallow_irregular)/best-v1.ckpt\" as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/rakhoon/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:688: UserWarning: Detected KeyboardInterrupt, attempting graceful shutdown...\n",
      "  rank_zero_warn(\"Detected KeyboardInterrupt, attempting graceful shutdown...\")\n"
     ]
    }
   ],
   "source": [
    "# Checkpointing: monitor `val_rel_error` (lower is better), keep only the\n",
    "# top-1 checkpoint under the file name \"best\", and also save the last one.\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    monitor=\"val_rel_error\",\n",
    "    mode=\"min\",\n",
    "    save_top_k=1,\n",
    "    save_last=True,\n",
    "    every_n_epochs=10,\n",
    "    filename=\"best\",\n",
    "    dirpath=dirpath,\n",
    "    verbose=True,\n",
    ")\n",
    "\n",
    "# Trainer: run on the GPU ids listed in `hparams.gpus` for at most 250 epochs,\n",
    "# with the checkpoint callback attached.\n",
    "trainer = pl.Trainer(\n",
    "    gpus=hparams.gpus,\n",
    "    accelerator=\"gpu\",\n",
    "    max_epochs=250,\n",
    "    log_every_n_steps=300,\n",
    "    callbacks=[checkpoint_callback],\n",
    "    default_root_dir=\"lightning_logs\",\n",
    ")\n",
    "\n",
    "# Fit on the training loader; the validation loader is passed alongside it.\n",
    "trainer.fit(model, train_loader, valid_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "34c1d6a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "IOStream.flush timed out\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Test Loss: 0.03863672912120819, Average Test Relative Error: 0.08974304050207138\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/rakhoon/.local/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py:415: UserWarning: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`\n",
      "  rank_zero_warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(0.03863672912120819, 0.08974304050207138)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def compute_test_error(model, test_loader):\n",
    "    # Move model to CPU\n",
    "    model = model.to(\"cpu\")\n",
    "    \n",
    "    # Ensure model is in eval mode\n",
    "    model.eval()\n",
    "    \n",
    "    # Accumulators for loss and relative error\n",
    "    total_test_loss = 0.0\n",
    "    total_test_rel_error = 0.0\n",
    "    \n",
    "    # Disable gradients to save memory\n",
    "    with torch.no_grad():\n",
    "        num_batches = 0\n",
    "        for batch in test_loader:\n",
    "            results = model.test_step(batch, batch_idx=num_batches)\n",
    "            total_test_loss += results['test_loss'].item()\n",
    "            total_test_rel_error += results['test_rel_error'].item()\n",
    "            num_batches += 1\n",
    "            \n",
    "    # Compute average values\n",
    "    avg_test_loss = total_test_loss / num_batches\n",
    "    avg_test_rel_error = total_test_rel_error / num_batches\n",
    "            \n",
    "    # Print or return the results\n",
    "    print(f\"Average Test Loss: {avg_test_loss}, Average Test Relative Error: {avg_test_rel_error}\")\n",
    "    return avg_test_loss, avg_test_rel_error\n",
    "\n",
    "# Load the checkpoint this run actually wrote. When the checkpoint directory\n",
    "# already holds a 'best.ckpt' from an earlier run, Lightning versions the new\n",
    "# file name (e.g. 'best-v1.ckpt'), so a hard-coded 'best.ckpt' would silently\n",
    "# evaluate the stale model. `best_model_path` holds the real location.\n",
    "try:\n",
    "    checkpoint_path = checkpoint_callback.best_model_path\n",
    "except NameError:\n",
    "    # The training cell was not run in this kernel session.\n",
    "    checkpoint_path = ''\n",
    "if not checkpoint_path:\n",
    "    # Nothing saved by this session: fall back to the conventional file name.\n",
    "    checkpoint_path = dirpath+'/best.ckpt'\n",
    "model = MAgNetGNN2d.load_from_checkpoint(checkpoint_path)\n",
    "\n",
    "# Call the function to compute the test error\n",
    "compute_test_error(model, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f9e1c0c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "graph",
   "language": "python",
   "name": "graph"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
