{"cells":[{"cell_type":"code","source":[],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"0RlETSQ-pYmm","executionInfo":{"status":"ok","timestamp":1743639235540,"user_tz":420,"elapsed":114,"user":{"displayName":"Divya Anand Sinha","userId":"00262442409674458618"}},"outputId":"a1d9b51c-88e2-45c7-9cb2-426dada00613"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["/bin/bash: line 1: nvidia-smi: command not found\n"]}]},{"cell_type":"code","source":["# Add this in a Google Colab cell to install the correct version of Pytorch Geometric.\n","import torch\n","print(torch.__version__)\n","print(torch.version.cuda)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"fceNfATkpppb","executionInfo":{"status":"ok","timestamp":1750381766236,"user_tz":420,"elapsed":4,"user":{"displayName":"Divyaanand Sinha","userId":"10890862448949785739"}},"outputId":"2ae491a6-4b1f-4c0e-8657-8f6792e0c243"},"execution_count":2,"outputs":[{"output_type":"stream","name":"stdout","text":["2.6.0+cu124\n","12.4\n"]}]},{"cell_type":"code","execution_count":3,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":20307,"status":"ok","timestamp":1750381787673,"user":{"displayName":"Divyaanand Sinha","userId":"10890862448949785739"},"user_tz":420},"id":"2JIX4guUyTU8","outputId":"047dcfa4-5d4d-4a36-b83e-18dd8c4c0046"},"outputs":[{"output_type":"stream","name":"stdout","text":["Looking in links: https://data.pyg.org/whl/torch-2.6.0+cu124.html\n","Collecting pyg_lib\n","  Downloading https://data.pyg.org/whl/torch-2.6.0%2Bcu124/pyg_lib-0.4.0%2Bpt26cu124-cp311-cp311-linux_x86_64.whl (4.7 MB)\n","\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.7/4.7 MB\u001b[0m \u001b[31m35.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting torch_scatter\n","  Downloading 
https://data.pyg.org/whl/torch-2.6.0%2Bcu124/torch_scatter-2.1.2%2Bpt26cu124-cp311-cp311-linux_x86_64.whl (10.8 MB)\n","\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.8/10.8 MB\u001b[0m \u001b[31m44.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting torch_sparse\n","  Downloading https://data.pyg.org/whl/torch-2.6.0%2Bcu124/torch_sparse-0.6.18%2Bpt26cu124-cp311-cp311-linux_x86_64.whl (5.0 MB)\n","\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.0/5.0 MB\u001b[0m \u001b[31m45.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting torch_cluster\n","  Downloading https://data.pyg.org/whl/torch-2.6.0%2Bcu124/torch_cluster-1.6.3%2Bpt26cu124-cp311-cp311-linux_x86_64.whl (3.4 MB)\n","\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m49.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting torch_spline_conv\n","  Downloading https://data.pyg.org/whl/torch-2.6.0%2Bcu124/torch_spline_conv-1.2.2%2Bpt26cu124-cp311-cp311-linux_x86_64.whl (1.0 MB)\n","\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m14.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: scipy in /usr/local/lib/python3.11/dist-packages (from torch_sparse) (1.15.3)\n","Requirement already satisfied: numpy<2.5,>=1.23.5 in /usr/local/lib/python3.11/dist-packages (from scipy->torch_sparse) (2.0.2)\n","Installing collected packages: torch_spline_conv, torch_scatter, pyg_lib, torch_sparse, torch_cluster\n","Successfully installed pyg_lib-0.4.0+pt26cu124 torch_cluster-1.6.3+pt26cu124 torch_scatter-2.1.2+pt26cu124 torch_sparse-0.6.18+pt26cu124 torch_spline_conv-1.2.2+pt26cu124\n","Collecting torch_geometric\n","  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)\n","\u001b[2K     
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m63.1/63.1 kB\u001b[0m \u001b[31m2.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: aiohttp in /usr/local/lib/python3.11/dist-packages (from torch_geometric) (3.11.15)\n","Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch_geometric) (2025.3.2)\n","Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch_geometric) (3.1.6)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torch_geometric) (2.0.2)\n","Requirement already satisfied: psutil>=5.8.0 in /usr/local/lib/python3.11/dist-packages (from torch_geometric) (5.9.5)\n","Requirement already satisfied: pyparsing in /usr/local/lib/python3.11/dist-packages (from torch_geometric) (3.2.3)\n","Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from torch_geometric) (2.32.3)\n","Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from torch_geometric) (4.67.1)\n","Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->torch_geometric) (2.6.1)\n","Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp->torch_geometric) (1.3.2)\n","Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->torch_geometric) (25.3.0)\n","Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp->torch_geometric) (1.7.0)\n","Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp->torch_geometric) (6.4.4)\n","Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->torch_geometric) (0.3.2)\n","Requirement already satisfied: yarl<2.0,>=1.17.0 in 
/usr/local/lib/python3.11/dist-packages (from aiohttp->torch_geometric) (1.20.1)\n","Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch_geometric) (3.0.2)\n","Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->torch_geometric) (3.4.2)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->torch_geometric) (3.10)\n","Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->torch_geometric) (2.4.0)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->torch_geometric) (2025.6.15)\n","Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)\n","\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m21.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hInstalling collected packages: torch_geometric\n","Successfully installed torch_geometric-2.6.1\n","Collecting rdkit\n","  Downloading rdkit-2025.3.3-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.0 kB)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from rdkit) (2.0.2)\n","Requirement already satisfied: Pillow in /usr/local/lib/python3.11/dist-packages (from rdkit) (11.2.1)\n","Downloading rdkit-2025.3.3-cp311-cp311-manylinux_2_28_x86_64.whl (34.9 MB)\n","\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m34.9/34.9 MB\u001b[0m \u001b[31m67.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hInstalling collected packages: rdkit\n","Successfully installed rdkit-2025.3.3\n"]}],"source":["!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.6.0+cu124.html\n","!pip install torch_geometric\n","!pip install 
rdkit\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":2,"status":"ok","timestamp":1743552491648,"user":{"displayName":"Divya Anand Sinha","userId":"00262442409674458618"},"user_tz":420},"id":"olWGyxpWeuRw","outputId":"ebb6e2a8-0ae6-4743-a67c-c513ecb2b3bc"},"outputs":[{"output_type":"execute_result","data":{"text/plain":["torch.Size([16, 9])"]},"metadata":{},"execution_count":8}],"source":["data.x.shape\n"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":311},"executionInfo":{"elapsed":2993,"status":"error","timestamp":1743552495215,"user":{"displayName":"Divya Anand Sinha","userId":"00262442409674458618"},"user_tz":420},"id":"MELIY57ZdWtr","outputId":"c7163831-8352-4662-95c8-2e672e3953b9"},"outputs":[{"output_type":"error","ename":"MessageError","evalue":"Error: credential propagation was unsuccessful","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mMessageError\u001b[0m                              Traceback (most recent call last)","\u001b[0;32m<ipython-input-9-d5df0069828e>\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mgoogle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolab\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdrive\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mdrive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmount\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/content/drive'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/google/colab/drive.py\u001b[0m in \u001b[0;36mmount\u001b[0;34m(mountpoint, force_remount, timeout_ms, readonly)\u001b[0m\n\u001b[1;32m     98\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0mmount\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmountpoint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mforce_remount\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout_ms\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m120000\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreadonly\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     99\u001b[0m   \u001b[0;34m\"\"\"Mount your Google Drive at the specified mountpoint path.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 100\u001b[0;31m   return _mount(\n\u001b[0m\u001b[1;32m    101\u001b[0m       \u001b[0mmountpoint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    102\u001b[0m       \u001b[0mforce_remount\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mforce_remount\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/google/colab/drive.py\u001b[0m in \u001b[0;36m_mount\u001b[0;34m(mountpoint, force_remount, timeout_ms, ephemeral, readonly)\u001b[0m\n\u001b[1;32m    135\u001b[0m   )\n\u001b[1;32m    136\u001b[0m   \u001b[0;32mif\u001b[0m \u001b[0mephemeral\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 137\u001b[0;31m     _message.blocking_request(\n\u001b[0m\u001b[1;32m    138\u001b[0m         \u001b[0;34m'request_auth'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    139\u001b[0m         \u001b[0mrequest\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m'authType'\u001b[0m\u001b[0;34m:\u001b[0m 
\u001b[0;34m'dfs_ephemeral'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/google/colab/_message.py\u001b[0m in \u001b[0;36mblocking_request\u001b[0;34m(request_type, request, timeout_sec, parent)\u001b[0m\n\u001b[1;32m    174\u001b[0m       \u001b[0mrequest_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequest\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparent\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexpect_reply\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    175\u001b[0m   )\n\u001b[0;32m--> 176\u001b[0;31m   \u001b[0;32mreturn\u001b[0m \u001b[0mread_reply_from_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrequest_id\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout_sec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m","\u001b[0;32m/usr/local/lib/python3.11/dist-packages/google/colab/_message.py\u001b[0m in \u001b[0;36mread_reply_from_input\u001b[0;34m(message_id, timeout_sec)\u001b[0m\n\u001b[1;32m    101\u001b[0m     ):\n\u001b[1;32m    102\u001b[0m       \u001b[0;32mif\u001b[0m \u001b[0;34m'error'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mreply\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m         \u001b[0;32mraise\u001b[0m \u001b[0mMessageError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreply\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'error'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    104\u001b[0m       \u001b[0;32mreturn\u001b[0m \u001b[0mreply\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'data'\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    105\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mMessageError\u001b[0m: Error: credential propagation was unsuccessful"]}],"source":["from google.colab import drive\n","drive.mount('/content/drive')"]},{"cell_type":"code","execution_count":6,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":51341,"status":"ok","timestamp":1750382012312,"user":{"displayName":"Divyaanand Sinha","userId":"10890862448949785739"},"user_tz":420},"id":"tggygGqLdWYr","outputId":"9ddb18a3-ba7c-4d86-a7a0-f47a36a60486"},"outputs":[{"output_type":"stream","name":"stdout","text":["Sampled all node [18574 23658 32561 ... 25613  7256 26554]\n","Adj unique tensor([0., 1.], device='cuda:0')\n","tensor([[ 1,  9,  0, 19,  0, 12, 11,  6, 11,  7, 11,  8, 11,  9,  2, 12,  3, 12,\n","          5, 12,  6, 17,  6,  7,  7, 17,  7,  9,  4, 14,  4, 12,  4, 15, 10, 19,\n","         10, 12, 12, 18, 12, 14, 12, 13, 12,  9, 12, 20, 12, 16, 12, 15, 12, 19,\n","         14, 19, 15, 19, 20, 19],\n","        [ 9,  1, 19,  0, 12,  0,  6, 11,  7, 11,  8, 11,  9, 11, 12,  2, 12,  3,\n","         12,  5, 17,  6,  7,  6, 17,  7,  9,  7, 14,  4, 12,  4, 15,  4, 19, 10,\n","         12, 10, 18, 12, 14, 12, 13, 12,  9, 12, 20, 12, 16, 12, 15, 12, 19, 12,\n","         19, 14, 19, 15, 19, 20]], device='cuda:0')\n","Number of sampled nodes 21\n","Adj  torch.Size([21, 21])\n","y shape torch.Size([1, 2])\n","y_out shape torch.Size([1, 2])\n","alpha 1e-09 beta 1e-07\n","tensor(0.0034, device='cuda:0', grad_fn=<AddBackward0>)\n","0 Loss 0.0034\n","tensor(3.6514e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","100 Loss 0.0000\n","tensor(1.4778e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","200 Loss 0.0000\n","tensor(1.3152e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","300 Loss 0.0000\n","tensor(1.2817e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","400 Loss 
0.0000\n","tensor(1.2851e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","500 Loss 0.0000\n","tensor(1.2871e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","600 Loss 0.0000\n","tensor(1.2762e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","700 Loss 0.0000\n","tensor(1.2648e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","800 Loss 0.0000\n","tensor(1.2651e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","900 Loss 0.0000\n","0.9629629629629629\n","0.9307086614173228 0.9629629629629629 0.9773242630385488\n","tensor([[-0.4805,  0.3983,  0.1519,  ...,  2.5808, -0.2978,  0.1775],\n","        [-0.6911,  0.3629, -0.3812,  ...,  2.7927, -0.5381, -0.2075],\n","        [-1.2073,  1.1467,  0.4570,  ...,  0.7104, -0.5663, -0.0841],\n","        ...,\n","        [-0.0916, -0.3585,  3.9636,  ...,  2.5393, -0.4342, -0.1599],\n","        [-0.4144, -0.1953,  0.2298,  ...,  1.2843, -0.4003, -0.0900],\n","        [ 0.0200, -0.3370, -0.2167,  ...,  0.6009, -0.1064,  0.2269]],\n","       device='cuda:0', requires_grad=True)\n","tensor(0.6199, device='cuda:0', grad_fn=<MeanBackward0>)\n","Adj unique tensor([0., 1.], device='cuda:0')\n","tensor([[ 6, 16,  0, 30,  0, 27,  0, 33, 36, 10, 36, 29, 36, 20, 36,  3, 27, 16,\n","         27, 11, 27, 28, 27,  1, 27, 10, 27, 13, 27,  3, 27, 30, 27, 25, 27, 18,\n","         27,  8, 27, 20, 27, 23,  2, 34,  3, 29,  3, 10,  3, 11,  4,  9,  4,  5,\n","          5,  9,  7, 16,  8, 16,  8, 23,  9, 37, 10, 11, 10, 29, 10,  1, 11, 16,\n","         11,  1,  1, 33, 12, 16, 30, 16, 14, 37, 15, 35, 16, 25, 16, 37, 16, 32,\n","         16, 24, 16, 13, 16, 31, 16, 17, 16, 19, 16, 23, 16, 28, 16, 18, 16, 21,\n","         22, 37, 26, 37, 37, 35, 37, 34, 37, 29, 37, 33, 33, 20],\n","        [16,  6, 30,  0, 27,  0, 33,  0, 10, 36, 29, 36, 20, 36,  3, 36, 16, 27,\n","         11, 27, 28, 27,  1, 27, 10, 27, 13, 27,  3, 27, 30, 27, 25, 27, 18, 27,\n","          8, 27, 20, 27, 23, 27, 34,  2, 29,  3, 10,  3, 11,  3,  9,  4,  5,  4,\n","          9,  5, 16,  7, 16,  
8, 23,  8, 37,  9, 11, 10, 29, 10,  1, 10, 16, 11,\n","          1, 11, 33,  1, 16, 12, 16, 30, 37, 14, 35, 15, 25, 16, 37, 16, 32, 16,\n","         24, 16, 13, 16, 31, 16, 17, 16, 19, 16, 23, 16, 28, 16, 18, 16, 21, 16,\n","         37, 22, 37, 26, 35, 37, 34, 37, 29, 37, 33, 37, 20, 33]],\n","       device='cuda:0')\n","Number of sampled nodes 38\n","Adj  torch.Size([38, 38])\n","y shape torch.Size([1, 2])\n","y_out shape torch.Size([1, 2])\n","alpha 1e-09 beta 1e-07\n","tensor(0.0040, device='cuda:0', grad_fn=<AddBackward0>)\n","0 Loss 0.0040\n","tensor(4.9041e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","100 Loss 0.0000\n","tensor(2.1988e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","200 Loss 0.0000\n","tensor(2.0552e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","300 Loss 0.0000\n","tensor(1.9801e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","400 Loss 0.0000\n","tensor(1.9345e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","500 Loss 0.0000\n","tensor(1.9217e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","600 Loss 0.0000\n","tensor(1.9089e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","700 Loss 0.0000\n","tensor(2.1027e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","800 Loss 0.0000\n","tensor(1.9118e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","900 Loss 0.0000\n","0.9420289855072463\n","0.7648805833188661 0.9420289855072463 0.9577562326869806\n","tensor([[-0.5763, -0.4179, -0.3663,  ..., -0.3483,  0.0750,  0.2597],\n","        [-0.7752, -0.2336,  0.4404,  ..., -0.1141, -0.2415,  0.1142],\n","        [ 0.1652, -0.7026, -0.3119,  ..., -0.1432, -0.6123,  0.3798],\n","        ...,\n","        [ 0.2658,  0.4631, -0.2895,  ..., -0.3803, -0.4471,  0.1380],\n","        [-0.0679,  0.0512,  0.5625,  ..., -0.1220, -0.2285, -0.2194],\n","        [-0.2645, -0.4281,  0.0314,  ..., -0.2176,  0.0284,  1.8644]],\n","       device='cuda:0', requires_grad=True)\n","tensor(0.7690, device='cuda:0', grad_fn=<MeanBackward0>)\n","Adj unique tensor([0., 1.], 
device='cuda:0')\n","tensor([[ 0, 29,  0,  5,  0,  4,  0, 16,  1, 29,  1,  2,  1, 16,  1, 33,  2, 36,\n","          2, 29,  2, 33,  2, 18,  2, 16,  3, 29,  3, 16,  4, 18,  4, 29,  4, 33,\n","          4,  8, 26, 29, 26,  5,  5, 25,  5, 12,  5, 34,  5,  8,  5, 29,  5, 20,\n","          5, 16,  5, 31,  5,  7,  5, 28,  6, 34,  6, 16,  6, 29,  7, 29,  9, 29,\n","         10, 33, 10, 29, 11, 29, 12, 29, 13, 29, 15, 29, 15, 31, 16,  8, 16, 32,\n","         16, 24, 16, 25, 16, 18, 16, 20, 16, 34, 16, 19, 16, 29, 16, 36, 16, 27,\n","         18, 19, 18, 29, 19,  8, 19, 32, 19, 29, 19, 36, 19, 20, 19, 33, 19, 34,\n","         20, 33, 20, 25, 20, 29, 20, 32, 20,  8, 21, 29, 35, 14, 23, 29, 23, 34,\n","         24, 29, 24, 33, 24,  8, 25, 29, 25, 31, 17, 36, 17, 29, 17, 28, 14, 29,\n","         34, 29, 27, 29, 28, 29, 29, 36, 29, 33, 29,  8, 29, 30, 29, 32, 29, 31,\n","         29, 22],\n","        [29,  0,  5,  0,  4,  0, 16,  0, 29,  1,  2,  1, 16,  1, 33,  1, 36,  2,\n","         29,  2, 33,  2, 18,  2, 16,  2, 29,  3, 16,  3, 18,  4, 29,  4, 33,  4,\n","          8,  4, 29, 26,  5, 26, 25,  5, 12,  5, 34,  5,  8,  5, 29,  5, 20,  5,\n","         16,  5, 31,  5,  7,  5, 28,  5, 34,  6, 16,  6, 29,  6, 29,  7, 29,  9,\n","         33, 10, 29, 10, 29, 11, 29, 12, 29, 13, 29, 15, 31, 15,  8, 16, 32, 16,\n","         24, 16, 25, 16, 18, 16, 20, 16, 34, 16, 19, 16, 29, 16, 36, 16, 27, 16,\n","         19, 18, 29, 18,  8, 19, 32, 19, 29, 19, 36, 19, 20, 19, 33, 19, 34, 19,\n","         33, 20, 25, 20, 29, 20, 32, 20,  8, 20, 29, 21, 14, 35, 29, 23, 34, 23,\n","         29, 24, 33, 24,  8, 24, 29, 25, 31, 25, 36, 17, 29, 17, 28, 17, 29, 14,\n","         29, 34, 29, 27, 29, 28, 36, 29, 33, 29,  8, 29, 30, 29, 32, 29, 31, 29,\n","         22, 29]], device='cuda:0')\n","Number of sampled nodes 37\n","Adj  torch.Size([37, 37])\n","y shape torch.Size([1, 2])\n","y_out shape torch.Size([1, 2])\n","alpha 1e-09 beta 1e-07\n","tensor(0.0034, device='cuda:0', grad_fn=<AddBackward0>)\n","0 
Loss 0.0034\n","tensor(2.7056e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","100 Loss 0.0000\n","tensor(2.0331e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","200 Loss 0.0000\n","tensor(1.9124e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","300 Loss 0.0000\n","tensor(1.7840e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","400 Loss 0.0000\n","tensor(1.5610e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","500 Loss 0.0000\n","tensor(1.5444e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","600 Loss 0.0000\n","tensor(1.4890e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","700 Loss 0.0000\n","tensor(1.4413e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","800 Loss 0.0000\n","tensor(1.4238e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","900 Loss 0.0000\n","0.9830508474576272\n","0.6589194293490839 0.9830508474576272 0.9086924762600438\n","tensor([[ 0.3975, -0.2245,  0.1808,  ...,  0.1141, -0.4320, -0.0809],\n","        [-0.1146, -0.0814,  0.1860,  ...,  0.0939, -0.4183, -0.4324],\n","        [-0.4593, -0.7677,  0.2004,  ..., -0.2845, -0.3565, -0.3082],\n","        ...,\n","        [-0.4494, -1.1012, -0.1333,  ...,  0.2246, -0.0255, -0.0866],\n","        [-0.7907,  0.3261, -0.3474,  ..., -0.2319, -0.6155, -0.2542],\n","        [-0.4120,  0.1930,  0.3087,  ..., -0.0255, -0.4068, -0.2467]],\n","       device='cuda:0', requires_grad=True)\n","tensor(0.7917, device='cuda:0', grad_fn=<MeanBackward0>)\n","Adj unique tensor([0., 1.], device='cuda:0')\n","tensor([[ 0,  7,  1, 16, 18,  4, 18,  8, 18,  3, 18, 25, 18, 14, 18, 12, 18, 13,\n","         18,  2, 18,  7, 18, 10, 18, 21, 18, 11,  2,  7,  2, 25,  3,  7,  5,  7,\n","          5, 11,  5, 20,  7,  6,  7, 13,  7, 22,  7,  9,  7, 14,  7,  8,  7, 11,\n","          7, 20,  7,  4,  7, 21,  7, 12,  7, 25,  7, 19, 12, 10, 10, 24, 10, 25,\n","         10, 16, 10, 13, 16, 19, 16, 23, 16, 17, 16, 24, 17, 23, 15, 19],\n","        [ 7,  0, 16,  1,  4, 18,  8, 18,  3, 18, 25, 18, 14, 18, 12, 18, 13, 18,\n","          2, 18,  7, 
18, 10, 18, 21, 18, 11, 18,  7,  2, 25,  2,  7,  3,  7,  5,\n","         11,  5, 20,  5,  6,  7, 13,  7, 22,  7,  9,  7, 14,  7,  8,  7, 11,  7,\n","         20,  7,  4,  7, 21,  7, 12,  7, 25,  7, 19,  7, 10, 12, 24, 10, 25, 10,\n","         16, 10, 13, 10, 19, 16, 23, 16, 17, 16, 24, 16, 23, 17, 19, 15]],\n","       device='cuda:0')\n","Number of sampled nodes 26\n","Adj  torch.Size([26, 26])\n","y shape torch.Size([1, 2])\n","y_out shape torch.Size([1, 2])\n","alpha 1e-09 beta 1e-07\n","tensor(0.0041, device='cuda:0', grad_fn=<AddBackward0>)\n","0 Loss 0.0041\n","tensor(4.2813e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","100 Loss 0.0000\n","tensor(2.1205e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","200 Loss 0.0000\n","tensor(1.9142e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","300 Loss 0.0000\n","tensor(1.8831e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","400 Loss 0.0000\n","tensor(1.8766e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","500 Loss 0.0000\n","tensor(1.8793e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","600 Loss 0.0000\n","tensor(1.8573e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","700 Loss 0.0000\n","tensor(1.8821e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","800 Loss 0.0000\n","tensor(1.8827e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","900 Loss 0.0000\n","0.974025974025974\n","0.9244356833642547 0.974025974025974 0.977810650887574\n","tensor([[ 0.7199,  1.1188,  0.0841,  ..., -0.2977,  0.5186, -0.2667],\n","        [-1.4513, -0.6876, -0.3222,  ..., -0.3542, -0.1471, -0.1595],\n","        [-0.5868, -0.2334, -0.0396,  ..., -0.6927, -0.7533, -0.6554],\n","        ...,\n","        [-0.6863,  0.3194, -0.0075,  ..., -0.1854, -0.5230, -0.3646],\n","        [-0.6394, -0.3987, -0.3485,  ..., -0.1005, -0.9882, -0.5479],\n","        [-0.6630, -0.2624, -0.0035,  ..., -0.5459, -0.3076, -0.4191]],\n","       device='cuda:0', requires_grad=True)\n","tensor(0.6880, device='cuda:0', grad_fn=<MeanBackward0>)\n","Adj unique 
tensor([0., 1.], device='cuda:0')\n","tensor([[ 0, 13,  0, 16, 17,  8, 17, 16, 17,  2, 17, 20, 17, 21, 17, 15, 17, 24,\n","          1, 16,  2, 16, 23, 16,  3, 10,  3, 19,  4, 16,  5, 13,  5, 19, 21,  8,\n","         21, 16, 21, 19, 21, 12,  6, 16,  7, 16,  8, 19,  9, 16, 10, 19, 11, 16,\n","         12, 16, 14, 19, 14, 18, 15, 19, 16, 22, 16, 20, 16, 24, 18, 19],\n","        [13,  0, 16,  0,  8, 17, 16, 17,  2, 17, 20, 17, 21, 17, 15, 17, 24, 17,\n","         16,  1, 16,  2, 16, 23, 10,  3, 19,  3, 16,  4, 13,  5, 19,  5,  8, 21,\n","         16, 21, 19, 21, 12, 21, 16,  6, 16,  7, 19,  8, 16,  9, 19, 10, 16, 11,\n","         16, 12, 19, 14, 18, 14, 19, 15, 22, 16, 20, 16, 24, 16, 19, 18]],\n","       device='cuda:0')\n","Number of sampled nodes 25\n","Adj  torch.Size([25, 25])\n","y shape torch.Size([1, 2])\n","y_out shape torch.Size([1, 2])\n","alpha 1e-09 beta 1e-07\n","tensor(0.0041, device='cuda:0', grad_fn=<AddBackward0>)\n","0 Loss 0.0041\n","tensor(4.7848e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","100 Loss 0.0000\n","tensor(2.1313e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","200 Loss 0.0000\n","tensor(1.9465e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","300 Loss 0.0000\n","tensor(1.9132e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","400 Loss 0.0000\n","tensor(1.8833e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","500 Loss 0.0000\n","tensor(1.8662e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","600 Loss 0.0000\n","tensor(1.8768e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","700 Loss 0.0000\n","tensor(1.8644e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","800 Loss 0.0000\n","tensor(1.8463e-05, device='cuda:0', grad_fn=<AddBackward0>)\n","900 Loss 0.0000\n","0.9787234042553191\n","0.8276705276705276 0.9787234042553191 0.96\n","tensor([[-0.1190,  0.3080, -0.0864,  ..., -0.1559, -0.3649, -0.5323],\n","        [ 0.1389, -0.2305, -0.0038,  ...,  0.4882,  1.0600,  0.1912],\n","        [-0.9940,  0.0193, -0.2737,  ...,  1.8194, -0.9953,  
0.2496],\n","        ...,\n","        [-0.5674,  0.3331, -0.3614,  ..., -0.0193, -0.2781, -0.3639],\n","        [-0.3145, -0.4120, -0.0696,  ..., -0.3764,  0.0557, -0.3483],\n","        [ 0.2483, -0.1981,  0.3405,  ...,  0.0226, -0.2848, -0.2287]],\n","       device='cuda:0', requires_grad=True)\n","tensor(0.6983, device='cuda:0', grad_fn=<MeanBackward0>)\n","-------Adjacency Error-----\n","tensor([0.8213, 0.9682, 0.9563])\n","tensor([0.1142, 0.0164, 0.0282])\n","-------X Error-----\n","tensor([0.7532, 0.7134])\n","tensor([0.0904, 0.0686])\n"]}],"source":["# from torch._C import int64\n","import torch\n","import math\n","from torch_geometric.data import Data\n","from torch_geometric.datasets import Planetoid,TUDataset,QM9,PPI,KarateClub, FakeDataset, MoleculeNet, GNNBenchmarkDataset,FacebookPagePage, GeometricShapes, GitHub\n","# from torch_geometric.datasets import GraphGenerator\n","import torch.nn.functional as F\n","import torch_geometric\n","from torch_geometric.nn import GCNConv,SAGEConv, GATConv,DenseGCNConv, DenseGINConv,  DenseSAGEConv, global_mean_pool\n","from torch_geometric.loader import DataLoader,NeighborLoader, DenseDataLoader\n","from torch_geometric.transforms import GCNNorm\n","from torch_geometric.utils import to_dense_adj, to_dense_batch\n","import networkx as nx\n","from scipy.optimize import linear_sum_assignment\n","import numpy as np\n","import copy\n","from scipy.optimize import linear_sum_assignment as lsa\n","from torch.optim.lr_scheduler import ExponentialLR\n","import argparse\n","import json\n","from sklearn import metrics\n","import seaborn as sns\n","import matplotlib.pyplot as plt\n","import numpy as np\n","\n","torch.manual_seed(0)\n","device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n","\n","\n","cora =  Planetoid(root='data', name='CiteSeer')\n","fb=FacebookPagePage(root='data/fb')\n","git = GitHub(root='data/git')\n","fake = 
FakeDataset(avg_degree=10,num_channels=10,avg_num_nodes=50,task='node')\n",
"dataset=git\n",
"cora= dataset[0].to(device)\n",
"num_classes=dataset.num_classes #number of labels\n",
"num_node_features=dataset.num_node_features\n",
"\n",
"#onehot labels\n",
"def label_to_onehot(target, num_classes=num_classes):\n",
"    \"\"\"Convert integer class labels to a one-hot matrix.\n",
"\n",
"    Args:\n",
"        target: 1-D tensor of int64 class indices (scatter_ needs an integer index).\n",
"        num_classes: width of the encoding; the default is captured from the\n",
"            dataset loaded above at the time this cell is executed.\n",
"\n",
"    Returns:\n",
"        Float tensor of shape (len(target), num_classes) on target's device.\n",
"    \"\"\"\n",
"    target = torch.unsqueeze(target, 1)\n",
"    onehot_target = torch.zeros(target.size(0), num_classes, device=target.device)\n",
"    onehot_target.scatter_(1, target, 1)\n",
"    return onehot_target\n",
"\n",
"#cross-entropy loss\n",
"def cross_entropy_for_onehot(pred, target):\n",
"    \"\"\"Mean soft-label cross-entropy between `pred` and one-hot `target`.\n",
"\n",
"    log_softmax is applied to `pred` here, so `pred` is treated as raw logits.\n",
"    NOTE(review): the GNN in run_optimizer returns F.softmax outputs, so the loss\n",
"    there is a log_softmax of probabilities - TODO confirm this is intended.\n",
"    \"\"\"\n",
"    return torch.mean(torch.sum(- target * F.log_softmax(pred, dim=-1), 1))\n",
"\n",
"def cross_entropy_for_onehot2(pred, target):\n",
"    \"\"\"Unreduced variant of cross_entropy_for_onehot: one loss per row of `pred`.\"\"\"\n",
"    return torch.sum(- target * F.log_softmax(pred, dim=-1), -1)\n",
"\n",
"criterion = cross_entropy_for_onehot\n",
"# criterion = cross_entropy_for_onehot2\n",
"\n",
"#min-max normalisation\n",
"def normalise(mat):\n",
"    \"\"\"Min-max scale `mat` into [0, 1], printing its max/min as a debugging aid.\n",
"\n",
"    A constant `mat` is returned unchanged to avoid dividing by zero.\n",
"    \"\"\"\n",
"    maxx=mat.max()\n",
"    minn=mat.min()\n",
"    print(\"maxx is {}, minn is {}\".format(maxx,minn))\n",
"    if(maxx==minn):\n",
"        return mat\n",
"\n",
"    return (mat-minn)/(maxx-minn)\n",
"\n",
"#Cost matrix for Hungarian\n",
"#Adjacency error\n",
"def get_A_error(A_true,A_pred, thresholds):\n",
"    \"\"\"Reconstruction error of a dense adjacency matrix.\n",
"\n",
"    A_pred is min-max normalised once with normalise(). The returned list holds\n",
"    the mean absolute difference |A_true - A_pred| over N*N entries, followed by\n",
"    one such error per threshold after binarising A_pred with A_pred > threshold.\n",
"    N is A_true.shape[0]; both matrices are assumed to be N x N.\n",
"    \"\"\"\n",
"    err_lst=[]\n",
"    N=A_true.shape[0]\n",
"    A_pred=normalise(A_pred)\n",
"    err_lst.append(torch.sum(torch.abs(A_true-A_pred))/(N*N))\n",
"    # A_pred is already scaled to [0, 1] here, so the thresholds act on the scaled\n",
"    # values; normalising a second time would be a no-op.\n",
"    for threshold in thresholds:\n",
"        new_A=torch.where(A_pred>threshold,1.0,0.0)\n",
"        err_lst.append(torch.sum(torch.abs(A_true-new_A))/(N*N))\n",
"    return err_lst\n",
"\n",
"def init_matrix(shape,init):\n",
"    \"\"\"Create a trainable tensor of `shape` on the global `device`.\n",
"\n",
"    init == -1 draws entries from a standard normal distribution; any other\n",
"    value fills the tensor with that constant.\n",
"    \"\"\"\n",
"    if(init==-1):\n",
"        return torch.randn(shape).to(device).requires_grad_(True)\n",
"\n",
"    return (torch.ones(shape)*init).to(device).requires_grad_(True)\n",
"\n",
"def get_root_inv_degree_matrix( adj):\n",
"    \"\"\"Return D^(-1/2) as a dense diagonal matrix for a dense adjacency `adj`.\n",
"\n",
"    Degrees are the row sums of `adj`. Rows with a non-positive degree (isolated\n",
"    nodes, or negative sums from an unconstrained learned adjacency) get 0 on the\n",
"    diagonal instead of inf/NaN. The power is evaluated on a safe placeholder so\n",
"    no NaN gradient flows back into `adj`. The result lives on adj's device and\n",
"    uses adj's floating dtype (float32 for integer/bool adjacencies).\n",
"    \"\"\"\n",
"    degrees = adj.sum(dim=1)\n",
"    if not degrees.is_floating_point():\n",
"        degrees = degrees.float()\n",
"    positive = degrees > 0\n",
"    safe_degrees = torch.where(positive, degrees, torch.ones_like(degrees))\n",
"    inv_sqrt = torch.where(positive, safe_degrees.pow(-0.5), torch.zeros_like(degrees))\n",
"    return torch.diag(inv_sqrt)\n",
"\n",
"def get_degree_matrix( adj):\n",
"    \"\"\"Return the dense diagonal degree matrix D (row sums of `adj`) on adj's device.\"\"\"\n",
"    degrees = adj.sum(dim=1)\n",
"    if not degrees.is_floating_point():\n",
"        degrees = degrees.float()\n",
"    return torch.diag(degrees)\n",
"\n",
"# Symmetric normalised alternative, L = I - D^(-1/2) A D^(-1/2):\n",
"# def get_laplacian(adj):\n",
"#   N = adj.shape[0]\n",
"#   degree_mat  = get_root_inv_degree_matrix(adj)\n",
"#   lap = torch.eye(N).to(device) - degree_mat @ adj @ degree_mat\n",
"#   return lap\n",
"def get_laplacian(adj):\n",
"    \"\"\"Combinatorial graph Laplacian L = D - A for a dense adjacency `adj`.\"\"\"\n",
"    return get_degree_matrix(adj) - adj\n",
"\n",
"\n",
"def get_X_error_old(X_true,X_pred):\n",
"    \"\"\"Legacy feature error: mean over rows of ||X_true - X_pred|| / ||X_true||.\n",
"\n",
"    Kept for reference. NOTE(review): an all-zero row in X_true makes this\n",
"    inf/NaN; get_X_error below does not have that problem.\n",
"    \"\"\"\n",
"    row_norm_orig= torch.norm(X_true,dim=-1)\n",
"    err_norms= torch.norm(X_true-X_pred,dim=-1)\n",
"    errs=torch.div(err_norms,row_norm_orig)\n",
"    err_x=torch.mean(errs)\n",
"    return err_x\n",
"\n",
"def get_X_error(X_true,X_pred):\n",
"    \"\"\"Mean L2 distance between the row-normalised X_true and X_pred.\n",
"\n",
"    Every row is scaled to unit L2 norm before comparison, so per-row magnitude\n",
"    is ignored and each row distance lies in [0, 2]. F.normalize clamps the norm\n",
"    at 1e-12, so all-zero rows become zero vectors instead of producing NaN.\n",
"    \"\"\"\n",
"    x = F.normalize(X_true, p=2, dim=-1)\n",
"    x_hat = F.normalize(X_pred, p=2, dim=-1)\n",
"    return torch.mean(torch.norm(x - x_hat, dim=-1))\n",
"\n",
"\n",
"\n",
"# Run Algorithm 2 in the paper\n",
"def run_optimizer(data,params='all', init=-1,module=DenseGCNConv,alpha= 1e-9, beta= 1e-7):\n",
"  # data.to(device)\n",
"  # print(\"Data is {}\".format(data))\n",
"\n",
"  # data.adj=to_dense_adj(data.edge_index,max_num_nodes=data.x.shape[0]).squeeze(0)\n",
"  # data.adj=to_dense_adj(data.edge_index).squeeze(0)\n",
"\n",
"  # batch_size=data.x.shape[0]\n",
"  # 
print(data.x.shape)\n","  # print(data.adj.shape)\n","\n","  class GNN(torch.nn.Module):\n","      def __init__(self):\n","          super().__init__()\n","          self.act=torch.nn.Sigmoid()\n","          self.conv1=module(num_node_features,100,bias=True)\n","          self.conv2=module(100, 100,bias=True)\n","          # self.conv2=module(100,num_classes,bias=True)\n","          self.lin=torch.nn.Linear(100*data.x.shape[0],dataset.num_classes)\n","\n","      def forward(self, data):\n","          x, adj= data.x, data.adj\n","          x = self.conv1(x, adj)\n","          x = self.act(x)\n","          x = self.conv2(x, adj)\n","          x = self.act(x)\n","          x = x.view(1,-1)\n","          # print(x.shape)\n","          x = self.lin(x)\n","          x = self.act(x)\n","          return F.softmax(x,dim=-1)\n","\n","  model = GNN().to(device)\n","  print(\"Adj \",data.adj.shape)\n","  out = model(data)\n","  # print(\"out shape {}\".format(out.shape))\n","\n","  # y_out=out[0,:].unsqueeze(0)\n","  y_out=out\n","\n","  y=label_to_onehot(data.y[0].unsqueeze(0))\n","  # y=label_to_onehot(data.y[0])\n","  print(\"y shape {}\".format(y.shape))\n","  print(\"y_out shape {}\".format(y_out.shape))\n","\n","  # loss = criterion(y_out, y).squeeze()\n","  loss = criterion(y_out, y)\n","\n","\n","\n","  # print('Loss shape ',loss.shape)\n","  original_dy_dw = []\n","  # for l in loss:\n","  #   #  print('loss is ',l)\n","  #    dy_dw = torch.autograd.grad(l, model.parameters(),retain_graph=True)\n","  #    original_dy_dw.extend(list((_.detach().clone() for _ in dy_dw)))\n","    #  break\n","  # exit()\n","  # print('grad reqs ',original_dy_dw[0].requires_grad)\n","  dy_dw = torch.autograd.grad(loss, model.parameters())\n","  original_dy_dw = list((_.detach().clone() for _ in dy_dw))\n","  # print(\"original {}\".format(original_dy_dw))\n","  dummy_data=copy.deepcopy(data)\n","  # print(\"Initial diff {}\".format(dummy_data.adj-data.adj))\n","  
inp=copy.deepcopy(data.x).to(device).requires_grad_(True)\n","  adj=copy.deepcopy(data.adj).to(device).requires_grad_(True)\n","  # opt_lst=[]\n","  # adj=torch.randint(low=0, high=2, size=(data.x.shape[0], data.x.shape[0]),dtype=torch.float).to(device).requires_grad_(True)\n","  # adj=torch.zeros(size=(data.x.shape[0], data.x.shape[0]),dtype=torch.float).to(device).requires_grad_(True)\n","    # adj=torch.randn((data.x.shape[0], data.x.shape[0])).to(device).requires_grad_(True)\n","\n","  # if(params=='all'):\n","  #   inp = init_matrix(data.x.shape,init)\n","  #   adj =  init_matrix((n,n),init)\n","  #   opt_lst=[inp,adj]\n","  # elif (params=='X'):\n","  #   inp = init_matrix(data.x.shape,init)\n","  #   opt_lst = [inp]\n","  # else:\n","  #   # adj = init_matrix((n,n),init)\n","  #   print(\"Doing randint\")\n","  adj=torch.randint(low=0, high=2, size=(data.x.shape[0], data.x.shape[0]),dtype=torch.float).to(device).requires_grad_(True)\n","  # adj=torch.randn((data.x.shape[0], data.x.shape[0])).to(device).requires_grad_(True)\n","  opt_lst = [adj]\n","\n","\n","  # print(\"Adj is {}\".format(adj))\n","  inp = init_matrix(data.x.shape,init)\n","  opt_lst = [adj,inp]\n","  # opt_lst = [inp]\n","  # opt_lst = [adj]\n","\n","\n","\n","  # print(\"Inp is {}\".format(inp))\n","  optimizer=torch.optim.Adam(opt_lst, lr=0.1)\n","  print('alpha',alpha,'beta',beta)\n","  for iters in range(1000):\n","      def closure():\n","          optimizer.zero_grad()\n","          dummy_data.x=inp\n","          symm_adj=torch.tril(adj,diagonal=-1)+torch.tril(adj,diagonal=0).T\n","          dummy_data.adj=symm_adj\n","          pred = model(dummy_data)\n","          # pred = model(dummy_data)[0]\n","\n","          # print(\"preds.shape={}\".format(pred.shape))\n","          dummy_onehot_label = y\n","          # print(\"-----pred shape \",pred.shape)\n","          # print(\"-----y shape \",y.shape)\n","\n","          # print(\"pred.shape {}\".format(pred.shape))\n","          # 
print(\"dummy_onehot_label.shape {}\".format(dummy_onehot_label.shape))\n","\n","          # l=dummy_loss = criterion(pred, dummy_onehot_label).squeeze() # TODO: fix the gt_label to dummy_label in both code and slides.\n","          l=dummy_loss = criterion(pred, dummy_onehot_label) # TODO: fix the gt_label to dummy_label in both code and slides.\n","\n","          # print(\"Dummy Loss {}\".format(dummy_loss.shape))\n","          # print(dummy_loss.requires_grad)\n","\n","          # dummy_dy_dw = []\n","          # for l in dummy_loss:\n","          #   # print('l is ',l)\n","          #   dy_dw = torch.autograd.grad(l, model.parameters(),create_graph=True)\n","          #   # print('dy_dw grad',dy_dw[0].requires_grad)\n","          #   dummy_dy_dw.extend(list((_ for _ in dy_dw)))\n","\n","          # dy_dw = torch.autograd.grad(l, model.parameters(),retain_graph=True)\n","          # print('dy_dw grads ',dy_dw[0].requires_grad)\n","\n","          dummy_dy_dw = torch.autograd.grad(dummy_loss, model.parameters(), create_graph=True)\n","          # print('dummy_dy_dw grads ',dummy_dy_dw[0].requires_grad)\n","          dot_product = 0.0\n","          mag1 = 0.0\n","          mag2 = 0.0\n","          # print(dummy_dy_dw)\n","          for gx, gy in zip(dummy_dy_dw, original_dy_dw): # TODO: fix the variablas here\n","\n","            dot_product += (gx*gy).sum()\n","\n","            mag1 += torch.linalg.norm(gx)**2\n","            mag2 += torch.linalg.norm(gy)**2\n","          symm_lap = get_laplacian(dummy_data.adj)\n","          # print('symm_lap shape',symm_lap.shape)\n","          alpha = 1e-8\n","          beta = 1e-7\n","          smooth_loss = torch.trace(dummy_data.x.T @ symm_lap @ dummy_data.x)\n","          sparse_loss = torch.norm(adj)*0.5\n","          # print('sparse loss is ',sparse_loss)\n","          # print(mag1,mag2)\n","          cosine_loss=1- (dot_product/(mag1.sqrt()*mag2.sqrt()) )\n","          # print('cosine_loss ',cosine_loss)\n","          # 
print('gx ',gx, 'gy ',gy)\n","          # loss = cosine_loss\n","          # print('require grad ',data.adj.requires_grad)\n","          # print('smooth loss is ', smooth_loss)\n","          # loss=cosine_loss + (alpha*smooth_loss)\n","          loss=cosine_loss + (alpha*smooth_loss)  + (beta*sparse_loss)\n","          # loss=cosine_loss\n","\n","\n","          # loss= (alpha*sparse_loss)\n","          # print('loss is ',loss)\n","          loss.backward(retain_graph=True)\n","          return loss\n","\n","      optimizer.step(closure)\n","      with torch.no_grad():\n","        adj.clamp_(0,1)\n","        # inp.clamp_(0,1)\n","\n","      if iters % 100 == 0:\n","        current_loss = closure()\n","        print(current_loss)\n","        print(iters, \"Loss %.4f\" % current_loss.item())\n","        if(current_loss.item()==0.0):\n","          break\n","  return data,dummy_data\n","\n","\n","# print(\"Dataset size \",len(dataset))\n","loader=DataLoader(dataset,batch_size=1)\n","\n","\n","thresholds=[0.2, 0.4, 0.5, 0.6, 0.8]\n","results={}\n","\n","\n","\n","\n","# data.adj = to_dense_adj(data.edge_index).squeeze()\n","# print(data.adj.shape)\n","# _, edge_index, _, _ =torch_geometric.utils.k_hop_subgraph(100,2,data.edge_index,relabel_nodes=True)\n","node_lst=[]\n","all_nodes = [i for i in range(cora.num_nodes)]\n","sampled_all_nodes = np.random.choice(all_nodes,size=cora.num_nodes,replace=False)\n","idx= 0\n","print('Sampled all node',sampled_all_nodes)\n","while len(node_lst)<10:\n","    # print(sampled_all_nodes[idx])\n","    nodes, edge_index, u,v =torch_geometric.utils.k_hop_subgraph(int(sampled_all_nodes[idx]),3,cora.edge_index,relabel_nodes=True)\n","    if(len(nodes)<40 and len(nodes)>=20):\n","      node_lst.append((nodes, edge_index, u,v))\n","    idx+=1\n","\n","# print('Randomly generated node list',node_lst)\n","ans = {}\n","# alphas = [1e-13,1e-12,1e-11,1e-10,1e-9,1e-8,1e-7,1e-6,1e-5,1e-4]\n","betas = 
[1e-13,1e-12,1e-11,1e-10,1e-9,1e-8,1e-7,1e-6,1e-5,1e-4]\n",
"\n",
"avg_A_error=[]  # rows: [ROC-AUC, precision, accuracy] of the reconstructed adjacency\n",
"avg_X_error=[]  # rows: [get_X_error, get_X_error_old] of the reconstructed features\n",
"# Guard against node_lst holding fewer than 5 subgraphs.\n",
"for i in range(min(5, len(node_lst))):\n",
"  nodes, edge_index, _, _ = node_lst[i]\n",
"  data=Data().to(device)\n",
"  data.edge_index = edge_index\n",
"  data.x = cora.x[nodes]\n",
"  data.y = cora.y[nodes]\n",
"  # Size the dense adjacency by the node count so it always matches data.x.\n",
"  data.adj = to_dense_adj(data.edge_index, max_num_nodes=len(nodes)).squeeze(0)\n",
"  data.adj.clamp_(0,1)\n",
"  print('Adj unique',data.adj.unique())\n",
"  print(data.edge_index)\n",
"  print(\"Number of sampled nodes\",data.x.shape[0])\n",
"  data,dummy_data=run_optimizer(data)\n",
"\n",
"  true_edges = data.adj.flatten().detach().cpu().numpy()\n",
"  edge_scores = dummy_data.adj.flatten().detach().cpu().numpy()\n",
"  # Hard reconstruction: sample each entry from Bernoulli(dummy_data.adj).\n",
"  sampled_edges = torch.bernoulli(dummy_data.adj.detach()).flatten().cpu().numpy()\n",
"  # ROC-AUC ranks the continuous scores; precision and accuracy need hard labels.\n",
"  area = metrics.roc_auc_score(true_edges, edge_scores)\n",
"  ap = metrics.precision_score(true_edges, sampled_edges)\n",
"  acc = metrics.accuracy_score(true_edges, sampled_edges)\n",
"  print(area,ap,acc)\n",
"  avg_A_error.append([area,ap,acc])\n",
"\n",
"  # Store plain floats so the summary tensors do not keep autograd graphs or device memory alive.\n",
"  x_err = get_X_error(data.x,dummy_data.x).item()\n",
"  x_err_old = get_X_error_old(data.x,dummy_data.x).item()\n",
"  print(x_err, x_err_old)\n",
"  avg_X_error.append([x_err, x_err_old])\n",
"\n",
"print(\"-------Adjacency Error-----\")\n",
"avg_A_error=torch.Tensor(avg_A_error)\n",
"print(avg_A_error.mean(axis=0))\n",
"print(avg_A_error.std(axis=0))\n",
"# ans[beta]=[avg_A_error.mean(axis=0).tolist(),avg_A_error.std(axis=0).tolist()]\n",
"\n",
"print(\"-------X Error-----\")\n",
"avg_X_error=torch.Tensor(avg_X_error)\n",
"print(avg_X_error.mean(axis=0))\n",
"print(avg_X_error.std(axis=0))\n",
"\n",
"# with open('result_beta.json', 'w') as fp:\n",
"#   json.dump(ans, fp)"]},{"cell_type":"markdown","metadata":{"id":"sYA8_OsNpkKj"},"source":[]}],"metadata":{"colab":{"provenance":[{"file_id":"1ogiezBsJrhM_PXJy1ORRdw4TI3k9zSfs","timestamp":1750381028099},{"file_id":"19ATDOfJutR7RLtKDgfSrVkH9pfZfl7lR","timestamp":1673659183867}],"machine_shape":"hm","gpuType":"A100"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":0}