{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hsNadJLO_CPq",
        "outputId": "4ea90691-f247-4846-c6fc-878dd0fa1b5c"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "2.4.1 cu121\n"
          ]
        }
      ],
      "source": [
        "import torch\n",
        "TORCH = torch.__version__.split('+')[0]\n",
        "CUDA = 'cu' + torch.version.cuda.replace('.','')\n",
        "print(TORCH, CUDA)\n",
        "\n",
        "import pickle\n",
        "import os"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%%capture\n",
        "# Install the PyG companion packages from the wheel index for torch 2.4.0 + cu121, then\n",
        "# torch-geometric itself. The index URL is hard-coded on every line, so it has to be edited\n",
        "# by hand if the runtime's torch/CUDA build (printed by the first cell) changes.\n",
        "!pip install pyg_lib -f https://data.pyg.org/whl/torch-2.4.0+cu121.html\n",
        "!pip install torch-scatter     -f https://data.pyg.org/whl/torch-2.4.0+cu121.html\n",
        "!pip install torch-cluster -f https://data.pyg.org/whl/torch-2.4.0+cu121.html\n",
        "!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.4.0+cu121.html\n",
        "!pip install torch-sparse -f https://data.pyg.org/whl/torch-2.4.0+cu121.html\n",
        "!pip install torch-geometric"
      ],
      "metadata": {
        "id": "d6npCCK-_JtD"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import torch_geometric\n",
        "import torch_geometric.nn as geom_nn\n",
        "import torch_geometric.data as geom_data\n",
        "from torch_geometric.loader import DataLoader\n",
        "from torch_geometric.data import InMemoryDataset\n",
        "from torch_geometric.data import Data\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import networkx as nx\n",
        "\n",
        "from scipy.linalg import fractional_matrix_power\n",
        "from scipy.spatial.distance import pdist\n",
        "from scipy.spatial.distance import squareform\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib as matplotlib\n",
        "import matplotlib.cm as cm\n",
        "\n",
        "from tqdm import tqdm\n",
        "\n",
        "import math\n",
        "from numba import cuda\n",
        "import numpy as np\n",
        "import torch\n",
        "from torch_geometric.utils import (\n",
        "    to_networkx,\n",
        "    from_networkx,\n",
        "    to_dense_adj,\n",
        "    remove_self_loops,\n",
        "    to_undirected,\n",
        ")"
      ],
      "metadata": {
        "id": "QG5QUEIe_M2L"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "v77ZZalQ_O5D",
        "outputId": "886a8895-4639-456d-d063-36bfb6cc4a85"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Set up Functions"
      ],
      "metadata": {
        "id": "4atzWaGU_Rj1"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def data_to_kNN(X,k):\n",
        "  \"\"\"Return a symmetrized k-nearest-neighbour edge index over the rows of X.\n",
        "\n",
        "  The graph is built with torch_geometric.nn.knn_graph(X, k) and then passed through\n",
        "  to_undirected with num_nodes set to X.shape[0].\n",
        "  \"\"\"\n",
        "  n_points = X.shape[0]\n",
        "  directed_edges = torch_geometric.nn.knn_graph(X, k)\n",
        "  return to_undirected(directed_edges, num_nodes=n_points)\n",
        "\n",
        "def data_to_GNN_data(X,Y,k):\n",
        "  \"\"\"Turn stacked samples into a list of PyG Data graphs with k-NN connectivity.\n",
        "\n",
        "  Args:\n",
        "    X: 3-D tensor indexed as X[i,:,:]; each slice holds the node features of one graph.\n",
        "    Y: 3-D tensor indexed the same way; each slice holds the targets of one graph.\n",
        "    k: Number of neighbours passed to data_to_kNN.\n",
        "\n",
        "  Returns:\n",
        "    List with one Data object per sample (x, y, edge_index, num_nodes, num_edges set).\n",
        "  \"\"\"\n",
        "  dataset = []\n",
        "  for i in range(X.shape[0]):\n",
        "    x_i = X[i,:,:]\n",
        "    data = Data(x = x_i, y = Y[i,:,:], edge_index = data_to_kNN(x_i,k))\n",
        "    # Take the node count from the sample itself instead of the previous hard-coded 20,\n",
        "    # so the helper also works for systems with a different number of agents.\n",
        "    data.num_nodes = x_i.shape[0]\n",
        "    data.num_edges = data.edge_index.shape[1]\n",
        "    dataset.append(data)\n",
        "  return dataset\n",
        "\n",
        "def data_to_GNN_data_fm(X,Y,k):\n",
        "  \"\"\"Build one k-NN Data graph per (X[i], Y[i]) pair, showing a tqdm progress bar.\n",
        "\n",
        "  X and Y only need len() and integer indexing, so samples may have different node counts;\n",
        "  num_nodes is read from each sample's first dimension.\n",
        "  \"\"\"\n",
        "  def build_graph(sample_x, sample_y):\n",
        "    graph = Data(x = sample_x, y = sample_y, edge_index = data_to_kNN(sample_x,k))\n",
        "    graph.num_nodes = sample_x.shape[0]\n",
        "    graph.num_edges = graph.edge_index.shape[1]\n",
        "    return graph\n",
        "\n",
        "  return [build_graph(X[i], Y[i]) for i in tqdm(range(len(X)))]"
      ],
      "metadata": {
        "id": "B7e_hwuG_REL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Cucker-Smale Consensus Model (2D)\n",
        "\n",
        "import numpy as np\n",
        "from scipy import integrate\n",
        "import matplotlib.pyplot as plt\n",
        "import torch\n",
        "\n",
        "def CS(t,z,N):\n",
        "  \"\"\"Right-hand side of the 2D Cucker-Smale system in the fun(t, y, *args) form used by solve_ivp.\n",
        "\n",
        "  Args:\n",
        "    t: Current time. Unused: the system has no explicit time dependence.\n",
        "    z: 1-D state of length 4*N laid out as [x, y, u, v] (positions, then velocities).\n",
        "    N: Number of agents.\n",
        "\n",
        "  Returns:\n",
        "    1-D array [dx/dt, dy/dt, du/dt, dv/dt] with dx/dt = u, dy/dt = v and\n",
        "    du_i/dt = (1/N) * sum_j Phi(|p_i - p_j|) * (u_j - u_i) (likewise for v).\n",
        "  \"\"\"\n",
        "  z = np.asarray(z)\n",
        "  x = z[0:N]; y = z[N:2*N]\n",
        "  u = z[2*N:3*N]; v = z[3*N:]\n",
        "\n",
        "  # weights[i,j] = Phi(distance between agents i and j) / N. Built once by broadcasting\n",
        "  # instead of an N x N Python loop that evaluated the same Phi value twice per pair.\n",
        "  weights = (1/N)*Phi(np.hypot(x[:,None]-x[None,:], y[:,None]-y[None,:]))\n",
        "\n",
        "  dudt = np.sum(weights*(u[None,:]-u[:,None]), axis=1)\n",
        "  dvdt = np.sum(weights*(v[None,:]-v[:,None]), axis=1)\n",
        "\n",
        "  return np.concatenate((u, v, dudt, dvdt))\n",
        "\n",
        "def Phi(r):\n",
        "   H = 1; s = 1; b = 1\n",
        "   y = H/(s**2+r**2)**b\n",
        "   return y\n",
        "\n",
        "# %%%%%%%%%%%%%%%%%%%%%%% Main %%%%%%%%%%%%%%%%%%%%%%%\n",
        "\n",
        "N = 20 # Number of agents\n",
        "T = 100 # Final time\n",
        "\n",
        "# Four consecutive uniform(0,1) draws of length N, in the order x0, y0, u0, v0.\n",
        "x0, y0, u0, v0 = [np.random.uniform(0,1,N) for _ in range(4)]\n",
        "z0 = np.concatenate((x0,y0,u0,v0))\n",
        "\n",
        "# 500 evenly spaced times on [0, T); computed here but not passed to solve_ivp below.\n",
        "t_eval = torch.tensor(range(0,500))*T/500\n",
        "\n",
        "soln = integrate.solve_ivp(CS, (0, T), z0, method='BDF',args=[N])\n",
        "\n",
        "# Split the stacked solution rows back into positions (x, y) and velocities (u, v).\n",
        "x, y, u, v = [soln.y[k*N:(k+1)*N,:] for k in range(4)]"
      ],
      "metadata": {
        "id": "2bbsQUEr_Vmm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from tqdm import tqdm\n",
        "\n",
        "# Tensor copies of the trajectory integrated in the previous cell.\n",
        "X, t = torch.tensor(soln.y), torch.tensor(soln.t)\n",
        "\n",
        "# Number of agents, final time, number of random initializations.\n",
        "N, T_end, N_initials = 20, 100, 500\n",
        "\n",
        "\n",
        "def create_data_one_initialization(N,T_end):\n",
        "    \"\"\"Simulate one random Cucker-Smale trajectory and return (inputs, targets).\n",
        "\n",
        "    Initial positions and velocities are drawn from uniform(0,1); the system is integrated on\n",
        "    (0, T_end) with BDF, targets are built by create_y, and both tensors are rearranged by\n",
        "    resize_tensor.\n",
        "    \"\"\"\n",
        "    # Four consecutive draws of length N, in the order x0, y0, u0, v0.\n",
        "    z0 = np.concatenate([np.random.uniform(0,1,N) for _ in range(4)])\n",
        "\n",
        "    trajectory = integrate.solve_ivp(CS, (0, T_end), z0, method='BDF',args=[N])\n",
        "\n",
        "    # trajectory.y is 4N x t = [x; y; xdot; ydot] x t\n",
        "    states, targets = create_y(torch.tensor(trajectory.y), torch.tensor(trajectory.t))\n",
        "\n",
        "    return resize_tensor(states), resize_tensor(targets)\n",
        "\n",
        "def create_data(N, T_end, N_initials):\n",
        "  \"\"\"Generate and stack training pairs from several random initializations.\n",
        "\n",
        "  Args:\n",
        "    N: Number of agents.\n",
        "    T_end: Final integration time.\n",
        "    N_initials: Number of random initializations; at least one is always generated.\n",
        "\n",
        "  Returns:\n",
        "    (Xtrn, Ytrn): the per-initialization tensors concatenated along dim 0.\n",
        "  \"\"\"\n",
        "  X_chunks, Y_chunks = [], []\n",
        "  for _ in tqdm(range(max(N_initials, 1))):\n",
        "    Xtrn_new, Ytrn_new = create_data_one_initialization(N,T_end)\n",
        "    X_chunks.append(Xtrn_new)\n",
        "    Y_chunks.append(Ytrn_new)\n",
        "\n",
        "  # Concatenate once at the end; re-concatenating inside the loop copied the whole\n",
        "  # accumulated tensor on every iteration.\n",
        "  return torch.cat(X_chunks, dim = 0), torch.cat(Y_chunks, dim = 0)\n",
        "\n",
        "def create_y(X,t):\n",
        "  N = X.shape[0]\n",
        "  T = X.shape[1]\n",
        "\n",
        "  n = N//4\n",
        "\n",
        "  y = torch.zeros(N,T)\n",
        "  for i in range(1,T-1):\n",
        "    y[:n,i] = X[2*n:3*n, i]\n",
        "    y[n:2*n,i] = X[3*n:, i]\n",
        "    y[2*n:,i] = (X[2*n:,i+1] - X[2*n:,i-1])/(t[i+1]-t[i-1])\n",
        "\n",
        "  return X[:,1:-1], y[:,1:-1]\n",
        "\n",
        "def resize_tensor(A):\n",
        "  N, T = A.shape\n",
        "  n = N//4\n",
        "  new_A = torch.zeros(T, N//4, 4)\n",
        "\n",
        "  new_A[:,:,0] = A.T[:,:n]\n",
        "  new_A[:,:,1] = A.T[:,n:2*n]\n",
        "  new_A[:,:,2] = A.T[:,2*n:3*n]\n",
        "  new_A[:,:,3] = A.T[:,3*n:]\n",
        "\n",
        "  return new_A"
      ],
      "metadata": {
        "id": "Gjac7ljMBrid"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Set up the Models"
      ],
      "metadata": {
        "id": "Iz0bQRLdBj-9"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UsT13ZQ4IOro"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import scipy.io as sio\n",
        "import torch.nn as nn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "efkeaJDcIPSu"
      },
      "outputs": [],
      "source": [
        "# Throughout this we are going to assume that data is of the form B x N x D\n",
        "# Where B is the batch size, N is the sequence length for the transformer\n",
        "# this the number of data points. Finally D is the embedding dimension.\n",
        "class SimpleAttention(nn.Module):\n",
        "  # Initialize the parameter\n",
        "  def __init__(self, hidden_dim):\n",
        "    super(SimpleAttention, self).__init__()\n",
        "    self.linear = nn.Linear(hidden_dim, hidden_dim)\n",
        "    self.WQ = nn.Linear(hidden_dim, hidden_dim, bias = False)\n",
        "    self.WK = nn.Linear(hidden_dim, hidden_dim, bias = False)\n",
        "    self.WV = nn.Linear(hidden_dim, hidden_dim, bias = False)\n",
        "    self.skip = nn.Linear(hidden_dim, hidden_dim)\n",
        "    self.attention = nn.MultiheadAttention(hidden_dim, 1, batch_first=True)\n",
        "\n",
        "  # Forward pass\n",
        "  def forward(self, input):\n",
        "    Q = self.WQ(input)\n",
        "    K = self.WK(input)\n",
        "    V = self.WV(input)\n",
        "    output_attention,_ = self.attention(Q,K,V)\n",
        "    output_linear = self.linear(output_attention.relu()).relu()\n",
        "    return output_linear + self.skip(input)\n",
        "\n",
        "class SimpleTransformer(nn.Module):\n",
        "  \"\"\"Linear embedding, num_layers SimpleAttention blocks, then a linear predictor.\"\"\"\n",
        "  def __init__(self, input_dim, hidden_dim, out_dim, num_layers):\n",
        "    super().__init__()\n",
        "    self.embed = nn.Linear(input_dim, hidden_dim)\n",
        "    self.predictor = nn.Linear(hidden_dim, out_dim)\n",
        "    self.AttentionLayers = nn.ModuleList(\n",
        "        [SimpleAttention(hidden_dim) for _ in range(num_layers)])\n",
        "    self.num_layers = num_layers\n",
        "\n",
        "  def forward(self, z):\n",
        "    hidden = self.embed(z)\n",
        "    for layer_idx in range(self.num_layers):\n",
        "      hidden = self.AttentionLayers[layer_idx](hidden)\n",
        "    return self.predictor(hidden)\n",
        "\n",
        "\n",
        "class FNN(nn.Module):\n",
        "  def __init__(self, input_dim, hidden_dim, out_dim, num_layers, d):\n",
        "    super(FNN, self).__init__()\n",
        "    self.embed = nn.Linear(input_dim, hidden_dim)\n",
        "    self.predictor = nn.Linear(hidden_dim, out_dim)\n",
        "    self.Layers = []\n",
        "    for i in range(num_layers):\n",
        "      self.Layers.append(nn.Linear(hidden_dim, hidden_dim))\n",
        "\n",
        "    self.Layers = nn.ModuleList(self.Layers)\n",
        "    self.num_layers = num_layers\n",
        "    self.d = d\n",
        "\n",
        "  # Input size B x N X 4\n",
        "  def forward(self,x):\n",
        "    x = torch.flatten(x, 1, 2) # B x N4\n",
        "    x = self.embed(x).relu()\n",
        "    for i in range(self.num_layers):\n",
        "      x = self.Layers[i](x).relu()\n",
        "    x = self.predictor(x)\n",
        "    x = torch.unflatten(x, 1, (-1,self.d))\n",
        "    return x\n",
        "\n",
        "def kernel_basis(X, d1, d2):\n",
        "  N = X.shape[0]\n",
        "  d = X.shape[1]\n",
        "\n",
        "  Phi = torch.zeros(N, (d1+2*d2)*d, device = X.device)\n",
        "  for i in range(d1+2*d2):\n",
        "    if i < d1:\n",
        "      for j in range(d):\n",
        "        Phi[:,d*i+j] = X[:,j].pow(i+1)\n",
        "    elif i < d1+d2:\n",
        "      for j in range(d):\n",
        "        k = i-d1+1\n",
        "        Phi[:,d*i+j] = (X[:,j]*k).sin()\n",
        "    else:\n",
        "      for j in range(d):\n",
        "        k = i-d1-d2+1\n",
        "        Phi[:,d*i+j] = torch.cos(X[:,j]*k)\n",
        "\n",
        "  return Phi\n",
        "\n",
        "class Kernel(nn.Module):\n",
        "  \"\"\"Linear read-out on already-expanded basis features.\n",
        "\n",
        "  Args:\n",
        "    feature_dim: Size of the last dimension of the feature matrix passed to forward.\n",
        "    out_dim: Output size of the linear layer.\n",
        "    d: Size of the trailing dimension the output is unflattened into.\n",
        "    embed: Feature map stored on the module; forward does not apply it.\n",
        "    device: Device for the linear layer. Defaults to 'cuda' when available (the previous\n",
        "      hard-coded value), otherwise 'cpu' so the class can be built without a GPU.\n",
        "  \"\"\"\n",
        "  def __init__(self, feature_dim, out_dim, d, embed=kernel_basis, device=None):\n",
        "    super(Kernel, self).__init__()\n",
        "    if device is None:\n",
        "      device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "    self.predictor = nn.Linear(feature_dim, out_dim, device = device)\n",
        "    self.embed = embed\n",
        "    self.d = d\n",
        "\n",
        "  # Input size B x feature_dim; output size B x (out_dim/d) x d\n",
        "  def forward(self,x):\n",
        "    x = self.predictor(x)\n",
        "    x = torch.unflatten(x, 1, (-1,self.d))\n",
        "    return x\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e2g0EVF6JEqI"
      },
      "outputs": [],
      "source": [
        "def get_nn(in_channels, out_channels):\n",
        "  return torch.nn.Sequential(torch.nn.Linear(in_channels, out_channels), torch.nn.ReLU(),\n",
        "                             torch.nn.Linear(out_channels, out_channels))\n",
        "\n",
        "class GNN(torch.nn.Module):\n",
        "  \"\"\"Stack of PyG convolution layers (ReLU after each) followed by a linear read-out.\n",
        "\n",
        "  arch selects the layer type: \"Transformer\" (TransformerConv), \"Graph\" (GraphConv) or\n",
        "  \"GIN\" (GINConv wrapping get_nn). Any other value builds no convolution layers. The batch\n",
        "  argument of forward is accepted but not used.\n",
        "  \"\"\"\n",
        "  def __init__(self, node_input_dim, output_dim, num_layers, hidden_dim = 128, device = \"cuda\", arch = \"Graph\"):\n",
        "    super().__init__()\n",
        "\n",
        "    self.num_layers = num_layers\n",
        "\n",
        "    # One factory per architecture, building a single conv layer of the given sizes.\n",
        "    conv_factories = {\n",
        "      \"Transformer\": lambda n_in, n_out: torch_geometric.nn.TransformerConv(n_in, n_out),\n",
        "      \"Graph\": lambda n_in, n_out: torch_geometric.nn.GraphConv(n_in, n_out),\n",
        "      \"GIN\": lambda n_in, n_out: torch_geometric.nn.GINConv(get_nn(n_in, n_out)),\n",
        "    }\n",
        "    make_conv = conv_factories.get(arch)\n",
        "\n",
        "    convs = []\n",
        "    if make_conv is not None:\n",
        "      # The input layer is always built, followed by num_layers-1 hidden layers.\n",
        "      in_dims = [node_input_dim] + [hidden_dim]*(num_layers-1)\n",
        "      convs = [make_conv(n_in, hidden_dim).to(device) for n_in in in_dims]\n",
        "\n",
        "    self.layers = torch.nn.ModuleList(convs)\n",
        "    self.lin = torch.nn.Linear(hidden_dim, output_dim).to(device)\n",
        "\n",
        "  def forward(self, x, edge_index, batch):\n",
        "    for layer_idx in range(self.num_layers):\n",
        "      x = self.layers[layer_idx](x, edge_index).relu()\n",
        "\n",
        "    return self.lin(x)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Training the Model"
      ],
      "metadata": {
        "id": "Eyft-IBABxzb"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "Xtrn, Ytrn = torch.load(\"drive/MyDrive/MeasureMaps/data-cucker-2.pt\")\n",
        "\n",
        "# The testing cells read this same file as a sequence of index tensors (idxes[1] is used as\n",
        "# the validation indices), while the code below needs a single index tensor. When a\n",
        "# list/tuple is stored, take entry 0 as the training split; a bare tensor is used as-is.\n",
        "# NOTE(review): entry 0 being the training split is an assumption - confirm against the\n",
        "# code that wrote the split file.\n",
        "split = torch.load(\"drive/MyDrive/MeasureMaps/Cucker-data-split.pt\")\n",
        "train_idx = split[0] if isinstance(split, (list, tuple)) else split\n",
        "\n",
        "train_X = [Xtrn[i] for i in train_idx]\n",
        "train_Y = [Ytrn[i] for i in train_idx]\n",
        "\n",
        "# Graph datasets over all samples: a 3-NN graph and a 20-NN graph per sample.\n",
        "dataset_knn = data_to_GNN_data(Xtrn, Ytrn, 3)\n",
        "dataset_full = data_to_GNN_data(Xtrn, Ytrn, 20)\n",
        "train_data = torch.utils.data.TensorDataset(Xtrn[train_idx,:,:], Ytrn[train_idx,:,:])\n",
        "\n",
        "train_loader = torch.utils.data.DataLoader(train_data, shuffle = True, batch_size=500)\n",
        "train_loader_gnn_full = DataLoader([dataset_full[i] for i in train_idx.numpy()], shuffle = True, batch_size = 500)\n",
        "train_loader_gnn_knn = DataLoader([dataset_knn[i] for i in train_idx.numpy()], shuffle = True, batch_size = 500)"
      ],
      "metadata": {
        "id": "8ItBdD1PBznE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Hyper-parameter sweep for SimpleTransformer: depth x width x learning rate, per trial.\n",
        "depths = [3,4,5]\n",
        "widths = [128,256,512]\n",
        "lrs = [1e-4, 2e-4, 1e-3]\n",
        "T = 5\n",
        "\n",
        "epochs = 1001\n",
        "\n",
        "# NOTE(review): trials start at 3 here while the kernel and GNN sweeps start at 0 -\n",
        "# presumably resuming an interrupted run; confirm before a clean re-run.\n",
        "for t in range(3,T):\n",
        "  for d in depths:\n",
        "    for h in widths:\n",
        "      for lr in lrs:\n",
        "        model_transformer = SimpleTransformer(4,h,4,d).to('cuda')\n",
        "\n",
        "        optimizer_transformer = torch.optim.Adam(model_transformer.parameters(), lr = lr)\n",
        "        scheduler_transformer = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_transformer, epochs)\n",
        "\n",
        "        for i in tqdm(range(epochs)):\n",
        "          for X,Y in train_loader:\n",
        "            # Move the batch to the model's device, as the GNN loops and the test helper do;\n",
        "            # this is a no-op when the data already lives on the GPU.\n",
        "            X = X.to('cuda')\n",
        "            Y = Y.to('cuda')\n",
        "            optimizer_transformer.zero_grad()\n",
        "            Y_pred = model_transformer(X)\n",
        "            loss = torch.nn.functional.mse_loss(Y_pred, Y)\n",
        "            loss.backward()\n",
        "            optimizer_transformer.step()\n",
        "          # Checkpoint every 100 epochs (0, 100, ..., 1000).\n",
        "          if i % 100 == 0:\n",
        "            torch.save(model_transformer,\"drive/MyDrive/MeasureMaps/Transformer/trial-\"+str(t)+\"-depth-\"+str(d)+\"-width-\"+str(h)+\"-lr-\"+str(lr)+\"-epoch-\"+str(i)+\".pt\")\n",
        "          scheduler_transformer.step()"
      ],
      "metadata": {
        "id": "SlkLr-l3B8ou"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Hyper-parameter sweep for the Kernel model: polynomial degree x frequencies x learning rate.\n",
        "degree = [2,3,4]\n",
        "frequencies = [4,5,6]\n",
        "lrs = [1e-4, 2e-4, 1e-3]\n",
        "T = 5\n",
        "\n",
        "epochs = 1001\n",
        "\n",
        "for t in range(T):\n",
        "  for d1 in degree:\n",
        "    for d2 in frequencies:\n",
        "      # The basis expansion depends only on (d1, d2), so it is built once per pair rather\n",
        "      # than once per learning rate. Xtrn/Ytrn are the tensors loaded in the data cell; the\n",
        "      # previous names Xtrn_data/Ytrn_data are not defined in any cell above.\n",
        "      Xtrn_data_kernel = kernel_basis(torch.flatten(Xtrn,1,2),d1,d2)\n",
        "      train_data_kernel = torch.utils.data.TensorDataset(Xtrn_data_kernel[train_idx,:], Ytrn[train_idx,:,:])\n",
        "      train_loader_kernel = torch.utils.data.DataLoader(train_data_kernel, shuffle = True, batch_size=500)\n",
        "\n",
        "      for lr in lrs:\n",
        "        # Input width is read from the feature matrix instead of repeating (d1+2*d2)*4*20.\n",
        "        model_kernel = Kernel(Xtrn_data_kernel.shape[1],80,4).to('cuda')\n",
        "\n",
        "        optimizer_kernel = torch.optim.Adam(model_kernel.parameters(), lr = lr)\n",
        "        scheduler_kernel = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_kernel, epochs)\n",
        "\n",
        "        for i in tqdm(range(epochs)):\n",
        "          for X,Y in train_loader_kernel:\n",
        "            # Move the batch to the model's device (no-op if it is already on the GPU).\n",
        "            X = X.to('cuda')\n",
        "            Y = Y.to('cuda')\n",
        "            optimizer_kernel.zero_grad()\n",
        "            Y_pred = model_kernel(X)\n",
        "            loss = torch.nn.functional.mse_loss(Y_pred, Y)\n",
        "            loss.backward()\n",
        "            optimizer_kernel.step()\n",
        "          # Checkpoint every 100 epochs (0, 100, ..., 1000).\n",
        "          if i % 100 == 0:\n",
        "            torch.save(model_kernel,\"drive/MyDrive/MeasureMaps/Kernel/trial-\"+str(t)+\"-degree-\"+str(d1)+\"-frequency-\"+str(d2)+\"-lr-\"+str(lr)+\"-epoch-\"+str(i)+\".pt\")\n",
        "          scheduler_kernel.step()"
      ],
      "metadata": {
        "id": "ifnOylwFDtnK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from tqdm import tqdm\n",
        "\n",
        "# Hyper-parameter sweep for the FNN baseline: depth x width x learning rate.\n",
        "depths = [3,4,5]\n",
        "widths = [128,256,512]\n",
        "lrs = [1e-4, 2e-4, 1e-3]\n",
        "T = 5\n",
        "N = 20\n",
        "\n",
        "epochs = 1001\n",
        "\n",
        "# NOTE(review): only trial 0 is trained here, but the evaluation cells load FNN checkpoints\n",
        "# for trials 0..T-1 - confirm whether this should be range(T).\n",
        "for t in range(1):\n",
        "  for d in depths:\n",
        "    for h in widths:\n",
        "      for lr in lrs:\n",
        "        model_fnn = FNN(4*N,h,4*N,d,4).to('cuda')\n",
        "\n",
        "\n",
        "        optimizer_fnn = torch.optim.Adam(model_fnn.parameters(), lr = lr)\n",
        "        scheduler_fnn = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_fnn, epochs)\n",
        "\n",
        "        for i in tqdm(range(epochs)):\n",
        "          for X,Y in train_loader:\n",
        "            # Move the batch to the model's device, as the GNN loops and the test helper do;\n",
        "            # this is a no-op when the data already lives on the GPU.\n",
        "            X = X.to('cuda')\n",
        "            Y = Y.to('cuda')\n",
        "            optimizer_fnn.zero_grad()\n",
        "            Y_pred = model_fnn(X)\n",
        "            loss = torch.nn.functional.mse_loss(Y_pred, Y)\n",
        "            loss.backward()\n",
        "            optimizer_fnn.step()\n",
        "          # Checkpoint every 100 epochs (0, 100, ..., 1000).\n",
        "          if i % 100 == 0:\n",
        "            torch.save(model_fnn,\"drive/MyDrive/MeasureMaps/FNN/trial-\"+str(t)+\"-depth-\"+str(d)+\"-width-\"+str(h)+\"-lr-\"+str(lr)+\"-epoch-\"+str(i)+\".pt\")\n",
        "          scheduler_fnn.step()"
      ],
      "metadata": {
        "id": "yh98hLRYDwoR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from tqdm import tqdm\n",
        "\n",
        "# GraphConv GNN on the 20-NN graphs: depth x width x learning rate, T trials each.\n",
        "depths = [3,4,5]\n",
        "widths = [128,256,512]\n",
        "lrs = [1e-4, 2e-4, 1e-3]\n",
        "T = 5\n",
        "\n",
        "epochs = 1001\n",
        "\n",
        "for t in range(T):\n",
        "  for d in depths:\n",
        "    for h in widths:\n",
        "      for lr in lrs:\n",
        "        model_gnn = GNN(4,4,d,h).to('cuda')\n",
        "        optimizer_gnn = torch.optim.Adam(model_gnn.parameters(), lr = lr)\n",
        "        scheduler_gnn = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_gnn, epochs)\n",
        "        run_tag = f\"trial-{t}-depth-{d}-width-{h}-lr-{lr}\"\n",
        "\n",
        "        for i in tqdm(range(epochs)):\n",
        "          for data in train_loader_gnn_full:\n",
        "            data = data.to('cuda')\n",
        "            optimizer_gnn.zero_grad()\n",
        "            Y_pred = model_gnn(data.x, data.edge_index, data.batch)\n",
        "            loss = torch.nn.functional.mse_loss(Y_pred, data.y)\n",
        "            loss.backward()\n",
        "            optimizer_gnn.step()\n",
        "          # Checkpoint every 100 epochs (0, 100, ..., 1000).\n",
        "          if i % 100 == 0:\n",
        "            torch.save(model_gnn, f\"drive/MyDrive/MeasureMaps/GNN/GraphConv Full/{run_tag}-epoch-{i}.pt\")\n",
        "          scheduler_gnn.step()"
      ],
      "metadata": {
        "id": "Frzi8tVBD0Te"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# GraphConv GNN on the 3-NN graphs: depth x width x learning rate, T trials each.\n",
        "depths = [3,4,5]\n",
        "widths = [128,256,512]\n",
        "lrs = [1e-4, 2e-4, 1e-3]\n",
        "T = 5\n",
        "\n",
        "epochs = 1001\n",
        "\n",
        "for t in range(T):\n",
        "  for d in depths:\n",
        "    for h in widths:\n",
        "      for lr in lrs:\n",
        "        model_gnn = GNN(4,4,d,h).to('cuda')\n",
        "        optimizer_gnn = torch.optim.Adam(model_gnn.parameters(), lr = lr)\n",
        "        scheduler_gnn = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_gnn, epochs)\n",
        "        run_tag = f\"trial-{t}-depth-{d}-width-{h}-lr-{lr}\"\n",
        "\n",
        "        for i in tqdm(range(epochs)):\n",
        "          for data in train_loader_gnn_knn:\n",
        "            data = data.to('cuda')\n",
        "            optimizer_gnn.zero_grad()\n",
        "            Y_pred = model_gnn(data.x, data.edge_index, data.batch)\n",
        "            loss = torch.nn.functional.mse_loss(Y_pred, data.y)\n",
        "            loss.backward()\n",
        "            optimizer_gnn.step()\n",
        "          # Checkpoint every 100 epochs (0, 100, ..., 1000).\n",
        "          if i % 100 == 0:\n",
        "            torch.save(model_gnn, f\"drive/MyDrive/MeasureMaps/GNN/GraphConv KNN/{run_tag}-epoch-{i}.pt\")\n",
        "          scheduler_gnn.step()"
      ],
      "metadata": {
        "id": "JMQQ9-bUD4sw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# TransformerConv GNN on the 20-NN graphs: depth x width x learning rate, T trials each.\n",
        "depths = [3,4,5]\n",
        "widths = [128,256,512]\n",
        "lrs = [1e-4, 2e-4, 1e-3]\n",
        "T = 5\n",
        "\n",
        "epochs = 1001\n",
        "\n",
        "for t in range(T):\n",
        "  for d in depths:\n",
        "    for h in widths:\n",
        "      for lr in lrs:\n",
        "        model_gnn = GNN(4,4,d,h,arch=\"Transformer\").to('cuda')\n",
        "        optimizer_gnn = torch.optim.Adam(model_gnn.parameters(), lr = lr)\n",
        "        scheduler_gnn = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_gnn, epochs)\n",
        "        run_tag = f\"trial-{t}-depth-{d}-width-{h}-lr-{lr}\"\n",
        "\n",
        "        for i in tqdm(range(epochs)):\n",
        "          for data in train_loader_gnn_full:\n",
        "            data = data.to('cuda')\n",
        "            optimizer_gnn.zero_grad()\n",
        "            Y_pred = model_gnn(data.x, data.edge_index, data.batch)\n",
        "            loss = torch.nn.functional.mse_loss(Y_pred, data.y)\n",
        "            loss.backward()\n",
        "            optimizer_gnn.step()\n",
        "          # Checkpoint every 100 epochs (0, 100, ..., 1000).\n",
        "          if i % 100 == 0:\n",
        "            torch.save(model_gnn, f\"drive/MyDrive/MeasureMaps/GNN/TransformerConv Full/{run_tag}-epoch-{i}.pt\")\n",
        "          scheduler_gnn.step()"
      ],
      "metadata": {
        "id": "n3Lyzt4BDg8t"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# TransformerConv GNN on the 3-NN graphs: depth x width x learning rate, T trials each.\n",
        "depths = [3,4,5]\n",
        "widths = [128,256,512]\n",
        "lrs = [1e-4, 2e-4, 1e-3]\n",
        "T = 5\n",
        "\n",
        "epochs = 1001\n",
        "\n",
        "for t in range(T):\n",
        "  for d in depths:\n",
        "    for h in widths:\n",
        "      for lr in lrs:\n",
        "        model_gnn = GNN(4,4,d,h,arch=\"Transformer\").to('cuda')\n",
        "        optimizer_gnn = torch.optim.Adam(model_gnn.parameters(), lr = lr)\n",
        "        scheduler_gnn = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_gnn, epochs)\n",
        "        run_tag = f\"trial-{t}-depth-{d}-width-{h}-lr-{lr}\"\n",
        "\n",
        "        for i in tqdm(range(epochs)):\n",
        "          for data in train_loader_gnn_knn:\n",
        "            data = data.to('cuda')\n",
        "            optimizer_gnn.zero_grad()\n",
        "            Y_pred = model_gnn(data.x, data.edge_index, data.batch)\n",
        "            loss = torch.nn.functional.mse_loss(Y_pred, data.y)\n",
        "            loss.backward()\n",
        "            optimizer_gnn.step()\n",
        "          # Checkpoint every 100 epochs (0, 100, ..., 1000).\n",
        "          if i % 100 == 0:\n",
        "            torch.save(model_gnn, f\"drive/MyDrive/MeasureMaps/GNN/TransformerConv KNN/{run_tag}-epoch-{i}.pt\")\n",
        "          scheduler_gnn.step()"
      ],
      "metadata": {
        "id": "OQwIy3vQDg41"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Testing Models"
      ],
      "metadata": {
        "id": "gUMRTmNvEUJz"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Reload the dataset and the saved index split for evaluation.\n",
        "measure_maps_dir = \"drive/MyDrive/MeasureMaps/\"\n",
        "Xtrn, Ytrn = torch.load(measure_maps_dir + \"data-cucker-2.pt\")\n",
        "idxes = torch.load(measure_maps_dir + \"Cucker-data-split.pt\")"
      ],
      "metadata": {
        "id": "tChIUUC3Dg0D"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "N = Xtrn.shape[0]\n",
        "val_idx = idxes[1]\n",
        "# Use a separate entry for the test split so that model selection (on val_idx) and the\n",
        "# reported test loss do not share samples. Falls back to the previous behaviour when the\n",
        "# split file holds only two entries.\n",
        "# NOTE(review): assumes the file is laid out as (train, val, test) - confirm.\n",
        "test_idx = idxes[2] if len(idxes) > 2 else idxes[1]\n",
        "print(val_idx.shape, test_idx.shape)"
      ],
      "metadata": {
        "id": "0-FSGBLiDgsN"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Validation and test loaders share the same settings: shuffled, 421 samples per batch.\n",
        "eval_loader_kwargs = dict(shuffle = True, batch_size = 421)\n",
        "\n",
        "val_data = torch.utils.data.TensorDataset(Xtrn[val_idx,:], Ytrn[val_idx,:,:])\n",
        "val_loader = torch.utils.data.DataLoader(val_data, **eval_loader_kwargs)\n",
        "\n",
        "test_data = torch.utils.data.TensorDataset(Xtrn[test_idx,:], Ytrn[test_idx,:,:])\n",
        "test_loader = torch.utils.data.DataLoader(test_data, **eval_loader_kwargs)"
      ],
      "metadata": {
        "id": "AM_SYhhDDgjo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def test_transformer_fnn_cs(model, loader):\n",
        "  N = 4\n",
        "  loss = torch.zeros(N)\n",
        "  model.eval()\n",
        "  i = 0\n",
        "  for X,Y in loader:\n",
        "    X = X.to('cuda')\n",
        "    Y = Y.to('cuda')\n",
        "    Y_pred = model(X)\n",
        "    loss[i] = torch.nn.functional.mse_loss(Y_pred, Y).cpu().detach()\n",
        "  return loss"
      ],
      "metadata": {
        "id": "H9YggsbgDgfS"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Graph versions of every sample: a 3-NN graph (knn) and a 20-NN graph (full).\n",
        "dataset_knn, dataset_full = [data_to_GNN_data(Xtrn, Ytrn, k) for k in (3, 20)]"
      ],
      "metadata": {
        "id": "U2NZHsr9E0Cw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def _single_graph_loader(graphs, idx):\n",
        "  # Shuffled loader yielding one graph per batch, restricted to the samples listed in idx.\n",
        "  return DataLoader([graphs[i] for i in idx.numpy()], shuffle = True, batch_size = 1)\n",
        "\n",
        "val_loader_gnn_full = _single_graph_loader(dataset_full, val_idx)\n",
        "val_loader_gnn_knn = _single_graph_loader(dataset_knn, val_idx)\n",
        "\n",
        "test_loader_gnn_full = _single_graph_loader(dataset_full, test_idx)\n",
        "test_loader_gnn_knn = _single_graph_loader(dataset_knn, test_idx)"
      ],
      "metadata": {
        "id": "9F-5uqamE4Uz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Model selection for the FNN: pick the (depth, width, lr) whose validation loss, averaged\n",
        "# over the T trials, is lowest.\n",
        "depths = [5]\n",
        "widths = [256]\n",
        "lrs = [2e-4]\n",
        "T = 5\n",
        "\n",
        "best_loss = 100000\n",
        "best_d = -1\n",
        "best_w = -1\n",
        "best_lr = -1\n",
        "\n",
        "for d in depths: # depth\n",
        "  for h in widths: # width\n",
        "    for lr in lrs:\n",
        "      avg_loss = 0\n",
        "      print(\"Testing: \",d,h,lr)\n",
        "      for t in range(T):\n",
        "        model = torch.load(\"drive/MyDrive/MeasureMaps/FNN/trial-\"+str(t)+\"-depth-\"+str(d)+\"-width-\"+str(h)+\"-lr-\"+str(lr)+\"-epoch-1000.pt\")\n",
        "        loss = test_transformer_fnn_cs(model, val_loader)\n",
        "        trial_loss = loss.mean()\n",
        "        print(trial_loss)\n",
        "        # Accumulate the mean over trials; previously avg_loss was overwritten, so only the\n",
        "        # last trial decided which configuration won.\n",
        "        avg_loss += trial_loss/T\n",
        "      if avg_loss < best_loss:\n",
        "        best_loss = avg_loss\n",
        "        best_d = d\n",
        "        best_w = h\n",
        "        best_lr = lr\n",
        "print(best_loss, best_d, best_w, best_lr)"
      ],
      "metadata": {
        "id": "omeGdQXFE_kk"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Test-set evaluation of the selected FNN setting, one entry per trial.\n",
        "# Buffers are sized by T so they always match the loop below.\n",
        "loss_avg = torch.zeros(T)\n",
        "loss_std = torch.zeros(T)\n",
        "for t in range(T):\n",
        "  model = torch.load(\"drive/MyDrive/MeasureMaps/FNN/trial-\"+str(t)+\"-depth-\"+str(best_d)+\"-width-\"+str(best_w)+\"-lr-\"+str(best_lr)+\"-epoch-1000.pt\")\n",
        "  loss = test_transformer_fnn_cs(model, test_loader)\n",
        "  loss_avg[t] = loss.mean()  # mean test MSE of trial t\n",
        "  loss_std[t] = loss.std()   # spread over test batches within trial t (stored, not printed)\n",
        "# Mean and standard deviation of the per-trial means.\n",
        "print(loss_avg.mean(), loss_avg.std())"
      ],
      "metadata": {
        "id": "93uZeVi1FCZj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def test_GNN(model, loader):\n",
        "  \"\"\"Compute the MSE of a graph model on every batch of `loader`.\n",
        "\n",
        "  Each batch is moved to 'cuda', the model is called as\n",
        "  `model(batch.x, batch.edge_index, batch.batch)` and the prediction is\n",
        "  compared with `batch.y`.\n",
        "\n",
        "  Returns:\n",
        "    1-D tensor of length `len(loader)` holding one loss per batch.\n",
        "  \"\"\"\n",
        "  per_batch_mse = torch.zeros(len(loader))\n",
        "  model.eval()\n",
        "  for idx, graph in enumerate(loader):\n",
        "    graph = graph.to('cuda')\n",
        "    prediction = model(graph.x, graph.edge_index, graph.batch)\n",
        "    mse = torch.nn.functional.mse_loss(prediction, graph.y)\n",
        "    per_batch_mse[idx] = mse.cpu().detach()\n",
        "  return per_batch_mse"
      ],
      "metadata": {
        "id": "GQ7wJCX_FVO-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Grid search for the GraphConv KNN models: every (depth, width, lr)\n",
        "# combination is scored by its validation MSE averaged over T trials.\n",
        "depths, widths, lrs = [3,4,5], [128,256,512], [2e-4, 1e-3, 1e-4]\n",
        "T = 5\n",
        "\n",
        "best_loss = 100000\n",
        "best_d = best_w = best_lr = -1\n",
        "\n",
        "for d in depths: # depth\n",
        "  for h in widths: # width\n",
        "    for lr in lrs:\n",
        "      print(\"Testing: \",d,h,lr)\n",
        "      avg_loss = 0\n",
        "      for t in range(T):\n",
        "        model = torch.load(\"drive/MyDrive/MeasureMaps/GNN/GraphConv KNN/trial-\"+str(t)+\"-depth-\"+str(d)+\"-width-\"+str(h)+\"-lr-\"+str(lr)+\"-epoch-1000.pt\")\n",
        "        loss = test_GNN(model, val_loader_gnn_knn)\n",
        "        # Each trial contributes 1/T of its mean validation loss.\n",
        "        avg_loss += loss.mean()/T\n",
        "      print(avg_loss)\n",
        "      # Only a strictly lower average replaces the current best setting.\n",
        "      if not (avg_loss < best_loss):\n",
        "        continue\n",
        "      best_loss, best_d, best_w, best_lr = avg_loss, d, h, lr\n",
        "print(best_loss, best_d, best_w, best_lr)"
      ],
      "metadata": {
        "id": "mtZZdhJ7FM-H"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Test-set evaluation of the selected GraphConv KNN setting, one entry per\n",
        "# trial. Buffers are sized by T so they always match the loop below.\n",
        "loss_avg = torch.zeros(T)\n",
        "loss_std = torch.zeros(T)\n",
        "for t in range(T):\n",
        "  model = torch.load(\"drive/MyDrive/MeasureMaps/GNN/GraphConv KNN/trial-\"+str(t)+\"-depth-\"+str(best_d)+\"-width-\"+str(best_w)+\"-lr-\"+str(best_lr)+\"-epoch-1000.pt\")\n",
        "  loss = test_GNN(model, test_loader_gnn_knn)\n",
        "  loss_avg[t] = loss.mean()  # mean test MSE of trial t\n",
        "  loss_std[t] = loss.std()   # spread over test batches within trial t (stored, not printed)\n",
        "# Mean and standard deviation of the per-trial means.\n",
        "print(loss_avg.mean(), loss_avg.std())"
      ],
      "metadata": {
        "id": "CjNp9Tx0FP39"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}