{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3dab437-3ccb-4ac0-aa94-503304cfe547",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'figure.figsize': (5.5, 1.699593469062211),\n",
       " 'figure.constrained_layout.use': True,\n",
       " 'figure.autolayout': False,\n",
       " 'savefig.bbox': 'tight',\n",
       " 'savefig.pad_inches': 0.015}"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch_geometric\n",
    "\n",
    "from torch.nn import ModuleList\n",
    "\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.utils import grid, remove_self_loops\n",
    "from torch_geometric.nn.conv import GCNConv\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from tueplots import bundles\n",
    "from tueplots import figsizes, fontsizes, fonts\n",
    "\n",
    "# The tueplots helpers return rcParams dicts; calling one without using the result\n",
    "# changes nothing. Update the global rcParams so the NeurIPS 2024 style is active.\n",
    "plt.rcParams.update(bundles.neurips2024())\n",
    "\n",
    "# Show the figure-size settings of the NeurIPS 2024 preset.\n",
    "figsizes.neurips2024()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5b24da7-d775-40ec-9a11-959f8cef07ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_grid(size, self_loops=False):\n",
    "    \"\"\"Build a ``size`` x ``size`` grid graph with constant node features.\n",
    "\n",
    "    Args:\n",
    "        size: Number of nodes along each side of the grid.\n",
    "        self_loops: If False, any self-loop edges in the ``grid`` output are\n",
    "            removed; if True they are kept.\n",
    "\n",
    "    Returns:\n",
    "        ``Data`` whose ``x`` is a ``(size * size, 1)`` tensor of ones with\n",
    "        ``requires_grad=True`` (so Jacobians w.r.t. the input can be taken),\n",
    "        together with the grid ``edge_index``.\n",
    "    \"\"\"\n",
    "    # Seeds torch's global RNG as a side effect (kept for backward compatibility).\n",
    "    torch.manual_seed(0)\n",
    "    (row, col), _ = grid(height=size, width=size)\n",
    "    edge_index = torch.stack((row, col), dim=0)\n",
    "\n",
    "    # Only strip self loops when they are not wanted; self_loops=True keeps them.\n",
    "    if not self_loops:\n",
    "        edge_index, _ = remove_self_loops(edge_index)\n",
    "\n",
    "    num_nodes = size * size\n",
    "    x = torch.ones((num_nodes, 1), requires_grad=True)\n",
    "    return Data(x=x, edge_index=edge_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ee4fced-ad4e-430e-b1d6-4defe2e07646",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LinearGCN(torch.nn.Module):\n",
    "    \"\"\"Stack of ``GCNConv`` layers with no nonlinearity between them.\n",
    "\n",
    "    Every layer is built with ``add_self_loops=False`` and ``normalize=True``, so\n",
    "    messages are propagated only along the edges present in ``edge_index``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim_input, dim_emb, num_layers):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            dim_input: Feature size of the input node features.\n",
    "            dim_emb: Output feature size of every layer.\n",
    "            num_layers: Number of stacked ``GCNConv`` layers.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.num_layers = num_layers\n",
    "        self.layers = ModuleList([\n",
    "            GCNConv(dim_input if i == 0 else dim_emb, dim_emb, add_self_loops=False, normalize=True)\n",
    "            for i in range(num_layers)\n",
    "        ])\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        \"\"\"Propagate ``x`` through all layers and return the final embedding.\n",
    "\n",
    "        With ``num_layers == 0`` the input is returned unchanged.\n",
    "        \"\"\"\n",
    "        h = x\n",
    "        for layer in self.layers:\n",
    "            h = layer(h, edge_index)\n",
    "            # Keep .grad on the embedding for later inspection. retain_grad() raises\n",
    "            # on tensors that do not require grad (e.g. under torch.no_grad()).\n",
    "            if h.requires_grad:\n",
    "                h.retain_grad()\n",
    "        return h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "468b1024-4c15-4b28-aabe-e026f59c89b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-04-28 12:27:24,155\tINFO worker.py:1783 -- Started a local Ray instance.\n"
     ]
    }
   ],
   "source": [
    "import ray\n",
    "import pandas as pd\n",
    "import torch\n",
    "from itertools import product\n",
    "\n",
    "# Initialize Ray; ignore_reinit_error lets this cell be re-run in the same kernel.\n",
    "ray.init(ignore_reinit_error=True)\n",
    "\n",
    "# Sweep configuration (relies on get_grid and LinearGCN from the cells above).\n",
    "GRID_SIZE = 10                                 # the grid has GRID_SIZE x GRID_SIZE nodes\n",
    "LAYER_COUNTS = [1] + list(range(10, 510, 10))  # depths: 1, 10, 20, ..., 500\n",
    "EMBEDDING_DIMS = [8, 32, 64, 128]              # hidden sizes\n",
    "NUM_RUNS = 1                                   # independent seeds per configuration\n",
    "\n",
    "size = GRID_SIZE\n",
    "g = get_grid(size=size, self_loops=False)\n",
    "\n",
    "edge_index = g.edge_index\n",
    "x = g.x\n",
    "\n",
    "# One entry per (num_layers, dim_emb, run_id) combination.\n",
    "param_grid = list(product(LAYER_COUNTS, EMBEDDING_DIMS, range(NUM_RUNS)))\n",
    "\n",
    "\n",
    "@ray.remote(\n",
    "    num_cpus=1,\n",
    "    num_gpus=0,\n",
    "    max_calls=1,\n",
    "    # max_calls=1 --> the worker process exits after executing one task, releasing\n",
    "    # whatever memory it held (no GPUs are requested here).\n",
    ")\n",
    "def compute_jacobian(args):\n",
    "    \"\"\"Compute the input Jacobian of a freshly initialised LinearGCN.\n",
    "\n",
    "    Args:\n",
    "        args: Tuple ``(num_layers, dim_emb, edge_index, run_id, x)``. ``run_id``\n",
    "            also seeds torch's RNG so weight initialisation is reproducible.\n",
    "\n",
    "    Returns:\n",
    "        A dict with the configuration (as strings), the Jacobian shape (or the\n",
    "        error message if the computation failed), the L1 norm of the Jacobian\n",
    "        under the key 'jacobian_mean' (None on failure) and ``run_id``.\n",
    "    \"\"\"\n",
    "    num_layers, dim_emb, edge_index, run_id, x = args\n",
    "    # Each Ray worker is its own process with its own RNG state; seed explicitly\n",
    "    # so the model weights for a given run are reproducible.\n",
    "    torch.manual_seed(run_id)\n",
    "    model = LinearGCN(dim_input=1, dim_emb=dim_emb, num_layers=num_layers)\n",
    "\n",
    "    def func(x_input):\n",
    "        return model(x_input, edge_index)\n",
    "\n",
    "    try:\n",
    "        jac = torch.autograd.functional.jacobian(func, x)\n",
    "        # NOTE: despite the 'jacobian_mean' key, this is the L1 norm (sum of |entries|).\n",
    "        l1_norm = torch.linalg.vector_norm(jac, ord=1).item()\n",
    "        shape = tuple(jac.shape)\n",
    "    except Exception as e:\n",
    "        # Best effort: record the failure in the result row instead of aborting the sweep.\n",
    "        l1_norm = None\n",
    "        shape = str(e)\n",
    "\n",
    "    return {\n",
    "        'num_layers': str(num_layers),\n",
    "        'size of node embedding': str(dim_emb),\n",
    "        'jacobian_shape': str(shape),\n",
    "        'jacobian_mean': l1_norm,\n",
    "        'run_id': run_id\n",
    "    }\n",
    "\n",
    "# Launch one Ray task per configuration\n",
    "futures = [\n",
    "    compute_jacobian.remote((num_layers, dim_emb, edge_index, run_id, x))\n",
    "    for num_layers, dim_emb, run_id in param_grid\n",
    "]\n",
    "\n",
    "# Block until every task finishes and collect one result row per task\n",
    "results = ray.get(futures)\n",
    "df = pd.DataFrame(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00fa7aec-faca-4905-baf0-fe21578b241c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import seaborn as sns\n",
    "\n",
    "# Create the output directory on demand so savefig does not fail on a fresh checkout.\n",
    "FIG_DIR = 'figs'\n",
    "os.makedirs(FIG_DIR, exist_ok=True)\n",
    "\n",
    "with plt.rc_context(bundles.neurips2024()):\n",
    "    fig, ax = plt.subplots()\n",
    "    # num_layers is stored as a string in df, so seaborn treats the x-axis as categorical.\n",
    "    sns.lineplot(data=df, x='num_layers', y='jacobian_mean', hue='size of node embedding', ax=ax)\n",
    "    ax.set_xlabel('Number of Message Passing Layers')\n",
    "    ax.set_ylabel('Sensitivity')\n",
    "    legend = ax.legend(loc='upper right')\n",
    "    for text in legend.get_texts():\n",
    "        text.set_ha('right')  # Right-align each legend text\n",
    "\n",
    "    # Rotate x tick labels so the many layer counts do not overlap.\n",
    "    ax.tick_params(axis='x', labelrotation=90)\n",
    "\n",
    "    fig.savefig(os.path.join(FIG_DIR, 'grid_sensitivity.pdf'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3e54cb0-17b9-4c1f-8afd-204c0d52df97",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ce61980-3cc7-4475-96f6-ccec132895a2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
