{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hsNadJLO_CPq",
        "outputId": "4ea90691-f247-4846-c6fc-878dd0fa1b5c"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "2.4.1 cu121\n"
          ]
        }
      ],
      "source": [
        "import os\n",
        "import pickle\n",
        "\n",
        "import torch\n",
        "\n",
        "# Report the installed torch / CUDA versions (useful for picking the PyG wheel\n",
        "# index in the install cell).\n",
        "TORCH = torch.__version__.split('+')[0]\n",
        "# torch.version.cuda is None on CPU-only builds; report 'cpu' instead of\n",
        "# crashing with AttributeError.\n",
        "CUDA = 'cu' + torch.version.cuda.replace('.', '') if torch.version.cuda else 'cpu'\n",
        "print(TORCH, CUDA)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%%capture\n",
        "# Install PyG and its compiled extensions (cell output is captured/suppressed).\n",
        "# NOTE(review): the wheel index is hard-coded to torch-2.4.0+cu121; update it if\n",
        "# the TORCH / CUDA values printed by the first cell change. The package\n",
        "# versions themselves are not pinned.\n",
        "!pip install pyg_lib -f https://data.pyg.org/whl/torch-2.4.0+cu121.html\n",
        "!pip install torch-scatter     -f https://data.pyg.org/whl/torch-2.4.0+cu121.html\n",
        "!pip install torch-cluster -f https://data.pyg.org/whl/torch-2.4.0+cu121.html\n",
        "!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.4.0+cu121.html\n",
        "!pip install torch-sparse -f https://data.pyg.org/whl/torch-2.4.0+cu121.html\n",
        "!pip install torch-geometric"
      ],
      "metadata": {
        "id": "d6npCCK-_JtD"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import torch_geometric\n",
        "import torch_geometric.nn as geom_nn\n",
        "import torch_geometric.data as geom_data\n",
        "from torch_geometric.loader import DataLoader\n",
        "from torch_geometric.data import InMemoryDataset\n",
        "from torch_geometric.data import Data\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import networkx as nx\n",
        "\n",
        "from scipy.linalg import fractional_matrix_power\n",
        "from scipy.spatial.distance import pdist\n",
        "from scipy.spatial.distance import squareform\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib as matplotlib\n",
        "import matplotlib.cm as cm\n",
        "\n",
        "from tqdm import tqdm\n",
        "\n",
        "import math\n",
        "from numba import cuda\n",
        "import numpy as np\n",
        "import torch\n",
        "from torch_geometric.utils import (\n",
        "    to_networkx,\n",
        "    from_networkx,\n",
        "    to_dense_adj,\n",
        "    remove_self_loops,\n",
        "    to_undirected,\n",
        ")"
      ],
      "metadata": {
        "id": "QG5QUEIe_M2L"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Mount Google Drive. The data paths used below are relative (drive/MyDrive/...),\n",
        "# which presumes the working directory is /content.\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "v77ZZalQ_O5D",
        "outputId": "886a8895-4639-456d-d063-36bfb6cc4a85"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Set up Functions"
      ],
      "metadata": {
        "id": "4atzWaGU_Rj1"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def data_to_kNN(X,k):\n",
        "  \"\"\"Return an undirected kNN edge_index over the rows of X.\"\"\"\n",
        "  directed_edges = torch_geometric.nn.knn_graph(X, k)\n",
        "  # The kNN relation is not symmetric; symmetrise so messages flow both ways.\n",
        "  return to_undirected(directed_edges, num_nodes=X.shape[0])\n",
        "\n",
        "def data_to_GNN_data(X,Y,k):\n",
        "  \"\"\"Convert batched samples into a list of PyG Data graphs.\n",
        "\n",
        "  Args:\n",
        "    X: node features, indexed per sample as X[i, :, :] (assumed\n",
        "      (num_samples, num_nodes, feat) -- TODO confirm).\n",
        "    Y: targets, indexed like X.\n",
        "    k: neighbours per node for the kNN graph built on X[i].\n",
        "\n",
        "  Returns:\n",
        "    list of Data objects, one per sample.\n",
        "  \"\"\"\n",
        "  dataset = []\n",
        "  N = X.shape[0]\n",
        "  for i in range(N):\n",
        "    x_i = X[i,:,:]\n",
        "    data = Data(x = x_i, y = Y[i,:,:], edge_index = data_to_kNN(x_i,k))\n",
        "    # Use the real node count rather than a hard-coded 20, so samples with a\n",
        "    # different number of nodes get a correct num_nodes.\n",
        "    data.num_nodes = x_i.shape[0]\n",
        "    data.num_edges = data.edge_index.shape[1]\n",
        "    dataset.append(data)\n",
        "  return dataset\n",
        "\n",
        "def data_to_GNN_data_fm(X,Y,k):\n",
        "  \"\"\"Convert per-frame node features/targets into a list of PyG Data graphs.\n",
        "\n",
        "  X and Y are indexable sequences of length len(X). X[i] may have a different\n",
        "  number of rows (nodes) for each i, so num_nodes is read from X[i].\n",
        "  A tqdm progress bar is shown while converting.\n",
        "  \"\"\"\n",
        "  graphs = []\n",
        "  for i in tqdm(range(len(X))):\n",
        "    nodes = X[i]\n",
        "    graph = Data(x = nodes, y = Y[i], edge_index = data_to_kNN(nodes,k))\n",
        "    graph.num_nodes = nodes.shape[0]\n",
        "    graph.num_edges = graph.edge_index.shape[1]\n",
        "    graphs.append(graph)\n",
        "  return graphs"
      ],
      "metadata": {
        "id": "B7e_hwuG_REL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vbBR8QuND5zv"
      },
      "source": [
        "# Preprocessing Fish Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q2Td10RkD2Qd"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "\n",
        "# Load the raw tracking data. The context manager closes the file handle once\n",
        "# parsing is done, instead of leaving it open for the kernel's lifetime.\n",
        "# Presumably maps frame number (as a string) -> per-frame record; later cells\n",
        "# read it as data[str(i)] for i in 1..len(data).\n",
        "with open('drive/MyDrive/MeasureMaps/schooling_frames.json') as f:\n",
        "  data = json.load(f)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g4_TzgovwqXY"
      },
      "outputs": [],
      "source": [
        "def _fish_state(frame, j):\n",
        "  \"\"\"Return the j-th fish entry of a frame as a (1, 4) tensor [px, py, vx, vy].\n",
        "\n",
        "  If the velocity entries cannot be read or converted, zero velocity is used\n",
        "  so the observation is still recorded.\n",
        "  \"\"\"\n",
        "  try:\n",
        "    row = torch.tensor([frame[\"px\"][j], frame[\"py\"][j], frame[\"vx\"][j], frame[\"vy\"][j]])\n",
        "  except Exception:  # narrower than a bare except: lets KeyboardInterrupt through\n",
        "    row = torch.tensor([frame[\"px\"][j], frame[\"py\"][j], 0, 0])\n",
        "  return row.unsqueeze(0)\n",
        "\n",
        "# fish id -> {\"id\": sequential index, \"times\": frames the fish appears in,\n",
        "#             \"x\": one [px, py, vx, vy] row per entry of \"times\"}\n",
        "fish_ids_to_idx = dict()\n",
        "\n",
        "idx = 0\n",
        "for i in range(1,len(data)+1):\n",
        "  curr_time = data[str(i)]\n",
        "  for j, fish_id in enumerate(curr_time[\"onfish\"]):\n",
        "    # The zero-velocity fallback applies to a fish's first observation as well\n",
        "    # as to later ones.\n",
        "    x = _fish_state(curr_time, j)\n",
        "    if fish_id not in fish_ids_to_idx:\n",
        "      fish_ids_to_idx[fish_id] = {\"id\": idx, \"times\": [i], \"x\": x}\n",
        "      idx += 1\n",
        "    else:\n",
        "      fish = fish_ids_to_idx[fish_id]\n",
        "      fish[\"times\"].append(i)\n",
        "      fish[\"x\"] = torch.cat((fish[\"x\"], x), dim=0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NkDRu5TjfdmK",
        "outputId": "3023322a-6faa-429d-d642-9b65d9bf574c"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "104.00130220444609"
            ]
          },
          "execution_count": 15,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Mean number of frames per fish (track length). create_y_for_fish later trims\n",
        "# each fish's \"times\" by two frames, so the value depends on running this first.\n",
        "track_lengths = [len(fish[\"times\"]) for fish in fish_ids_to_idx.values()]\n",
        "sum(track_lengths)/len(track_lengths)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8Za9EempftXq"
      },
      "outputs": [],
      "source": [
        "# Spot-check: frames in which one example fish appears.\n",
        "# NOTE(review): assumes id 110450 is present in this dataset.\n",
        "fish_ids_to_idx[110450][\"times\"]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hCClMYk1f9jH"
      },
      "outputs": [],
      "source": [
        "def create_y_for_fish(fish):\n",
        "  \"\"\"Trim a fish track and attach its regression targets, in place.\n",
        "\n",
        "  fish[\"x\"] arrives with one row per frame. It is transposed to feature-major\n",
        "  (features, frames) and passed with fish[\"times\"] to create_y. Afterwards\n",
        "  fish[\"x\"] / fish[\"y\"] hold only the interior frames, and fish[\"times\"] drops\n",
        "  its first and last entry. Returns the same (mutated) dict.\n",
        "  \"\"\"\n",
        "  frames = fish[\"times\"]\n",
        "  fish[\"x\"], fish[\"y\"] = create_y(fish[\"x\"].T, frames)\n",
        "  fish[\"times\"] = frames[1:-1]\n",
        "  return fish\n",
        "\n",
        "def create_y(X,t):\n",
        "  \"\"\"Build central-difference regression targets for a state trajectory.\n",
        "\n",
        "  Args:\n",
        "    X: (N, T) feature-major tensor split into four row groups of n = N // 4.\n",
        "      For a fish built from [px, py, vx, vy] these are x/y positions followed\n",
        "      by x/y velocities (N is assumed divisible by 4).\n",
        "    t: length-T sequence of frame times (gaps between frames are allowed).\n",
        "\n",
        "  Returns:\n",
        "    (X[:, 1:-1], y[:, 1:-1]), restricted to the interior frames. y[:2n] are the\n",
        "    velocity rows X[2n:], and y[2n:] are the central differences\n",
        "    (X[2n:, i+1] - X[2n:, i-1]) / (t[i+1] - t[i-1]). Tracks with T <= 2\n",
        "    yield empty (N, 0) outputs.\n",
        "  \"\"\"\n",
        "  N = X.shape[0]\n",
        "  T = X.shape[1]\n",
        "\n",
        "  n = N//4\n",
        "\n",
        "  y = torch.zeros(N,T)\n",
        "  if T > 2:\n",
        "    # Vectorised over all interior frames at once (one op per row group\n",
        "    # instead of a Python loop over frames).\n",
        "    y[:n,1:-1] = X[2*n:3*n,1:-1]\n",
        "    y[n:2*n,1:-1] = X[3*n:,1:-1]\n",
        "    dt = torch.tensor([t[i+1]-t[i-1] for i in range(1,T-1)], device=X.device)\n",
        "    y[2*n:,1:-1] = (X[2*n:,2:] - X[2*n:,:-2])/dt\n",
        "\n",
        "  return X[:,1:-1], y[:,1:-1]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iwTyczanwlXj"
      },
      "outputs": [],
      "source": [
        "# Attach targets to every fish. create_y_for_fish mutates each record (it\n",
        "# transposes and trims \"x\" and \"times\"), so run this cell only once.\n",
        "for fish_id, fish in fish_ids_to_idx.items():\n",
        "  fish_ids_to_idx[fish_id] = create_y_for_fish(fish)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bYiWniUs8Joe"
      },
      "outputs": [],
      "source": [
        "# Invert fish -> frames into frame -> [ids of fish present] for frames 1..5000,\n",
        "# with ids in fish_ids_to_idx order. A single pass over every fish's frame list\n",
        "# avoids a linear `t in times` scan for every (frame, fish) pair.\n",
        "times_to_fish = {t: [] for t in range(1,5001)}\n",
        "for fish_id, fish in fish_ids_to_idx.items():\n",
        "  for t in fish[\"times\"]:\n",
        "    bucket = times_to_fish.get(t)\n",
        "    # Ignore frames outside 1..5000 and list each fish at most once per frame.\n",
        "    if bucket is not None and (not bucket or bucket[-1] != fish_id):\n",
        "      bucket.append(fish_id)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cGUnd2eHxNSK"
      },
      "outputs": [],
      "source": [
        "# One (num_fish_in_frame, 4) input tensor and a matching target tensor per frame\n",
        "# 1..5000, with rows in the order given by times_to_fish[t].\n",
        "Xs = []\n",
        "Ys = []\n",
        "for t in range(1,5001):\n",
        "  fish_idxs = times_to_fish[t]\n",
        "  x = torch.zeros(len(fish_idxs),4)\n",
        "  y = torch.zeros(len(fish_idxs),4)\n",
        "  for i,fish_id in enumerate(fish_idxs):\n",
        "    fish = fish_ids_to_idx[fish_id]\n",
        "    t_idx = fish[\"times\"].index(t)\n",
        "    # fish[\"x\"] / fish[\"y\"] are feature-major (4, frames): take column t_idx.\n",
        "    x[i,:] = fish[\"x\"][:,t_idx]\n",
        "    y[i,:] = fish[\"y\"][:,t_idx]\n",
        "  Xs.append(x)\n",
        "  Ys.append(y)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T5CI-ziRfl2z"
      },
      "outputs": [],
      "source": [
        "# Persist the per-frame (inputs, targets) lists so training can start from here.\n",
        "torch.save((Xs, Ys), \"drive/MyDrive/MeasureMaps/FishMilling.pt\")"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Setting Up Models"
      ],
      "metadata": {
        "id": "DNn3Uc2yG3JB"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UsT13ZQ4IOro"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import scipy.io as sio\n",
        "import torch.nn as nn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "efkeaJDcIPSu"
      },
      "outputs": [],
      "source": [
        "# Shape convention for the models below: inputs are B x N x D, where B is the\n",
        "# batch size, N is the sequence length seen by the transformer (here, the number\n",
        "# of data points), and D is the embedding dimension.\n",
        "class SimpleAttention(nn.Module):\n",
        "  \"\"\"Single-head self-attention block with a learned linear skip connection.\n",
        "\n",
        "  forward(x) = relu(linear(relu(attention(WQ x, WK x, WV x)))) + skip(x)\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, hidden_dim):\n",
        "    super().__init__()\n",
        "    # Sub-modules are created in a fixed order, which fixes the order of random\n",
        "    # initialisation and of parameters().\n",
        "    self.linear = nn.Linear(hidden_dim, hidden_dim)\n",
        "    self.WQ = nn.Linear(hidden_dim, hidden_dim, bias = False)\n",
        "    self.WK = nn.Linear(hidden_dim, hidden_dim, bias = False)\n",
        "    self.WV = nn.Linear(hidden_dim, hidden_dim, bias = False)\n",
        "    self.skip = nn.Linear(hidden_dim, hidden_dim)\n",
        "    # NOTE(review): MultiheadAttention applies its own Q/K/V input projections\n",
        "    # on top of WQ/WK/WV.\n",
        "    self.attention = nn.MultiheadAttention(hidden_dim, 1, batch_first=True)\n",
        "\n",
        "  def forward(self, input):\n",
        "    attended, _ = self.attention(self.WQ(input), self.WK(input), self.WV(input))\n",
        "    mixed = self.linear(attended.relu()).relu()\n",
        "    return mixed + self.skip(input)\n",
        "\n",
        "class SimpleTransformer(nn.Module):\n",
        "  \"\"\"Linear embedding -> num_layers SimpleAttention blocks -> linear readout.\"\"\"\n",
        "\n",
        "  def __init__(self, input_dim, hidden_dim, out_dim, num_layers):\n",
        "    super().__init__()\n",
        "    self.embed = nn.Linear(input_dim, hidden_dim)\n",
        "    self.predictor = nn.Linear(hidden_dim, out_dim)\n",
        "    self.AttentionLayers = nn.ModuleList(\n",
        "        [SimpleAttention(hidden_dim) for _ in range(num_layers)])\n",
        "    self.num_layers = num_layers\n",
        "\n",
        "  def forward(self, z):\n",
        "    hidden = self.embed(z)\n",
        "    for layer_idx in range(self.num_layers):\n",
        "      hidden = self.AttentionLayers[layer_idx](hidden)\n",
        "    return self.predictor(hidden)\n",
        "\n",
        "\n",
        "class FNN(nn.Module):\n",
        "  \"\"\"Fully connected baseline on a flattened point set.\n",
        "\n",
        "  forward flattens a (B, N, F) input (B x N x 4 in this notebook) to (B, N*F).\n",
        "  It then applies embed, num_layers hidden Linear+ReLU layers, and predictor,\n",
        "  and finally reshapes the output to (B, -1, d).\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, input_dim, hidden_dim, out_dim, num_layers, d):\n",
        "    super().__init__()\n",
        "    self.embed = nn.Linear(input_dim, hidden_dim)\n",
        "    self.predictor = nn.Linear(hidden_dim, out_dim)\n",
        "    self.Layers = nn.ModuleList(\n",
        "        [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)])\n",
        "    self.num_layers = num_layers\n",
        "    self.d = d\n",
        "\n",
        "  def forward(self,x):\n",
        "    hidden = self.embed(x.flatten(1, 2)).relu()\n",
        "    for layer_idx in range(self.num_layers):\n",
        "      hidden = self.Layers[layer_idx](hidden).relu()\n",
        "    return self.predictor(hidden).unflatten(1, (-1, self.d))\n",
        "\n",
        "def kernel_basis(X, d1, d2):\n",
        "  \"\"\"Expand every column of X into polynomial and Fourier features.\n",
        "\n",
        "  Args:\n",
        "    X: (N, d) tensor.\n",
        "    d1: number of monomials X**1 .. X**d1.\n",
        "    d2: number of frequencies k = 1 .. d2 for sin(k*X) and cos(k*X).\n",
        "\n",
        "  Returns:\n",
        "    (N, (d1 + 2*d2) * d) tensor of the default float dtype on X.device. Columns\n",
        "    come in d-wide groups, one per basis function, in the order\n",
        "    X**1..X**d1, sin(1*X)..sin(d2*X), cos(1*X)..cos(d2*X).\n",
        "  \"\"\"\n",
        "  N = X.shape[0]\n",
        "  d = X.shape[1]\n",
        "\n",
        "  Phi = torch.zeros(N, (d1+2*d2)*d, device = X.device)\n",
        "  # Fill each d-wide group with one whole-matrix op instead of one op per column.\n",
        "  for i in range(d1):\n",
        "    Phi[:, d*i:d*(i+1)] = X.pow(i+1)\n",
        "  for k in range(1, d2+1):\n",
        "    Phi[:, d*(d1+k-1):d*(d1+k)] = (X*k).sin()\n",
        "  for k in range(1, d2+1):\n",
        "    Phi[:, d*(d1+d2+k-1):d*(d1+d2+k)] = torch.cos(X*k)\n",
        "\n",
        "  return Phi\n",
        "\n",
        "class Kernel(nn.Module):\n",
        "  \"\"\"Linear readout on precomputed basis features (e.g. from kernel_basis).\n",
        "\n",
        "  Args:\n",
        "    feature_dim: number of input features per sample.\n",
        "    out_dim: number of outputs per sample; reshaped to (-1, d).\n",
        "    d: size of the trailing output dimension.\n",
        "    embed: feature map kept on the module. forward does not apply it, so\n",
        "      inputs must already be embedded.\n",
        "    device: where the linear layer is allocated. Defaults to \"cuda\"; pass\n",
        "      \"cpu\" to build the model without a GPU.\n",
        "  \"\"\"\n",
        "  def __init__(self, feature_dim, out_dim, d, embed=kernel_basis, device=\"cuda\"):\n",
        "    super(Kernel, self).__init__()\n",
        "    self.predictor = nn.Linear(feature_dim, out_dim, device = device)\n",
        "    self.embed = embed\n",
        "    self.d = d\n",
        "\n",
        "  # Input: B x feature_dim (already embedded). Output: B x (out_dim // d) x d.\n",
        "  def forward(self,x):\n",
        "    x = self.predictor(x)\n",
        "    x = torch.unflatten(x, 1, (-1,self.d))\n",
        "    return x\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e2g0EVF6JEqI"
      },
      "outputs": [],
      "source": [
        "def get_nn(in_channels, out_channels):\n",
        "  \"\"\"Two-layer MLP (Linear -> ReLU -> Linear) used as GINConv's update network.\"\"\"\n",
        "  layers = [torch.nn.Linear(in_channels, out_channels),\n",
        "            torch.nn.ReLU(),\n",
        "            torch.nn.Linear(out_channels, out_channels)]\n",
        "  return torch.nn.Sequential(*layers)\n",
        "\n",
        "class GNN(torch.nn.Module):\n",
        "  \"\"\"Message-passing layers (ReLU after each) followed by a per-node linear readout.\n",
        "\n",
        "  Args:\n",
        "    node_input_dim: input feature size per node.\n",
        "    output_dim: output size per node.\n",
        "    num_layers: number of message-passing layers.\n",
        "    hidden_dim: hidden feature size.\n",
        "    device: device each layer is moved to.\n",
        "    arch: \"Transformer\" (TransformerConv), \"Graph\" (GraphConv) or \"GIN\" (GINConv).\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if arch is not one of the names above.\n",
        "  \"\"\"\n",
        "  def __init__(self, node_input_dim, output_dim, num_layers, hidden_dim = 128, device = \"cuda\", arch = \"Graph\"):\n",
        "    super().__init__()\n",
        "\n",
        "    self.num_layers = num_layers\n",
        "    self.layers = []\n",
        "    if arch == \"Transformer\":\n",
        "      self.layers.append(torch_geometric.nn.TransformerConv(node_input_dim, hidden_dim).to(device))\n",
        "      for i in range(num_layers-1):\n",
        "        self.layers.append(torch_geometric.nn.TransformerConv(hidden_dim, hidden_dim).to(device))\n",
        "    elif arch == \"Graph\":\n",
        "      self.layers.append(torch_geometric.nn.GraphConv(node_input_dim, hidden_dim).to(device))\n",
        "      for i in range(num_layers-1):\n",
        "        self.layers.append(torch_geometric.nn.GraphConv(hidden_dim, hidden_dim).to(device))\n",
        "    elif arch == \"GIN\":\n",
        "      self.layers.append(torch_geometric.nn.GINConv(get_nn(node_input_dim, hidden_dim)).to(device))\n",
        "      for i in range(num_layers-1):\n",
        "        self.layers.append(torch_geometric.nn.GINConv(get_nn(hidden_dim, hidden_dim)).to(device))\n",
        "    else:\n",
        "      # Fail fast instead of building a model with no message-passing layers,\n",
        "      # which would only error (or silently misbehave) at forward time.\n",
        "      raise ValueError(\"Unknown arch %r; expected 'Transformer', 'Graph' or 'GIN'\" % (arch,))\n",
        "\n",
        "    self.layers = torch.nn.ModuleList(self.layers)\n",
        "    self.lin = torch.nn.Linear(hidden_dim, output_dim).to(device)\n",
        "\n",
        "  def forward(self, x, edge_index, batch):\n",
        "    # `batch` is accepted for the usual PyG call signature but is not used:\n",
        "    # predictions are per node, with no pooling.\n",
        "    for i in range(self.num_layers):\n",
        "      x = self.layers[i](x, edge_index).relu()\n",
        "\n",
        "    z = self.lin(x)\n",
        "    return z"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Training"
      ],
      "metadata": {
        "id": "NWouYiW9HEcJ"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kE7VqDvGb9Vo"
      },
      "outputs": [],
      "source": [
        "# Reload the per-frame inputs / targets saved by the preprocessing section\n",
        "# (lists with one (num_fish, 4) tensor per frame). Then build PyG graphs with\n",
        "# k=3 and k=20 nearest neighbours.\n",
        "# NOTE(review): the k=20 \"full\" graph is presumably meant to be fully connected;\n",
        "# confirm knn_graph's behaviour for frames with 20 or fewer fish.\n",
        "Xtrn, Ytrn = torch.load(\"drive/MyDrive/MeasureMaps/FishMilling.pt\")\n",
        "\n",
        "dataset_knn = data_to_GNN_data_fm(Xtrn, Ytrn, 3)\n",
        "dataset_full = data_to_GNN_data_fm(Xtrn, Ytrn, 20)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mj7TR0W8a_fa"
      },
      "outputs": [],
      "source": [
        "# Training-split frame indices into Xtrn / Ytrn (presumably a 1-D tensor; the\n",
        "# loaders below call .numpy() on it).\n",
        "train_idx = torch.load(\"drive/MyDrive/MeasureMaps/FM-data-split.pt\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "okS60owkIIqY"
      },
      "outputs": [],
      "source": [
        "# Xtrn / Ytrn are Python lists of per-frame tensors (as saved by the preprocessing\n",
        "# section), so select the training frames element-wise. Indexing a list with a\n",
        "# tensor raises TypeError.\n",
        "train_X = [Xtrn[i] for i in train_idx.tolist()]\n",
        "train_Y = [Ytrn[i] for i in train_idx.tolist()]\n",
        "\n",
        "# The stacked TensorDataset (used by the FNN baseline) needs every frame to have\n",
        "# the same number of fish; otherwise train_loader is left as None.\n",
        "if len({x.shape for x in train_X}) == 1:\n",
        "  train_data = torch.utils.data.TensorDataset(torch.stack(train_X), torch.stack(train_Y))\n",
        "  train_loader = torch.utils.data.DataLoader(train_data, shuffle = True, batch_size=500)\n",
        "else:\n",
        "  train_data = None\n",
        "  train_loader = None\n",
        "  print(\"Frames have different numbers of fish; fixed-size train_loader not built.\")\n",
        "\n",
        "train_loader_gnn_full = DataLoader([dataset_full[i] for i in train_idx.numpy()], shuffle = True, batch_size = 500)\n",
        "train_loader_gnn_knn = DataLoader([dataset_knn[i] for i in train_idx.numpy()], shuffle = True, batch_size = 500)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "aWeai7rWdbYN"
      },
      "outputs": [],
      "source": [
        "from tqdm import tqdm\n",
        "\n",
        "depths = [3,4,5]\n",
        "widths = [128,256,512]\n",
        "lrs = [2e-4, 1e-3, 1e-4]\n",
        "T = 5\n",
        "\n",
        "# Per-frame training samples, defined here so this cell does not depend on hidden\n",
        "# state. Xtrn / Ytrn are lists of per-frame tensors with varying fish counts, so\n",
        "# frames are selected element-wise and trained on one at a time.\n",
        "train_X = [Xtrn[i] for i in train_idx.tolist()]\n",
        "train_Y = [Ytrn[i] for i in train_idx.tolist()]\n",
        "N = len(train_X)\n",
        "\n",
        "epochs = 11\n",
        "\n",
        "for t in range(T):\n",
        "  for d in depths: # depth\n",
        "    for h in widths: # width\n",
        "      for lr in lrs:\n",
        "        # NOTE(review): the torch.save below is commented out, so this checkpoint is\n",
        "        # never written by this cell and every configuration gets retrained.\n",
        "        if os.path.exists(\"drive/MyDrive/MeasureMaps/Transformer/fm-trial-\"+str(t)+\"-depth-\"+str(d)+\"-width-\"+str(h)+\"-lr-\"+str(lr)+\"-epoch-\"+str(10)+\".pt\"):\n",
        "          print(\"Done with\",t,d,h,lr)\n",
        "          continue\n",
        "        model_transformer = SimpleTransformer(4,h,4,d).to('cuda')\n",
        "\n",
        "        optimizer_transformer = torch.optim.Adam(model_transformer.parameters(), lr = lr)\n",
        "        scheduler_transformer = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_transformer, epochs)\n",
        "\n",
        "        for i in tqdm(range(epochs)):\n",
        "          # One optimisation step per frame, visiting frames in a fixed order.\n",
        "          for j in range(N):\n",
        "            optimizer_transformer.zero_grad()\n",
        "            X = train_X[j].to('cuda')\n",
        "            Y = train_Y[j].to('cuda')\n",
        "            Y_pred = model_transformer(X)\n",
        "            loss = torch.nn.functional.mse_loss(Y_pred, Y)\n",
        "            loss.backward()\n",
        "            optimizer_transformer.step()\n",
        "          # torch.save(model_transformer,\"drive/MyDrive/MeasureMaps/Transformer/fm-trial-\"+str(t)+\"-depth-\"+str(d)+\"-width-\"+str(h)+\"-lr-\"+str(lr)+\"-epoch-\"+str(i)+\".pt\")\n",
        "          scheduler_transformer.step()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "lZzreSBFl25g"
      },
      "outputs": [],
      "source": [
        "from tqdm import tqdm\n",
        "\n",
        "degree = [2,3,4]\n",
        "frequencies = [4,5,6]\n",
        "lrs = [1e-4, 2e-4, 1e-3]\n",
        "T = 5\n",
        "\n",
        "epochs = 1001\n",
        "\n",
        "# NOTE(review): Xtrn_data / Ytrn_data are not defined anywhere in this notebook.\n",
        "# They are assumed to be fixed-size tensors of shape (num_frames, 20, 4), matching\n",
        "# the (d1+2*d2)*4*20 input size below -- TODO confirm.\n",
        "features_key = None  # (d1, d2) of the cached kernel features / loader\n",
        "\n",
        "for t in range(T):\n",
        "  for d1 in degree:\n",
        "    for d2 in frequencies:\n",
        "      for lr in lrs:\n",
        "        # Same path as the torch.save below at the last epoch (degree/frequency naming,\n",
        "        # this loop's own variables), so finished runs are actually skipped.\n",
        "        if os.path.exists(\"drive/MyDrive/MeasureMaps/Kernel/trial-\"+str(t)+\"-degree-\"+str(d1)+\"-frequency-\"+str(d2)+\"-lr-\"+str(lr)+\"-epoch-\"+str(epochs-1)+\".pt\"):\n",
        "          continue\n",
        "        # The basis expansion depends only on (d1, d2), not on lr: build it once and\n",
        "        # reuse it for consecutive learning rates.\n",
        "        if features_key != (d1, d2):\n",
        "          Xtrn_data_kernel = kernel_basis(torch.flatten(Xtrn_data,1,2),d1,d2)\n",
        "          train_data_kernel = torch.utils.data.TensorDataset(Xtrn_data_kernel[train_idx,:], Ytrn_data[train_idx,:,:])\n",
        "          train_loader_kernel = torch.utils.data.DataLoader(train_data_kernel, shuffle = True, batch_size=500)\n",
        "          features_key = (d1, d2)\n",
        "\n",
        "        model_kernel = Kernel((d1+2*d2)*4*20,80,4).to('cuda')\n",
        "\n",
        "        optimizer_kernel = torch.optim.Adam(model_kernel.parameters(), lr = lr)\n",
        "        scheduler_kernel = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_kernel, epochs)\n",
        "\n",
        "        for i in tqdm(range(epochs)):\n",
        "          for X,Y in train_loader_kernel:\n",
        "            # The model is on the GPU; move each batch there (a no-op if it already is).\n",
        "            X, Y = X.to('cuda'), Y.to('cuda')\n",
        "            optimizer_kernel.zero_grad()\n",
        "            Y_pred = model_kernel(X)\n",
        "            loss = torch.nn.functional.mse_loss(Y_pred, Y)\n",
        "            loss.backward()\n",
        "            optimizer_kernel.step()\n",
        "          if i % 100 == 0:\n",
        "            torch.save(model_kernel,\"drive/MyDrive/MeasureMaps/Kernel/trial-\"+str(t)+\"-degree-\"+str(d1)+\"-frequency-\"+str(d2)+\"-lr-\"+str(lr)+\"-epoch-\"+str(i)+\".pt\")\n",
        "          scheduler_kernel.step()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c20j3ohMe8gV"
      },
      "outputs": [],
      "source": [
        "from tqdm import tqdm\n",
        "\n",
        "depths = [3,4,5]\n",
        "widths = [128,256,512]\n",
        "lrs = [1e-4, 2e-4, 1e-3]\n",
        "T = 5\n",
        "N = 20  # fish per frame assumed by the fixed-size FNN input (4*N features)\n",
        "\n",
        "epochs = 1001\n",
        "\n",
        "# NOTE(review): only trial 0 is run here, although T = 5 is defined above.\n",
        "for t in range(1):\n",
        "  for d in depths:\n",
        "    for h in widths:\n",
        "      for lr in lrs:\n",
        "        if os.path.exists(\"drive/MyDrive/MeasureMaps/FNN/trial-\"+str(t)+\"-depth-\"+str(d)+\"-width-\"+str(h)+\"-lr-\"+str(lr)+\"-epoch-\"+str(1000)+\".pt\"):\n",
        "          continue\n",
        "        model_fnn = FNN(4*N,h,4*N,d,4).to('cuda')\n",
        "\n",
        "        optimizer_fnn = torch.optim.Adam(model_fnn.parameters(), lr = lr)\n",
        "        scheduler_fnn = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_fnn, epochs)\n",
        "\n",
        "        for i in tqdm(range(epochs)):\n",
        "          for X,Y in train_loader:\n",
        "            # train_loader yields CPU tensors but the model is on the GPU.\n",
        "            X, Y = X.to('cuda'), Y.to('cuda')\n",
        "            optimizer_fnn.zero_grad()\n",
        "            Y_pred = model_fnn(X)\n",
        "            loss = torch.nn.functional.mse_loss(Y_pred, Y)\n",
        "            loss.backward()\n",
        "            optimizer_fnn.step()\n",
        "          if i % 100 == 0:\n",
        "            torch.save(model_fnn,\"drive/MyDrive/MeasureMaps/FNN/trial-\"+str(t)+\"-depth-\"+str(d)+\"-width-\"+str(h)+\"-lr-\"+str(lr)+\"-epoch-\"+str(i)+\".pt\")\n",
        "          scheduler_fnn.step()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TKhoLgHPfZyz"
      },
      "outputs": [],
      "source": [
        "from tqdm import tqdm\n",
        "\n",
        "depths = [3,4,5]\n",
        "widths = [128,256,512]\n",
        "lrs = [1e-4, 2e-4, 1e-3]\n",
        "T = 5\n",
        "\n",
        "epochs = 501\n",
        "\n",
        "for t in range(T):\n",
        "  for d in depths:\n",
        "    for h in widths:\n",
        "      for lr in lrs:\n",
        "        # NOTE(review): resume check disabled; it also looks for epoch 10 although this\n",
        "        # cell trains for 501 epochs.\n",
        "        # if os.path.exists(\"drive/MyDrive/MeasureMaps/GNN/GraphConv Full/fm-trial-\"+str(t)+\"-depth-\"+str(d)+\"-width-\"+str(h)+\"-lr-\"+str(lr)+\"-epoch-\"+str(10)+\".pt\"):\n",
        "        #   print(\"done with\", t, d, h, lr)\n",
        "        #   continue\n",
        "        model_gnn = GNN(4,4,d,h).to('cuda')\n",
        "\n",
        "        optimizer_gnn = torch.optim.Adam(model_gnn.parameters(), lr = lr)\n",
        "        scheduler_gnn = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_gnn, epochs)\n",
        "\n",
        "        for i in tqdm(range(epochs)):\n",
        "          # Loop variable is graph_batch (not `data`) so the raw JSON `data` dict is\n",
        "          # not overwritten by the last mini-batch.\n",
        "          for graph_batch in train_loader_gnn_full:\n",
        "            optimizer_gnn.zero_grad()\n",
        "            graph_batch = graph_batch.to('cuda')\n",
        "            Y_pred = model_gnn(graph_batch.x, graph_batch.edge_index, graph_batch.batch)\n",
        "            loss = torch.nn.functional.mse_loss(Y_pred, graph_batch.y)\n",
        "            loss.backward()\n",
        "            optimizer_gnn.step()\n",
        "          print(loss.item())  # loss of the epoch's last mini-batch\n",
        "          # NOTE(review): a checkpoint is written every epoch (501 files per configuration).\n",
        "          torch.save(model_gnn,\"drive/MyDrive/MeasureMaps/GNN/GraphConv Full/fm-trial-\"+str(t)+\"-depth-\"+str(d)+\"-width-\"+str(h)+\"-lr-\"+str(lr)+\"-epoch-\"+str(i)+\".pt\")\n",
        "          scheduler_gnn.step()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BiYbgHl7gfA9"
      },
      "outputs": [],
      "source": [
        "from tqdm import tqdm\n",
        "\n",
        "depths = [3,4,5]\n",
        "widths = [128,256,512]\n",
        "lrs = [1e-4, 2e-4, 1e-3]\n",
        "T = 5\n",
        "\n",
        "epochs = 11\n",
        "\n",
        "for t in range(T):\n",
        "  for d in depths:\n",
        "    for h in widths:\n",
        "      for lr in lrs:\n",
        "        # NOTE(review): resume check and checkpoint saving are both commented out.\n",
        "        # if os.path.exists(\"drive/MyDrive/MeasureMaps/GNN/GraphConv KNN/fm-trial-\"+str(t)+\"-depth-\"+str(d)+\"-width-\"+str(h)+\"-lr-\"+str(lr)+\"-epoch-\"+str(10)+\".pt\"):\n",
        "        #   print(\"Done with\",t,d,h,lr)\n",
        "        #   continue\n",
        "        model_gnn = GNN(4,4,d,h).to('cuda')\n",
        "\n",
        "        optimizer_gnn = torch.optim.Adam(model_gnn.parameters(), lr = lr)\n",
        "        scheduler_gnn = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_gnn, epochs)\n",
        "\n",
        "        for i in tqdm(range(epochs)):\n",
        "          # Loop variable is graph_batch (not `data`) so the raw JSON `data` dict is\n",
        "          # not overwritten by the last mini-batch.\n",
        "          for graph_batch in train_loader_gnn_knn:\n",
        "            optimizer_gnn.zero_grad()\n",
        "            graph_batch = graph_batch.to('cuda')\n",
        "            Y_pred = model_gnn(graph_batch.x, graph_batch.edge_index, graph_batch.batch)\n",
        "            loss = torch.nn.functional.mse_loss(Y_pred, graph_batch.y)\n",
        "            loss.backward()\n",
        "            optimizer_gnn.step()\n",
        "          print(loss.item())  # loss of the epoch's last mini-batch\n",
        "          # torch.save(model_gnn,\"drive/MyDrive/MeasureMaps/GNN/GraphConv KNN/fm-trial-\"+str(t)+\"-depth-\"+str(d)+\"-width-\"+str(h)+\"-lr-\"+str(lr)+\"-epoch-\"+str(i)+\".pt\")\n",
        "          scheduler_gnn.step()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "lcuczXGAhN9C"
      },
      "outputs": [],
      "source": [
        "from tqdm import tqdm\n",
        "\n",
        "depths = [3,4,5]\n",
        "widths = [128,256,512]\n",
        "lrs = [1e-4, 2e-4, 1e-3]\n",
        "T = 5\n",
        "\n",
        "epochs = 11\n",
        "\n",
        "for t in range(T):\n",
        "  for d in depths:\n",
        "    for h in widths:\n",
        "      for lr in lrs:\n",
        "        # NOTE(review): resume check and checkpoint saving are both commented out.\n",
        "        # if os.path.exists(\"drive/MyDrive/MeasureMaps/GNN/TransformerConv Full/fm-trial-\"+str(t)+\"-depth-\"+str(d)+\"-width-\"+str(h)+\"-lr-\"+str(lr)+\"-epoch-\"+str(10)+\".pt\"):\n",
        "        #   print(\"Done with\",t,d,h,lr)\n",
        "        #   continue\n",
        "        model_gnn = GNN(4,4,d,h,arch=\"Transformer\").to('cuda')\n",
        "\n",
        "        optimizer_gnn = torch.optim.Adam(model_gnn.parameters(), lr = lr)\n",
        "        scheduler_gnn = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_gnn, epochs)\n",
        "\n",
        "        for i in tqdm(range(epochs)):\n",
        "          # Loop variable is graph_batch (not `data`) so the raw JSON `data` dict is\n",
        "          # not overwritten by the last mini-batch.\n",
        "          for graph_batch in train_loader_gnn_full:\n",
        "            optimizer_gnn.zero_grad()\n",
        "            graph_batch = graph_batch.to('cuda')\n",
        "            Y_pred = model_gnn(graph_batch.x, graph_batch.edge_index, graph_batch.batch)\n",
        "            loss = torch.nn.functional.mse_loss(Y_pred, graph_batch.y)\n",
        "            loss.backward()\n",
        "            optimizer_gnn.step()\n",
        "          print(loss.item())  # loss of the epoch's last mini-batch\n",
        "          # torch.save(model_gnn,\"drive/MyDrive/MeasureMaps/GNN/TransformerConv Full/fm-trial-\"+str(t)+\"-depth-\"+str(d)+\"-width-\"+str(h)+\"-lr-\"+str(lr)+\"-epoch-\"+str(i)+\".pt\")\n",
        "          scheduler_gnn.step()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cfNKKVtzhNvg"
      },
      "outputs": [],
      "source": [
        "from tqdm import tqdm\n",
        "\n",
        "depths = [3,4,5]\n",
        "widths = [128,256,512]\n",
        "lrs = [1e-4, 2e-4, 1e-3]\n",
        "T = 5\n",
        "\n",
        "epochs = 11\n",
        "\n",
        "for t in range(T):\n",
        "  for d in depths:\n",
        "    for h in widths:\n",
        "      for lr in lrs:\n",
        "        # Skip configurations whose final-epoch checkpoint (epoch 10) already exists.\n",
        "        if os.path.exists(\"drive/MyDrive/MeasureMaps/GNN/TransformerConv KNN/fm-trial-\"+str(t)+\"-depth-\"+str(d)+\"-width-\"+str(h)+\"-lr-\"+str(lr)+\"-epoch-\"+str(10)+\".pt\"):\n",
        "          print(\"Done with\",t,d,h,lr)\n",
        "          continue\n",
        "        model_gnn = GNN(4,4,d,h,arch=\"Transformer\").to('cuda')\n",
        "\n",
        "        optimizer_gnn = torch.optim.Adam(model_gnn.parameters(), lr = lr)\n",
        "        scheduler_gnn = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_gnn, epochs)\n",
        "\n",
        "        for i in tqdm(range(epochs)):\n",
        "          # Loop variable is graph_batch (not `data`) so the raw JSON `data` dict is\n",
        "          # not overwritten by the last mini-batch.\n",
        "          for graph_batch in train_loader_gnn_knn:\n",
        "            optimizer_gnn.zero_grad()\n",
        "            graph_batch = graph_batch.to('cuda')\n",
        "            Y_pred = model_gnn(graph_batch.x, graph_batch.edge_index, graph_batch.batch)\n",
        "            loss = torch.nn.functional.mse_loss(Y_pred, graph_batch.y)\n",
        "            loss.backward()\n",
        "            optimizer_gnn.step()\n",
        "          torch.save(model_gnn,\"drive/MyDrive/MeasureMaps/GNN/TransformerConv KNN/fm-trial-\"+str(t)+\"-depth-\"+str(d)+\"-width-\"+str(h)+\"-lr-\"+str(lr)+\"-epoch-\"+str(i)+\".pt\")\n",
        "          scheduler_gnn.step()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aU4pWjQGymWh"
      },
      "source": [
        "# Testing"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "Xtrn, Ytrn = torch.load(\"drive/MyDrive/MeasureMaps/FishMilling.pt\")\n",
        "train_idx = torch.load(\"drive/MyDrive/MeasureMaps/FM-data-split.pt\")"
      ],
      "metadata": {
        "id": "NZJNc7DBHnXP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "val_idx = torch.load(\"drive/MyDrive/MeasureMaps/FM-val-idx.pt\")\n",
        "test_idx = torch.load(\"drive/MyDrive/MeasureMaps/FM-test-idx.pt\")"
      ],
      "metadata": {
        "id": "F-uW2vtuHnLb"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "dataset_knn = data_to_GNN_data_fm(Xtrn, Ytrn, 3)\n",
        "dataset_full = data_to_GNN_data_fm(Xtrn, Ytrn, 20)"
      ],
      "metadata": {
        "id": "vKgbyq4fHnGK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "val_X = [Xtrn[i] for i in val_idx]\n",
        "val_Y = [Ytrn[i] for i in val_idx]\n",
        "\n",
        "val_loader_gnn_full = DataLoader([dataset_full[i] for i in val_idx.numpy()], shuffle = True, batch_size = 1)\n",
        "val_loader_gnn_knn = DataLoader([dataset_knn[i] for i in val_idx.numpy()], shuffle = True, batch_size = 1)\n",
        "\n",
        "test_X = [Xtrn[i] for i in test_idx]\n",
        "test_Y = [Ytrn[i] for i in test_idx]\n",
        "\n",
        "test_loader_gnn_full = DataLoader([dataset_full[i] for i in test_idx.numpy()], shuffle = True, batch_size = 1)\n",
        "test_loader_gnn_knn = DataLoader([dataset_knn[i] for i in test_idx.numpy()], shuffle = True, batch_size = 1)"
      ],
      "metadata": {
        "id": "p1tTEuQAHnAT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def test_transformer_fm(model, loader_X, loader_Y):\n",
        "  N = len(loader_Y)\n",
        "  loss = torch.zeros(N)\n",
        "  model.eval()\n",
        "  for i in range(N):\n",
        "    X = loader_X[i].to('cuda')\n",
        "    Y = loader_Y[i].to('cuda')\n",
        "    Y_pred = model(X)\n",
        "    loss[i] = torch.nn.functional.mse_loss(Y_pred, Y).cpu().detach()\n",
        "  return loss"
      ],
      "metadata": {
        "id": "B6Q_cEZ8Hm5S"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "depths = [3,4,5]\n",
        "widths = [128,256,512]\n",
        "lrs = [2e-4, 1e-3, 1e-4]\n",
        "T = 5\n",
        "\n",
        "best_loss = 100000\n",
        "best_d = -1\n",
        "best_w = -1\n",
        "best_lr = -1\n",
        "\n",
        "for d in depths: # depth\n",
        "  for h in widths: # width\n",
        "    for lr in lrs:\n",
        "      avg_loss = 0\n",
        "      print(\"Testing: \",d,h,lr)\n",
        "      for t in range(T):\n",
        "        model = torch.load(\"drive/MyDrive/MeasureMaps/Transformer/fm-trial-\"+str(t)+\"-depth-\"+str(d)+\"-width-\"+str(h)+\"-lr-\"+str(lr)+\"-epoch-10.pt\")\n",
        "        loss = test_transformer_fm(model, val_X, val_Y)\n",
        "        avg_loss += loss.mean()/T\n",
        "      if avg_loss < best_loss:\n",
        "        best_loss = avg_loss\n",
        "        best_d = d\n",
        "        best_w = h\n",
        "        best_lr = lr\n",
        "print(best_loss, best_d, best_w, best_lr)"
      ],
      "metadata": {
        "id": "yYQT51XrHmzT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "OdASyiDp925D",
        "outputId": "a7147821-23bc-4705-96fa-94379ef46f7f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tensor(0.0222) tensor(1.8757e-05)\n"
          ]
        }
      ],
      "source": [
        "loss_avg = torch.zeros(5)\n",
        "loss_std = torch.zeros(5)\n",
        "for t in range(T):\n",
        "  model = torch.load(\"drive/MyDrive/MeasureMaps/Transformer/fm-trial-\"+str(t)+\"-depth-\"+str(best_d)+\"-width-\"+str(best_w)+\"-lr-\"+str(best_lr)+\"-epoch-10.pt\")\n",
        "  loss = test_transformer_fm(model, test_X, test_Y)\n",
        "  loss_avg[t] = loss.mean()\n",
        "  loss_std[t] = loss.std()\n",
        "print(loss_avg.mean(), loss_avg.std())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o27GEiwgBBB6"
      },
      "outputs": [],
      "source": [
        "def test_GNN_fm(model, loader):\n",
        "  N = len(loader)\n",
        "  loss = torch.zeros(N)\n",
        "  model.eval()\n",
        "  i = 0\n",
        "  for data in loader:\n",
        "    data = data.to('cuda')\n",
        "    Y_pred = model(data.x, data.edge_index, data.batch)\n",
        "    loss[i] = torch.nn.functional.mse_loss(Y_pred, data.y).cpu().detach()\n",
        "    i += 1\n",
        "  return loss"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "z9ghP4aVBetS",
        "outputId": "4a2ba166-78b5-40ec-ffe6-5d794cdf527b"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Testing:  3 128 0.0002\n",
            "tensor(24886930.)\n",
            "Testing:  3 128 0.001\n",
            "tensor(8389030.)\n",
            "Testing:  3 128 0.0001\n",
            "tensor(1.5088e+08)\n",
            "Testing:  3 256 0.0002\n",
            "tensor(7238146.5000)\n",
            "Testing:  3 256 0.001\n",
            "tensor(8501644.)\n",
            "Testing:  3 256 0.0001\n",
            "tensor(7180415.5000)\n",
            "Testing:  3 512 0.0002\n",
            "tensor(6561584.)\n",
            "Testing:  3 512 0.001\n",
            "tensor(5210658.)\n",
            "Testing:  3 512 0.0001\n",
            "tensor(4909415.)\n",
            "Testing:  4 128 0.0002\n",
            "tensor(1.1334e+09)\n",
            "Testing:  4 128 0.001\n",
            "tensor(4.5062e+08)\n",
            "Testing:  4 128 0.0001\n",
            "tensor(2.8523e+09)\n",
            "Testing:  4 256 0.0002\n",
            "tensor(4.7832e+08)\n",
            "Testing:  4 256 0.001\n",
            "tensor(2.3417e+08)\n",
            "Testing:  4 256 0.0001\n",
            "tensor(1.6139e+09)\n",
            "Testing:  4 512 0.0002\n",
            "tensor(3.8287e+08)\n",
            "Testing:  4 512 0.001\n",
            "tensor(1.9696e+08)\n",
            "Testing:  4 512 0.0001\n",
            "tensor(3.9414e+08)\n",
            "Testing:  5 128 0.0002\n",
            "tensor(5.4111e+10)\n",
            "Testing:  5 128 0.001\n",
            "tensor(2.2792e+10)\n",
            "Testing:  5 128 0.0001\n",
            "tensor(3.1073e+11)\n",
            "Testing:  5 256 0.0002\n",
            "tensor(2.9897e+10)\n",
            "Testing:  5 256 0.001\n",
            "tensor(1.2194e+10)\n",
            "Testing:  5 256 0.0001\n",
            "tensor(4.4744e+10)\n",
            "Testing:  5 512 0.0002\n",
            "tensor(2.8784e+10)\n",
            "Testing:  5 512 0.001\n",
            "tensor(1.2767e+10)\n",
            "Testing:  5 512 0.0001\n",
            "tensor(3.0203e+10)\n",
            "100000 -1 -1 -1\n"
          ]
        }
      ],
      "source": [
        "depths = [3,4,5]\n",
        "widths = [128,256,512]\n",
        "lrs = [2e-4, 1e-3, 1e-4]\n",
        "T = 5\n",
        "\n",
        "best_loss = 100000\n",
        "best_d = -1\n",
        "best_w = -1\n",
        "best_lr = -1\n",
        "\n",
        "for d in depths: # depth\n",
        "  for h in widths: # width\n",
        "    for lr in lrs:\n",
        "      avg_loss = 0\n",
        "      print(\"Testing: \",d,h,lr)\n",
        "      for t in range(T):\n",
        "        model = torch.load(\"drive/MyDrive/MeasureMaps/GNN/GraphConv Full/fm-trial-\"+str(t)+\"-depth-\"+str(d)+\"-width-\"+str(h)+\"-lr-\"+str(lr)+\"-epoch-10.pt\")\n",
        "        loss = test_GNN_fm(model, val_loader_gnn_full)\n",
        "        avg_loss += loss.mean()/T\n",
        "      print(avg_loss)\n",
        "      if avg_loss < best_loss:\n",
        "        best_loss = avg_loss\n",
        "        best_d = d\n",
        "        best_w = h\n",
        "        best_lr = lr\n",
        "print(best_loss, best_d, best_w, best_lr)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "pDvQfqO2BooA",
        "outputId": "14f7c538-a83a-4b69-84c7-f55c52f72934"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tensor(0.1210) tensor(0.0207)\n"
          ]
        }
      ],
      "source": [
        "loss_avg = torch.zeros(5)\n",
        "loss_std = torch.zeros(5)\n",
        "for t in range(T):\n",
        "  model = torch.load(\"drive/MyDrive/MeasureMaps/GNN/GraphConv Full/fm-trial-\"+str(t)+\"-depth-\"+str(best_d)+\"-width-\"+str(best_w)+\"-lr-\"+str(best_lr)+\"-epoch-10.pt\")\n",
        "  loss = test_GNN_fm(model, test_loader_gnn_full)\n",
        "  loss_avg[t] = loss.mean()\n",
        "  loss_std[t] = loss.std()\n",
        "print(loss_avg.mean(), loss_avg.std())"
      ]
    }
  ]
}