{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyObEw7LUjWdqKlhltSJ63mV"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","source":["!pip install torch==2.5.0"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":867},"id":"IzOzEUgTBhBr","executionInfo":{"status":"ok","timestamp":1750384240426,"user_tz":420,"elapsed":117346,"user":{"displayName":"Divyaanand Sinha","userId":"10890862448949785739"}},"outputId":"8e971d87-4080-4c97-cfe6-efa1b5e1687b"},"execution_count":3,"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting torch==2.5.0\n","  Downloading torch-2.5.0-cp311-cp311-manylinux1_x86_64.whl.metadata (28 kB)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch==2.5.0) (3.18.0)\n","Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.11/dist-packages (from torch==2.5.0) (4.14.0)\n","Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch==2.5.0) (3.5)\n","Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch==2.5.0) (3.1.6)\n","Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch==2.5.0) (2025.3.2)\n","Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch==2.5.0) (12.4.127)\n","Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch==2.5.0) (12.4.127)\n","Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch==2.5.0) (12.4.127)\n","Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch==2.5.0) (9.1.0.70)\n","Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in 
/usr/local/lib/python3.11/dist-packages (from torch==2.5.0) (12.4.5.8)\n","Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch==2.5.0) (11.2.1.3)\n","Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch==2.5.0) (10.3.5.147)\n","Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch==2.5.0) (11.6.1.9)\n","Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch==2.5.0) (12.3.1.170)\n","Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch==2.5.0) (2.21.5)\n","Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch==2.5.0) (12.4.127)\n","Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch==2.5.0) (12.4.127)\n","Collecting triton==3.1.0 (from torch==2.5.0)\n","  Downloading triton-3.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.3 kB)\n","Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch==2.5.0) (1.13.1)\n","Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch==2.5.0) (1.3.0)\n","Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch==2.5.0) (3.0.2)\n","Downloading torch-2.5.0-cp311-cp311-manylinux1_x86_64.whl (906.5 MB)\n","\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m906.5/906.5 MB\u001b[0m \u001b[31m783.4 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading triton-3.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (209.5 MB)\n","\u001b[2K   
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m209.5/209.5 MB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hInstalling collected packages: triton, torch\n","  Attempting uninstall: triton\n","    Found existing installation: triton 3.2.0\n","    Uninstalling triton-3.2.0:\n","      Successfully uninstalled triton-3.2.0\n","  Attempting uninstall: torch\n","    Found existing installation: torch 2.6.0+cu124\n","    Uninstalling torch-2.6.0+cu124:\n","      Successfully uninstalled torch-2.6.0+cu124\n","\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n","torchaudio 2.6.0+cu124 requires torch==2.6.0, but you have torch 2.5.0 which is incompatible.\n","torchvision 0.21.0+cu124 requires torch==2.6.0, but you have torch 2.5.0 which is incompatible.\u001b[0m\u001b[31m\n","\u001b[0mSuccessfully installed torch-2.5.0 triton-3.1.0\n"]},{"output_type":"display_data","data":{"application/vnd.colab-display-data+json":{"pip_warning":{"packages":["torch","torchgen","triton"]},"id":"55439f7298314c278961c9990430c4c8"}},"metadata":{}}]},{"cell_type":"code","execution_count":1,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"O3XWhd--8sCI","executionInfo":{"status":"ok","timestamp":1750383632424,"user_tz":420,"elapsed":158295,"user":{"displayName":"Divyaanand Sinha","userId":"10890862448949785739"}},"outputId":"f859cd95-4ca9-4e23-c4ee-8a1e55651cea"},"outputs":[{"output_type":"stream","name":"stdout","text":["Looking in links: https://data.pyg.org/whl/torch-2.5.0+cu124.html\n","Collecting pyg_lib\n","  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/pyg_lib-0.4.0%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (2.5 MB)\n","\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.5/2.5 MB\u001b[0m \u001b[31m13.6 MB/s\u001b[0m eta 
\u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting torch_scatter\n","  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/torch_scatter-2.1.2%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (10.8 MB)\n","\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.8/10.8 MB\u001b[0m \u001b[31m27.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting torch_sparse\n","  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/torch_sparse-0.6.18%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (5.2 MB)\n","\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.2/5.2 MB\u001b[0m \u001b[31m15.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting torch_cluster\n","  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/torch_cluster-1.6.3%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (3.4 MB)\n","\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting torch_spline_conv\n","  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/torch_spline_conv-1.2.2%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (1.0 MB)\n","\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m18.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: scipy in /usr/local/lib/python3.11/dist-packages (from torch_sparse) (1.15.3)\n","Requirement already satisfied: numpy<2.5,>=1.23.5 in /usr/local/lib/python3.11/dist-packages (from scipy->torch_sparse) (2.0.2)\n","Installing collected packages: torch_spline_conv, torch_scatter, pyg_lib, torch_sparse, torch_cluster\n","Successfully installed pyg_lib-0.4.0+pt25cu124 torch_cluster-1.6.3+pt25cu124 torch_scatter-2.1.2+pt25cu124 torch_sparse-0.6.18+pt25cu124 torch_spline_conv-1.2.2+pt25cu124\n","Collecting torch-geometric\n","  Downloading 
torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)\n","\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m63.1/63.1 kB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: aiohttp in /usr/local/lib/python3.11/dist-packages (from torch-geometric) (3.11.15)\n","Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch-geometric) (2025.3.2)\n","Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch-geometric) (3.1.6)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torch-geometric) (2.0.2)\n","Requirement already satisfied: psutil>=5.8.0 in /usr/local/lib/python3.11/dist-packages (from torch-geometric) (5.9.5)\n","Requirement already satisfied: pyparsing in /usr/local/lib/python3.11/dist-packages (from torch-geometric) (3.2.3)\n","Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from torch-geometric) (2.32.3)\n","Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from torch-geometric) (4.67.1)\n","Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->torch-geometric) (2.6.1)\n","Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp->torch-geometric) (1.3.2)\n","Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->torch-geometric) (25.3.0)\n","Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp->torch-geometric) (1.7.0)\n","Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp->torch-geometric) (6.4.4)\n","Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->torch-geometric) 
(0.3.2)\n","Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->torch-geometric) (1.20.1)\n","Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch-geometric) (3.0.2)\n","Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->torch-geometric) (3.4.2)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->torch-geometric) (3.10)\n","Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->torch-geometric) (2.4.0)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->torch-geometric) (2025.6.15)\n","Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)\n","\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m27.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hInstalling collected packages: torch-geometric\n","Successfully installed torch-geometric-2.6.1\n","Collecting rdkit\n","  Downloading rdkit-2025.3.3-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.0 kB)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from rdkit) (2.0.2)\n","Requirement already satisfied: Pillow in /usr/local/lib/python3.11/dist-packages (from rdkit) (11.2.1)\n","Downloading rdkit-2025.3.3-cp311-cp311-manylinux_2_28_x86_64.whl (34.9 MB)\n","\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m34.9/34.9 MB\u001b[0m \u001b[31m16.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hInstalling collected packages: rdkit\n","Successfully installed rdkit-2025.3.3\n","Collecting ogb\n","  Downloading ogb-1.3.6-py3-none-any.whl.metadata (6.2 kB)\n","Requirement already satisfied: torch>=1.6.0 in 
/usr/local/lib/python3.11/dist-packages (from ogb) (2.6.0+cu124)\n","Requirement already satisfied: numpy>=1.16.0 in /usr/local/lib/python3.11/dist-packages (from ogb) (2.0.2)\n","Requirement already satisfied: tqdm>=4.29.0 in /usr/local/lib/python3.11/dist-packages (from ogb) (4.67.1)\n","Requirement already satisfied: scikit-learn>=0.20.0 in /usr/local/lib/python3.11/dist-packages (from ogb) (1.6.1)\n","Requirement already satisfied: pandas>=0.24.0 in /usr/local/lib/python3.11/dist-packages (from ogb) (2.2.2)\n","Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.11/dist-packages (from ogb) (1.17.0)\n","Requirement already satisfied: urllib3>=1.24.0 in /usr/local/lib/python3.11/dist-packages (from ogb) (2.4.0)\n","Collecting outdated>=0.2.0 (from ogb)\n","  Downloading outdated-0.2.2-py2.py3-none-any.whl.metadata (4.7 kB)\n","Requirement already satisfied: setuptools>=44 in /usr/local/lib/python3.11/dist-packages (from outdated>=0.2.0->ogb) (75.2.0)\n","Collecting littleutils (from outdated>=0.2.0->ogb)\n","  Downloading littleutils-0.2.4-py3-none-any.whl.metadata (679 bytes)\n","Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from outdated>=0.2.0->ogb) (2.32.3)\n","Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas>=0.24.0->ogb) (2.9.0.post0)\n","Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas>=0.24.0->ogb) (2025.2)\n","Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas>=0.24.0->ogb) (2025.2)\n","Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn>=0.20.0->ogb) (1.15.3)\n","Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn>=0.20.0->ogb) (1.5.1)\n","Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages 
(from scikit-learn>=0.20.0->ogb) (3.6.0)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch>=1.6.0->ogb) (3.18.0)\n","Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch>=1.6.0->ogb) (4.14.0)\n","Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=1.6.0->ogb) (3.5)\n","Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=1.6.0->ogb) (3.1.6)\n","Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch>=1.6.0->ogb) (2025.3.2)\n","Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.6.0->ogb)\n","  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n","Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.6.0->ogb)\n","  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n","Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.6.0->ogb)\n","  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n","Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.6.0->ogb)\n","  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n","Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.6.0->ogb)\n","  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n","Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=1.6.0->ogb)\n","  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n","Collecting nvidia-curand-cu12==10.3.5.147 (from torch>=1.6.0->ogb)\n","  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n","Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch>=1.6.0->ogb)\n","  Downloading 
nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n","Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch>=1.6.0->ogb)\n","  Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n","Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=1.6.0->ogb) (0.6.2)\n","Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=1.6.0->ogb) (2.21.5)\n","Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.6.0->ogb) (12.4.127)\n","Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch>=1.6.0->ogb)\n","  Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n","Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=1.6.0->ogb) (3.2.0)\n","Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=1.6.0->ogb) (1.13.1)\n","Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=1.6.0->ogb) (1.3.0)\n","Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch>=1.6.0->ogb) (3.0.2)\n","Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->outdated>=0.2.0->ogb) (3.4.2)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->outdated>=0.2.0->ogb) (3.10)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->outdated>=0.2.0->ogb) (2025.6.15)\n","Downloading ogb-1.3.6-py3-none-any.whl (78 kB)\n","\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m78.8/78.8 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m 
eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading outdated-0.2.2-py2.py3-none-any.whl (7.5 kB)\n","Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n","\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n","\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m48.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n","\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m30.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n","\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m35.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n","\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m1.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n","\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n","\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta 
\u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n","\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m9.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n","\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m6.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n","\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m25.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hDownloading littleutils-0.2.4-py3-none-any.whl (8.1 kB)\n","Installing collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, littleutils, outdated, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, ogb\n","  Attempting uninstall: nvidia-nvjitlink-cu12\n","    Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n","    Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n","      Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n","  Attempting uninstall: nvidia-curand-cu12\n","    Found existing installation: nvidia-curand-cu12 10.3.6.82\n","    Uninstalling nvidia-curand-cu12-10.3.6.82:\n","      Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n","  Attempting uninstall: nvidia-cufft-cu12\n","    Found existing installation: nvidia-cufft-cu12 11.2.3.61\n","    Uninstalling nvidia-cufft-cu12-11.2.3.61:\n","      Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n","  Attempting uninstall: nvidia-cuda-runtime-cu12\n","    Found existing installation: nvidia-cuda-runtime-cu12 
12.5.82\n","    Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n","      Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n","  Attempting uninstall: nvidia-cuda-nvrtc-cu12\n","    Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n","    Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n","      Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n","  Attempting uninstall: nvidia-cuda-cupti-cu12\n","    Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n","    Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n","      Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n","  Attempting uninstall: nvidia-cublas-cu12\n","    Found existing installation: nvidia-cublas-cu12 12.5.3.2\n","    Uninstalling nvidia-cublas-cu12-12.5.3.2:\n","      Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n","  Attempting uninstall: nvidia-cusparse-cu12\n","    Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n","    Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n","      Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n","  Attempting uninstall: nvidia-cudnn-cu12\n","    Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n","    Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n","      Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n","  Attempting uninstall: nvidia-cusolver-cu12\n","    Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n","    Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n","      Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n","Successfully installed littleutils-0.2.4 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 ogb-1.3.6 outdated-0.2.2\n"]}],"source":["!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f 
https://data.pyg.org/whl/torch-2.5.0+cu124.html\n","!pip install torch-geometric\n","!pip install rdkit\n","!pip install ogb"]},{"cell_type":"code","source":["import torch\n",
"import math\n",
"\n",
"import numpy as np\n",
"from scipy.optimize import linear_sum_assignment as lsa\n",
"from torch_geometric.data import Data\n",
"from torch_geometric.datasets import OGB_MAG,KarateClub, FakeDataset, FacebookPagePage, GeometricShapes, GitHub\n",
"import torch.nn.functional as F\n",
"import torch_geometric\n",
"from torch_geometric.nn import GCNConv,SAGEConv, GATConv\n",
"from torch_geometric.loader import DataLoader,NeighborLoader\n",
"from torch_geometric.transforms import GCNNorm\n",
"from torch_geometric.utils import to_dense_adj\n",
"import networkx as nx\n",
"from sklearn.metrics import mean_absolute_percentage_error as mape\n",
"from tqdm import tqdm\n",
"import argparse\n",
"from ogb.nodeproppred import PygNodePropPredDataset\n",
"torch.manual_seed(0)\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"parser = argparse.ArgumentParser()\n",
"\n",
"# parser.add_argument('--dataset_name',\n",
"#                     help='name of dataset',\n",
"#                     choices=('fb','github','fake'),\n",
"#                     type=str,\n",
"#                     required=True)\n",
"# args = parser.parse_args()\n",
"\n",
"\n",
"# Only `dataset` is used below; the other datasets back the commented-out switch.\n",
"# NOTE(review): FakeDataset presumably draws from the torch RNG seeded above, so\n",
"# dropping it would shift every later random draw; keep it if results must match.\n",
"fb=FacebookPagePage(root='data/fb')\n",
"fake = FakeDataset(avg_degree=4,num_channels=10,avg_num_nodes=20,task='node')\n",
"github= GitHub(root='data/github')\n",
"ogb = PygNodePropPredDataset(name = 'ogbn-arxiv')\n",
"dataset=github\n",
"\n",
"# if(args.dataset_name=='fb'):\n",
"#   print('Facebook Data')\n",
"#   dataset=fb # set your dataset\n",
"# elif(args.dataset_name=='github'):\n",
"#   print('Github Data')\n",
"#   dataset=github # set your dataset\n",
"# else:\n",
"#   print('Synthetic Data')\n",
"#   dataset=fake\n",
"\n",
"batch_size=1\n",
"\n",
"\n",
"def get_dummy_batch_data(tree_degree=10):\n",
"  \"\"\"Build `batch_size` random dummy graphs for the gradient-matching attack.\n",
"\n",
"  Each dummy graph is a networkx balanced tree (branching factor `tree_degree`,\n",
"  depth 2) whose node features are random leaf tensors with requires_grad=True.\n",
"\n",
"  Returns:\n",
"    (dummy_data_batch, dummy_labels, inps): lists of length `batch_size` with the\n",
"    PyG Data objects, their random label tensors, and the feature tensors that\n",
"    the optimizer updates (the same tensors stored in each Data's `x`).\n",
"  \"\"\"\n",
"  dummy_data_batch=[]\n",
"  inps=[]\n",
"  dummy_labels=[]\n",
"  for i in range(batch_size):\n",
"    tree=nx.balanced_tree(tree_degree,2)\n",
"    inp=(torch.randn((len(tree),dataset.num_node_features))).to(device).requires_grad_(True)\n",
"    dummy_label=torch.randn((1,dataset.num_classes)).to(device).requires_grad_(True)\n",
"    dummy_data=torch_geometric.utils.from_networkx(tree).to(device)\n",
"    dummy_data.x=inp\n",
"    dummy_data.y=dummy_label\n",
"\n",
"    dummy_labels.append(dummy_label)\n",
"    dummy_data_batch.append(dummy_data)\n",
"    inps.append(inp)\n",
"\n",
"  return dummy_data_batch,dummy_labels,inps\n",
"\n",
"\n",
"\n",
"######################################\n",
"def get_dummy_batch_data_src_graph(data_batch=None, root=0, num_hops=2):\n",
"  \"\"\"Build dummy graphs that reuse the real `num_hops`-hop topology around `root`.\n",
"\n",
"  Unlike get_dummy_batch_data, the edges come from `data_batch` (its k-hop\n",
"  subgraph around `root`, relabelled to consecutive ids); only the node\n",
"  features and labels are random.\n",
"\n",
"  Args:\n",
"    data_batch: PyG batch whose `edge_index` supplies the topology.\n",
"    root: node index (within `data_batch`) at the centre of the subgraph.\n",
"    num_hops: neighbourhood radius passed to k_hop_subgraph.\n",
"\n",
"  Returns:\n",
"    (dummy_data_batch, dummy_labels, inps), as in get_dummy_batch_data.\n",
"\n",
"  Raises:\n",
"    ValueError: if `data_batch` is not provided.\n",
"  \"\"\"\n",
"  # There is no module-level `data_batch` (it only exists inside run_optimizer),\n",
"  # so the subgraph source has to be passed in explicitly.\n",
"  if data_batch is None:\n",
"    raise ValueError('get_dummy_batch_data_src_graph requires data_batch')\n",
"  node_idx,edge_index,_,_=torch_geometric.utils.k_hop_subgraph(root,num_hops,data_batch.edge_index,relabel_nodes=True)\n",
"  dummy_data_batch=[]\n",
"  inps=[]\n",
"  dummy_labels=[]\n",
"  for i in range(batch_size):\n",
"    inp=(torch.randn((len(node_idx),dataset.num_node_features))).to(device).requires_grad_(True)\n",
"\n",
"    dummy_label=torch.randn((1,dataset.num_classes)).to(device).requires_grad_(True)\n",
"    dummy_data=Data(x=inp,edge_index=edge_index,y=dummy_label)\n",
"\n",
"    dummy_labels.append(dummy_label)\n",
"    dummy_data_batch.append(dummy_data)\n",
"    inps.append(inp)\n",
"\n",
"  return dummy_data_batch,dummy_labels,inps\n",
"\n",
"dummy_data_batch,dummy_labels,inps=get_dummy_batch_data()\n",
"print(len(dummy_data_batch))\n",
"print(len(inps))\n",
"\n",
"\n",
"def label_to_onehot(target, num_classes=dataset.num_classes):\n",
"    \"\"\"Turn integer class labels of shape (N,) into a float one-hot (N, num_classes) tensor.\"\"\"\n",
"    target = torch.unsqueeze(target, 1)\n",
"    onehot_target = torch.zeros(target.size(0), num_classes, device=target.device)\n",
"    onehot_target.scatter_(1, target, 1)\n",
"    return onehot_target\n",
"\n",
"def cross_entropy_for_onehot(pred, target):\n",
"    \"\"\"Mean cross-entropy between scores `pred` (N, C) and one-hot/soft targets `target` (N, C).\"\"\"\n",
"    return torch.mean(torch.sum(- target * F.log_softmax(pred, dim=-1), 1) )\n",
"\n",
"criterion = cross_entropy_for_onehot\n",
"\n",
"def weights_init_sage(m):\n",
"    \"\"\"Re-draw lin_l / lin_r weights (and lin_l bias, if any) of a module from U(-1, 1).\n",
"\n",
"    Modules without a `lin_l` attribute are left untouched.\n",
"    \"\"\"\n",
"    if hasattr(m,'lin_l'):\n",
"      m.lin_l.weight.data.uniform_(-1,1)\n",
"      m.lin_r.weight.data.uniform_(-1,1)\n",
"\n",
"      # A linear layer built with bias=False still has a `bias` attribute (set to\n",
"      # None), so hasattr() alone would lead to None.data and an AttributeError.\n",
"      if getattr(m.lin_l, 'bias', None) is not None:\n",
"          m.lin_l.bias.data.uniform_(-1,1)\n",
"\n",
"class GNN(torch.nn.Module):\n",
"    \"\"\"Two-layer bias-free GCN whose parameter gradients the attack tries to match.\n",
"\n",
"    forward() returns softmax probabilities of shape (num_nodes, num_classes).\n",
"    NOTE(review): `criterion` applies log_softmax again on these probabilities\n",
"    (sigmoid -> softmax -> log_softmax); left unchanged, confirm it is intended.\n",
"    \"\"\"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.conv1=GCNConv(dataset.num_node_features,100, bias=False)\n",
"        self.act=torch.nn.Sigmoid()\n",
"        self.conv2=GCNConv(100, dataset.num_classes,bias=False)\n",
"\n",
"    def forward(self, data):\n",
"        x, edge_index = data.x, data.edge_index\n",
"        x = self.conv1(x,edge_index)\n",
"        x = self.act(x)\n",
"        x = self.conv2(x,edge_index)\n",
"        x = self.act(x)\n",
"\n",
"        return F.softmax(x, dim=1)\n",
"\n",
"model = GNN().to(device)\n",
"data = dataset[0].to(device)\n",
"\n",
"\n",
"def get_one_hop_matrix(data):\n",
"  \"\"\"Return the feature rows of the 1-hop neighbours of node 0 (node 0 excluded).\"\"\"\n",
"  one_hop=torch_geometric.utils.k_hop_subgraph(0,1,data.edge_index)[0]\n",
"  return data.x[one_hop[one_hop != 0]]\n",
"\n",
"def get_cost_matrix(x_true,x_pred):\n",
"  \"\"\"Pairwise Euclidean distances between rows of x_true (n, d) and x_pred (m, d).\n",
"\n",
"  Returns an (n, m) float64 numpy array; inputs are detached first.\n",
"  \"\"\"\n",
"  print(x_true.shape)\n",
"  # cdist replaces the O(n*m) Python double loop; the non-matmul mode keeps the\n",
"  # direct ||a - b|| computation instead of the less precise expansion trick.\n",
"  cost=torch.cdist(x_true.detach().double(), x_pred.detach().double(),\n",
"                   compute_mode='donot_use_mm_for_euclid_dist')\n",
"  return cost.cpu().numpy()\n",
"\n",
"def get_perfect_cost(x_true,x_pred,get_vals=False):\n",
"  \"\"\"Optimally match predicted rows to true rows and report relative errors.\n",
"\n",
"  Solves the assignment problem (scipy linear_sum_assignment) on the pairwise\n",
"  distance matrix and divides each matched distance by the norm of its true row.\n",
"\n",
"  Returns:\n",
"    If get_vals: the per-match relative errors (fractions, numpy array).\n",
"    Otherwise: (mean, std, min) of those errors in percent.\n",
"  \"\"\"\n",
"  C=get_cost_matrix(x_true,x_pred)\n",
"  print(C.shape)\n",
"  row_ind,col_ind=lsa(C)\n",
"\n",
"  norms=torch.norm(x_true,dim=1).detach().cpu().numpy()\n",
"  print(row_ind,col_ind)\n",
"  # Index the norms with row_ind so they stay aligned with the matched costs even\n",
"  # when x_true has more rows than x_pred (only min(n, m) rows get matched).\n",
"  rel_err=C[row_ind,col_ind]/norms[row_ind]\n",
"  if get_vals:\n",
"    return rel_err\n",
"\n",
"  rel_err_pct=rel_err*100\n",
"  return np.mean(rel_err_pct),np.std(rel_err_pct), np.min(rel_err_pct)\n",
"\n",
"def rnmse(x_true,x_pred):\n",
"  \"\"\"Relative error ||x_true - x_pred|| / ||x_true|| (Frobenius norm for matrices).\"\"\"\n",
"  diff=x_true-x_pred\n",
"  return torch.norm(diff)/torch.norm(x_true)\n",
"\n",
"def get_eigenvalues(X):\n",
"  \"\"\"Eigenvalues of the Gram matrix X^T X, sorted in descending order.\"\"\"\n",
"  # X^T X is symmetric PSD, so eigvalsh returns real eigenvalues directly; the\n",
"  # general eig() can produce tiny imaginary parts that torch.isreal() then drops.\n",
"  eigvals=torch.linalg.eigvalsh(X.T @ X)\n",
"  return torch.sort(eigvals,descending=True).values\n",
"\n",
"def get_neighbor(data,idx):\n",
"  \"\"\"Count eigenvalues > 1 of the Gram matrix of node `idx`'s 1-hop neighbour features.\"\"\"\n",
"  one_hop_nodes=torch_geometric.utils.k_hop_subgraph(idx,1,data.edge_index)[0]\n",
"  # The subset is sorted by node id, so `idx` is not necessarily first; drop it\n",
"  # explicitly instead of slicing off element 0.\n",
"  neighbours=one_hop_nodes[one_hop_nodes != idx]\n",
"  eig=get_eigenvalues(data.x[neighbours,:])\n",
"  return (eig>1).sum()\n",
"\n",
"def get_batched_data_error(x_true,x_pred):\n",
"  \"\"\"(mean, std, min) percent relative error after optimal row matching.\"\"\"\n",
"  return get_perfect_cost(x_true,x_pred)\n",
"\n",
"\n",
"def run_optimizer(model,dataloader,\n",
"                  inps,num_nodes=5):\n",
"  \"\"\"Reconstruct node features by gradient matching on sampled subgraphs.\n",
"\n",
"  For each batch drawn from `dataloader` (stopping after 200 accepted batches),\n",
"  the true gradient of `criterion` w.r.t. the model parameters is computed on the\n",
"  seed nodes. Random features on a dummy balanced tree are then optimised with\n",
"  Adam (2000 steps, lr=0.1) to minimise 1 - cosine(dummy gradient, true gradient).\n",
"\n",
"  A batch is skipped when it has no edges or when the 1-hop subset of node 0\n",
"  (node 0 included) has fewer than 3 or more than 10 nodes.\n",
"\n",
"  Args:\n",
"    model: network whose parameter gradients are matched.\n",
"    dataloader: iterable of PyG batches with `batch_size`, `x`, `y`, `edge_index`.\n",
"    inps: not used; fresh dummy inputs are created for every batch.\n",
"    num_nodes: not used (only referenced by commented-out code).\n",
"\n",
"  Returns:\n",
"    List of dicts {'original': seed-node features of the batch,\n",
"    'reproduced': list holding row 0 of each optimised dummy feature matrix}.\n",
"  \"\"\"\n",
"  reproduced_inps= []\n",
"  rounds=0\n",
"  print('starting')\n",
"  for batch_idx, data_batch in enumerate(dataloader):\n",
"    if(rounds>=200):\n",
"      break\n",
"    batch_size=data_batch.batch_size\n",
"    # Log the loader position rather than `rounds`: `rounds` only advances on\n",
"    # accepted batches, so skipped batches would all print the same index.\n",
"    print('Running data_batch={}, batch size is {}'.format(batch_idx,batch_size))\n",
"    if(data_batch.edge_index.shape[1]==0):\n",
"      continue\n",
"    num_one_hop=torch_geometric.utils.k_hop_subgraph(0,1,data_batch.edge_index)[0].numel()\n",
"    if num_one_hop>10 or num_one_hop<3:\n",
"      continue\n",
"    print('number of onehop neighbors {}'.format(num_one_hop))\n",
"    # data_batch.to(device)\n",
"    # if(data_batch.batch_size<num_nodes):\n",
"    #   break\n",
"\n",
"    # True gradient of the loss on the seed nodes w.r.t. the model parameters.\n",
"    out = model(data_batch)[:batch_size]\n",
"    y=label_to_onehot(data_batch.y[:batch_size])\n",
"    print('y shape is {}'.format(y.shape))\n",
"    print('y out shape is {}'.format(out.shape))\n",
"\n",
"    loss = criterion(out, y)\n",
"    print(loss)\n",
"    dy_dw = torch.autograd.grad(loss, model.parameters())\n",
"    original_dy_dw = [g.detach().clone() for g in dy_dw]\n",
"    with torch.no_grad():\n",
"      dummy_data_batch,_,inps=get_dummy_batch_data(tree_degree=10)\n",
"      # NOTE(review): this copies the TRUE features of seed node 0 into the dummy\n",
"      # root, and the evaluation below scores exactly that row. Confirm this is the\n",
"      # intended threat model (attacker already knows the target's features).\n",
"      inps[0][0]=data_batch.x[0]\n",
"\n",
"    optimizer=torch.optim.Adam(inps, lr=0.1)\n",
"\n",
"    def closure():\n",
"      \"\"\"One gradient-matching step; returns 1 - cosine(dummy grads, true grads).\"\"\"\n",
"      optimizer.zero_grad()\n",
"      preds=[]\n",
"      dummy_onehot_labels=[]\n",
"      for i in range(data_batch.batch_size):\n",
"        dummy_data_batch[i].x=inps[i]\n",
"        # Only row 0 of each dummy graph's output is matched.\n",
"        preds.append(model(dummy_data_batch[i])[0])\n",
"        # The dummy loss reuses the true labels; the random dummy labels are unused.\n",
"        dummy_onehot_labels.append(y)\n",
"\n",
"      preds=torch.stack(preds)\n",
"      dummy_onehot_labels=torch.stack(dummy_onehot_labels).squeeze(1)\n",
"      dummy_loss = criterion(preds, dummy_onehot_labels) # TODO: fix the gt_label to dummy_label in both code and slides.\n",
"      dummy_dy_dw = torch.autograd.grad(dummy_loss, model.parameters(), create_graph=True)\n",
"\n",
"      dot_product = 0\n",
"      mag1 = 0\n",
"      mag2 = 0\n",
"      for gx, gy in zip(dummy_dy_dw, original_dy_dw): # TODO: fix the variables here\n",
"          dot_product += (gx*gy).sum()\n",
"          mag1 += torch.linalg.norm(gx)**2\n",
"          mag2 += torch.linalg.norm(gy)**2\n",
"\n",
"      match_loss=1- (dot_product/(mag1.sqrt()*mag2.sqrt()))\n",
"      match_loss.backward(retain_graph=True)\n",
"      return match_loss\n",
"\n",
"    for iters in range(2000):\n",
"        optimizer.step(closure)\n",
"        if iters % 100== 0:\n",
"            current_loss = closure()\n",
"            if(current_loss==0.0):\n",
"              break\n",
"            print(iters, '%.4f' % current_loss.item())\n",
"\n",
"        # n_pred=get_neighbor(dummy_data_batch[0],0).item()\n",
"    # stats={'original':data_batch.x, 'reproduced': dummy_data_batch[0].x} #for non-batched data\n",
"    # stats={'original':data_batch, 'reproduced': dummy_data_batch[0]} #for neighbors\n",
"\n",
"    reproduced=[x_pred[0] for x_pred in inps]\n",
"    stats={'original':data_batch.x[:batch_size], 'reproduced': reproduced} #for batched data\n",
"\n",
"    reproduced_inps.append(stats)\n",
"    rounds+=1\n",
"  return reproduced_inps\n",
"\n",
"\n",
"\n",
"\n",
"# subgraph_type='induced' is the non-deprecated spelling of directed=False.\n",
"dl=NeighborLoader(data,batch_size=batch_size,num_neighbors=[-1]*2,input_nodes=torch.arange(data.num_nodes),replace=False,subgraph_type='induced')\n",
"outputs=run_optimizer(model,dl,inps)\n",
"\n",
"\n",
"#for non batched neighboring nodes recovery\n",
"# data_true=[]\n",
"# data_pred=[]\n",
"# avg_error=0\n",
"# means=[]\n",
"# mins=[]\n",
"# stds=[]\n",
"# all_costs=[]\n",
"# for out in outputs:\n",
"#   # print('Original {}'.format(out['original']))\n",
"#   actual=out['original'].x\n",
"#   reproduced=out['reproduced'].x\n",
"#   actual_onehop=out['original'].x[torch_geometric.utils.k_hop_subgraph(0,1,out['original'].edge_index)[0]]\n",
"#   recovered_onehop=out['reproduced'].x[torch_geometric.utils.k_hop_subgraph(0,1,out['reproduced'].edge_index)[0]]\n",
"#   # print(actual_onehop)\n",
"#   # print(recovered_onehop)\n",
"#   best_costs=get_perfect_cost(actual_onehop[1:],recovered_onehop[1:],get_vals=True)\n",
"#   all_costs.append(best_costs)\n",
"# ###########################################################################\n",
"\n",
"# for non batched source node feature recovery\n",
"data_true=[]\n",
"data_pred=[]\n",
"avg_error=[]\n",
"for out in outputs:\n",
"  avg_error.append(rnmse(out['original'][0],out['reproduced'][0]).item())\n",
"print(avg_error)\n",
"# np.min raises on an empty list, which happens if every batch was skipped.\n",
"if avg_error:\n",
"  print('Mean:{}'.format(np.mean(avg_error)))\n",
"  print('Std:{}'.format(np.std(avg_error)))\n",
"  print('Min:{}'.format(np.min(avg_error)))\n",
"else:\n",
"  print('No batches were reconstructed; nothing to evaluate.')\n",
"##########################################################################\n",
"\n",
"\n",
"# # for batched data\n",
"# data_true=[]\n",
"# data_pred=[]\n",
"# for out in outputs:\n",
"#   # print('Original {}'.format(out['original']))\n",
"#   data_true.append(out['original'])\n",
"#   data_pred.append(out['reproduced'])\n",
"\n",
"# for idx,tensor in enumerate(data_pred[0]):\n",
"#   data_pred[0][idx]=tensor.unsqueeze(dim=0)\n",
"\n",
"# 
data_true=torch.cat(data_true)\n","# data_pred=torch.cat(data_pred[0]).squeeze()\n","# print(data_true.shape)\n","# print(data_pred.shape)\n","# get_batched_data_error(data_true,data_pred)\n","# ####################################################################"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"aooX1PNr8vgh","outputId":"3494b480-4606-4362-ddaa-43ddcfa615f3"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.11/dist-packages/ogb/nodeproppred/dataset_pyg.py:69: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n","  self.data, self.slices = torch.load(self.processed_paths[0])\n"]},{"output_type":"stream","name":"stdout","text":["1\n","1\n","starting\n","Running data_batch=0, batch size is 1\n","Running data_batch=0, batch size is 1\n"]},{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.11/dist-packages/torch_geometric/sampler/neighbor_sampler.py:55: UserWarning: The usage of the 'directed' argument in 'NeighborSampler' is deprecated. 
Use `subgraph_type='induced'` instead.\n","  warnings.warn(f\"The usage of the 'directed' argument in \"\n"]},{"output_type":"stream","name":"stdout","text":["number of onehop neighbors 9\n","y shape is torch.Size([1, 2])\n","y out shape is torch.Size([1, 2])\n","tensor(0.6892, grad_fn=<MeanBackward0>)\n","0 0.0267\n","100 0.0002\n","200 0.0001\n","300 0.0001\n","400 0.0000\n","500 0.0000\n","600 0.0000\n","700 0.0000\n","800 0.0000\n","900 0.0000\n","1000 0.0000\n","1100 0.0000\n","1200 0.0000\n","1300 0.0000\n","1400 0.0000\n","1500 0.0000\n","1600 0.0000\n"]}]},{"cell_type":"code","source":[],"metadata":{"id":"9mZjw6Hs-EYU"},"execution_count":null,"outputs":[]}]}