{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "#install\n",
        "!pip install torch-scatter\n",
        "!pip install torch_geometric\n",
        "!pip install captum"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "collapsed": true,
        "id": "9b67R7ry2srl",
        "outputId": "853618aa-0778-46dd-d9b8-32387d5dd027"
      },
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting torch-scatter\n",
            "  Downloading torch_scatter-2.1.2.tar.gz (108 kB)\n",
            "\u001b[?25l     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/108.0 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m108.0/108.0 kB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Building wheels for collected packages: torch-scatter\n",
            "  Building wheel for torch-scatter (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for torch-scatter: filename=torch_scatter-2.1.2-cp312-cp312-linux_x86_64.whl size=640889 sha256=6114b72ed89140f77e2e9821eac761bc484c313b8b4a54184c86adcca339d98a\n",
            "  Stored in directory: /root/.cache/pip/wheels/84/20/50/44800723f57cd798630e77b3ec83bc80bd26a1e3dc3a672ef5\n",
            "Successfully built torch-scatter\n",
            "Installing collected packages: torch-scatter\n",
            "Successfully installed torch-scatter-2.1.2\n",
            "Collecting torch_geometric\n",
            "  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m63.7/63.7 kB\u001b[0m \u001b[31m3.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: aiohttp in /usr/local/lib/python3.12/dist-packages (from torch_geometric) (3.13.2)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch_geometric) (2025.3.0)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch_geometric) (3.1.6)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from torch_geometric) (2.0.2)\n",
            "Requirement already satisfied: psutil>=5.8.0 in /usr/local/lib/python3.12/dist-packages (from torch_geometric) (5.9.5)\n",
            "Requirement already satisfied: pyparsing in /usr/local/lib/python3.12/dist-packages (from torch_geometric) (3.2.5)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from torch_geometric) (2.32.4)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from torch_geometric) (4.67.1)\n",
            "Requirement already satisfied: xxhash in /usr/local/lib/python3.12/dist-packages (from torch_geometric) (3.6.0)\n",
            "Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->torch_geometric) (2.6.1)\n",
            "Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->torch_geometric) (1.4.0)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->torch_geometric) (25.4.0)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp->torch_geometric) (1.8.0)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp->torch_geometric) (6.7.0)\n",
            "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->torch_geometric) (0.4.1)\n",
            "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->torch_geometric) (1.22.0)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch_geometric) (3.0.3)\n",
            "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->torch_geometric) (3.4.4)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->torch_geometric) (3.11)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->torch_geometric) (2.5.0)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->torch_geometric) (2025.10.5)\n",
            "Requirement already satisfied: typing-extensions>=4.2 in /usr/local/lib/python3.12/dist-packages (from aiosignal>=1.4.0->aiohttp->torch_geometric) (4.15.0)\n",
            "Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m45.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hInstalling collected packages: torch_geometric\n",
            "Successfully installed torch_geometric-2.7.0\n",
            "Collecting captum\n",
            "  Downloading captum-0.8.0-py3-none-any.whl.metadata (26 kB)\n",
            "Requirement already satisfied: matplotlib in /usr/local/lib/python3.12/dist-packages (from captum) (3.10.0)\n",
            "Collecting numpy<2.0 (from captum)\n",
            "  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m61.0/61.0 kB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from captum) (25.0)\n",
            "Requirement already satisfied: torch>=1.10 in /usr/local/lib/python3.12/dist-packages (from captum) (2.8.0+cu126)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from captum) (4.67.1)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (3.20.0)\n",
            "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (4.15.0)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (75.2.0)\n",
            "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (1.13.3)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (3.5)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (3.1.6)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (2025.3.0)\n",
            "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (12.6.77)\n",
            "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (12.6.77)\n",
            "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (12.6.80)\n",
            "Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (9.10.2.21)\n",
            "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (12.6.4.1)\n",
            "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (11.3.0.4)\n",
            "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (10.3.7.77)\n",
            "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (11.7.1.2)\n",
            "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (12.5.4.2)\n",
            "Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (0.7.1)\n",
            "Requirement already satisfied: nvidia-nccl-cu12==2.27.3 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (2.27.3)\n",
            "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (12.6.77)\n",
            "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (12.6.85)\n",
            "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (1.11.1.6)\n",
            "Requirement already satisfied: triton==3.4.0 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (3.4.0)\n",
            "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->captum) (1.3.3)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib->captum) (0.12.1)\n",
            "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib->captum) (4.60.1)\n",
            "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->captum) (1.4.9)\n",
            "Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.12/dist-packages (from matplotlib->captum) (11.3.0)\n",
            "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->captum) (3.2.5)\n",
            "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.12/dist-packages (from matplotlib->captum) (2.9.0.post0)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.7->matplotlib->captum) (1.17.0)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=1.10->captum) (1.3.0)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=1.10->captum) (3.0.3)\n",
            "Downloading captum-0.8.0-py3-none-any.whl (1.4 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.4/1.4 MB\u001b[0m \u001b[31m38.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.0 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m18.0/18.0 MB\u001b[0m \u001b[31m67.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hInstalling collected packages: numpy, captum\n",
            "  Attempting uninstall: numpy\n",
            "    Found existing installation: numpy 2.0.2\n",
            "    Uninstalling numpy-2.0.2:\n",
            "      Successfully uninstalled numpy-2.0.2\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "opencv-python-headless 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n",
            "shap 0.50.0 requires numpy>=2, but you have numpy 1.26.4 which is incompatible.\n",
            "opencv-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n",
            "jax 0.7.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n",
            "jaxlib 0.7.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n",
            "pytensor 2.35.1 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n",
            "opencv-contrib-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0mSuccessfully installed captum-0.8.0 numpy-1.26.4\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.colab-display-data+json": {
              "pip_warning": {
                "packages": [
                  "numpy"
                ]
              },
              "id": "ffb461e747af40a8a9d8ac7ffbb3c1fa"
            }
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install lime\n",
        "#!git clone https://github.com/WilliamCCHuang/GraphLIME.git"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "collapsed": true,
        "id": "nxYB7hM3ajIl",
        "outputId": "038bec64-e29c-439c-e932-c294b3cd8207"
      },
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting lime\n",
            "  Downloading lime-0.2.0.1.tar.gz (275 kB)\n",
            "\u001b[?25l     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/275.7 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m275.7/275.7 kB\u001b[0m \u001b[31m9.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Requirement already satisfied: matplotlib in /usr/local/lib/python3.12/dist-packages (from lime) (3.10.0)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from lime) (1.26.4)\n",
            "Requirement already satisfied: scipy in /usr/local/lib/python3.12/dist-packages (from lime) (1.16.3)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from lime) (4.67.1)\n",
            "Requirement already satisfied: scikit-learn>=0.18 in /usr/local/lib/python3.12/dist-packages (from lime) (1.6.1)\n",
            "Requirement already satisfied: scikit-image>=0.12 in /usr/local/lib/python3.12/dist-packages (from lime) (0.25.2)\n",
            "Requirement already satisfied: networkx>=3.0 in /usr/local/lib/python3.12/dist-packages (from scikit-image>=0.12->lime) (3.5)\n",
            "Requirement already satisfied: pillow>=10.1 in /usr/local/lib/python3.12/dist-packages (from scikit-image>=0.12->lime) (11.3.0)\n",
            "Requirement already satisfied: imageio!=2.35.0,>=2.33 in /usr/local/lib/python3.12/dist-packages (from scikit-image>=0.12->lime) (2.37.2)\n",
            "Requirement already satisfied: tifffile>=2022.8.12 in /usr/local/lib/python3.12/dist-packages (from scikit-image>=0.12->lime) (2025.10.16)\n",
            "Requirement already satisfied: packaging>=21 in /usr/local/lib/python3.12/dist-packages (from scikit-image>=0.12->lime) (25.0)\n",
            "Requirement already satisfied: lazy-loader>=0.4 in /usr/local/lib/python3.12/dist-packages (from scikit-image>=0.12->lime) (0.4)\n",
            "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn>=0.18->lime) (1.5.2)\n",
            "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn>=0.18->lime) (3.6.0)\n",
            "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->lime) (1.3.3)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib->lime) (0.12.1)\n",
            "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib->lime) (4.60.1)\n",
            "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->lime) (1.4.9)\n",
            "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->lime) (3.2.5)\n",
            "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.12/dist-packages (from matplotlib->lime) (2.9.0.post0)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.7->matplotlib->lime) (1.17.0)\n",
            "Building wheels for collected packages: lime\n",
            "  Building wheel for lime (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for lime: filename=lime-0.2.0.1-py3-none-any.whl size=283834 sha256=1955c12105e6d76341ca11b098fa6d5e72e2225272d6c7333e2856ca181a2102\n",
            "  Stored in directory: /root/.cache/pip/wheels/e7/5d/0e/4b4fff9a47468fed5633211fb3b76d1db43fe806a17fb7486a\n",
            "Successfully built lime\n",
            "Installing collected packages: lime\n",
            "Successfully installed lime-0.2.0.1\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!python -c \"import torch_geometric; print(torch_geometric.__version__)\""
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "jouMiG7p7FfS",
        "outputId": "61c61ca2-20fd-4543-bd80-48cbbd814054"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "2.7.0\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## AMZ-COMPUTER"
      ],
      "metadata": {
        "id": "CT0NAmSY80h-"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 53,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "68a316yJ2lfT",
        "outputId": "3e5ad664-7462-44c2-ea45-1f077d738030"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "=== Amazon-Computers ===\n",
            "#Nodes     = 13752\n",
            "#Edges     = 491722\n",
            "#Features  = 767\n",
            "#Classes   = 10\n",
            "#Train nodes = 9621\n",
            "#Val nodes   = 1370\n",
            "#Test nodes  = 2761\n",
            "Final feature dim after augmentation = 767\n",
            "=== Amazon-Computers ===\n",
            "#Nodes     = 13752\n",
            "#Edges     = 491722\n",
            "#Features  = 767\n",
            "#Classes   = 10\n",
            "Added noise to features.\n",
            "#Train nodes = 9621\n",
            "#Val nodes   = 1370\n",
            "#Test nodes  = 2761\n",
            "Final feature dim after augmentation = 767\n",
            "\n",
            "=== Loading existing model from amazon_computer_gcn.pt ===\n",
            "Loaded model | Train=0.6363 Val=0.6314 Test=0.6338\n",
            "\n",
            "=== Building normalized adjacency list ===\n",
            "\n",
            "=== Explaining 100 test nodes (pred class as target) ===\n",
            "Building feature importance via decomposition for 100 nodes...\n",
            "  processed 1/100 nodes\n",
            "  processed 20/100 nodes\n",
            "  processed 40/100 nodes\n",
            "  processed 60/100 nodes\n",
            "  processed 80/100 nodes\n",
            "  processed 100/100 nodes\n",
            "Building feature importance via decomposition for 100 nodes...\n",
            "  processed 1/100 nodes\n",
            "  processed 20/100 nodes\n",
            "  processed 40/100 nodes\n",
            "  processed 60/100 nodes\n",
            "  processed 80/100 nodes\n",
            "  processed 100/100 nodes\n",
            "Running GNN-LRP for 100 nodes...\n",
            "  GNN-LRP processed 1/100 nodes\n",
            "  GNN-LRP processed 10/100 nodes\n",
            "  GNN-LRP processed 20/100 nodes\n",
            "  GNN-LRP processed 30/100 nodes\n",
            "  GNN-LRP processed 40/100 nodes\n",
            "  GNN-LRP processed 50/100 nodes\n",
            "  GNN-LRP processed 60/100 nodes\n",
            "  GNN-LRP processed 70/100 nodes\n",
            "  GNN-LRP processed 80/100 nodes\n",
            "  GNN-LRP processed 90/100 nodes\n",
            "  GNN-LRP processed 100/100 nodes\n",
            "Running GNN-LRP for 100 nodes...\n",
            "  GNN-LRP processed 1/100 nodes\n",
            "  GNN-LRP processed 10/100 nodes\n",
            "  GNN-LRP processed 20/100 nodes\n",
            "  GNN-LRP processed 30/100 nodes\n",
            "  GNN-LRP processed 40/100 nodes\n",
            "  GNN-LRP processed 50/100 nodes\n",
            "  GNN-LRP processed 60/100 nodes\n",
            "  GNN-LRP processed 70/100 nodes\n",
            "  GNN-LRP processed 80/100 nodes\n",
            "  GNN-LRP processed 90/100 nodes\n",
            "  GNN-LRP processed 100/100 nodes\n",
            "Running GOAT for 100 nodes...\n",
            "  GOAT processed 1/100 nodes\n",
            "  GOAT processed 10/100 nodes\n",
            "  GOAT processed 20/100 nodes\n",
            "  GOAT processed 30/100 nodes\n",
            "  GOAT processed 40/100 nodes\n",
            "  GOAT processed 50/100 nodes\n",
            "  GOAT processed 60/100 nodes\n",
            "  GOAT processed 70/100 nodes\n",
            "  GOAT processed 80/100 nodes\n",
            "  GOAT processed 90/100 nodes\n",
            "  GOAT processed 100/100 nodes\n",
            "Running GOAT for 100 nodes...\n",
            "  GOAT processed 1/100 nodes\n",
            "  GOAT processed 10/100 nodes\n",
            "  GOAT processed 20/100 nodes\n",
            "  GOAT processed 30/100 nodes\n",
            "  GOAT processed 40/100 nodes\n",
            "  GOAT processed 50/100 nodes\n",
            "  GOAT processed 60/100 nodes\n",
            "  GOAT processed 70/100 nodes\n",
            "  GOAT processed 80/100 nodes\n",
            "  GOAT processed 90/100 nodes\n",
            "  GOAT processed 100/100 nodes\n",
            "Running LIME for 100 nodes...\n",
            "  LIME processed 1/100 nodes\n",
            "  LIME processed 10/100 nodes\n",
            "  LIME processed 20/100 nodes\n",
            "  LIME processed 30/100 nodes\n",
            "  LIME processed 40/100 nodes\n",
            "  LIME processed 50/100 nodes\n",
            "  LIME processed 60/100 nodes\n",
            "  LIME processed 70/100 nodes\n",
            "  LIME processed 80/100 nodes\n",
            "  LIME processed 90/100 nodes\n",
            "  LIME processed 100/100 nodes\n",
            "Running LIME for 100 nodes...\n",
            "  LIME processed 1/100 nodes\n",
            "  LIME processed 10/100 nodes\n",
            "  LIME processed 20/100 nodes\n",
            "  LIME processed 30/100 nodes\n",
            "  LIME processed 40/100 nodes\n",
            "  LIME processed 50/100 nodes\n",
            "  LIME processed 60/100 nodes\n",
            "  LIME processed 70/100 nodes\n",
            "  LIME processed 80/100 nodes\n",
            "  LIME processed 90/100 nodes\n",
            "  LIME processed 100/100 nodes\n",
            "Running Clean IG for 100 nodes...\n",
            "  IG processed 1/100 nodes\n",
            "  IG processed 10/100 nodes\n",
            "  IG processed 20/100 nodes\n",
            "  IG processed 30/100 nodes\n",
            "  IG processed 40/100 nodes\n",
            "  IG processed 50/100 nodes\n",
            "  IG processed 60/100 nodes\n",
            "  IG processed 70/100 nodes\n",
            "  IG processed 80/100 nodes\n",
            "  IG processed 90/100 nodes\n",
            "  IG processed 100/100 nodes\n",
            "Running Clean IG for 100 nodes...\n",
            "  IG processed 1/100 nodes\n",
            "  IG processed 10/100 nodes\n",
            "  IG processed 20/100 nodes\n",
            "  IG processed 30/100 nodes\n",
            "  IG processed 40/100 nodes\n",
            "  IG processed 50/100 nodes\n",
            "  IG processed 60/100 nodes\n",
            "  IG processed 70/100 nodes\n",
            "  IG processed 80/100 nodes\n",
            "  IG processed 90/100 nodes\n",
            "  IG processed 100/100 nodes\n",
            "\n",
            "=== Computing Recovery scores ===\n",
            "\n",
            "=== Summary: mean Δ = p_orig - p_mask ===\n",
            "\n",
            "Robustness (pick top-k% features):\n",
            "k=0.02 | self=0.9642 | self+2hop=0.9842 | rand=0.5060 | gradcam=0.9987 | lrp=0.6607 | goat=0.9949 | lime=0.3779 | ig=0.9642 | \n",
            "k=0.05 | self=0.9627 | self+2hop=0.9774 | rand=0.5089 | gradcam=0.9987 | lrp=0.6553 | goat=0.9934 | lime=0.3879 | ig=0.9627 | \n",
            "k=0.10 | self=0.8976 | self+2hop=0.9607 | rand=0.5220 | gradcam=0.9988 | lrp=0.5935 | goat=0.9828 | lime=0.4084 | ig=0.8978 | \n",
            "k=0.20 | self=0.7785 | self+2hop=0.9441 | rand=0.5404 | gradcam=0.9987 | lrp=0.5000 | goat=0.9753 | lime=0.4318 | ig=0.7787 | \n",
            "k=0.95 | self=0.3321 | self+2hop=0.8840 | rand=0.4218 | gradcam=0.9967 | lrp=0.2043 | goat=0.9522 | lime=0.3590 | ig=0.3322 | \n"
          ]
        }
      ],
      "source": [
        "import sys\n",
        "sys.path.append(\"GraphLIME\")\n",
        "#from graphlime import GraphLIME\n",
        "from lime.lime_tabular import LimeTabularExplainer\n",
        "import time\n",
        "import argparse\n",
        "import random\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "import os\n",
        "from torch_geometric.datasets import Amazon\n",
        "import torch_geometric.transforms as T\n",
        "from torch_geometric.nn import GCNConv\n",
        "from torch_geometric.utils import k_hop_subgraph\n",
        "from torch_geometric.explain import Explainer, GNNExplainer\n",
        "from torch_geometric.explain import Explainer, GNNExplainer, CaptumExplainer\n",
        "from captum.attr import IntegratedGradients\n",
        "\n",
        "# -------------------------------------------------------------\n",
        "# 0. Utils\n",
        "# -------------------------------------------------------------\n",
        "def set_seed(seed: int = 42):\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    torch.cuda.manual_seed_all(seed)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "def _now():\n",
        "    # sync CUDA to avoid async timing issues\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.synchronize()\n",
        "    return time.time()\n",
        "\n",
        "@torch.no_grad()\n",
        "def is_binary_features(x, tol=1e-6):\n",
        "    \"\"\"\n",
        "    Heuristic: return True if all entries are (approximately) 0 or 1.\n",
        "    \"\"\"\n",
        "    x_flat = x.view(-1)\n",
        "    # allow tiny numerical noise\n",
        "    return torch.all(\n",
        "        ((x_flat - 0.0).abs() < tol) | ((x_flat - 1.0).abs() < tol)\n",
        "    ).item()\n",
        "\n",
        "@torch.no_grad()\n",
        "def augment_with_noisy_features(data, n_noisy: int, seed: int = 42):\n",
        "    \"\"\"\n",
        "    Augment data.x with n_noisy extra 'pure noise' features.\n",
        "\n",
        "    Returns:\n",
        "        data_aug   : a cloned Data object with x.shape = [N, F + n_noisy]\n",
        "        noisy_mask : [F + n_noisy] bool tensor, True for noisy feature indices.\n",
        "    \"\"\"\n",
        "    torch.manual_seed(seed)\n",
        "    device = data.x.device\n",
        "\n",
        "    X = data.x\n",
        "    N, F = X.size()\n",
        "\n",
        "    binary = is_binary_features(X)\n",
        "    print(f\"augment_with_noisy_features: detected binary={binary}, F={F}, N={N}\")\n",
        "\n",
        "    if binary:\n",
        "        # global Bernoulli probability\n",
        "        p = X.mean().item()\n",
        "        print(f\"  Using Bernoulli noise with p={p:.4f}\")\n",
        "        noisy = torch.bernoulli(p * torch.ones(N, n_noisy, device=device))\n",
        "    else:\n",
        "        # global normal noise\n",
        "        mu = X.mean().item()\n",
        "        sigma = X.std().item()\n",
        "        if sigma == 0:\n",
        "            sigma = 1.0\n",
        "        print(f\"  Using Gaussian noise with mean={mu:.4f}, std={sigma:.4f}\")\n",
        "        noisy = torch.randn(N, n_noisy, device=device) * sigma + mu\n",
        "\n",
        "    X_aug = torch.cat([X, noisy], dim=1)  # [N, F + n_noisy]\n",
        "\n",
        "    data_aug = data.clone()\n",
        "    data_aug.x = X_aug\n",
        "\n",
        "    noisy_mask = torch.zeros(F + n_noisy, dtype=torch.bool, device=device)\n",
        "    noisy_mask[F:] = True\n",
        "\n",
        "    print(f\"  Augmented features: original F={F}, noisy={n_noisy}, total={F + n_noisy}\")\n",
        "    return data_aug, noisy_mask\n",
        "\n",
        "@torch.no_grad()\n",
        "def make_noisy_data(data, noise_scale=0.05):\n",
        "    \"\"\"\n",
        "    Return a shallow copy of `data` with noisy x:\n",
        "        x_noisy = x + noise_scale * std(x) * N(0,1)\n",
        "    \"\"\"\n",
        "    device = data.x.device\n",
        "    x = data.x\n",
        "\n",
        "    # Feature-wise std (over nodes)\n",
        "    x_mean = x.mean(dim=0, keepdim=True)\n",
        "    x_std = x.std(dim=0, keepdim=True)\n",
        "    x_std[x_std == 0] = 1.0\n",
        "\n",
        "    noise = torch.randn_like(x) * (noise_scale * x_std.to(device))\n",
        "    x_noisy = x + noise\n",
        "\n",
        "    # Make a new Data object with same everything except x\n",
        "    data_noisy = data.clone()\n",
        "    data_noisy.x = x_noisy\n",
        "\n",
        "    return data_noisy\n",
        "\n",
        "\n",
        "# -------------------------------------------------------------\n",
        "# 1. Dataset + splits\n",
        "# -------------------------------------------------------------\n",
        "# def load_amazon_computers(root: str = \"data/Amazon\"):\n",
        "#     dataset = Amazon(root=root, name=\"Computers\", transform=T.NormalizeFeatures())\n",
        "#     data = dataset[0]\n",
        "\n",
        "#     print(\"=== Amazon-Computers ===\")\n",
        "#     print(f\"#Nodes     = {data.num_nodes}\")\n",
        "#     print(f\"#Edges     = {data.num_edges}\")\n",
        "#     print(f\"#Features  = {dataset.num_features}\")\n",
        "#     print(f\"#Classes   = {dataset.num_classes}\")\n",
        "\n",
        "#     if not hasattr(data, \"train_mask\"):\n",
        "#         # data = create_splits(data, num_train_per_class=20, num_val_per_class=30)\n",
        "#         data = create_splits_stratified(data, train_ratio=0.7, val_ratio=0.1, seed=1234)\n",
        "\n",
        "#     return data, dataset.num_features, dataset.num_classes\n",
        "\n",
        "# ------------------------------------------------------------\n",
        "# Main loader function (with noisy-feature augmentation)\n",
        "# ------------------------------------------------------------\n",
        "def load_amazon_computers(\n",
        "    root: str = \"data/Amazon\",\n",
        "    percent_noisy: float = 0,              # <<< % of noisy features to add\n",
        "    seed: int = 1234,\n",
        "    add_noise = True\n",
        "):\n",
        "    dataset = Amazon(root=root, name=\"Computers\", transform=T.NormalizeFeatures())\n",
        "    data = dataset[0]\n",
        "\n",
        "    print(\"=== Amazon-Computers ===\")\n",
        "    print(f\"#Nodes     = {data.num_nodes}\")\n",
        "    print(f\"#Edges     = {data.num_edges}\")\n",
        "    print(f\"#Features  = {dataset.num_features}\")\n",
        "    print(f\"#Classes   = {dataset.num_classes}\")\n",
        "\n",
        "    # -----------------------------\n",
        "    # 1) Add Noisy Features (NEW)\n",
        "    # -----------------------------\n",
        "    if percent_noisy > 0:\n",
        "        print(f\"Adding {percent_noisy*100}% noisy features...\")\n",
        "        data, noisy_mask = augment_with_noisy_features(\n",
        "            data,\n",
        "            n_noisy=int(percent_noisy*dataset.num_features),\n",
        "            seed=seed\n",
        "        )\n",
        "    else:\n",
        "        noisy_mask = torch.zeros(dataset.num_features, dtype=torch.bool)\n",
        "    if add_noise:\n",
        "        data = make_noisy_data(data)\n",
        "        print(\"Added noise to features.\")\n",
        "    # -----------------------------\n",
        "    # 2) Split (same as before)\n",
        "    # -----------------------------\n",
        "    if not hasattr(data, \"train_mask\") or data.train_mask is None:\n",
        "        data = create_splits_stratified(\n",
        "            data, train_ratio=0.7, val_ratio=0.1, seed=seed\n",
        "        )\n",
        "\n",
        "    # -----------------------------\n",
        "    # Return updated feature dim\n",
        "    # -----------------------------\n",
        "    in_dim = data.x.size(1)\n",
        "    num_classes = dataset.num_classes\n",
        "\n",
        "    print(f\"Final feature dim after augmentation = {in_dim}\")\n",
        "\n",
        "    return data, in_dim, num_classes, noisy_mask\n",
        "\n",
        "def create_splits_stratified(data, train_ratio=0.7, val_ratio=0.1, seed=42):\n",
        "    torch.manual_seed(seed)\n",
        "\n",
        "    y = data.y\n",
        "    num_nodes = data.num_nodes\n",
        "    num_classes = int(y.max().item() + 1)\n",
        "\n",
        "    train_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
        "    val_mask   = torch.zeros(num_nodes, dtype=torch.bool)\n",
        "    test_mask  = torch.zeros(num_nodes, dtype=torch.bool)\n",
        "\n",
        "    for c in range(num_classes):\n",
        "        idx = (y == c).nonzero(as_tuple=False).view(-1)\n",
        "\n",
        "        if idx.numel() == 0:\n",
        "            continue\n",
        "\n",
        "        idx = idx[torch.randperm(idx.size(0))]\n",
        "\n",
        "        n_c = idx.size(0)\n",
        "        n_train = int(train_ratio * n_c)\n",
        "        n_val   = int(val_ratio * n_c)\n",
        "        n_test  = n_c - n_train - n_val\n",
        "\n",
        "        train_idx = idx[:n_train]\n",
        "        val_idx   = idx[n_train:n_train + n_val]\n",
        "        test_idx  = idx[n_train + n_val:]\n",
        "\n",
        "        train_mask[train_idx] = True\n",
        "        val_mask[val_idx]     = True\n",
        "        test_mask[test_idx]   = True\n",
        "\n",
        "    data.train_mask = train_mask\n",
        "    data.val_mask   = val_mask\n",
        "    data.test_mask  = test_mask\n",
        "\n",
        "    print(f\"#Train nodes = {int(train_mask.sum())}\")\n",
        "    print(f\"#Val nodes   = {int(val_mask.sum())}\")\n",
        "    print(f\"#Test nodes  = {int(test_mask.sum())}\")\n",
        "\n",
        "    return data\n",
        "\n",
        "\n",
        "\n",
        "# -------------------------------------------------------------\n",
        "# 2. 2-layer GCN model\n",
        "# -------------------------------------------------------------\n",
        "class GCN(torch.nn.Module):\n",
        "    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):\n",
        "        super().__init__()\n",
        "        self.conv1 = GCNConv(in_channels, hidden_channels, cached=True)\n",
        "        self.conv2 = GCNConv(hidden_channels, out_channels, cached=True)\n",
        "        self.dropout = dropout\n",
        "\n",
        "    def forward(self, x, edge_index):\n",
        "        x = self.conv1(x, edge_index)\n",
        "        x = F.relu(x)\n",
        "        x = F.dropout(x, p=self.dropout, training=self.training)\n",
        "        x = self.conv2(x, edge_index)\n",
        "        return x\n",
        "\n",
        "\n",
        "def train(model, data, optimizer):\n",
        "    model.train()\n",
        "    optimizer.zero_grad()\n",
        "    out = model(data.x, data.edge_index)\n",
        "    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "    return loss.item()\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def evaluate(model, data):\n",
        "    model.eval()\n",
        "    logits = model(data.x, data.edge_index)\n",
        "    preds = logits.argmax(dim=-1)\n",
        "\n",
        "    def acc(mask):\n",
        "        total = int(mask.sum())\n",
        "        if total == 0:\n",
        "            return 0.0\n",
        "        return (preds[mask] == data.y[mask]).sum().item() / total\n",
        "\n",
        "    return acc(data.train_mask), acc(data.val_mask), acc(data.test_mask), logits\n",
        "\n",
        "\n",
        "# -------------------------------------------------------------\n",
        "# 3. Build normalized adjacency list from cached GCNConv\n",
        "#    (row = target node, col = source node)\n",
        "# -------------------------------------------------------------\n",
        "@torch.no_grad()\n",
        "def build_normalized_adjacency_list(model, data):\n",
        "    \"\"\"\n",
        "    Use conv1's cached normalized edge weights.\n",
        "    For each node i, store a list of (neighbor_j, weight_ij) where\n",
        "    weight_ij corresponds to A[i, j] in the matrix A @ (XW).\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "\n",
        "    # Ensure cache is filled\n",
        "    _ = model(data.x, data.edge_index)\n",
        "\n",
        "    edge_index, edge_weight = model.conv1._cached_edge_index\n",
        "    edge_index = edge_index.cpu()\n",
        "    edge_weight = edge_weight.cpu()\n",
        "\n",
        "    num_nodes = data.num_nodes\n",
        "    adj = [[] for _ in range(num_nodes)]  # adj[i] = list of (j, weight) s.t. A[i,j] = weight\n",
        "\n",
        "    # In PyG, edge_index[0] = source (j), edge_index[1] = target (i)\n",
        "    src = edge_index[0].tolist()\n",
        "    dst = edge_index[1].tolist()\n",
        "    w   = edge_weight.tolist()\n",
        "\n",
        "    for s, d, weight in zip(src, dst, w):\n",
        "        adj[d].append((s, weight))\n",
        "\n",
        "    return adj\n",
        "\n",
        "\n",
        "# -------------------------------------------------------------\n",
        "# 4. 2-layer GCN decomposition on an ego-subgraph\n",
        "# -------------------------------------------------------------\n",
        "@torch.no_grad()\n",
        "def decompose_two_layer_gcn_ego(\n",
        "    A_sub,          # [n_sub, n_sub] dense normalized adjacency\n",
        "    X_sub,          # [n_sub, F]\n",
        "    W1, b1,         # W1: [F, H], b1: [H]\n",
        "    W2, b2,         # W2: [H, C], b2: [C]\n",
        "    v_local,        # index of central node in subgraph [0..n_sub-1]\n",
        "    c_idx           # target class index\n",
        "):\n",
        "    \"\"\"\n",
        "    Exact decomposition of 2-layer GCN logit for node v and class c_idx.\n",
        "\n",
        "    Returns:\n",
        "        C: [n_sub, F] contributions from node-features -> logit of v,c_idx\n",
        "        logit_vc: scalar, logit of v,c_idx\n",
        "    \"\"\"\n",
        "    device = X_sub.device\n",
        "    n_sub, F = X_sub.shape\n",
        "    H = W1.shape[1]\n",
        "\n",
        "    # 1) Forward on subgraph\n",
        "    Z1 = A_sub @ (X_sub @ W1) + b1        # [n_sub, H]\n",
        "    M = (Z1 > 0).float()                  # ReLU mask\n",
        "    H1 = M * Z1                           # [n_sub, H]\n",
        "\n",
        "    Z2 = A_sub @ (H1 @ W2) + b2           # [n_sub, C]\n",
        "\n",
        "    # 2) Contributions C[u,f] to logit of v for class c_idx\n",
        "    # φ_{v,u,f} = sum_{j,k} A[v,j] * M[j,k] * A[j,u] * W1[f,k] * W2[k,c_idx]\n",
        "    # We compute this vectorized.\n",
        "\n",
        "    Av = A_sub[v_local]                   # [n_sub] == A[v, :]\n",
        "    # alpha[k,j] = A[v,j] * M[j,k]\n",
        "    alpha = (M.T * Av.unsqueeze(0))       # [H, n_sub]\n",
        "    # S[k, u] = sum_j alpha[k,j] * A[j,u]\n",
        "    S = alpha @ A_sub                     # [H, n_sub]\n",
        "    U = S.T                               # [n_sub, H]\n",
        "\n",
        "    w_kc = W2[:, c_idx]                   # [H]\n",
        "    T = W1 * w_kc.unsqueeze(0)            # [F, H]\n",
        "\n",
        "    # Φ[u,f] = sum_k U[u,k] * T[f,k]\n",
        "    Phi = U @ T.T                         # [n_sub, F]\n",
        "\n",
        "    # Contribution from each node-feature = Φ[u,f] * X_sub[u,f]\n",
        "    C = Phi * X_sub                       # [n_sub, F]\n",
        "\n",
        "    return C, Z2[v_local, c_idx]\n",
        "\n",
        "\n",
        "\n",
        "# # -------------------------------------------------------------\n",
        "# # 5. Build feature importance (self vs self+1hop vs self+2hop)\n",
        "# # -------------------------------------------------------------\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def build_feature_importance_decomposition(\n",
        "    data,\n",
        "    model,\n",
        "    adj_list,\n",
        "    nodes_to_explain,\n",
        "    hop=2,\n",
        "    use_true_class=True,\n",
        "):\n",
        "    \"\"\"\n",
        "    For each node v in nodes_to_explain, compute three feature importance variants:\n",
        "\n",
        "        self_only[v,f]        = |C[v,f]|        (u = v only)\n",
        "        self_plus_1hop[v,f]   = |C[v,f]| + sum_{u:dist=1} |C[u,f]|\n",
        "        self_plus_2hop[v,f]   = |C[v,f]| + sum_{u:dist<=2} |C[u,f]|\n",
        "\n",
        "    where C[u,f] is contribution from node u's feature f to v's logit.\n",
        "\n",
        "    Returns:\n",
        "        feat_self:       [num_nodes, num_features]\n",
        "        feat_self_1hop:  [num_nodes, num_features]\n",
        "        feat_self_2hop:  [num_nodes, num_features]\n",
        "    \"\"\"\n",
        "    device = next(model.parameters()).device\n",
        "    num_nodes, num_features = data.x.size()\n",
        "\n",
        "    feat_self = torch.zeros(num_nodes, num_features, device=device)\n",
        "    feat_self_1hop = torch.zeros(num_nodes, num_features, device=device)\n",
        "    feat_self_2hop = torch.zeros(num_nodes, num_features, device=device)\n",
        "\n",
        "    # One forward pass to get target class\n",
        "    logits = model(data.x, data.edge_index)\n",
        "    preds = logits.argmax(dim=-1)\n",
        "    if use_true_class:\n",
        "        target_class = data.y\n",
        "    else:\n",
        "        target_class = preds\n",
        "\n",
        "    # Extract weights once\n",
        "    W1 = model.conv1.lin.weight.T.to(device)  # [F, H]\n",
        "    W2 = model.conv2.lin.weight.T.to(device)  # [H, C]\n",
        "\n",
        "    if model.conv1.bias is not None:\n",
        "        b1 = model.conv1.bias.to(device)\n",
        "    else:\n",
        "        b1 = torch.zeros(W1.shape[1], device=device)\n",
        "\n",
        "    edge_index = data.edge_index\n",
        "\n",
        "    print(f\"Building feature importance via decomposition for {len(nodes_to_explain)} nodes...\")\n",
        "    for i, v in enumerate(nodes_to_explain.tolist(), start=1):\n",
        "        v = int(v)\n",
        "\n",
        "        # 2-hop ego subgraph\n",
        "        nodes_sub, _, _, _ = k_hop_subgraph(\n",
        "            v, hop, edge_index, relabel_nodes=False\n",
        "        )\n",
        "        nodes_sub = nodes_sub.to(device)\n",
        "        n_sub = nodes_sub.numel()\n",
        "\n",
        "        # mapping: global node id -> local index in [0..n_sub-1]\n",
        "        # NOTE: nodes_sub is usually small (ego), so Python dict is fine\n",
        "        mapping = {int(nid): idx for idx, nid in enumerate(nodes_sub.tolist())}\n",
        "        v_local = mapping[v]\n",
        "\n",
        "        # Dense normalized adjacency A_sub\n",
        "        A_sub = torch.zeros(n_sub, n_sub, device=device)\n",
        "        # Use adj_list (already normalized) to fill rows\n",
        "        for local_row, global_row in enumerate(nodes_sub.tolist()):\n",
        "            for (nbr, w) in adj_list[global_row]:\n",
        "                if nbr in mapping:\n",
        "                    local_col = mapping[nbr]\n",
        "                    A_sub[local_row, local_col] = w\n",
        "\n",
        "        X_sub = data.x[nodes_sub]  # [n_sub, F]\n",
        "        c_idx = int(target_class[v].item())\n",
        "\n",
        "        # Decomposition on ego subgraph\n",
        "        C_sub, _ = decompose_two_layer_gcn_ego(\n",
        "            A_sub, X_sub, W1, b1, W2, model.conv2.bias, v_local, c_idx\n",
        "        )\n",
        "        C_abs = C_sub.abs()  # [n_sub, F]\n",
        "\n",
        "        # ---- Hop masks without BFS ----\n",
        "        # self node mask\n",
        "        mask_self = (nodes_sub == v)        # [n_sub]\n",
        "\n",
        "        # 1-hop neighbors of v (global indices), then intersect with nodes_sub\n",
        "        neighbors_global = [nbr for (nbr, _) in adj_list[v]]\n",
        "        neighbors_set = set(neighbors_global)\n",
        "        # Build 1-hop mask over nodes_sub\n",
        "        mask_1hop_list = [ (int(nid) in neighbors_set) for nid in nodes_sub.tolist() ]\n",
        "        mask_1hop = torch.tensor(mask_1hop_list, dtype=torch.bool, device=device)\n",
        "\n",
        "        # 2-hop = everyone else in the 2-hop ego (excluding self and 1-hop)\n",
        "        mask_2hop = ~(mask_self | mask_1hop)\n",
        "\n",
        "        # Aggregate contributions by hop\n",
        "        if mask_self.any():\n",
        "            self_contrib = C_abs[mask_self].sum(dim=0)\n",
        "        else:\n",
        "            self_contrib = torch.zeros(num_features, device=device)\n",
        "\n",
        "        if mask_1hop.any():\n",
        "            hop1_contrib = C_abs[mask_1hop].sum(dim=0)\n",
        "        else:\n",
        "            hop1_contrib = torch.zeros(num_features, device=device)\n",
        "\n",
        "        if mask_2hop.any():\n",
        "            hop2_contrib = C_abs[mask_2hop].sum(dim=0)\n",
        "        else:\n",
        "            hop2_contrib = torch.zeros(num_features, device=device)\n",
        "\n",
        "        feat_self[v] = self_contrib\n",
        "        feat_self_1hop[v] = self_contrib + hop1_contrib\n",
        "        feat_self_2hop[v] = self_contrib + hop1_contrib + hop2_contrib\n",
        "\n",
        "        if i % 20 == 0 or i == 1 or i == len(nodes_to_explain):\n",
        "            print(f\"  processed {i}/{len(nodes_to_explain)} nodes\", flush=True)\n",
        "\n",
        "    return feat_self, feat_self_1hop, feat_self_2hop\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def build_random_importance(num_nodes, num_features, device):\n",
        "    \"\"\"\n",
        "    Generate random feature-importance scores, normalized per-node\n",
        "    so that distributions resemble real importance magnitudes.\n",
        "    \"\"\"\n",
        "    rand_imp = torch.rand(num_nodes, num_features, device=device)\n",
        "    return rand_imp\n",
        "\n",
        "@torch.no_grad()\n",
        "def build_lime_importance(\n",
        "    data,\n",
        "    model,\n",
        "    nodes_to_explain,\n",
        "    num_classes,\n",
        "    num_samples: int = 200,\n",
        "):\n",
        "    \"\"\"\n",
        "    Build LIME-based feature importance for a subset of nodes.\n",
        "\n",
        "    Returns:\n",
        "        feat_lime: [num_nodes, num_features], zeros for nodes not explained.\n",
        "    \"\"\"\n",
        "    device = next(model.parameters()).device\n",
        "    X = data.x\n",
        "    num_nodes, num_features = X.size()\n",
        "\n",
        "    # Fit LIME on training features only (numpy)\n",
        "    X_train_np = X[data.train_mask].detach().cpu().numpy()\n",
        "    class_names = [f\"class_{i}\" for i in range(num_classes)]\n",
        "    feature_names = [f\"f{j}\" for j in range(num_features)]\n",
        "\n",
        "    lime_expl = LimeTabularExplainer(\n",
        "        training_data=X_train_np,\n",
        "        feature_names=feature_names,\n",
        "        class_names=class_names,\n",
        "        discretize_continuous=False,\n",
        "        mode=\"classification\",\n",
        "    )\n",
        "\n",
        "    feat_lime = torch.zeros(num_nodes, num_features, device=device)\n",
        "\n",
        "    nodes_to_explain = nodes_to_explain.to(device)\n",
        "\n",
        "    print(f\"Running LIME for {len(nodes_to_explain)} nodes...\")\n",
        "    for idx, nid in enumerate(nodes_to_explain.tolist(), start=1):\n",
        "        nid = int(nid)\n",
        "        x0 = X[nid].detach().cpu().numpy()\n",
        "\n",
        "        def predict_fn(X_batch_np):\n",
        "            # X_batch_np: [m, D_total] numpy; return probs [m, num_classes]\n",
        "            X_batch = torch.tensor(X_batch_np, dtype=torch.float32, device=device)\n",
        "            probs = []\n",
        "            # Save and restore node nid's features\n",
        "            saved = X[nid].clone()\n",
        "            try:\n",
        "                for i_row in range(X_batch.shape[0]):\n",
        "                    data.x[nid] = X_batch[i_row]\n",
        "                    logits_n = model(data.x, data.edge_index)[nid]  # [C] logits\n",
        "                    prob_n = F.softmax(logits_n, dim=-1)\n",
        "                    probs.append(prob_n.detach().cpu().numpy())\n",
        "            finally:\n",
        "                data.x[nid] = saved\n",
        "            return np.vstack(probs)\n",
        "\n",
        "        # LIME explanation around x0\n",
        "        exp = lime_expl.explain_instance(\n",
        "            data_row=x0,\n",
        "            predict_fn=predict_fn,\n",
        "            num_features=num_features,  # ask for full vector\n",
        "            num_samples=num_samples,\n",
        "        )\n",
        "\n",
        "        # Use the first available label (usually the predicted class)\n",
        "        label = exp.available_labels()[0]\n",
        "        weights = np.zeros(num_features, dtype=np.float32)\n",
        "\n",
        "        # exp.as_map()[label]: list of (feat_idx_or_name, weight)\n",
        "        for feat_id, w in exp.as_map()[label]:\n",
        "            if isinstance(feat_id, str):\n",
        "                if feat_id.startswith(\"f\"):\n",
        "                    feat_idx = int(feat_id[1:])\n",
        "                else:\n",
        "                    feat_idx = int(feat_id)\n",
        "            else:\n",
        "                feat_idx = int(feat_id)\n",
        "            if 0 <= feat_idx < num_features:\n",
        "                weights[feat_idx] = abs(w)\n",
        "\n",
        "        feat_lime[nid] = torch.from_numpy(weights).to(device)\n",
        "\n",
        "        if idx % 10 == 0 or idx == 1 or idx == len(nodes_to_explain):\n",
        "            print(f\"  LIME processed {idx}/{len(nodes_to_explain)} nodes\", flush=True)\n",
        "\n",
        "    return feat_lime\n",
        "\n",
        "@torch.no_grad()\n",
        "def build_graphlime_importance(\n",
        "    data,\n",
        "    model,\n",
        "    nodes_to_explain,\n",
        "    hop: int = 2,\n",
        "    rho: float = 0.1,\n",
        "):\n",
        "    \"\"\"\n",
        "    Build GraphLIME-based feature importance for a subset of nodes.\n",
        "\n",
        "    Assumes you have a GraphLIME API like:\n",
        "        glime = GraphLIME(model=model, hop=hop, rho=rho)\n",
        "        coefs = glime.explain_node(nid, x, edge_index)  # [D_total]\n",
        "\n",
        "    Returns:\n",
        "        feat_glime: [num_nodes, num_features], zeros for nodes not explained.\n",
        "    \"\"\"\n",
        "    device = next(model.parameters()).device\n",
        "    X = data.x.to(device)\n",
        "    num_nodes, num_features = X.size()\n",
        "\n",
        "    glime = GraphLIME(model=model, hop=hop, rho=rho)\n",
        "\n",
        "    feat_glime = torch.zeros(num_nodes, num_features, device=device)\n",
        "    nodes_to_explain = nodes_to_explain.to(device)\n",
        "\n",
        "    print(f\"Running GraphLIME for {len(nodes_to_explain)} nodes...\")\n",
        "    for idx, nid in enumerate(nodes_to_explain.tolist(), start=1):\n",
        "        nid = int(nid)\n",
        "        coefs = glime.explain_node(nid, X, data.edge_index.to(device))  # [D_total]\n",
        "\n",
        "        # coefs could already be a tensor; make sure we abs and move to device\n",
        "        if isinstance(coefs, np.ndarray):\n",
        "            coefs = torch.from_numpy(coefs)\n",
        "        coefs = coefs.to(device)\n",
        "        feat_glime[nid] = coefs.abs()\n",
        "\n",
        "        if idx % 10 == 0 or idx == 1 or idx == len(nodes_to_explain):\n",
        "            print(f\"  GraphLIME processed {idx}/{len(nodes_to_explain)} nodes\", flush=True)\n",
        "\n",
        "    return feat_glime\n",
        "\n",
        "\n",
        "def build_gnnexplainer_importance(\n",
        "    data,\n",
        "    model,\n",
        "    nodes_to_explain,\n",
        "    epochs=50,\n",
        "):\n",
        "    \"\"\"\n",
        "    Use PyG's GNNExplainer (via Explainer) to get per-node feature importance\n",
        "    for the nodes in `nodes_to_explain`.\n",
        "\n",
        "    Returns:\n",
        "        feat_gnnexp: [num_nodes, num_features] tensor\n",
        "                     (zeros for nodes not explained)\n",
        "    \"\"\"\n",
        "    device = next(model.parameters()).device\n",
        "    num_nodes, num_features = data.x.size()\n",
        "\n",
        "    feat_gnnexp = torch.zeros(num_nodes, num_features, device=device)\n",
        "\n",
        "    model.eval()\n",
        "\n",
        "    explainer = Explainer(\n",
        "        model=model,\n",
        "        algorithm=GNNExplainer(epochs=epochs),\n",
        "        explanation_type=\"model\",\n",
        "        node_mask_type=\"attributes\",\n",
        "        edge_mask_type=None,\n",
        "        model_config=dict(\n",
        "            mode=\"multiclass_classification\",\n",
        "            task_level=\"node\",\n",
        "            return_type=\"raw\",  # our model returns logits\n",
        "        ),\n",
        "    )\n",
        "\n",
        "    print(f\"Running GNNExplainer for {len(nodes_to_explain)} nodes...\")\n",
        "    for i, nid in enumerate(nodes_to_explain.tolist(), start=1):\n",
        "        # GNNExplainer call: node-level explanation for this index\n",
        "        exp = explainer(data.x, data.edge_index, index=int(nid))\n",
        "\n",
        "        node_mask = exp.node_mask  # shape [num_nodes_sub, F] or [F]\n",
        "\n",
        "        # Find the row corresponding to the explained node.\n",
        "        # exp.index is the node index used inside the explainer.\n",
        "        if hasattr(exp, \"index\"):\n",
        "            row = int(exp.index)\n",
        "        else:\n",
        "            row = int(nid)\n",
        "\n",
        "        if node_mask.dim() == 2:\n",
        "            feat_imp_row = node_mask[row]     # [F]\n",
        "        else:\n",
        "            feat_imp_row = node_mask          # [F]\n",
        "\n",
        "        # GNNExplainer node mask is non-negative importance; we can take it as-is\n",
        "        feat_gnnexp[nid] = feat_imp_row\n",
        "\n",
        "        if i % 10 == 0 or i == 1 or i == len(nodes_to_explain):\n",
        "            print(f\"  GNNExplainer processed {i}/{len(nodes_to_explain)} nodes\", flush=True)\n",
        "\n",
        "    return feat_gnnexp\n",
        "\n",
        "\n",
        "#IG - Manual way\n",
        "def build_ig_importance(\n",
        "    data,\n",
        "    model,\n",
        "    nodes_to_explain,\n",
        "    n_steps=30,\n",
        "    use_true_class=True,\n",
        "    baseline='zero'\n",
        "):\n",
        "    \"\"\"\n",
        "    Cleaner implementation using torch.autograd.grad directly.\n",
        "    \"\"\"\n",
        "    device = next(model.parameters()).device\n",
        "    num_nodes, num_features = data.x.size()\n",
        "\n",
        "    feat_ig = torch.zeros(num_nodes, num_features, device=device)\n",
        "\n",
        "    model.eval()\n",
        "\n",
        "    # Prepare baseline\n",
        "    if baseline == 'zero':\n",
        "        baseline_x = torch.zeros_like(data.x)\n",
        "    elif baseline == 'random':\n",
        "        baseline_x = torch.rand_like(data.x)\n",
        "    else:\n",
        "        baseline_x = torch.zeros_like(data.x)\n",
        "\n",
        "    x = data.x.to(device)\n",
        "    edge_index = data.edge_index.to(device)\n",
        "\n",
        "    # Get target classes\n",
        "    with torch.no_grad():\n",
        "        logits = model(x, edge_index)\n",
        "        if use_true_class:\n",
        "            target_class = data.y.to(device)\n",
        "        else:\n",
        "            target_class = logits.argmax(dim=-1)\n",
        "\n",
        "    print(f\"Running Clean IG for {len(nodes_to_explain)} nodes...\")\n",
        "\n",
        "    for i, nid in enumerate(nodes_to_explain.tolist(), start=1):\n",
        "        nid = int(nid)\n",
        "        node_target_class = target_class[nid]\n",
        "\n",
        "        # Compute attribution using clean method\n",
        "        attribution = _integrated_gradients_single(\n",
        "            model=model,\n",
        "            x=x,\n",
        "            baseline=baseline_x,\n",
        "            edge_index=edge_index,\n",
        "            target_node=nid,\n",
        "            target_class=node_target_class,\n",
        "            steps=n_steps\n",
        "        )\n",
        "\n",
        "        feat_ig[nid] = attribution.abs()\n",
        "\n",
        "        if i % 10 == 0 or i == 1 or i == len(nodes_to_explain):\n",
        "            print(f\"  IG processed {i}/{len(nodes_to_explain)} nodes\", flush=True)\n",
        "\n",
        "    return feat_ig\n",
        "\n",
        "\n",
        "def _integrated_gradients_single(model, x, baseline, edge_index, target_node, target_class, steps):\n",
        "    \"\"\"\n",
        "    Compute Integrated Gradients for a single node using autograd.grad.\n",
        "    \"\"\"\n",
        "    # Scale factor\n",
        "    scaled_inputs = [baseline + (float(i) / steps) * (x - baseline) for i in range(0, steps + 1)]\n",
        "\n",
        "    # Compute gradients for each scaled input\n",
        "    gradients = []\n",
        "\n",
        "    for scaled_input in scaled_inputs:\n",
        "        scaled_input = scaled_input.clone().requires_grad_(True)\n",
        "\n",
        "        # Forward pass\n",
        "        output = model(scaled_input, edge_index)\n",
        "\n",
        "        # Get the target node's score for target class\n",
        "        target_score = output[target_node, target_class]\n",
        "\n",
        "        # Compute gradient\n",
        "        gradient = torch.autograd.grad(outputs=target_score, inputs=scaled_input)[0]\n",
        "        gradients.append(gradient)\n",
        "\n",
        "    # Average the gradients\n",
        "    avg_gradients = torch.stack(gradients).mean(dim=0)\n",
        "\n",
        "    # Integrated Gradients formula\n",
        "    integrated_gradients = (x - baseline) * avg_gradients\n",
        "\n",
        "    return integrated_gradients[target_node]\n",
        "\n",
        "\n",
        "#GOAT\n",
        "from torch_geometric.utils import to_undirected, add_self_loops, degree\n",
        "\n",
        "def build_gcn_norm_dense(edge_index, num_nodes, device):\n",
        "    edge_index = to_undirected(edge_index, num_nodes=num_nodes)\n",
        "    edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)\n",
        "    row, col = edge_index\n",
        "    deg = degree(row, num_nodes=num_nodes)\n",
        "    deg_inv_sqrt = deg.pow(-0.5)\n",
        "    deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0.\n",
        "    w = deg_inv_sqrt[row] * deg_inv_sqrt[col]\n",
        "    A = torch.zeros((num_nodes, num_nodes), device=device)\n",
        "    A.index_put_((row, col), w, accumulate=True)\n",
        "    return A  # [N,N]\n",
        "\n",
        "def preactivations_and_masks(data, model, A_hat):\n",
        "    \"\"\"\n",
        "    Compute pre-activations z1 and ReLU masks M1 for GOAT,\n",
        "    using the current operating point (data.x).\n",
        "    For a 2-layer GCN:\n",
        "        z1 = A_hat @ (X W1)\n",
        "        h1 = ReLU(z1)\n",
        "        logits = A_hat @ (h1 W2)\n",
        "    \"\"\"\n",
        "    device = data.x.device\n",
        "\n",
        "    # Extract linear weights used inside GCNConv\n",
        "    # conv.lin.weight: [out_channels, in_channels]  so we transpose\n",
        "    W1 = model.conv1.lin.weight.T.to(device)  # [in_dim, H]\n",
        "    W2 = model.conv2.lin.weight.T.to(device)  # [H, C]\n",
        "\n",
        "    X = data.x.to(device)                     # [N, D_total]\n",
        "\n",
        "    # First layer pre-activation: z1 = Â X W1\n",
        "    z1 = A_hat @ (X @ W1)                     # [N, H]\n",
        "    M1 = (z1 > 0).float()                     # ReLU mask for h1\n",
        "\n",
        "    # Logits: z2 = Â (ReLU(z1) W2)\n",
        "    h1 = M1 * z1                              # [N, H] (output of first layer after ReLU)\n",
        "    z2 = A_hat @ (h1 @ W2)                    # [N, C] (logits, output of second layer)\n",
        "\n",
        "    return z1, z2, M1, W1, W2\n",
        "\n",
        "@torch.no_grad()\n",
        "def goat_2layer_feature_importance_for_node(\n",
        "    i: int,\n",
        "    data,\n",
        "    model,\n",
        "    A_hat,\n",
        "    z1, z2, M1,\n",
        "    W1, W2,\n",
        "):\n",
        "    \"\"\"\n",
        "    Returns phi: [D_total], feature contributions to the predicted class logit at node i.\n",
        "\n",
        "    Implements:\n",
        "        phi_f(i,c) ≈ sum_{h,u} x_{u,f} * [ (r_i^T D1_h Â)_u ] * W1[f,h] * W2[h,c]\n",
        "\n",
        "    where:\n",
        "        r_i^T is row i of Â,\n",
        "        D1_h = diag(M1[:,h]) is the ReLU mask for hidden unit h across nodes.\n",
        "    \"\"\"\n",
        "    device = data.x.device\n",
        "    X = data.x.to(device)                    # [N, D_total]\n",
        "    XT = X.t()                               # [D_total, N]\n",
        "    N, D_total = X.size()\n",
        "    H = W1.size(1)                           # hidden size\n",
        "\n",
        "    i = int(i)\n",
        "\n",
        "    # predicted class c at node i (using full model)\n",
        "    logits = model(X, data.edge_index.to(device))  # [N, C]\n",
        "    c = int(logits[i].argmax().item())\n",
        "\n",
        "    r = A_hat[i, :]                          # [N]  (row i of Â)\n",
        "    phi = torch.zeros(D_total, device=device)\n",
        "\n",
        "    # For each hidden channel (h) of the first layer\n",
        "    for h in range(H):\n",
        "        # t1 corresponds to (r_i^T D1_h Â)_u\n",
        "        # D1_h is diag(M1[:,h])\n",
        "        t1 = (r * M1[:, h]) @ A_hat          # [N]\n",
        "\n",
        "        # Aggregate over source nodes u: v_f = sum_u x_{u,f} * t1[u] = (X^T @ t1)_f\n",
        "        v = XT @ t1                      # [D_total]\n",
        "\n",
        "        # Weight chain for (f -> h -> c)\n",
        "        # W1[f, h] * W2[h, c]\n",
        "        chain = W1[:, h] * W2[h, c]      # [D_total]\n",
        "\n",
        "        phi += v * chain\n",
        "\n",
        "    return phi  # [D_total]\n",
        "\n",
        "def build_goat_feature_importance(\n",
        "    data,\n",
        "    model,\n",
        "    nodes_to_explain=None,\n",
        "):\n",
        "    \"\"\"\n",
        "    Compute GOAT-style feature importance for a set of nodes.\n",
        "\n",
        "    Returns:\n",
        "        feat_goat: [num_nodes, num_features]\n",
        "                   feat_goat[v, f] = |GOAT contribution of feature f to logit of v|\n",
        "                   (0 for nodes not in nodes_to_explain)\n",
        "    \"\"\"\n",
        "    device = data.x.device\n",
        "    N, D_total = data.x.size()\n",
        "\n",
        "    # Dense normalized adjacency\n",
        "    A_hat = build_gcn_norm_dense(data.edge_index.to(device), N, device=device)\n",
        "\n",
        "    # Pre-activations, masks, and weights\n",
        "    z1, z2, M1, W1, W2 = preactivations_and_masks(data, model, A_hat)\n",
        "\n",
        "    # If nodes_to_explain is None, explain all nodes\n",
        "    if nodes_to_explain is None:\n",
        "        nodes_to_explain = torch.arange(N, device=device)\n",
        "    else:\n",
        "        nodes_to_explain = nodes_to_explain.to(device)\n",
        "\n",
        "    feat_goat = torch.zeros(N, D_total, device=device)\n",
        "\n",
        "    print(f\"Running GOAT for {len(nodes_to_explain)} nodes...\")\n",
        "    for idx, nid in enumerate(nodes_to_explain.tolist(), start=1):\n",
        "        phi = goat_2layer_feature_importance_for_node(\n",
        "            nid,\n",
        "            data=data,\n",
        "            model=model,\n",
        "            A_hat=A_hat,\n",
        "            z1=z1, z2=z2,\n",
        "            M1=M1,\n",
        "            W1=W1, W2=W2,\n",
        "        )\n",
        "        feat_goat[nid] = phi.abs()\n",
        "\n",
        "        if idx % 10 == 0 or idx == 1 or idx == len(nodes_to_explain):\n",
        "            print(f\"  GOAT processed {idx}/{len(nodes_to_explain)} nodes\", flush=True)\n",
        "\n",
        "    return feat_goat\n",
        "\n",
        "\n",
        "def gradcam_feature_importance(model, data, use_true_class=True):\n",
        "    \"\"\"\n",
        "    GradCAM-style feature importance using ONLY positive gradients as importance scores.\n",
        "    \"\"\"\n",
        "    device = next(model.parameters()).device\n",
        "    x = data.x.to(device)\n",
        "    edge_index = data.edge_index.to(device)\n",
        "\n",
        "    model.eval()\n",
        "\n",
        "    # ---- Forward conv1 with grad tracking ----\n",
        "    x.requires_grad_(True)\n",
        "    z1 = model.conv1(x, edge_index)      # [N, H]\n",
        "    h1 = F.relu(z1)\n",
        "    h1.retain_grad()                     # store grad wrt h1\n",
        "\n",
        "    # Forward to logits\n",
        "    z2 = model.conv2(h1, edge_index)     # [N, C]\n",
        "    logits = z2\n",
        "\n",
        "    # Choose target class per node (true or predicted)\n",
        "    if use_true_class:\n",
        "        target_class = data.y.to(device)\n",
        "    else:\n",
        "        target_class = logits.argmax(dim=-1)\n",
        "\n",
        "    # Backprop from sum over target logits to get global grads\n",
        "    model.zero_grad(set_to_none=True)\n",
        "    if h1.grad is not None:\n",
        "        h1.grad.zero_()\n",
        "\n",
        "    idx = torch.arange(x.size(0), device=device)\n",
        "    target_logits = logits[idx, target_class[idx]]\n",
        "    target_sum = target_logits.sum()\n",
        "    target_sum.backward(retain_graph=False)\n",
        "\n",
        "    grads = h1.grad           # [N, H]\n",
        "\n",
        "    # MODIFICATION: Only keep positive gradients\n",
        "    positive_grads = F.relu(grads)  # [N, H]\n",
        "\n",
        "    # Channel weights α_k: global average over nodes (using only positive grads)\n",
        "    alpha = positive_grads.mean(dim=0)  # [H]\n",
        "\n",
        "    # ReLU mask on z1\n",
        "    M = (z1 > 0).float()      # [N, H]\n",
        "\n",
        "    W1 = model.conv1.lin.weight.T.to(device)  # [F, H]\n",
        "\n",
        "    # S[f, k] = W1[f, k] * alpha[k]\n",
        "    S = W1 * alpha.unsqueeze(0)              # [F, H]\n",
        "\n",
        "    # For each node v, gradient of cam_v wrt X[v, f] is:\n",
        "    #   dcam_dX[v, f] = sum_k M[v,k] * S[f,k]\n",
        "    dcam_dX_T = S @ M.T       # [F, N]\n",
        "    dcam_dX = dcam_dX_T.T     # [N, F]\n",
        "\n",
        "    # MODIFICATION: Use only positive gradients as importance (no multiplication by x)\n",
        "    # Keep only positive values\n",
        "    feat_imp = F.relu(dcam_dX)          # [N, F]\n",
        "\n",
        "    return feat_imp\n",
        "\n",
        "###GNN-LRP\n",
        "def gnn_lrp_feature_importance(model, data, eps: float = 1e-6, use_true_class=True, nodes_to_explain=None):\n",
        "    \"\"\"\n",
        "    Simple 2-layer GCN LRP producing feature-level relevance per node.\n",
        "\n",
        "    We apply epsilon-rule LRP:\n",
        "      1) logit -> hidden units at *same node*\n",
        "      2) hidden units -> input features at *same node*\n",
        "\n",
        "    This ignores neighbor mixing in conv1 but is a clean feature-level baseline.\n",
        "\n",
        "    Returns:\n",
        "        feat_imp: [num_nodes, num_features] tensor of |relevance|.\n",
        "    \"\"\"\n",
        "    device = next(model.parameters()).device\n",
        "    x = data.x.to(device)\n",
        "    edge_index = data.edge_index.to(device)\n",
        "    model.eval()\n",
        "\n",
        "    # Forward full model to get logits and hidden layer\n",
        "    z1 = model.conv1(x, edge_index)   # [N, H]\n",
        "    h1 = F.relu(z1)                   # [N, H]\n",
        "    z2 = model.conv2(h1, edge_index)  # [N, C]\n",
        "    logits = z2\n",
        "\n",
        "    num_nodes, num_features = x.size()\n",
        "    H = h1.size(1)\n",
        "    C = logits.size(1)\n",
        "\n",
        "    if use_true_class:\n",
        "        target_class = data.y.to(device)\n",
        "    else:\n",
        "        target_class = logits.argmax(dim=-1)\n",
        "\n",
        "    # Extract W1, W2\n",
        "    W1 = model.conv1.lin.weight.T.to(device)  # [F, H]\n",
        "    W2 = model.conv2.lin.weight.T.to(device)  # [H, C]\n",
        "\n",
        "    feat_relevance = torch.zeros(num_nodes, num_features, device=device)\n",
        "\n",
        "    if nodes_to_explain is None:\n",
        "        nodes_to_process = torch.arange(num_nodes, device=device)\n",
        "    else:\n",
        "        nodes_to_process = nodes_to_explain.to(device)\n",
        "\n",
        "    print(f\"Running GNN-LRP for {len(nodes_to_process)} nodes...\")\n",
        "    for idx, v_global in enumerate(nodes_to_process.tolist(), start=1):\n",
        "        v = int(v_global)\n",
        "        c = int(target_class[v].item())\n",
        "\n",
        "        # ---- Step 0: output relevance at (v,c)\n",
        "        R_out = logits[v, c]  # scalar\n",
        "\n",
        "        # ---- Step 1: distribute to hidden units at node v\n",
        "        h_v = h1[v]           # [H]\n",
        "        w2_c = W2[:, c]       # [H]\n",
        "\n",
        "        z_vk = h_v * w2_c     # [H]\n",
        "        Zk = z_vk.sum() + eps\n",
        "        if Zk.item() == 0.0:\n",
        "            continue\n",
        "\n",
        "        R_h = (z_vk / Zk) * R_out  # [H]\n",
        "\n",
        "        # ---- Step 2: distribute each hidden unit relevance to features of node v\n",
        "        x_v = x[v]                 # [F]\n",
        "        # z_fk for each k: [F] = x_v * W1[:, k]\n",
        "        R_x_v = torch.zeros(num_features, device=device)\n",
        "\n",
        "        for k in range(H):\n",
        "            w1_k = W1[:, k]        # [F]\n",
        "            z_fk = x_v * w1_k      # [F]\n",
        "            Zf = z_fk.sum() + eps\n",
        "            if Zf.item() != 0.0:\n",
        "                R_x_v += (z_fk / Zf) * R_h[k]  # [F]\n",
        "\n",
        "        feat_relevance[v] = R_x_v.abs()\n",
        "\n",
        "        if idx % 10 == 0 or idx == 1 or idx == len(nodes_to_process):\n",
        "            print(f\"  GNN-LRP processed {idx}/{len(nodes_to_process)} nodes\", flush=True)\n",
        "\n",
        "    return feat_relevance\n",
        "\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def compute_noisy_fraction(\n",
        "    feat_imp,        # [N, F_total]\n",
        "    noisy_mask,      # [F_total] bool\n",
        "    nodes,           # [M]\n",
        "    k_frac,          # float or list/tuple of floats\n",
        "    label=''\n",
        "):\n",
        "    \"\"\"\n",
        "    Compute fraction of noisy features in top-k importance features.\n",
        "\n",
        "    k_frac:\n",
        "        - float: return scalar\n",
        "        - list/tuple of floats: return dict {k: fraction}\n",
        "\n",
        "    Returns scalar or dict.\n",
        "    \"\"\"\n",
        "    print(f\"\\n=== Recovery scores for importance: {label} ===\")\n",
        "    device = feat_imp.device\n",
        "    nodes = nodes.to(device)\n",
        "    noisy_mask = noisy_mask.to(device)\n",
        "\n",
        "    imp_nodes = feat_imp[nodes]  # [M, F_total]\n",
        "    M, F_total = imp_nodes.shape\n",
        "\n",
        "    # convert to iterable\n",
        "    if isinstance(k_frac, (float, int)):\n",
        "        k_frac_list = [float(k_frac)]\n",
        "        return_dict = False\n",
        "    else:\n",
        "        k_frac_list = [float(k) for k in k_frac]\n",
        "        return_dict = True\n",
        "\n",
        "    results = {}\n",
        "\n",
        "    for kf in k_frac_list:\n",
        "        k = max(1, int(round(kf * F_total)))\n",
        "\n",
        "        # top-k indices per node\n",
        "        topk_idx = torch.topk(imp_nodes, k=k, dim=1, largest=True).indices  # [M, k]\n",
        "\n",
        "        # fraction noisy\n",
        "        noisy_flags = noisy_mask[topk_idx]       # [M, k]\n",
        "        frac = noisy_flags.float().mean().item()\n",
        "\n",
        "        results[kf] = frac\n",
        "\n",
        "    # return scalar if single k_frac was given\n",
        "    if not return_dict:\n",
        "        return results[k_frac_list[0]]\n",
        "    return results\n",
        "\n",
        "\n",
        "# -------------------------------------------------------------\n",
        "# 6. Fidelity metrics (feature deletion / keeping)\n",
        "# -------------------------------------------------------------\n",
        "@torch.no_grad()\n",
        "def compute_fidelity_scores(\n",
        "    data,\n",
        "    model,\n",
        "    feat_imp,\n",
        "    nodes,\n",
        "    k_fracs=(0.1, 0.2, 0.5),\n",
        "    use_true_class=True,\n",
        "    label=\"\",\n",
        "):\n",
        "    \"\"\"\n",
        "    Compute fidelity+ / fidelity- / keep-top for given feat_imp.\n",
        "    \"\"\"\n",
        "    device = next(model.parameters()).device\n",
        "    model.eval()\n",
        "\n",
        "    x_orig = data.x.clone()\n",
        "    edge_index = data.edge_index\n",
        "\n",
        "    logits_orig = model(x_orig, edge_index)\n",
        "    probs_orig = F.softmax(logits_orig, dim=-1)\n",
        "\n",
        "    if use_true_class:\n",
        "        target_class = data.y\n",
        "    else:\n",
        "        target_class = logits_orig.argmax(dim=-1)\n",
        "\n",
        "    nodes = nodes.to(device)\n",
        "    num_features = x_orig.size(1)\n",
        "\n",
        "    p_orig_nodes = probs_orig[nodes, target_class[nodes]]\n",
        "\n",
        "    results_plus = {}\n",
        "    results_minus = {}\n",
        "    results_keep = {}\n",
        "\n",
        "    print(f\"\\n=== Fidelity scores for importance: {label} ===\")\n",
        "    for k_frac in k_fracs:\n",
        "        k = max(1, int(round(k_frac * num_features)))\n",
        "        print(f\"\\n--- k_frac = {k_frac:.2f} (top {int(k_frac*100)}% -> k={k}) ---\")\n",
        "\n",
        "        imp_nodes = feat_imp[nodes]  # [M, F]\n",
        "        topk_idx = torch.topk(imp_nodes, k=k, dim=1, largest=True).indices    # [M, k]\n",
        "        bottomk_idx = torch.topk(imp_nodes, k=k, dim=1, largest=False).indices  # [M, k]\n",
        "\n",
        "        rows = nodes.unsqueeze(1).expand(-1, k).reshape(-1)\n",
        "\n",
        "        # 1) Fidelity+ (remove TOP-k)\n",
        "        x_mask_top = x_orig.clone()\n",
        "        cols_top = topk_idx.reshape(-1)\n",
        "        x_mask_top[rows, cols_top] = 0.0\n",
        "\n",
        "        logits_mask_top = model(x_mask_top, edge_index)\n",
        "        probs_mask_top = F.softmax(logits_mask_top, dim=-1)\n",
        "        p_mask_top = probs_mask_top[nodes, target_class[nodes]]\n",
        "\n",
        "        delta_plus = p_orig_nodes - p_mask_top\n",
        "        results_plus[k_frac] = float(delta_plus.mean().item())\n",
        "        print(f\"Fidelity+ (remove TOP-k)   = {results_plus[k_frac]:.4f}\")\n",
        "\n",
        "        # 2) Fidelity- (remove BOTTOM-k)\n",
        "        x_mask_bottom = x_orig.clone()\n",
        "        cols_bottom = bottomk_idx.reshape(-1)\n",
        "        x_mask_bottom[rows, cols_bottom] = 0.0\n",
        "\n",
        "        logits_mask_bottom = model(x_mask_bottom, edge_index)\n",
        "        probs_mask_bottom = F.softmax(logits_mask_bottom, dim=-1)\n",
        "        p_mask_bottom = probs_mask_bottom[nodes, target_class[nodes]]\n",
        "\n",
        "        delta_minus = p_orig_nodes - p_mask_bottom\n",
        "        results_minus[k_frac] = float(delta_minus.mean().item())\n",
        "        print(f\"Fidelity- (remove BOTTOM-k)= {results_minus[k_frac]:.4f}\")\n",
        "\n",
        "        # 3) Keep-top (only TOP-k kept)\n",
        "        x_keep_top = x_orig.clone()\n",
        "        x_keep_top[nodes] = 0.0\n",
        "        cols_keep = topk_idx.reshape(-1)\n",
        "        x_keep_top[rows, cols_keep] = x_orig[rows, cols_keep]\n",
        "\n",
        "        logits_keep_top = model(x_keep_top, edge_index)\n",
        "        probs_keep_top = F.softmax(logits_keep_top, dim=-1)\n",
        "        p_keep_top = probs_keep_top[nodes, target_class[nodes]]\n",
        "\n",
        "        delta_keep = p_orig_nodes - p_keep_top\n",
        "        results_keep[k_frac] = float(delta_keep.mean().item())\n",
        "        print(f\"Keep-top (only TOP-k kept) = {results_keep[k_frac]:.4f}\")\n",
        "\n",
        "    return results_plus, results_minus, results_keep\n",
        "\n",
        "@torch.no_grad()\n",
        "def compute_robustness_topk(\n",
        "    feat_imp_clean,\n",
        "    feat_imp_noisy,\n",
        "    nodes,\n",
        "    k_frac,  # Can be float or list of floats\n",
        "    eps: float = 1e-8,\n",
        "):\n",
        "    device = feat_imp_clean.device\n",
        "    nodes = nodes.to(device)\n",
        "\n",
        "    I_c = feat_imp_clean[nodes]  # [M, F]\n",
        "    I_n = feat_imp_noisy[nodes]  # [M, F]\n",
        "\n",
        "    M, F = I_c.shape\n",
        "\n",
        "    # Handle both single k_frac and list of k_frac\n",
        "    if isinstance(k_frac, (int, float)):\n",
        "        k_frac = [k_frac]\n",
        "\n",
        "    results = {}\n",
        "\n",
        "    for kf in k_frac:\n",
        "        k = max(1, int(round(kf * F)))\n",
        "\n",
        "        # 1) top-k indices *from clean*\n",
        "        topk_idx = torch.topk(I_c.abs(), k=k, dim=1, largest=True).indices   # [M, k]\n",
        "\n",
        "        # 2) gather clean/noisy importances at those indices\n",
        "        rows = torch.arange(M, device=device).unsqueeze(1).expand(-1, k)     # [M, k]\n",
        "        I_c_top = I_c[rows, topk_idx]\n",
        "        I_n_top = I_n[rows, topk_idx]\n",
        "\n",
        "        # 3) robustness\n",
        "        diff = (I_c_top - I_n_top).abs()\n",
        "        denom = I_c_top.abs() + eps\n",
        "        R = 1.0 - diff / denom\n",
        "        R = torch.clamp(R, min=0.0, max=1.0)\n",
        "\n",
        "        # FIX: Use numeric key instead of string key\n",
        "        results[kf] = float(R.mean().item())\n",
        "\n",
        "    return results\n",
        "\n",
        "\n",
        "# -------------------------------------------------------------\n",
        "# 7. Main script\n",
        "# -------------------------------------------------------------\n",
        "def main():\n",
        "    method_times = {}\n",
        "    parser = argparse.ArgumentParser()\n",
        "    parser.add_argument(\"--seed\", type=int, default=42)\n",
        "    parser.add_argument(\"--hidden_dim\", type=int, default=64)\n",
        "    parser.add_argument(\"--dropout\", type=float, default=0.5)\n",
        "    parser.add_argument(\"--epochs\", type=int, default=500)\n",
        "    parser.add_argument(\"--lr\", type=float, default=0.05)\n",
        "    parser.add_argument(\"--weight_decay\", type=float, default=5e-4)\n",
        "    parser.add_argument(\"--root\", type=str, default=\"data/Amazon\")\n",
        "    parser.add_argument(\"--k_fracs\", type=float, nargs=\"+\",\n",
        "                        default=[0.02,0.05,0.1,0.2,0.95]) #0.95 sanity check\n",
        "    parser.add_argument(\"--num_explain\", type=int, default=100,\n",
        "                        help=\"how many test nodes to explain via decomposition\")\n",
        "    parser.add_argument(\"--use_true_class\", action=\"store_true\",\n",
        "                        help=\"use true label instead of predicted label as target class\")\n",
        "    parser.add_argument(\"--model_path\", type=str, default=\"amazon_computer_gcn.pt\",\n",
        "                        help=\"path to save/load trained GCN model\")\n",
        "    args = parser.parse_args(args=[])\n",
        "\n",
        "    set_seed(args.seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # 1) Load data\n",
        "    data, in_dim, num_classes,_ = load_amazon_computers(args.root,add_noise=False)\n",
        "    data = data.to(device)\n",
        "\n",
        "    data_noise, _, _,_ = load_amazon_computers(args.root,add_noise=True)\n",
        "    data_noise = data_noise.to(device)\n",
        "\n",
        "    # 2) Build model\n",
        "    model = GCN(\n",
        "        in_channels=in_dim,\n",
        "        hidden_channels=args.hidden_dim,\n",
        "        out_channels=num_classes,\n",
        "        dropout=args.dropout,\n",
        "    ).to(device)\n",
        "\n",
        "    # 3) Train or load model\n",
        "    if os.path.exists(args.model_path):\n",
        "        print(f\"\\n=== Loading existing model from {args.model_path} ===\")\n",
        "        state = torch.load(args.model_path, map_location=device)\n",
        "        model.load_state_dict(state)\n",
        "        train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
        "        print(f\"Loaded model | Train={train_acc:.4f} Val={val_acc:.4f} Test={test_acc:.4f}\")\n",
        "    else:\n",
        "        print(\"\\n=== Training GCN (no saved model found) ===\")\n",
        "        optimizer = torch.optim.Adam(\n",
        "            model.parameters(),\n",
        "            lr=args.lr,\n",
        "            weight_decay=args.weight_decay,\n",
        "        )\n",
        "\n",
        "        best_val_acc = 0.0\n",
        "        best_state = None\n",
        "\n",
        "        for epoch in range(1, args.epochs + 1):\n",
        "            loss = train(model, data, optimizer)\n",
        "            train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
        "\n",
        "            if val_acc > best_val_acc:\n",
        "                best_val_acc = val_acc\n",
        "                best_state = {\n",
        "                    \"model\": model.state_dict(),\n",
        "                    \"epoch\": epoch,\n",
        "                    \"test_acc\": test_acc,\n",
        "                }\n",
        "\n",
        "            if epoch % 20 == 0 or epoch == 1:\n",
        "                print(\n",
        "                    f\"Epoch {epoch:03d} | \"\n",
        "                    f\"Loss {loss:.4f} | \"\n",
        "                    f\"Train {train_acc:.4f} | \"\n",
        "                    f\"Val {val_acc:.4f} | \"\n",
        "                    f\"Test {test_acc:.4f}\"\n",
        "                )\n",
        "\n",
        "        if best_state is not None:\n",
        "            model.load_state_dict(best_state[\"model\"])\n",
        "            print(\n",
        "                f\"\\nBest epoch = {best_state['epoch']} | \"\n",
        "                f\"Best val acc = {best_val_acc:.4f} | \"\n",
        "                f\"Test acc @best = {best_state['test_acc']:.4f}\"\n",
        "            )\n",
        "\n",
        "        # Save trained model\n",
        "        torch.save(model.state_dict(), args.model_path)\n",
        "        print(f\"Model saved to {args.model_path}\")\n",
        "\n",
        "    # 4) Build normalized adjacency list\n",
        "    print(\"\\n=== Building normalized adjacency list ===\")\n",
        "    adj_list = build_normalized_adjacency_list(model, data)\n",
        "\n",
        "    # 5) Choose test nodes to explain (correct predictions preferred)\n",
        "    model.eval()\n",
        "    logits = model(data.x, data.edge_index)\n",
        "    preds = logits.argmax(dim=-1)\n",
        "\n",
        "    test_nodes = torch.nonzero(data.test_mask, as_tuple=False).view(-1)\n",
        "    correct_mask = preds == data.y\n",
        "    test_correct_nodes = test_nodes[correct_mask[test_nodes]]\n",
        "\n",
        "    if test_correct_nodes.numel() == 0:\n",
        "        print(\"WARNING: no correctly classified test nodes, using all test nodes.\")\n",
        "        nodes_to_explain = test_nodes\n",
        "    else:\n",
        "        nodes_to_explain = test_correct_nodes\n",
        "\n",
        "    if nodes_to_explain.numel() > args.num_explain:\n",
        "        nodes_to_explain = nodes_to_explain[:args.num_explain]\n",
        "\n",
        "    print(\n",
        "        f\"\\n=== Explaining {nodes_to_explain.numel()} test nodes \"\n",
        "        f\"({'true' if args.use_true_class else 'pred'} class as target) ===\"\n",
        "    )\n",
        "\n",
        "    # 6) Build feature importance: self vs self+1hop vs self+2hop\n",
        "\n",
        "    feat_self, feat_self_1hop, feat_self_2hop = build_feature_importance_decomposition(\n",
        "        data=data,\n",
        "        model=model,\n",
        "        adj_list=adj_list,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "        hop=2,\n",
        "        use_true_class=args.use_true_class,\n",
        "    )\n",
        "\n",
        "\n",
        "    feat_self_noise, feat_self_1hop_noise, feat_self_2hop_noise = build_feature_importance_decomposition(\n",
        "        data=data_noise,\n",
        "        model=model,\n",
        "        adj_list=adj_list,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "        hop=2,\n",
        "        use_true_class=args.use_true_class,\n",
        "    )\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "#     # feat_self, feat_self_1hop, feat_self_2hop = build_feature_importance_decomposition(\n",
        "#     #     data=data,\n",
        "#     #     model=model,\n",
        "#     #     adj_list=adj_list,\n",
        "#     #     nodes_to_explain=nodes_to_explain,\n",
        "#     #     hop=2,\n",
        "#     #     use_true_class=args.use_true_class,\n",
        "#     # )\n",
        "#     # ---- Random baseline ----\n",
        "#     print(\"\\n=== Building RANDOM importance baseline ===\")\n",
        "\n",
        "#     t0 = _now()\n",
        "    feat_random = build_random_importance(\n",
        "        num_nodes=data.num_nodes,\n",
        "        num_features=data.x.size(1),\n",
        "        device=device,\n",
        "    )\n",
        "    feat_random_noise = build_random_importance(\n",
        "        num_nodes=data.num_nodes,\n",
        "        num_features=data.x.size(1),\n",
        "        device=device,\n",
        "    )\n",
        "#     t1 = _now()\n",
        "#     method_times[\"random\"] = t1 - t0\n",
        "\n",
        "# #     feat_random = build_random_importance(\n",
        "# #         num_nodes=data.num_nodes,\n",
        "# #         num_features=data.x.size(1),\n",
        "# #         device=device\n",
        "# # )\n",
        "#     print(\"\\n=== Building GRAD-Cam ===\")\n",
        "#     # feat_gradcam = gradcam_feature_importance(model, data, use_true_class=True)\n",
        "#     t0 = _now()\n",
        "    feat_gradcam = gradcam_feature_importance(\n",
        "        model=model,\n",
        "        data=data,\n",
        "        use_true_class=args.use_true_class,\n",
        "    )\n",
        "    feat_gradcam_noise = gradcam_feature_importance(\n",
        "        model=model,\n",
        "        data=data_noise,\n",
        "        use_true_class=args.use_true_class,\n",
        "    )\n",
        "#     t1 = _now()\n",
        "#     method_times[\"gradcam\"] = t1 - t0\n",
        "\n",
        "#     print(\"\\n=== Building GNN-LRP ===\")\n",
        "# #     feat_lrp = gnn_lrp_feature_importance(\n",
        "# #     model,\n",
        "# #     data,\n",
        "# #     eps=1e-6,\n",
        "# #     use_true_class=True,\n",
        "# #     nodes_to_explain=nodes_to_explain  # same subset you use for fidelity\n",
        "# # )\n",
        "#     t0 = _now()\n",
        "    feat_lrp = gnn_lrp_feature_importance(\n",
        "        model=model,\n",
        "        data=data,\n",
        "        eps=1e-6,\n",
        "        use_true_class=args.use_true_class,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "    )\n",
        "    feat_lrp_noise = gnn_lrp_feature_importance(\n",
        "        model=model,\n",
        "        data=data_noise,\n",
        "        eps=1e-6,\n",
        "        use_true_class=args.use_true_class,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "    )\n",
        "#     t1 = _now()\n",
        "#     method_times[\"lrp\"] = t1 - t0\n",
        "\n",
        "\n",
        "#     print(\"\\n=== Building GOAT ===\")\n",
        "#     # feat_goat = build_goat_feature_importance(\n",
        "#     #     data=data,\n",
        "#     #     model=model,\n",
        "#     #     nodes_to_explain=nodes_to_explain,\n",
        "#     # )\n",
        "#     t0 = _now()\n",
        "    feat_goat = build_goat_feature_importance(\n",
        "        data=data,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "    )\n",
        "    feat_goat_noise = build_goat_feature_importance(\n",
        "        data=data_noise,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "    )\n",
        "    # t1 = _now()\n",
        "#     method_times[\"goat\"] = t1 - t0\n",
        "\n",
        "#     print(\"\\n=== Building LIME ===\")\n",
        "#     # feat_lime = build_lime_importance(\n",
        "#     #     data=data,\n",
        "#     #     model=model,\n",
        "#     #     nodes_to_explain=nodes_to_explain,\n",
        "#     #     num_classes=num_classes,\n",
        "#     #     num_samples=50,\n",
        "#     # )\n",
        "#     # t0 = _now()\n",
        "    feat_lime = build_lime_importance(\n",
        "        data=data,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,  # or smaller subset\n",
        "        num_classes=num_classes,\n",
        "        num_samples=50,\n",
        "    )\n",
        "    feat_lime_noise = build_lime_importance(\n",
        "        data=data_noise,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,  # or smaller subset\n",
        "        num_classes=num_classes,\n",
        "        num_samples=50,\n",
        "    )\n",
        "#     t1 = _now()\n",
        "#     method_times[\"lime\"] = t1 - t0\n",
        "\n",
        "#     # GraphLime: if possible\n",
        "#     # print(\"\\n=== Building GraphLIME ===\")\n",
        "#     # t0 = _now()\n",
        "#     # feat_glime = build_graphlime_importance(\n",
        "#     # data=data,\n",
        "#     # model=model,\n",
        "#     # nodes_to_explain=nodes_to_explain,\n",
        "#     # hop=2,\n",
        "#     # rho=0.1,\n",
        "#     # )\n",
        "#     # t1 = _now()\n",
        "#     # method_times[\"graphlime\"] = t1 - t0\n",
        "\n",
        "\n",
        "#     #probably not use the following two ---\n",
        "\n",
        "#     #     # ---- GNNExplainer baseline importance ----\n",
        "#     # print(\"\\n=== Building GNNExplainer importance baseline ===\")\n",
        "#     # feat_gnnexp = build_gnnexplainer_importance(\n",
        "#     #     data=data,\n",
        "#     #     model=model,\n",
        "#     #     nodes_to_explain=nodes_to_explain,\n",
        "#     #     epochs=50,   # you can change this\n",
        "#     # )\n",
        "\n",
        "#     print(\"\\n=== Building Integrated Gradients (Captum) importance baseline ===\")\n",
        "#     t0 = _now()\n",
        "    feat_ig = build_ig_importance(\n",
        "        data=data,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        ")\n",
        "    feat_ig_noise = build_ig_importance(\n",
        "        data=data_noise,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        ")\n",
        "#     t1 = _now()\n",
        "#     method_times[\"ig\"] = t1 - t0\n",
        "\n",
        "    # 7) Recovery scores for each importance variant\n",
        "    print(\"\\n=== Computing Recovery scores ===\")\n",
        "\n",
        "    robust_self = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_self,\n",
        "    feat_imp_noisy= feat_self_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_s2 = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_self_2hop,\n",
        "    feat_imp_noisy= feat_self_2hop_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_gradcam = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_gradcam,\n",
        "    feat_imp_noisy= feat_gradcam_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_lrp = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_lrp,\n",
        "    feat_imp_noisy= feat_lrp_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_goat = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_goat,\n",
        "    feat_imp_noisy= feat_goat_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_random = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_random,\n",
        "    feat_imp_noisy= feat_random_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_lime = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_lime,\n",
        "    feat_imp_noisy= feat_lime_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_ig = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_ig,\n",
        "    feat_imp_noisy= feat_ig_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "\n",
        "    # frac_self = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_self,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'SELF'\n",
        "    # )\n",
        "    # frac_s1 = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_self_1hop,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'SELF + 1HOP'\n",
        "    # )\n",
        "    # frac_s2 = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_self_2hop,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'SELF + 1HOP + 2HOP'\n",
        "    # )\n",
        "#     fp_s1, fm_s1, fk_s1 = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_self_1hop,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"SELF + 1HOP\",\n",
        "#     )\n",
        "\n",
        "#     fp_s2, fm_s2, fk_s2 = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_self_2hop,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"SELF + 1HOP + 2HOP\",\n",
        "#     )\n",
        "    # frac_rand = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_random,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'RANDOM BASELINE'\n",
        "    # )\n",
        "    # frac_gc = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_gradcam,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'GradCAM-feature'\n",
        "    # )\n",
        "#     fp_rand, fm_rand, fk_rand = compute_fidelity_scores(\n",
        "#     data=data,\n",
        "#     model=model,\n",
        "#     feat_imp=feat_random,\n",
        "#     nodes=nodes_to_explain,\n",
        "#     k_fracs=args.k_fracs,\n",
        "#     use_true_class=args.use_true_class,\n",
        "#     label=\"RANDOM BASELINE\",\n",
        "#     )\n",
        "\n",
        "#     fp_gc, fm_gc, fk_gc = compute_fidelity_scores(\n",
        "#     data=data,\n",
        "#     model=model,\n",
        "#     feat_imp=feat_gradcam,\n",
        "#     nodes=nodes_to_explain,\n",
        "#     k_fracs=args.k_fracs,\n",
        "#     use_true_class=args.use_true_class,\n",
        "#     label=\"GradCAM-feature\",\n",
        "# )\n",
        "\n",
        "#     fp_lrp, fm_lrp, fk_lrp = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_lrp,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"GNN-LRP-feature\",\n",
        "#     )\n",
        "    # frac_lrp = compute_noisy_fraction(\n",
        "    #     feat_imp=feat_lrp,       # [N, F_total]\n",
        "    #     noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #     nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #     k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #     label = 'GradCAM-feature'\n",
        "    #   )\n",
        "\n",
        "#     fp_goat, fm_goat, fk_goat = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_goat,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"GOAT-feature\",\n",
        "#     )\n",
        "    # frac_goat = compute_noisy_fraction(\n",
        "    #     feat_imp=feat_goat,       # [N, F_total]\n",
        "    #     noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #     nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #     k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #     label = 'GradCAM-feature'\n",
        "    #   )\n",
        "    # frac_lime = compute_noisy_fraction(\n",
        "    #     feat_imp=feat_lime,       # [N, F_total]\n",
        "    #     noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #     nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #     k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #     label = 'GradCAM-feature'\n",
        "      # )\n",
        "#     fp_lime, fm_lime, fk_lime = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_lime,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"LIME\",\n",
        "#     )\n",
        "    # fp_glime, fm_glime, fk_glime = compute_fidelity_scores(\n",
        "    # data=data,\n",
        "    # model=model,\n",
        "    # feat_imp=feat_glime,\n",
        "    # nodes=nodes_to_explain,\n",
        "    # k_fracs=args.k_fracs,\n",
        "    # use_true_class=args.use_true_class,\n",
        "    # label=\"GraphLIME\",\n",
        "    # )\n",
        "    # fp_gnn, fm_gnn, fk_gnn = compute_fidelity_scores(\n",
        "    # data=data,\n",
        "    # model=model,\n",
        "    # feat_imp=feat_gnnexp,\n",
        "    # nodes=nodes_to_explain,\n",
        "    # k_fracs=args.k_fracs,\n",
        "    # use_true_class=args.use_true_class,\n",
        "    # label=\"GNNEXPLAINER BASELINE\",\n",
        "    # )\n",
        "  #   fp_ig, fm_ig, fk_ig = compute_fidelity_scores(\n",
        "  #   data=data,\n",
        "  #   model=model,\n",
        "  #   feat_imp=feat_ig,\n",
        "  #   nodes=nodes_to_explain,\n",
        "  #   k_fracs=args.k_fracs,\n",
        "  #   use_true_class=args.use_true_class,\n",
        "  #   label=\"IG (Captum) BASELINE\",\n",
        "  # )\n",
        "    # frac_ig = compute_noisy_fraction(\n",
        "    #     feat_imp=feat_ig,       # [N, F_total]\n",
        "    #     noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #     nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #     k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #     label = 'IG BASELINE'\n",
        "    #   )\n",
        "\n",
        "\n",
        "    # 8) Summary\n",
        "    print(\"\\n=== Summary: mean Δ = p_orig - p_mask ===\")\n",
        "    print(\"\\nRobustness (pick top-k% features):\")\n",
        "    for k in args.k_fracs:\n",
        "        print(\n",
        "            f\"k={k:.2f} | \"\n",
        "            f\"self={robust_self[k]:.4f} | \"\n",
        "            f\"self+2hop={robust_s2[k]:.4f} | \"\n",
        "            f\"rand={robust_random[k]:.4f} | \"\n",
        "            f\"gradcam={robust_gradcam[k]:.4f} | \"\n",
        "            f\"lrp={robust_lrp[k]:.4f} | \"\n",
        "            f\"goat={robust_goat[k]:.4f} | \"\n",
        "            f\"lime={robust_lime[k]:.4f} | \"\n",
        "            f\"ig={robust_ig[k]:.4f} | \"\n",
        "        )\n",
        "\n",
        "    # print(\"\\nFidelity- (remove bottom-k% features):\")\n",
        "    # for k in args.k_fracs:\n",
        "    #     print(\n",
        "    #         f\"k={k:.2f} | \"\n",
        "    #         f\"self={fm_self[k]:.4f} | \"\n",
        "    #         f\"self+1hop={fm_s1[k]:.4f} | \"\n",
        "    #         f\"self+2hop={fm_s2[k]:.4f} | \"\n",
        "    #         f\"rand={fm_rand[k]:.4f} | \"\n",
        "    #         f\"gradcam={fm_gc[k]:.4f} | \"\n",
        "    #         f\"lrp={fm_lrp[k]:.4f} | \"\n",
        "    #         f\"goat={fm_goat[k]:.4f} | \"\n",
        "    #         f\"lime={fm_lime[k]:.4f} | \"\n",
        "    #         f\"ig={fm_ig[k]:.4f} | \"\n",
        "    #     )\n",
        "\n",
        "    # print(\"\\n=== Build-Time Summary (seconds) ===\")\n",
        "    # print(f\"{'method':20s} {'build_time':>12s}\")\n",
        "    # for name, t in method_times.items():\n",
        "    #     print(f\"{name:20s} {t:12.3f}\")\n",
        "    # print(frac_self)\n",
        "    # print(frac_s1)\n",
        "    # print(frac_s2)\n",
        "    # print(frac_rand)\n",
        "    # print(frac_gc)\n",
        "    # print(frac_lime)\n",
        "    #return frac_self\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "  main()"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## AMZ-PHOTO"
      ],
      "metadata": {
        "id": "ngHRDErv8TQd"
      }
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "IaTCL-1SC1Lu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# -------------------------------------------------------------\n",
        "# 1. Dataset + splits\n",
        "# -------------------------------------------------------------\n",
        "def load_amazon_photo(\n",
        "    root: str = \"data/Amazon\",\n",
        "    percent_noisy: float = 0,              # <<< % of noisy features to add\n",
        "    seed: int = 1234,\n",
        "    add_noise = True\n",
        "):\n",
        "    dataset = Amazon(root=root, name=\"Photo\", transform=T.NormalizeFeatures())\n",
        "    data = dataset[0]\n",
        "\n",
        "    print(\"=== Amazon-Photo ===\")\n",
        "    print(f\"#Nodes     = {data.num_nodes}\")\n",
        "    print(f\"#Edges     = {data.num_edges}\")\n",
        "    print(f\"#Features  = {dataset.num_features}\")\n",
        "    print(f\"#Classes   = {dataset.num_classes}\")\n",
        "\n",
        "    # -----------------------------\n",
        "    # 1) Add Noisy Features (NEW)\n",
        "    # -----------------------------\n",
        "    if percent_noisy > 0:\n",
        "        print(f\"Adding {percent_noisy*100}% noisy features...\")\n",
        "        data, noisy_mask = augment_with_noisy_features(\n",
        "            data,\n",
        "            n_noisy=int(percent_noisy*dataset.num_features),\n",
        "            seed=seed\n",
        "        )\n",
        "    else:\n",
        "        noisy_mask = torch.zeros(dataset.num_features, dtype=torch.bool)\n",
        "    if add_noise:\n",
        "        data = make_noisy_data(data)\n",
        "        print(\"Added noise to features.\")\n",
        "    # -----------------------------\n",
        "    # 2) Split (same as before)\n",
        "    # -----------------------------\n",
        "    if not hasattr(data, \"train_mask\") or data.train_mask is None:\n",
        "        data = create_splits_stratified(\n",
        "            data, train_ratio=0.7, val_ratio=0.1, seed=seed\n",
        "        )\n",
        "\n",
        "    # -----------------------------\n",
        "    # Return updated feature dim\n",
        "    # -----------------------------\n",
        "    in_dim = data.x.size(1)\n",
        "    num_classes = dataset.num_classes\n",
        "\n",
        "    print(f\"Final feature dim after augmentation = {in_dim}\")\n",
        "\n",
        "    return data, in_dim, num_classes, noisy_mask\n",
        "\n",
        "def main():\n",
        "    method_times = {}\n",
        "    parser = argparse.ArgumentParser()\n",
        "    parser.add_argument(\"--seed\", type=int, default=42)\n",
        "    parser.add_argument(\"--hidden_dim\", type=int, default=64)\n",
        "    parser.add_argument(\"--dropout\", type=float, default=0.5)\n",
        "    parser.add_argument(\"--epochs\", type=int, default=500)\n",
        "    parser.add_argument(\"--lr\", type=float, default=0.05)\n",
        "    parser.add_argument(\"--weight_decay\", type=float, default=5e-4)\n",
        "    parser.add_argument(\"--root\", type=str, default=\"data/Amazon\")\n",
        "    parser.add_argument(\"--k_fracs\", type=float, nargs=\"+\",\n",
        "                        default=[0.02,0.05,0.1,0.2,0.95]) #0.95 sanity check\n",
        "    parser.add_argument(\"--num_explain\", type=int, default=100,\n",
        "                        help=\"how many test nodes to explain via decomposition\")\n",
        "    parser.add_argument(\"--use_true_class\", action=\"store_true\",\n",
        "                        help=\"use true label instead of predicted label as target class\")\n",
        "    parser.add_argument(\"--model_path\", type=str, default=\"amazon_photo_gcn.pt\",\n",
        "                        help=\"path to save/load trained GCN model\")\n",
        "    args = parser.parse_args(args=[])\n",
        "\n",
        "    set_seed(args.seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # 1) Load data\n",
        "    data, in_dim, num_classes,_ = load_amazon_photo(args.root,add_noise=False)\n",
        "    data = data.to(device)\n",
        "\n",
        "    data_noise, _, _,_ = load_amazon_photo(args.root,add_noise=True)\n",
        "    data_noise = data_noise.to(device)\n",
        "\n",
        "    # 2) Build model\n",
        "    model = GCN(\n",
        "        in_channels=in_dim,\n",
        "        hidden_channels=args.hidden_dim,\n",
        "        out_channels=num_classes,\n",
        "        dropout=args.dropout,\n",
        "    ).to(device)\n",
        "\n",
        "    # 3) Train or load model\n",
        "    if os.path.exists(args.model_path):\n",
        "        print(f\"\\n=== Loading existing model from {args.model_path} ===\")\n",
        "        state = torch.load(args.model_path, map_location=device)\n",
        "        model.load_state_dict(state)\n",
        "        train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
        "        print(f\"Loaded model | Train={train_acc:.4f} Val={val_acc:.4f} Test={test_acc:.4f}\")\n",
        "    else:\n",
        "        print(\"\\n=== Training GCN (no saved model found) ===\")\n",
        "        optimizer = torch.optim.Adam(\n",
        "            model.parameters(),\n",
        "            lr=args.lr,\n",
        "            weight_decay=args.weight_decay,\n",
        "        )\n",
        "\n",
        "        best_val_acc = 0.0\n",
        "        best_state = None\n",
        "\n",
        "        for epoch in range(1, args.epochs + 1):\n",
        "            loss = train(model, data, optimizer)\n",
        "            train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
        "\n",
        "            if val_acc > best_val_acc:\n",
        "                best_val_acc = val_acc\n",
        "                best_state = {\n",
        "                    \"model\": model.state_dict(),\n",
        "                    \"epoch\": epoch,\n",
        "                    \"test_acc\": test_acc,\n",
        "                }\n",
        "\n",
        "            if epoch % 20 == 0 or epoch == 1:\n",
        "                print(\n",
        "                    f\"Epoch {epoch:03d} | \"\n",
        "                    f\"Loss {loss:.4f} | \"\n",
        "                    f\"Train {train_acc:.4f} | \"\n",
        "                    f\"Val {val_acc:.4f} | \"\n",
        "                    f\"Test {test_acc:.4f}\"\n",
        "                )\n",
        "\n",
        "        if best_state is not None:\n",
        "            model.load_state_dict(best_state[\"model\"])\n",
        "            print(\n",
        "                f\"\\nBest epoch = {best_state['epoch']} | \"\n",
        "                f\"Best val acc = {best_val_acc:.4f} | \"\n",
        "                f\"Test acc @best = {best_state['test_acc']:.4f}\"\n",
        "            )\n",
        "\n",
        "        # Save trained model\n",
        "        torch.save(model.state_dict(), args.model_path)\n",
        "        print(f\"Model saved to {args.model_path}\")\n",
        "\n",
        "    # 4) Build normalized adjacency list\n",
        "    print(\"\\n=== Building normalized adjacency list ===\")\n",
        "    adj_list = build_normalized_adjacency_list(model, data)\n",
        "\n",
        "    # 5) Choose test nodes to explain (correct predictions preferred)\n",
        "    model.eval()\n",
        "    logits = model(data.x, data.edge_index)\n",
        "    preds = logits.argmax(dim=-1)\n",
        "\n",
        "    test_nodes = torch.nonzero(data.test_mask, as_tuple=False).view(-1)\n",
        "    correct_mask = preds == data.y\n",
        "    test_correct_nodes = test_nodes[correct_mask[test_nodes]]\n",
        "\n",
        "    if test_correct_nodes.numel() == 0:\n",
        "        print(\"WARNING: no correctly classified test nodes, using all test nodes.\")\n",
        "        nodes_to_explain = test_nodes\n",
        "    else:\n",
        "        nodes_to_explain = test_correct_nodes\n",
        "\n",
        "    if nodes_to_explain.numel() > args.num_explain:\n",
        "        nodes_to_explain = nodes_to_explain[:args.num_explain]\n",
        "\n",
        "    print(\n",
        "        f\"\\n=== Explaining {nodes_to_explain.numel()} test nodes \"\n",
        "        f\"({'true' if args.use_true_class else 'pred'} class as target) ===\"\n",
        "    )\n",
        "\n",
        "    # 6) Build feature importance: self vs self+1hop vs self+2hop\n",
        "\n",
        "    feat_self, feat_self_1hop, feat_self_2hop = build_feature_importance_decomposition(\n",
        "        data=data,\n",
        "        model=model,\n",
        "        adj_list=adj_list,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "        hop=2,\n",
        "        use_true_class=args.use_true_class,\n",
        "    )\n",
        "\n",
        "\n",
        "    feat_self_noise, feat_self_1hop_noise, feat_self_2hop_noise = build_feature_importance_decomposition(\n",
        "        data=data_noise,\n",
        "        model=model,\n",
        "        adj_list=adj_list,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "        hop=2,\n",
        "        use_true_class=args.use_true_class,\n",
        "    )\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "#     # feat_self, feat_self_1hop, feat_self_2hop = build_feature_importance_decomposition(\n",
        "#     #     data=data,\n",
        "#     #     model=model,\n",
        "#     #     adj_list=adj_list,\n",
        "#     #     nodes_to_explain=nodes_to_explain,\n",
        "#     #     hop=2,\n",
        "#     #     use_true_class=args.use_true_class,\n",
        "#     # )\n",
        "#     # ---- Random baseline ----\n",
        "#     print(\"\\n=== Building RANDOM importance baseline ===\")\n",
        "\n",
        "#     t0 = _now()\n",
        "    feat_random = build_random_importance(\n",
        "        num_nodes=data.num_nodes,\n",
        "        num_features=data.x.size(1),\n",
        "        device=device,\n",
        "    )\n",
        "    feat_random_noise = build_random_importance(\n",
        "        num_nodes=data.num_nodes,\n",
        "        num_features=data.x.size(1),\n",
        "        device=device,\n",
        "    )\n",
        "#     t1 = _now()\n",
        "#     method_times[\"random\"] = t1 - t0\n",
        "\n",
        "# #     feat_random = build_random_importance(\n",
        "# #         num_nodes=data.num_nodes,\n",
        "# #         num_features=data.x.size(1),\n",
        "# #         device=device\n",
        "# # )\n",
        "#     print(\"\\n=== Building GRAD-Cam ===\")\n",
        "#     # feat_gradcam = gradcam_feature_importance(model, data, use_true_class=True)\n",
        "#     t0 = _now()\n",
        "    feat_gradcam = gradcam_feature_importance(\n",
        "        model=model,\n",
        "        data=data,\n",
        "        use_true_class=args.use_true_class,\n",
        "    )\n",
        "    feat_gradcam_noise = gradcam_feature_importance(\n",
        "        model=model,\n",
        "        data=data_noise,\n",
        "        use_true_class=args.use_true_class,\n",
        "    )\n",
        "#     t1 = _now()\n",
        "#     method_times[\"gradcam\"] = t1 - t0\n",
        "\n",
        "#     print(\"\\n=== Building GNN-LRP ===\")\n",
        "# #     feat_lrp = gnn_lrp_feature_importance(\n",
        "# #     model,\n",
        "# #     data,\n",
        "# #     eps=1e-6,\n",
        "# #     use_true_class=True,\n",
        "# #     nodes_to_explain=nodes_to_explain  # same subset you use for fidelity\n",
        "# # )\n",
        "#     t0 = _now()\n",
        "    feat_lrp = gnn_lrp_feature_importance(\n",
        "        model=model,\n",
        "        data=data,\n",
        "        eps=1e-6,\n",
        "        use_true_class=args.use_true_class,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "    )\n",
        "    feat_lrp_noise = gnn_lrp_feature_importance(\n",
        "        model=model,\n",
        "        data=data_noise,\n",
        "        eps=1e-6,\n",
        "        use_true_class=args.use_true_class,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "    )\n",
        "#     t1 = _now()\n",
        "#     method_times[\"lrp\"] = t1 - t0\n",
        "\n",
        "\n",
        "#     print(\"\\n=== Building GOAT ===\")\n",
        "#     # feat_goat = build_goat_feature_importance(\n",
        "#     #     data=data,\n",
        "#     #     model=model,\n",
        "#     #     nodes_to_explain=nodes_to_explain,\n",
        "#     # )\n",
        "#     t0 = _now()\n",
        "    feat_goat = build_goat_feature_importance(\n",
        "        data=data,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "    )\n",
        "    feat_goat_noise = build_goat_feature_importance(\n",
        "        data=data_noise,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "    )\n",
        "    # t1 = _now()\n",
        "#     method_times[\"goat\"] = t1 - t0\n",
        "\n",
        "#     print(\"\\n=== Building LIME ===\")\n",
        "#     # feat_lime = build_lime_importance(\n",
        "#     #     data=data,\n",
        "#     #     model=model,\n",
        "#     #     nodes_to_explain=nodes_to_explain,\n",
        "#     #     num_classes=num_classes,\n",
        "#     #     num_samples=50,\n",
        "#     # )\n",
        "#     # t0 = _now()\n",
        "    feat_lime = build_lime_importance(\n",
        "        data=data,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,  # or smaller subset\n",
        "        num_classes=num_classes,\n",
        "        num_samples=50,\n",
        "    )\n",
        "    feat_lime_noise = build_lime_importance(\n",
        "        data=data_noise,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,  # or smaller subset\n",
        "        num_classes=num_classes,\n",
        "        num_samples=50,\n",
        "    )\n",
        "#     t1 = _now()\n",
        "#     method_times[\"lime\"] = t1 - t0\n",
        "\n",
        "#     # GraphLime: if possible\n",
        "#     # print(\"\\n=== Building GraphLIME ===\")\n",
        "#     # t0 = _now()\n",
        "#     # feat_glime = build_graphlime_importance(\n",
        "#     # data=data,\n",
        "#     # model=model,\n",
        "#     # nodes_to_explain=nodes_to_explain,\n",
        "#     # hop=2,\n",
        "#     # rho=0.1,\n",
        "#     # )\n",
        "#     # t1 = _now()\n",
        "#     # method_times[\"graphlime\"] = t1 - t0\n",
        "\n",
        "\n",
        "#     #probably not use the following two ---\n",
        "\n",
        "#     #     # ---- GNNExplainer baseline importance ----\n",
        "#     # print(\"\\n=== Building GNNExplainer importance baseline ===\")\n",
        "#     # feat_gnnexp = build_gnnexplainer_importance(\n",
        "#     #     data=data,\n",
        "#     #     model=model,\n",
        "#     #     nodes_to_explain=nodes_to_explain,\n",
        "#     #     epochs=50,   # you can change this\n",
        "#     # )\n",
        "\n",
        "#     print(\"\\n=== Building Integrated Gradients (Captum) importance baseline ===\")\n",
        "#     t0 = _now()\n",
        "    feat_ig = build_ig_importance(\n",
        "        data=data,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        ")\n",
        "    feat_ig_noise = build_ig_importance(\n",
        "        data=data_noise,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        ")\n",
        "#     t1 = _now()\n",
        "#     method_times[\"ig\"] = t1 - t0\n",
        "\n",
        "    # 7) Recovery scores for each importance variant\n",
        "    print(\"\\n=== Computing Recovery scores ===\")\n",
        "\n",
        "    robust_self = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_self,\n",
        "    feat_imp_noisy= feat_self_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_s2 = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_self_2hop,\n",
        "    feat_imp_noisy= feat_self_2hop_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_gradcam = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_gradcam,\n",
        "    feat_imp_noisy= feat_gradcam_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_lrp = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_lrp,\n",
        "    feat_imp_noisy= feat_lrp_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_goat = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_goat,\n",
        "    feat_imp_noisy= feat_goat_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_random = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_random,\n",
        "    feat_imp_noisy= feat_random_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_lime = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_lime,\n",
        "    feat_imp_noisy= feat_lime_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_ig = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_ig,\n",
        "    feat_imp_noisy= feat_ig_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "\n",
        "    # frac_self = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_self,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'SELF'\n",
        "    # )\n",
        "    # frac_s1 = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_self_1hop,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'SELF + 1HOP'\n",
        "    # )\n",
        "    # frac_s2 = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_self_2hop,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'SELF + 1HOP + 2HOP'\n",
        "    # )\n",
        "#     fp_s1, fm_s1, fk_s1 = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_self_1hop,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"SELF + 1HOP\",\n",
        "#     )\n",
        "\n",
        "#     fp_s2, fm_s2, fk_s2 = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_self_2hop,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"SELF + 1HOP + 2HOP\",\n",
        "#     )\n",
        "    # frac_rand = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_random,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'RANDOM BASELINE'\n",
        "    # )\n",
        "    # frac_gc = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_gradcam,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'GradCAM-feature'\n",
        "    # )\n",
        "#     fp_rand, fm_rand, fk_rand = compute_fidelity_scores(\n",
        "#     data=data,\n",
        "#     model=model,\n",
        "#     feat_imp=feat_random,\n",
        "#     nodes=nodes_to_explain,\n",
        "#     k_fracs=args.k_fracs,\n",
        "#     use_true_class=args.use_true_class,\n",
        "#     label=\"RANDOM BASELINE\",\n",
        "#     )\n",
        "\n",
        "#     fp_gc, fm_gc, fk_gc = compute_fidelity_scores(\n",
        "#     data=data,\n",
        "#     model=model,\n",
        "#     feat_imp=feat_gradcam,\n",
        "#     nodes=nodes_to_explain,\n",
        "#     k_fracs=args.k_fracs,\n",
        "#     use_true_class=args.use_true_class,\n",
        "#     label=\"GradCAM-feature\",\n",
        "# )\n",
        "\n",
        "#     fp_lrp, fm_lrp, fk_lrp = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_lrp,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"GNN-LRP-feature\",\n",
        "#     )\n",
        "    # frac_lrp = compute_noisy_fraction(\n",
        "    #     feat_imp=feat_lrp,       # [N, F_total]\n",
        "    #     noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #     nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #     k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #     label = 'GradCAM-feature'\n",
        "    #   )\n",
        "\n",
        "#     fp_goat, fm_goat, fk_goat = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_goat,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"GOAT-feature\",\n",
        "#     )\n",
        "    # frac_goat = compute_noisy_fraction(\n",
        "    #     feat_imp=feat_goat,       # [N, F_total]\n",
        "    #     noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #     nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #     k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #     label = 'GradCAM-feature'\n",
        "    #   )\n",
        "    # frac_lime = compute_noisy_fraction(\n",
        "    #     feat_imp=feat_lime,       # [N, F_total]\n",
        "    #     noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #     nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #     k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #     label = 'GradCAM-feature'\n",
        "      # )\n",
        "#     fp_lime, fm_lime, fk_lime = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_lime,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"LIME\",\n",
        "#     )\n",
        "    # fp_glime, fm_glime, fk_glime = compute_fidelity_scores(\n",
        "    # data=data,\n",
        "    # model=model,\n",
        "    # feat_imp=feat_glime,\n",
        "    # nodes=nodes_to_explain,\n",
        "    # k_fracs=args.k_fracs,\n",
        "    # use_true_class=args.use_true_class,\n",
        "    # label=\"GraphLIME\",\n",
        "    # )\n",
        "    # fp_gnn, fm_gnn, fk_gnn = compute_fidelity_scores(\n",
        "    # data=data,\n",
        "    # model=model,\n",
        "    # feat_imp=feat_gnnexp,\n",
        "    # nodes=nodes_to_explain,\n",
        "    # k_fracs=args.k_fracs,\n",
        "    # use_true_class=args.use_true_class,\n",
        "    # label=\"GNNEXPLAINER BASELINE\",\n",
        "    # )\n",
        "  #   fp_ig, fm_ig, fk_ig = compute_fidelity_scores(\n",
        "  #   data=data,\n",
        "  #   model=model,\n",
        "  #   feat_imp=feat_ig,\n",
        "  #   nodes=nodes_to_explain,\n",
        "  #   k_fracs=args.k_fracs,\n",
        "  #   use_true_class=args.use_true_class,\n",
        "  #   label=\"IG (Captum) BASELINE\",\n",
        "  # )\n",
        "    # frac_ig = compute_noisy_fraction(\n",
        "    #     feat_imp=feat_ig,       # [N, F_total]\n",
        "    #     noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #     nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #     k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #     label = 'IG BASELINE'\n",
        "    #   )\n",
        "\n",
        "\n",
        "    # 8) Summary\n",
        "    print(\"\\n=== Summary: mean Δ = p_orig - p_mask ===\")\n",
        "    print(\"\\nRobustness (pick top-k% features):\")\n",
        "    for k in args.k_fracs:\n",
        "        print(\n",
        "            f\"k={k:.2f} | \"\n",
        "            f\"self={robust_self[k]:.4f} | \"\n",
        "            f\"self+2hop={robust_s2[k]:.4f} | \"\n",
        "            f\"rand={robust_random[k]:.4f} | \"\n",
        "            f\"gradcam={robust_gradcam[k]:.4f} | \"\n",
        "            f\"lrp={robust_lrp[k]:.4f} | \"\n",
        "            f\"goat={robust_goat[k]:.4f} | \"\n",
        "            f\"lime={robust_lime[k]:.4f} | \"\n",
        "            f\"ig={robust_ig[k]:.4f} | \"\n",
        "        )\n",
        "\n",
        "    # print(\"\\nFidelity- (remove bottom-k% features):\")\n",
        "    # for k in args.k_fracs:\n",
        "    #     print(\n",
        "    #         f\"k={k:.2f} | \"\n",
        "    #         f\"self={fm_self[k]:.4f} | \"\n",
        "    #         f\"self+1hop={fm_s1[k]:.4f} | \"\n",
        "    #         f\"self+2hop={fm_s2[k]:.4f} | \"\n",
        "    #         f\"rand={fm_rand[k]:.4f} | \"\n",
        "    #         f\"gradcam={fm_gc[k]:.4f} | \"\n",
        "    #         f\"lrp={fm_lrp[k]:.4f} | \"\n",
        "    #         f\"goat={fm_goat[k]:.4f} | \"\n",
        "    #         f\"lime={fm_lime[k]:.4f} | \"\n",
        "    #         f\"ig={fm_ig[k]:.4f} | \"\n",
        "    #     )\n",
        "\n",
        "    # print(\"\\n=== Build-Time Summary (seconds) ===\")\n",
        "    # print(f\"{'method':20s} {'build_time':>12s}\")\n",
        "    # for name, t in method_times.items():\n",
        "    #     print(f\"{name:20s} {t:12.3f}\")\n",
        "    # print(frac_self)\n",
        "    # print(frac_s1)\n",
        "    # print(frac_s2)\n",
        "    # print(frac_rand)\n",
        "    # print(frac_gc)\n",
        "    # print(frac_lime)\n",
        "    #return frac_self\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "  main()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fwpjF5gw8XzP",
        "outputId": "48f1a2fb-99aa-4361-a63c-a0f14883fc25"
      },
      "execution_count": 54,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "=== Amazon-Photo ===\n",
            "#Nodes     = 7650\n",
            "#Edges     = 238162\n",
            "#Features  = 745\n",
            "#Classes   = 8\n",
            "#Train nodes = 5352\n",
            "#Val nodes   = 762\n",
            "#Test nodes  = 1536\n",
            "Final feature dim after augmentation = 745\n",
            "=== Amazon-Photo ===\n",
            "#Nodes     = 7650\n",
            "#Edges     = 238162\n",
            "#Features  = 745\n",
            "#Classes   = 8\n",
            "Added noise to features.\n",
            "#Train nodes = 5352\n",
            "#Val nodes   = 762\n",
            "#Test nodes  = 1536\n",
            "Final feature dim after augmentation = 745\n",
            "\n",
            "=== Loading existing model from amazon_photo_gcn.pt ===\n",
            "Loaded model | Train=0.8696 Val=0.8622 Test=0.8639\n",
            "\n",
            "=== Building normalized adjacency list ===\n",
            "\n",
            "=== Explaining 100 test nodes (pred class as target) ===\n",
            "Building feature importance via decomposition for 100 nodes...\n",
            "  processed 1/100 nodes\n",
            "  processed 20/100 nodes\n",
            "  processed 40/100 nodes\n",
            "  processed 60/100 nodes\n",
            "  processed 80/100 nodes\n",
            "  processed 100/100 nodes\n",
            "Building feature importance via decomposition for 100 nodes...\n",
            "  processed 1/100 nodes\n",
            "  processed 20/100 nodes\n",
            "  processed 40/100 nodes\n",
            "  processed 60/100 nodes\n",
            "  processed 80/100 nodes\n",
            "  processed 100/100 nodes\n",
            "Running GNN-LRP for 100 nodes...\n",
            "  GNN-LRP processed 1/100 nodes\n",
            "  GNN-LRP processed 10/100 nodes\n",
            "  GNN-LRP processed 20/100 nodes\n",
            "  GNN-LRP processed 30/100 nodes\n",
            "  GNN-LRP processed 40/100 nodes\n",
            "  GNN-LRP processed 50/100 nodes\n",
            "  GNN-LRP processed 60/100 nodes\n",
            "  GNN-LRP processed 70/100 nodes\n",
            "  GNN-LRP processed 80/100 nodes\n",
            "  GNN-LRP processed 90/100 nodes\n",
            "  GNN-LRP processed 100/100 nodes\n",
            "Running GNN-LRP for 100 nodes...\n",
            "  GNN-LRP processed 1/100 nodes\n",
            "  GNN-LRP processed 10/100 nodes\n",
            "  GNN-LRP processed 20/100 nodes\n",
            "  GNN-LRP processed 30/100 nodes\n",
            "  GNN-LRP processed 40/100 nodes\n",
            "  GNN-LRP processed 50/100 nodes\n",
            "  GNN-LRP processed 60/100 nodes\n",
            "  GNN-LRP processed 70/100 nodes\n",
            "  GNN-LRP processed 80/100 nodes\n",
            "  GNN-LRP processed 90/100 nodes\n",
            "  GNN-LRP processed 100/100 nodes\n",
            "Running GOAT for 100 nodes...\n",
            "  GOAT processed 1/100 nodes\n",
            "  GOAT processed 10/100 nodes\n",
            "  GOAT processed 20/100 nodes\n",
            "  GOAT processed 30/100 nodes\n",
            "  GOAT processed 40/100 nodes\n",
            "  GOAT processed 50/100 nodes\n",
            "  GOAT processed 60/100 nodes\n",
            "  GOAT processed 70/100 nodes\n",
            "  GOAT processed 80/100 nodes\n",
            "  GOAT processed 90/100 nodes\n",
            "  GOAT processed 100/100 nodes\n",
            "Running GOAT for 100 nodes...\n",
            "  GOAT processed 1/100 nodes\n",
            "  GOAT processed 10/100 nodes\n",
            "  GOAT processed 20/100 nodes\n",
            "  GOAT processed 30/100 nodes\n",
            "  GOAT processed 40/100 nodes\n",
            "  GOAT processed 50/100 nodes\n",
            "  GOAT processed 60/100 nodes\n",
            "  GOAT processed 70/100 nodes\n",
            "  GOAT processed 80/100 nodes\n",
            "  GOAT processed 90/100 nodes\n",
            "  GOAT processed 100/100 nodes\n",
            "Running LIME for 100 nodes...\n",
            "  LIME processed 1/100 nodes\n",
            "  LIME processed 10/100 nodes\n",
            "  LIME processed 20/100 nodes\n",
            "  LIME processed 30/100 nodes\n",
            "  LIME processed 40/100 nodes\n",
            "  LIME processed 50/100 nodes\n",
            "  LIME processed 60/100 nodes\n",
            "  LIME processed 70/100 nodes\n",
            "  LIME processed 80/100 nodes\n",
            "  LIME processed 90/100 nodes\n",
            "  LIME processed 100/100 nodes\n",
            "Running LIME for 100 nodes...\n",
            "  LIME processed 1/100 nodes\n",
            "  LIME processed 10/100 nodes\n",
            "  LIME processed 20/100 nodes\n",
            "  LIME processed 30/100 nodes\n",
            "  LIME processed 40/100 nodes\n",
            "  LIME processed 50/100 nodes\n",
            "  LIME processed 60/100 nodes\n",
            "  LIME processed 70/100 nodes\n",
            "  LIME processed 80/100 nodes\n",
            "  LIME processed 90/100 nodes\n",
            "  LIME processed 100/100 nodes\n",
            "Running Clean IG for 100 nodes...\n",
            "  IG processed 1/100 nodes\n",
            "  IG processed 10/100 nodes\n",
            "  IG processed 20/100 nodes\n",
            "  IG processed 30/100 nodes\n",
            "  IG processed 40/100 nodes\n",
            "  IG processed 50/100 nodes\n",
            "  IG processed 60/100 nodes\n",
            "  IG processed 70/100 nodes\n",
            "  IG processed 80/100 nodes\n",
            "  IG processed 90/100 nodes\n",
            "  IG processed 100/100 nodes\n",
            "Running Clean IG for 100 nodes...\n",
            "  IG processed 1/100 nodes\n",
            "  IG processed 10/100 nodes\n",
            "  IG processed 20/100 nodes\n",
            "  IG processed 30/100 nodes\n",
            "  IG processed 40/100 nodes\n",
            "  IG processed 50/100 nodes\n",
            "  IG processed 60/100 nodes\n",
            "  IG processed 70/100 nodes\n",
            "  IG processed 80/100 nodes\n",
            "  IG processed 90/100 nodes\n",
            "  IG processed 100/100 nodes\n",
            "\n",
            "=== Computing Recovery scores ===\n",
            "\n",
            "=== Summary: mean Δ = p_orig - p_mask ===\n",
            "\n",
            "Robustness (pick top-k% features):\n",
            "k=0.02 | self=0.9629 | self+2hop=0.9863 | rand=0.4941 | gradcam=0.9947 | lrp=0.6710 | goat=0.9939 | lime=0.3745 | ig=0.9631 | \n",
            "k=0.05 | self=0.9544 | self+2hop=0.9800 | rand=0.5087 | gradcam=0.9949 | lrp=0.6588 | goat=0.9927 | lime=0.3962 | ig=0.9545 | \n",
            "k=0.10 | self=0.9141 | self+2hop=0.9730 | rand=0.5215 | gradcam=0.9948 | lrp=0.6147 | goat=0.9915 | lime=0.4115 | ig=0.9144 | \n",
            "k=0.20 | self=0.8162 | self+2hop=0.9629 | rand=0.5414 | gradcam=0.9949 | lrp=0.5312 | goat=0.9899 | lime=0.4352 | ig=0.8167 | \n",
            "k=0.95 | self=0.3784 | self+2hop=0.8893 | rand=0.4231 | gradcam=0.9913 | lrp=0.2385 | goat=0.9614 | lime=0.3624 | ig=0.3791 | \n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## CiteSeer"
      ],
      "metadata": {
        "id": "F23Zjt9wr3eU"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# -------------------------------------------------------------\n",
        "# 1. Dataset + splits\n",
        "# -------------------------------------------------------------\n",
        "\n",
        "def create_splits(data, num_train_per_class=20, num_val_per_class=30):\n",
        "    y = data.y\n",
        "    num_nodes = data.num_nodes\n",
        "    num_classes = int(y.max().item() + 1)\n",
        "\n",
        "    train_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
        "    val_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
        "    test_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
        "\n",
        "    for c in range(num_classes):\n",
        "        idx = (y == c).nonzero(as_tuple=False).view(-1)\n",
        "        idx = idx[torch.randperm(idx.size(0))]\n",
        "\n",
        "        n_train = min(num_train_per_class, idx.size(0))\n",
        "        n_val = min(num_val_per_class, max(0, idx.size(0) - n_train))\n",
        "\n",
        "        train_idx = idx[:n_train]\n",
        "        val_idx = idx[n_train:n_train + n_val]\n",
        "        test_idx = idx[n_train + n_val:]\n",
        "\n",
        "        train_mask[train_idx] = True\n",
        "        val_mask[val_idx] = True\n",
        "        test_mask[test_idx] = True\n",
        "\n",
        "    data.train_mask = train_mask\n",
        "    data.val_mask = val_mask\n",
        "    data.test_mask = test_mask\n",
        "\n",
        "    print(f\"#Train nodes = {int(train_mask.sum())}\")\n",
        "    print(f\"#Val nodes   = {int(val_mask.sum())}\")\n",
        "    print(f\"#Test nodes  = {int(test_mask.sum())}\")\n",
        "\n",
        "    return data\n",
        "\n",
        "def load_planetoid(\n",
        "    name: str = 'Cora',\n",
        "    root: str = \"data/Amazon\",\n",
        "    percent_noisy: float = 0,              # <<< % of noisy features to add\n",
        "    seed: int = 1234,\n",
        "    add_noise = True\n",
        "):\n",
        "    dataset = Planetoid(\n",
        "        root=root,\n",
        "        name=name,\n",
        "        transform=T.NormalizeFeatures(),   # common preprocessing\n",
        "    )\n",
        "    data = dataset[0]\n",
        "\n",
        "    print(\"=== Planetoid-{name} ===\")\n",
        "    print(f\"#Nodes     = {data.num_nodes}\")\n",
        "    print(f\"#Edges     = {data.num_edges}\")\n",
        "    print(f\"#Features  = {dataset.num_features}\")\n",
        "    print(f\"#Classes   = {dataset.num_classes}\")\n",
        "\n",
        "    # -----------------------------\n",
        "    # 1) Add Noisy Features (NEW)\n",
        "    # -----------------------------\n",
        "    if percent_noisy > 0:\n",
        "        print(f\"Adding {percent_noisy*100}% noisy features...\")\n",
        "        data, noisy_mask = augment_with_noisy_features(\n",
        "            data,\n",
        "            n_noisy=int(percent_noisy*dataset.num_features),\n",
        "            seed=seed\n",
        "        )\n",
        "    else:\n",
        "        noisy_mask = torch.zeros(dataset.num_features, dtype=torch.bool)\n",
        "    if add_noise:\n",
        "        data = make_noisy_data(data)\n",
        "        print(\"Added noise to features.\")\n",
        "    # -----------------------------\n",
        "    # 2) Split (same as before)\n",
        "    # -----------------------------\n",
        "    if not hasattr(data, \"train_mask\") or data.train_mask is None:\n",
        "        data = create_splits(\n",
        "            data, num_train_per_class=20, num_val_per_class=30\n",
        "        )\n",
        "\n",
        "    # -----------------------------\n",
        "    # Return updated feature dim\n",
        "    # -----------------------------\n",
        "    in_dim = data.x.size(1)\n",
        "    num_classes = dataset.num_classes\n",
        "\n",
        "    print(f\"Final feature dim after augmentation = {in_dim}\")\n",
        "\n",
        "    return data, in_dim, num_classes, noisy_mask\n",
        "\n",
        "\n",
        "def main():\n",
        "    method_times = {}\n",
        "    parser = argparse.ArgumentParser()\n",
        "    parser.add_argument(\"--seed\", type=int, default=42)\n",
        "    parser.add_argument(\"--hidden_dim\", type=int, default=64)\n",
        "    parser.add_argument(\"--dropout\", type=float, default=0.5)\n",
        "    parser.add_argument(\"--epochs\", type=int, default=500)\n",
        "    parser.add_argument(\"--lr\", type=float, default=0.05)\n",
        "    parser.add_argument(\"--weight_decay\", type=float, default=5e-4)\n",
        "    parser.add_argument(\"--root\", type=str, default=\"data/Planetoid\")\n",
        "    parser.add_argument(\"--k_fracs\", type=float, nargs=\"+\",\n",
        "                        default=[0.02,0.05,0.1,0.2,0.95]) #0.95 sanity check\n",
        "    parser.add_argument(\"--num_explain\", type=int, default=100,\n",
        "                        help=\"how many test nodes to explain via decomposition\")\n",
        "    parser.add_argument(\"--use_true_class\", action=\"store_true\",\n",
        "                        help=\"use true label instead of predicted label as target class\")\n",
        "    parser.add_argument(\"--model_path\", type=str, default=\"citeseer_gcn.pt\",\n",
        "                        help=\"path to save/load trained GCN model\")\n",
        "    args = parser.parse_args(args=[])\n",
        "\n",
        "    set_seed(args.seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # 1) Load data\n",
        "    data, in_dim, num_classes,_ = load_planetoid(name='CiteSeer',root=args.root,add_noise=False)\n",
        "    data = data.to(device)\n",
        "\n",
        "    data_noise, _, _,_ = load_planetoid(name='CiteSeer',root=args.root,add_noise=True)\n",
        "    data_noise = data_noise.to(device)\n",
        "\n",
        "    # 2) Build model\n",
        "    model = GCN(\n",
        "        in_channels=in_dim,\n",
        "        hidden_channels=args.hidden_dim,\n",
        "        out_channels=num_classes,\n",
        "        dropout=args.dropout,\n",
        "    ).to(device)\n",
        "\n",
        "    # 3) Train or load model\n",
        "    if os.path.exists(args.model_path):\n",
        "        print(f\"\\n=== Loading existing model from {args.model_path} ===\")\n",
        "        state = torch.load(args.model_path, map_location=device)\n",
        "        model.load_state_dict(state)\n",
        "        train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
        "        print(f\"Loaded model | Train={train_acc:.4f} Val={val_acc:.4f} Test={test_acc:.4f}\")\n",
        "    else:\n",
        "        print(\"\\n=== Training GCN (no saved model found) ===\")\n",
        "        optimizer = torch.optim.Adam(\n",
        "            model.parameters(),\n",
        "            lr=args.lr,\n",
        "            weight_decay=args.weight_decay,\n",
        "        )\n",
        "\n",
        "        best_val_acc = 0.0\n",
        "        best_state = None\n",
        "\n",
        "        for epoch in range(1, args.epochs + 1):\n",
        "            loss = train(model, data, optimizer)\n",
        "            train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
        "\n",
        "            if val_acc > best_val_acc:\n",
        "                best_val_acc = val_acc\n",
        "                best_state = {\n",
        "                    \"model\": model.state_dict(),\n",
        "                    \"epoch\": epoch,\n",
        "                    \"test_acc\": test_acc,\n",
        "                }\n",
        "\n",
        "            if epoch % 20 == 0 or epoch == 1:\n",
        "                print(\n",
        "                    f\"Epoch {epoch:03d} | \"\n",
        "                    f\"Loss {loss:.4f} | \"\n",
        "                    f\"Train {train_acc:.4f} | \"\n",
        "                    f\"Val {val_acc:.4f} | \"\n",
        "                    f\"Test {test_acc:.4f}\"\n",
        "                )\n",
        "\n",
        "        if best_state is not None:\n",
        "            model.load_state_dict(best_state[\"model\"])\n",
        "            print(\n",
        "                f\"\\nBest epoch = {best_state['epoch']} | \"\n",
        "                f\"Best val acc = {best_val_acc:.4f} | \"\n",
        "                f\"Test acc @best = {best_state['test_acc']:.4f}\"\n",
        "            )\n",
        "\n",
        "        # Save trained model\n",
        "        torch.save(model.state_dict(), args.model_path)\n",
        "        print(f\"Model saved to {args.model_path}\")\n",
        "\n",
        "    # 4) Build normalized adjacency list\n",
        "    print(\"\\n=== Building normalized adjacency list ===\")\n",
        "    adj_list = build_normalized_adjacency_list(model, data)\n",
        "\n",
        "    # 5) Choose test nodes to explain (correct predictions preferred)\n",
        "    model.eval()\n",
        "    logits = model(data.x, data.edge_index)\n",
        "    preds = logits.argmax(dim=-1)\n",
        "\n",
        "    test_nodes = torch.nonzero(data.test_mask, as_tuple=False).view(-1)\n",
        "    correct_mask = preds == data.y\n",
        "    test_correct_nodes = test_nodes[correct_mask[test_nodes]]\n",
        "\n",
        "    if test_correct_nodes.numel() == 0:\n",
        "        print(\"WARNING: no correctly classified test nodes, using all test nodes.\")\n",
        "        nodes_to_explain = test_nodes\n",
        "    else:\n",
        "        nodes_to_explain = test_correct_nodes\n",
        "\n",
        "    if nodes_to_explain.numel() > args.num_explain:\n",
        "        nodes_to_explain = nodes_to_explain[:args.num_explain]\n",
        "\n",
        "    print(\n",
        "        f\"\\n=== Explaining {nodes_to_explain.numel()} test nodes \"\n",
        "        f\"({'true' if args.use_true_class else 'pred'} class as target) ===\"\n",
        "    )\n",
        "\n",
        "    # 6) Build feature importance: self vs self+1hop vs self+2hop\n",
        "\n",
        "    feat_self, feat_self_1hop, feat_self_2hop = build_feature_importance_decomposition(\n",
        "        data=data,\n",
        "        model=model,\n",
        "        adj_list=adj_list,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "        hop=2,\n",
        "        use_true_class=args.use_true_class,\n",
        "    )\n",
        "\n",
        "\n",
        "    feat_self_noise, feat_self_1hop_noise, feat_self_2hop_noise = build_feature_importance_decomposition(\n",
        "        data=data_noise,\n",
        "        model=model,\n",
        "        adj_list=adj_list,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "        hop=2,\n",
        "        use_true_class=args.use_true_class,\n",
        "    )\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "#     # feat_self, feat_self_1hop, feat_self_2hop = build_feature_importance_decomposition(\n",
        "#     #     data=data,\n",
        "#     #     model=model,\n",
        "#     #     adj_list=adj_list,\n",
        "#     #     nodes_to_explain=nodes_to_explain,\n",
        "#     #     hop=2,\n",
        "#     #     use_true_class=args.use_true_class,\n",
        "#     # )\n",
        "#     # ---- Random baseline ----\n",
        "#     print(\"\\n=== Building RANDOM importance baseline ===\")\n",
        "\n",
        "#     t0 = _now()\n",
        "    feat_random = build_random_importance(\n",
        "        num_nodes=data.num_nodes,\n",
        "        num_features=data.x.size(1),\n",
        "        device=device,\n",
        "    )\n",
        "    feat_random_noise = build_random_importance(\n",
        "        num_nodes=data.num_nodes,\n",
        "        num_features=data.x.size(1),\n",
        "        device=device,\n",
        "    )\n",
        "#     t1 = _now()\n",
        "#     method_times[\"random\"] = t1 - t0\n",
        "\n",
        "# #     feat_random = build_random_importance(\n",
        "# #         num_nodes=data.num_nodes,\n",
        "# #         num_features=data.x.size(1),\n",
        "# #         device=device\n",
        "# # )\n",
        "#     print(\"\\n=== Building GRAD-Cam ===\")\n",
        "#     # feat_gradcam = gradcam_feature_importance(model, data, use_true_class=True)\n",
        "#     t0 = _now()\n",
        "    feat_gradcam = gradcam_feature_importance(\n",
        "        model=model,\n",
        "        data=data,\n",
        "        use_true_class=args.use_true_class,\n",
        "    )\n",
        "    feat_gradcam_noise = gradcam_feature_importance(\n",
        "        model=model,\n",
        "        data=data_noise,\n",
        "        use_true_class=args.use_true_class,\n",
        "    )\n",
        "#     t1 = _now()\n",
        "#     method_times[\"gradcam\"] = t1 - t0\n",
        "\n",
        "#     print(\"\\n=== Building GNN-LRP ===\")\n",
        "# #     feat_lrp = gnn_lrp_feature_importance(\n",
        "# #     model,\n",
        "# #     data,\n",
        "# #     eps=1e-6,\n",
        "# #     use_true_class=True,\n",
        "# #     nodes_to_explain=nodes_to_explain  # same subset you use for fidelity\n",
        "# # )\n",
        "#     t0 = _now()\n",
        "    feat_lrp = gnn_lrp_feature_importance(\n",
        "        model=model,\n",
        "        data=data,\n",
        "        eps=1e-6,\n",
        "        use_true_class=args.use_true_class,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "    )\n",
        "    feat_lrp_noise = gnn_lrp_feature_importance(\n",
        "        model=model,\n",
        "        data=data_noise,\n",
        "        eps=1e-6,\n",
        "        use_true_class=args.use_true_class,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "    )\n",
        "#     t1 = _now()\n",
        "#     method_times[\"lrp\"] = t1 - t0\n",
        "\n",
        "\n",
        "#     print(\"\\n=== Building GOAT ===\")\n",
        "#     # feat_goat = build_goat_feature_importance(\n",
        "#     #     data=data,\n",
        "#     #     model=model,\n",
        "#     #     nodes_to_explain=nodes_to_explain,\n",
        "#     # )\n",
        "#     t0 = _now()\n",
        "    feat_goat = build_goat_feature_importance(\n",
        "        data=data,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "    )\n",
        "    feat_goat_noise = build_goat_feature_importance(\n",
        "        data=data_noise,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "    )\n",
        "    # t1 = _now()\n",
        "#     method_times[\"goat\"] = t1 - t0\n",
        "\n",
        "#     print(\"\\n=== Building LIME ===\")\n",
        "#     # feat_lime = build_lime_importance(\n",
        "#     #     data=data,\n",
        "#     #     model=model,\n",
        "#     #     nodes_to_explain=nodes_to_explain,\n",
        "#     #     num_classes=num_classes,\n",
        "#     #     num_samples=50,\n",
        "#     # )\n",
        "#     # t0 = _now()\n",
        "    feat_lime = build_lime_importance(\n",
        "        data=data,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,  # or smaller subset\n",
        "        num_classes=num_classes,\n",
        "        num_samples=50,\n",
        "    )\n",
        "    feat_lime_noise = build_lime_importance(\n",
        "        data=data_noise,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,  # or smaller subset\n",
        "        num_classes=num_classes,\n",
        "        num_samples=50,\n",
        "    )\n",
        "#     t1 = _now()\n",
        "#     method_times[\"lime\"] = t1 - t0\n",
        "\n",
        "#     # GraphLime: if possible\n",
        "#     # print(\"\\n=== Building GraphLIME ===\")\n",
        "#     # t0 = _now()\n",
        "#     # feat_glime = build_graphlime_importance(\n",
        "#     # data=data,\n",
        "#     # model=model,\n",
        "#     # nodes_to_explain=nodes_to_explain,\n",
        "#     # hop=2,\n",
        "#     # rho=0.1,\n",
        "#     # )\n",
        "#     # t1 = _now()\n",
        "#     # method_times[\"graphlime\"] = t1 - t0\n",
        "\n",
        "\n",
        "#     #probably not use the following two ---\n",
        "\n",
        "#     #     # ---- GNNExplainer baseline importance ----\n",
        "#     # print(\"\\n=== Building GNNExplainer importance baseline ===\")\n",
        "#     # feat_gnnexp = build_gnnexplainer_importance(\n",
        "#     #     data=data,\n",
        "#     #     model=model,\n",
        "#     #     nodes_to_explain=nodes_to_explain,\n",
        "#     #     epochs=50,   # you can change this\n",
        "#     # )\n",
        "\n",
        "#     print(\"\\n=== Building Integrated Gradients (Captum) importance baseline ===\")\n",
        "#     t0 = _now()\n",
        "    feat_ig = build_ig_importance(\n",
        "        data=data,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        ")\n",
        "    feat_ig_noise = build_ig_importance(\n",
        "        data=data_noise,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        ")\n",
        "#     t1 = _now()\n",
        "#     method_times[\"ig\"] = t1 - t0\n",
        "\n",
        "    # 7) Recovery scores for each importance variant\n",
        "    print(\"\\n=== Computing Recovery scores ===\")\n",
        "\n",
        "    robust_self = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_self,\n",
        "    feat_imp_noisy= feat_self_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_s2 = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_self_2hop,\n",
        "    feat_imp_noisy= feat_self_2hop_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_gradcam = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_gradcam,\n",
        "    feat_imp_noisy= feat_gradcam_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_lrp = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_lrp,\n",
        "    feat_imp_noisy= feat_lrp_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_goat = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_goat,\n",
        "    feat_imp_noisy= feat_goat_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_random = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_random,\n",
        "    feat_imp_noisy= feat_random_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_lime = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_lime,\n",
        "    feat_imp_noisy= feat_lime_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_ig = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_ig,\n",
        "    feat_imp_noisy= feat_ig_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "\n",
        "    # frac_self = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_self,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'SELF'\n",
        "    # )\n",
        "    # frac_s1 = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_self_1hop,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'SELF + 1HOP'\n",
        "    # )\n",
        "    # frac_s2 = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_self_2hop,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'SELF + 1HOP + 2HOP'\n",
        "    # )\n",
        "#     fp_s1, fm_s1, fk_s1 = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_self_1hop,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"SELF + 1HOP\",\n",
        "#     )\n",
        "\n",
        "#     fp_s2, fm_s2, fk_s2 = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_self_2hop,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"SELF + 1HOP + 2HOP\",\n",
        "#     )\n",
        "    # frac_rand = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_random,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'RANDOM BASELINE'\n",
        "    # )\n",
        "    # frac_gc = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_gradcam,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'GradCAM-feature'\n",
        "    # )\n",
        "#     fp_rand, fm_rand, fk_rand = compute_fidelity_scores(\n",
        "#     data=data,\n",
        "#     model=model,\n",
        "#     feat_imp=feat_random,\n",
        "#     nodes=nodes_to_explain,\n",
        "#     k_fracs=args.k_fracs,\n",
        "#     use_true_class=args.use_true_class,\n",
        "#     label=\"RANDOM BASELINE\",\n",
        "#     )\n",
        "\n",
        "#     fp_gc, fm_gc, fk_gc = compute_fidelity_scores(\n",
        "#     data=data,\n",
        "#     model=model,\n",
        "#     feat_imp=feat_gradcam,\n",
        "#     nodes=nodes_to_explain,\n",
        "#     k_fracs=args.k_fracs,\n",
        "#     use_true_class=args.use_true_class,\n",
        "#     label=\"GradCAM-feature\",\n",
        "# )\n",
        "\n",
        "#     fp_lrp, fm_lrp, fk_lrp = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_lrp,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"GNN-LRP-feature\",\n",
        "#     )\n",
        "    # frac_lrp = compute_noisy_fraction(\n",
        "    #     feat_imp=feat_lrp,       # [N, F_total]\n",
        "    #     noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #     nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #     k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #     label = 'GradCAM-feature'\n",
        "    #   )\n",
        "\n",
        "#     fp_goat, fm_goat, fk_goat = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_goat,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"GOAT-feature\",\n",
        "#     )\n",
        "    # frac_goat = compute_noisy_fraction(\n",
        "    #     feat_imp=feat_goat,       # [N, F_total]\n",
        "    #     noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #     nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #     k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #     label = 'GradCAM-feature'\n",
        "    #   )\n",
        "    # frac_lime = compute_noisy_fraction(\n",
        "    #     feat_imp=feat_lime,       # [N, F_total]\n",
        "    #     noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #     nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #     k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #     label = 'GradCAM-feature'\n",
        "      # )\n",
        "#     fp_lime, fm_lime, fk_lime = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_lime,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"LIME\",\n",
        "#     )\n",
        "    # fp_glime, fm_glime, fk_glime = compute_fidelity_scores(\n",
        "    # data=data,\n",
        "    # model=model,\n",
        "    # feat_imp=feat_glime,\n",
        "    # nodes=nodes_to_explain,\n",
        "    # k_fracs=args.k_fracs,\n",
        "    # use_true_class=args.use_true_class,\n",
        "    # label=\"GraphLIME\",\n",
        "    # )\n",
        "    # fp_gnn, fm_gnn, fk_gnn = compute_fidelity_scores(\n",
        "    # data=data,\n",
        "    # model=model,\n",
        "    # feat_imp=feat_gnnexp,\n",
        "    # nodes=nodes_to_explain,\n",
        "    # k_fracs=args.k_fracs,\n",
        "    # use_true_class=args.use_true_class,\n",
        "    # label=\"GNNEXPLAINER BASELINE\",\n",
        "    # )\n",
        "  #   fp_ig, fm_ig, fk_ig = compute_fidelity_scores(\n",
        "  #   data=data,\n",
        "  #   model=model,\n",
        "  #   feat_imp=feat_ig,\n",
        "  #   nodes=nodes_to_explain,\n",
        "  #   k_fracs=args.k_fracs,\n",
        "  #   use_true_class=args.use_true_class,\n",
        "  #   label=\"IG (Captum) BASELINE\",\n",
        "  # )\n",
        "    # frac_ig = compute_noisy_fraction(\n",
        "    #     feat_imp=feat_ig,       # [N, F_total]\n",
        "    #     noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #     nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #     k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #     label = 'IG BASELINE'\n",
        "    #   )\n",
        "\n",
        "\n",
        "    # 8) Summary\n",
        "    print(\"\\n=== Summary: mean Δ = p_orig - p_mask ===\")\n",
        "    print(\"\\nRobustness (pick top-k% features):\")\n",
        "    for k in args.k_fracs:\n",
        "        print(\n",
        "            f\"k={k:.2f} | \"\n",
        "            f\"self={robust_self[k]:.4f} | \"\n",
        "            f\"self+2hop={robust_s2[k]:.4f} | \"\n",
        "            f\"rand={robust_random[k]:.4f} | \"\n",
        "            f\"gradcam={robust_gradcam[k]:.4f} | \"\n",
        "            f\"lrp={robust_lrp[k]:.4f} | \"\n",
        "            f\"goat={robust_goat[k]:.4f} | \"\n",
        "            f\"lime={robust_lime[k]:.4f} | \"\n",
        "            f\"ig={robust_ig[k]:.4f} | \"\n",
        "        )\n",
        "\n",
        "    # print(\"\\nFidelity- (remove bottom-k% features):\")\n",
        "    # for k in args.k_fracs:\n",
        "    #     print(\n",
        "    #         f\"k={k:.2f} | \"\n",
        "    #         f\"self={fm_self[k]:.4f} | \"\n",
        "    #         f\"self+1hop={fm_s1[k]:.4f} | \"\n",
        "    #         f\"self+2hop={fm_s2[k]:.4f} | \"\n",
        "    #         f\"rand={fm_rand[k]:.4f} | \"\n",
        "    #         f\"gradcam={fm_gc[k]:.4f} | \"\n",
        "    #         f\"lrp={fm_lrp[k]:.4f} | \"\n",
        "    #         f\"goat={fm_goat[k]:.4f} | \"\n",
        "    #         f\"lime={fm_lime[k]:.4f} | \"\n",
        "    #         f\"ig={fm_ig[k]:.4f} | \"\n",
        "    #     )\n",
        "\n",
        "    # print(\"\\n=== Build-Time Summary (seconds) ===\")\n",
        "    # print(f\"{'method':20s} {'build_time':>12s}\")\n",
        "    # for name, t in method_times.items():\n",
        "    #     print(f\"{name:20s} {t:12.3f}\")\n",
        "    # print(frac_self)\n",
        "    # print(frac_s1)\n",
        "    # print(frac_s2)\n",
        "    # print(frac_rand)\n",
        "    # print(frac_gc)\n",
        "    # print(frac_lime)\n",
        "    #return frac_self\n",
        "\n",
        "# Entry point: run the experiment only when this code is the top-level module.\n",
        "if __name__ == \"__main__\":\n",
        "    main()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "collapsed": true,
        "id": "qoarKn-Dr_ZN",
        "outputId": "a87dbcbd-bfb1-45ca-a7eb-00be1933b46b"
      },
      "execution_count": 55,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "=== Planetoid-{name} ===\n",
            "#Nodes     = 3327\n",
            "#Edges     = 9104\n",
            "#Features  = 3703\n",
            "#Classes   = 6\n",
            "Final feature dim after augmentation = 3703\n",
            "=== Planetoid-{name} ===\n",
            "#Nodes     = 3327\n",
            "#Edges     = 9104\n",
            "#Features  = 3703\n",
            "#Classes   = 6\n",
            "Added noise to features.\n",
            "Final feature dim after augmentation = 3703\n",
            "\n",
            "=== Loading existing model from citeseer_gcn.pt ===\n",
            "Loaded model | Train=1.0000 Val=0.7040 Test=0.6870\n",
            "\n",
            "=== Building normalized adjacency list ===\n",
            "\n",
            "=== Explaining 100 test nodes (pred class as target) ===\n",
            "Building feature importance via decomposition for 100 nodes...\n",
            "  processed 1/100 nodes\n",
            "  processed 20/100 nodes\n",
            "  processed 40/100 nodes\n",
            "  processed 60/100 nodes\n",
            "  processed 80/100 nodes\n",
            "  processed 100/100 nodes\n",
            "Building feature importance via decomposition for 100 nodes...\n",
            "  processed 1/100 nodes\n",
            "  processed 20/100 nodes\n",
            "  processed 40/100 nodes\n",
            "  processed 60/100 nodes\n",
            "  processed 80/100 nodes\n",
            "  processed 100/100 nodes\n",
            "Running GNN-LRP for 100 nodes...\n",
            "  GNN-LRP processed 1/100 nodes\n",
            "  GNN-LRP processed 10/100 nodes\n",
            "  GNN-LRP processed 20/100 nodes\n",
            "  GNN-LRP processed 30/100 nodes\n",
            "  GNN-LRP processed 40/100 nodes\n",
            "  GNN-LRP processed 50/100 nodes\n",
            "  GNN-LRP processed 60/100 nodes\n",
            "  GNN-LRP processed 70/100 nodes\n",
            "  GNN-LRP processed 80/100 nodes\n",
            "  GNN-LRP processed 90/100 nodes\n",
            "  GNN-LRP processed 100/100 nodes\n",
            "Running GNN-LRP for 100 nodes...\n",
            "  GNN-LRP processed 1/100 nodes\n",
            "  GNN-LRP processed 10/100 nodes\n",
            "  GNN-LRP processed 20/100 nodes\n",
            "  GNN-LRP processed 30/100 nodes\n",
            "  GNN-LRP processed 40/100 nodes\n",
            "  GNN-LRP processed 50/100 nodes\n",
            "  GNN-LRP processed 60/100 nodes\n",
            "  GNN-LRP processed 70/100 nodes\n",
            "  GNN-LRP processed 80/100 nodes\n",
            "  GNN-LRP processed 90/100 nodes\n",
            "  GNN-LRP processed 100/100 nodes\n",
            "Running GOAT for 100 nodes...\n",
            "  GOAT processed 1/100 nodes\n",
            "  GOAT processed 10/100 nodes\n",
            "  GOAT processed 20/100 nodes\n",
            "  GOAT processed 30/100 nodes\n",
            "  GOAT processed 40/100 nodes\n",
            "  GOAT processed 50/100 nodes\n",
            "  GOAT processed 60/100 nodes\n",
            "  GOAT processed 70/100 nodes\n",
            "  GOAT processed 80/100 nodes\n",
            "  GOAT processed 90/100 nodes\n",
            "  GOAT processed 100/100 nodes\n",
            "Running GOAT for 100 nodes...\n",
            "  GOAT processed 1/100 nodes\n",
            "  GOAT processed 10/100 nodes\n",
            "  GOAT processed 20/100 nodes\n",
            "  GOAT processed 30/100 nodes\n",
            "  GOAT processed 40/100 nodes\n",
            "  GOAT processed 50/100 nodes\n",
            "  GOAT processed 60/100 nodes\n",
            "  GOAT processed 70/100 nodes\n",
            "  GOAT processed 80/100 nodes\n",
            "  GOAT processed 90/100 nodes\n",
            "  GOAT processed 100/100 nodes\n",
            "Running LIME for 100 nodes...\n",
            "  LIME processed 1/100 nodes\n",
            "  LIME processed 10/100 nodes\n",
            "  LIME processed 20/100 nodes\n",
            "  LIME processed 30/100 nodes\n",
            "  LIME processed 40/100 nodes\n",
            "  LIME processed 50/100 nodes\n",
            "  LIME processed 60/100 nodes\n",
            "  LIME processed 70/100 nodes\n",
            "  LIME processed 80/100 nodes\n",
            "  LIME processed 90/100 nodes\n",
            "  LIME processed 100/100 nodes\n",
            "Running LIME for 100 nodes...\n",
            "  LIME processed 1/100 nodes\n",
            "  LIME processed 10/100 nodes\n",
            "  LIME processed 20/100 nodes\n",
            "  LIME processed 30/100 nodes\n",
            "  LIME processed 40/100 nodes\n",
            "  LIME processed 50/100 nodes\n",
            "  LIME processed 60/100 nodes\n",
            "  LIME processed 70/100 nodes\n",
            "  LIME processed 80/100 nodes\n",
            "  LIME processed 90/100 nodes\n",
            "  LIME processed 100/100 nodes\n",
            "Running Clean IG for 100 nodes...\n",
            "  IG processed 1/100 nodes\n",
            "  IG processed 10/100 nodes\n",
            "  IG processed 20/100 nodes\n",
            "  IG processed 30/100 nodes\n",
            "  IG processed 40/100 nodes\n",
            "  IG processed 50/100 nodes\n",
            "  IG processed 60/100 nodes\n",
            "  IG processed 70/100 nodes\n",
            "  IG processed 80/100 nodes\n",
            "  IG processed 90/100 nodes\n",
            "  IG processed 100/100 nodes\n",
            "Running Clean IG for 100 nodes...\n",
            "  IG processed 1/100 nodes\n",
            "  IG processed 10/100 nodes\n",
            "  IG processed 20/100 nodes\n",
            "  IG processed 30/100 nodes\n",
            "  IG processed 40/100 nodes\n",
            "  IG processed 50/100 nodes\n",
            "  IG processed 60/100 nodes\n",
            "  IG processed 70/100 nodes\n",
            "  IG processed 80/100 nodes\n",
            "  IG processed 90/100 nodes\n",
            "  IG processed 100/100 nodes\n",
            "\n",
            "=== Computing Recovery scores ===\n",
            "\n",
            "=== Summary: mean Δ = p_orig - p_mask ===\n",
            "\n",
            "Robustness (pick top-k% features):\n",
            "k=0.02 | self=0.5204 | self+2hop=0.9215 | rand=0.5043 | gradcam=0.9905 | lrp=0.3292 | goat=0.9456 | lime=0.0000 | ig=0.5235 | \n",
            "k=0.05 | self=0.2716 | self+2hop=0.7820 | rand=0.5119 | gradcam=0.9900 | lrp=0.1929 | goat=0.8350 | lime=0.0000 | ig=0.2728 | \n",
            "k=0.10 | self=0.1776 | self+2hop=0.5861 | rand=0.5229 | gradcam=0.9887 | lrp=0.1371 | goat=0.6493 | lime=0.0000 | ig=0.1783 | \n",
            "k=0.20 | self=0.1439 | self+2hop=0.3750 | rand=0.5416 | gradcam=0.9866 | lrp=0.1226 | goat=0.4317 | lime=0.0000 | ig=0.1442 | \n",
            "k=0.95 | self=0.1208 | self+2hop=0.1735 | rand=0.4230 | gradcam=0.9779 | lrp=0.1144 | goat=0.1898 | lime=0.0000 | ig=0.1209 | \n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## PubMed"
      ],
      "metadata": {
        "id": "-EJD8zd7C1yW"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# -------------------------------------------------------------\n",
        "# 1. Dataset + splits\n",
        "# -------------------------------------------------------------\n",
        "from torch_geometric.datasets import Planetoid\n",
        "def create_splits(data, num_train_per_class=20, num_val_per_class=30):\n",
        "    \"\"\"Attach random per-class train/val/test boolean masks to ``data``.\n",
        "\n",
        "    For every class label ``c`` in ``0 .. data.y.max()``, the nodes carrying\n",
        "    that label are shuffled with ``torch.randperm``; the first\n",
        "    ``num_train_per_class`` go to train, the next ``num_val_per_class`` to\n",
        "    validation, and every remaining node to test. A class with fewer nodes\n",
        "    than requested contributes only what it has.\n",
        "\n",
        "    The masks are allocated on the same device as ``data.y`` so the function\n",
        "    also works when ``data`` has already been moved to an accelerator\n",
        "    (a CPU mask cannot be indexed with index tensors living on another device).\n",
        "\n",
        "    Args:\n",
        "        data: object exposing ``y`` (integer class label per node) and\n",
        "            ``num_nodes``; it is modified in place.\n",
        "        num_train_per_class: maximum number of training nodes per class.\n",
        "        num_val_per_class: maximum number of validation nodes per class.\n",
        "\n",
        "    Returns:\n",
        "        The same ``data`` object with ``train_mask``, ``val_mask`` and\n",
        "        ``test_mask`` set.\n",
        "    \"\"\"\n",
        "    y = data.y\n",
        "    num_nodes = data.num_nodes\n",
        "    device = y.device\n",
        "    # y.max() raises on an empty tensor, so treat a graph without nodes as having no classes.\n",
        "    num_classes = int(y.max().item() + 1) if y.numel() > 0 else 0\n",
        "\n",
        "    # Keep the masks on y's device so they can be indexed by tensors derived from y.\n",
        "    train_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)\n",
        "    val_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)\n",
        "    test_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)\n",
        "\n",
        "    for c in range(num_classes):\n",
        "        idx = (y == c).nonzero(as_tuple=False).view(-1)\n",
        "        # Shuffle the members of class c; the split below takes consecutive slices.\n",
        "        idx = idx[torch.randperm(idx.size(0), device=device)]\n",
        "\n",
        "        # Clamp the requested sizes to what the class actually contains.\n",
        "        n_train = min(num_train_per_class, idx.size(0))\n",
        "        n_val = min(num_val_per_class, max(0, idx.size(0) - n_train))\n",
        "\n",
        "        train_mask[idx[:n_train]] = True\n",
        "        val_mask[idx[n_train:n_train + n_val]] = True\n",
        "        # Everything not used for train/val becomes test.\n",
        "        test_mask[idx[n_train + n_val:]] = True\n",
        "\n",
        "    data.train_mask = train_mask\n",
        "    data.val_mask = val_mask\n",
        "    data.test_mask = test_mask\n",
        "\n",
        "    print(f\"#Train nodes = {int(train_mask.sum())}\")\n",
        "    print(f\"#Val nodes   = {int(val_mask.sum())}\")\n",
        "    print(f\"#Test nodes  = {int(test_mask.sum())}\")\n",
        "\n",
        "    return data\n",
        "\n",
        "def load_planetoid(\n",
        "    name: str = 'Cora',\n",
        "    root: str = \"data/Amazon\",\n",
        "    percent_noisy: float = 0,\n",
        "    seed: int = 1234,\n",
        "    add_noise = True\n",
        "):\n",
        "    \"\"\"Load a Planetoid graph and optionally perturb its node features.\n",
        "\n",
        "    Args:\n",
        "        name: dataset name forwarded to ``Planetoid``.\n",
        "        root: directory forwarded to ``Planetoid`` as ``root``.\n",
        "            NOTE(review): the default is data/Amazon although this loads\n",
        "            Planetoid data - confirm that is intended.\n",
        "        percent_noisy: used as a fraction (not a percentage) of the original\n",
        "            feature count; when > 0, ``int(percent_noisy * num_features)`` is\n",
        "            passed as ``n_noisy`` to ``augment_with_noisy_features``.\n",
        "        seed: forwarded to ``augment_with_noisy_features``.\n",
        "        add_noise: when truthy, the data is passed through ``make_noisy_data``.\n",
        "\n",
        "    Returns:\n",
        "        ``(data, in_dim, num_classes, noisy_mask)``. ``in_dim`` is\n",
        "        ``data.x.size(1)`` after any augmentation; ``noisy_mask`` is the mask\n",
        "        returned by ``augment_with_noisy_features`` or, when no features are\n",
        "        added, an all-False boolean tensor of length ``num_features``.\n",
        "    \"\"\"\n",
        "    # Row-normalising the features is the usual preprocessing for these graphs.\n",
        "    planetoid = Planetoid(root=root, name=name, transform=T.NormalizeFeatures())\n",
        "    data = planetoid[0]\n",
        "    n_features = planetoid.num_features\n",
        "    n_classes = planetoid.num_classes\n",
        "\n",
        "    for line in (\n",
        "        f\"=== Planetoid-{name} ===\",\n",
        "        f\"#Nodes     = {data.num_nodes}\",\n",
        "        f\"#Edges     = {data.num_edges}\",\n",
        "        f\"#Features  = {n_features}\",\n",
        "        f\"#Classes   = {n_classes}\",\n",
        "    ):\n",
        "        print(line)\n",
        "\n",
        "    # Step 1a: optionally append extra noisy feature columns.\n",
        "    if not percent_noisy > 0:\n",
        "        # Nothing appended: every original feature is marked as not noisy.\n",
        "        noisy_mask = torch.zeros(n_features, dtype=torch.bool)\n",
        "    else:\n",
        "        print(f\"Adding {percent_noisy*100}% noisy features...\")\n",
        "        n_noisy = int(percent_noisy*n_features)\n",
        "        data, noisy_mask = augment_with_noisy_features(data, n_noisy=n_noisy, seed=seed)\n",
        "\n",
        "    # Step 1b: optionally perturb the existing features.\n",
        "    if add_noise:\n",
        "        data = make_noisy_data(data)\n",
        "        print(\"Added noise to features.\")\n",
        "\n",
        "    # Step 2: build train/val/test masks only when the dataset ships without them.\n",
        "    needs_split = not hasattr(data, \"train_mask\") or data.train_mask is None\n",
        "    if needs_split:\n",
        "        data = create_splits(data, num_train_per_class=20, num_val_per_class=30)\n",
        "\n",
        "    # Report the feature width the model has to accept after augmentation.\n",
        "    in_dim = data.x.size(1)\n",
        "    print(f\"Final feature dim after augmentation = {in_dim}\")\n",
        "\n",
        "    return data, in_dim, n_classes, noisy_mask\n",
        "\n",
        "\n",
        "def main():\n",
        "    method_times = {}\n",
        "    parser = argparse.ArgumentParser()\n",
        "    parser.add_argument(\"--seed\", type=int, default=42)\n",
        "    parser.add_argument(\"--hidden_dim\", type=int, default=64)\n",
        "    parser.add_argument(\"--dropout\", type=float, default=0.5)\n",
        "    parser.add_argument(\"--epochs\", type=int, default=500)\n",
        "    parser.add_argument(\"--lr\", type=float, default=0.05)\n",
        "    parser.add_argument(\"--weight_decay\", type=float, default=5e-4)\n",
        "    parser.add_argument(\"--root\", type=str, default=\"data/Planetoid\")\n",
        "    parser.add_argument(\"--k_fracs\", type=float, nargs=\"+\",\n",
        "                        default=[0.02,0.05,0.1,0.2,0.95]) #0.95 sanity check\n",
        "    parser.add_argument(\"--num_explain\", type=int, default=100,\n",
        "                        help=\"how many test nodes to explain via decomposition\")\n",
        "    parser.add_argument(\"--use_true_class\", action=\"store_true\",\n",
        "                        help=\"use true label instead of predicted label as target class\")\n",
        "    parser.add_argument(\"--model_path\", type=str, default=\"pubmed_gcn.pt\",\n",
        "                        help=\"path to save/load trained GCN model\")\n",
        "    args = parser.parse_args(args=[])\n",
        "\n",
        "    set_seed(args.seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # 1) Load data\n",
        "    data, in_dim, num_classes,_ = load_planetoid(name='PubMed',root=args.root,add_noise=False)\n",
        "    data = data.to(device)\n",
        "\n",
        "    data_noise, _, _,_ = load_planetoid(name='PubMed',root=args.root,add_noise=True)\n",
        "    data_noise = data_noise.to(device)\n",
        "\n",
        "    # 2) Build model\n",
        "    model = GCN(\n",
        "        in_channels=in_dim,\n",
        "        hidden_channels=args.hidden_dim,\n",
        "        out_channels=num_classes,\n",
        "        dropout=args.dropout,\n",
        "    ).to(device)\n",
        "\n",
        "    # 3) Train or load model\n",
        "    if os.path.exists(args.model_path):\n",
        "        print(f\"\\n=== Loading existing model from {args.model_path} ===\")\n",
        "        state = torch.load(args.model_path, map_location=device)\n",
        "        model.load_state_dict(state)\n",
        "        train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
        "        print(f\"Loaded model | Train={train_acc:.4f} Val={val_acc:.4f} Test={test_acc:.4f}\")\n",
        "    else:\n",
        "        print(\"\\n=== Training GCN (no saved model found) ===\")\n",
        "        optimizer = torch.optim.Adam(\n",
        "            model.parameters(),\n",
        "            lr=args.lr,\n",
        "            weight_decay=args.weight_decay,\n",
        "        )\n",
        "\n",
        "        best_val_acc = 0.0\n",
        "        best_state = None\n",
        "\n",
        "        for epoch in range(1, args.epochs + 1):\n",
        "            loss = train(model, data, optimizer)\n",
        "            train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
        "\n",
        "            if val_acc > best_val_acc:\n",
        "                best_val_acc = val_acc\n",
        "                best_state = {\n",
        "                    \"model\": model.state_dict(),\n",
        "                    \"epoch\": epoch,\n",
        "                    \"test_acc\": test_acc,\n",
        "                }\n",
        "\n",
        "            if epoch % 20 == 0 or epoch == 1:\n",
        "                print(\n",
        "                    f\"Epoch {epoch:03d} | \"\n",
        "                    f\"Loss {loss:.4f} | \"\n",
        "                    f\"Train {train_acc:.4f} | \"\n",
        "                    f\"Val {val_acc:.4f} | \"\n",
        "                    f\"Test {test_acc:.4f}\"\n",
        "                )\n",
        "\n",
        "        if best_state is not None:\n",
        "            model.load_state_dict(best_state[\"model\"])\n",
        "            print(\n",
        "                f\"\\nBest epoch = {best_state['epoch']} | \"\n",
        "                f\"Best val acc = {best_val_acc:.4f} | \"\n",
        "                f\"Test acc @best = {best_state['test_acc']:.4f}\"\n",
        "            )\n",
        "\n",
        "        # Save trained model\n",
        "        torch.save(model.state_dict(), args.model_path)\n",
        "        print(f\"Model saved to {args.model_path}\")\n",
        "\n",
        "    # 4) Build normalized adjacency list\n",
        "    print(\"\\n=== Building normalized adjacency list ===\")\n",
        "    adj_list = build_normalized_adjacency_list(model, data)\n",
        "\n",
        "    # 5) Choose test nodes to explain (correct predictions preferred)\n",
        "    model.eval()\n",
        "    logits = model(data.x, data.edge_index)\n",
        "    preds = logits.argmax(dim=-1)\n",
        "\n",
        "    test_nodes = torch.nonzero(data.test_mask, as_tuple=False).view(-1)\n",
        "    correct_mask = preds == data.y\n",
        "    test_correct_nodes = test_nodes[correct_mask[test_nodes]]\n",
        "\n",
        "    if test_correct_nodes.numel() == 0:\n",
        "        print(\"WARNING: no correctly classified test nodes, using all test nodes.\")\n",
        "        nodes_to_explain = test_nodes\n",
        "    else:\n",
        "        nodes_to_explain = test_correct_nodes\n",
        "\n",
        "    if nodes_to_explain.numel() > args.num_explain:\n",
        "        nodes_to_explain = nodes_to_explain[:args.num_explain]\n",
        "\n",
        "    print(\n",
        "        f\"\\n=== Explaining {nodes_to_explain.numel()} test nodes \"\n",
        "        f\"({'true' if args.use_true_class else 'pred'} class as target) ===\"\n",
        "    )\n",
        "\n",
        "    # 6) Build feature importance: self vs self+1hop vs self+2hop\n",
        "\n",
        "    feat_self, feat_self_1hop, feat_self_2hop = build_feature_importance_decomposition(\n",
        "        data=data,\n",
        "        model=model,\n",
        "        adj_list=adj_list,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "        hop=2,\n",
        "        use_true_class=args.use_true_class,\n",
        "    )\n",
        "\n",
        "\n",
        "    feat_self_noise, feat_self_1hop_noise, feat_self_2hop_noise = build_feature_importance_decomposition(\n",
        "        data=data_noise,\n",
        "        model=model,\n",
        "        adj_list=adj_list,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "        hop=2,\n",
        "        use_true_class=args.use_true_class,\n",
        "    )\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "#     # feat_self, feat_self_1hop, feat_self_2hop = build_feature_importance_decomposition(\n",
        "#     #     data=data,\n",
        "#     #     model=model,\n",
        "#     #     adj_list=adj_list,\n",
        "#     #     nodes_to_explain=nodes_to_explain,\n",
        "#     #     hop=2,\n",
        "#     #     use_true_class=args.use_true_class,\n",
        "#     # )\n",
        "#     # ---- Random baseline ----\n",
        "#     print(\"\\n=== Building RANDOM importance baseline ===\")\n",
        "\n",
        "#     t0 = _now()\n",
        "    feat_random = build_random_importance(\n",
        "        num_nodes=data.num_nodes,\n",
        "        num_features=data.x.size(1),\n",
        "        device=device,\n",
        "    )\n",
        "    feat_random_noise = build_random_importance(\n",
        "        num_nodes=data.num_nodes,\n",
        "        num_features=data.x.size(1),\n",
        "        device=device,\n",
        "    )\n",
        "#     t1 = _now()\n",
        "#     method_times[\"random\"] = t1 - t0\n",
        "\n",
        "# #     feat_random = build_random_importance(\n",
        "# #         num_nodes=data.num_nodes,\n",
        "# #         num_features=data.x.size(1),\n",
        "# #         device=device\n",
        "# # )\n",
        "#     print(\"\\n=== Building GRAD-Cam ===\")\n",
        "#     # feat_gradcam = gradcam_feature_importance(model, data, use_true_class=True)\n",
        "#     t0 = _now()\n",
        "    feat_gradcam = gradcam_feature_importance(\n",
        "        model=model,\n",
        "        data=data,\n",
        "        use_true_class=args.use_true_class,\n",
        "    )\n",
        "    feat_gradcam_noise = gradcam_feature_importance(\n",
        "        model=model,\n",
        "        data=data_noise,\n",
        "        use_true_class=args.use_true_class,\n",
        "    )\n",
        "#     t1 = _now()\n",
        "#     method_times[\"gradcam\"] = t1 - t0\n",
        "\n",
        "#     print(\"\\n=== Building GNN-LRP ===\")\n",
        "# #     feat_lrp = gnn_lrp_feature_importance(\n",
        "# #     model,\n",
        "# #     data,\n",
        "# #     eps=1e-6,\n",
        "# #     use_true_class=True,\n",
        "# #     nodes_to_explain=nodes_to_explain  # same subset you use for fidelity\n",
        "# # )\n",
        "#     t0 = _now()\n",
        "    feat_lrp = gnn_lrp_feature_importance(\n",
        "        model=model,\n",
        "        data=data,\n",
        "        eps=1e-6,\n",
        "        use_true_class=args.use_true_class,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "    )\n",
        "    feat_lrp_noise = gnn_lrp_feature_importance(\n",
        "        model=model,\n",
        "        data=data_noise,\n",
        "        eps=1e-6,\n",
        "        use_true_class=args.use_true_class,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "    )\n",
        "#     t1 = _now()\n",
        "#     method_times[\"lrp\"] = t1 - t0\n",
        "\n",
        "\n",
        "#     print(\"\\n=== Building GOAT ===\")\n",
        "#     # feat_goat = build_goat_feature_importance(\n",
        "#     #     data=data,\n",
        "#     #     model=model,\n",
        "#     #     nodes_to_explain=nodes_to_explain,\n",
        "#     # )\n",
        "#     t0 = _now()\n",
        "    feat_goat = build_goat_feature_importance(\n",
        "        data=data,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "    )\n",
        "    feat_goat_noise = build_goat_feature_importance(\n",
        "        data=data_noise,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "    )\n",
        "    # t1 = _now()\n",
        "#     method_times[\"goat\"] = t1 - t0\n",
        "\n",
        "#     print(\"\\n=== Building LIME ===\")\n",
        "#     # feat_lime = build_lime_importance(\n",
        "#     #     data=data,\n",
        "#     #     model=model,\n",
        "#     #     nodes_to_explain=nodes_to_explain,\n",
        "#     #     num_classes=num_classes,\n",
        "#     #     num_samples=50,\n",
        "#     # )\n",
        "#     # t0 = _now()\n",
        "    feat_lime = build_lime_importance(\n",
        "        data=data,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,  # or smaller subset\n",
        "        num_classes=num_classes,\n",
        "        num_samples=50,\n",
        "    )\n",
        "    feat_lime_noise = build_lime_importance(\n",
        "        data=data_noise,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,  # or smaller subset\n",
        "        num_classes=num_classes,\n",
        "        num_samples=50,\n",
        "    )\n",
        "#     t1 = _now()\n",
        "#     method_times[\"lime\"] = t1 - t0\n",
        "\n",
        "#     # GraphLime: if possible\n",
        "#     # print(\"\\n=== Building GraphLIME ===\")\n",
        "#     # t0 = _now()\n",
        "#     # feat_glime = build_graphlime_importance(\n",
        "#     # data=data,\n",
        "#     # model=model,\n",
        "#     # nodes_to_explain=nodes_to_explain,\n",
        "#     # hop=2,\n",
        "#     # rho=0.1,\n",
        "#     # )\n",
        "#     # t1 = _now()\n",
        "#     # method_times[\"graphlime\"] = t1 - t0\n",
        "\n",
        "\n",
        "#     #probably not use the following two ---\n",
        "\n",
        "#     #     # ---- GNNExplainer baseline importance ----\n",
        "#     # print(\"\\n=== Building GNNExplainer importance baseline ===\")\n",
        "#     # feat_gnnexp = build_gnnexplainer_importance(\n",
        "#     #     data=data,\n",
        "#     #     model=model,\n",
        "#     #     nodes_to_explain=nodes_to_explain,\n",
        "#     #     epochs=50,   # you can change this\n",
        "#     # )\n",
        "\n",
        "#     print(\"\\n=== Building Integrated Gradients (Captum) importance baseline ===\")\n",
        "#     t0 = _now()\n",
        "    feat_ig = build_ig_importance(\n",
        "        data=data,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        ")\n",
        "    feat_ig_noise = build_ig_importance(\n",
        "        data=data_noise,\n",
        "        model=model,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        ")\n",
        "#     t1 = _now()\n",
        "#     method_times[\"ig\"] = t1 - t0\n",
        "\n",
        "    # 7) Recovery scores for each importance variant\n",
        "    print(\"\\n=== Computing Recovery scores ===\")\n",
        "\n",
        "    robust_self = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_self,\n",
        "    feat_imp_noisy= feat_self_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_s2 = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_self_2hop,\n",
        "    feat_imp_noisy= feat_self_2hop_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_gradcam = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_gradcam,\n",
        "    feat_imp_noisy= feat_gradcam_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_lrp = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_lrp,\n",
        "    feat_imp_noisy= feat_lrp_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_goat = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_goat,\n",
        "    feat_imp_noisy= feat_goat_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_random = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_random,\n",
        "    feat_imp_noisy= feat_random_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_lime = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_lime,\n",
        "    feat_imp_noisy= feat_lime_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "    robust_ig = compute_robustness_topk(\n",
        "    feat_imp_clean = feat_ig,\n",
        "    feat_imp_noisy= feat_ig_noise,\n",
        "    nodes=nodes_to_explain,\n",
        "    k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "    eps = 1e-8,\n",
        ")\n",
        "\n",
        "    # frac_self = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_self,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'SELF'\n",
        "    # )\n",
        "    # frac_s1 = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_self_1hop,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'SELF + 1HOP'\n",
        "    # )\n",
        "    # frac_s2 = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_self_2hop,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'SELF + 1HOP + 2HOP'\n",
        "    # )\n",
        "#     fp_s1, fm_s1, fk_s1 = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_self_1hop,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"SELF + 1HOP\",\n",
        "#     )\n",
        "\n",
        "#     fp_s2, fm_s2, fk_s2 = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_self_2hop,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"SELF + 1HOP + 2HOP\",\n",
        "#     )\n",
        "    # frac_rand = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_random,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'RANDOM BASELINE'\n",
        "    # )\n",
        "    # frac_gc = compute_noisy_fraction(\n",
        "    #   feat_imp=feat_gradcam,       # [N, F_total]\n",
        "    #   noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #   nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #   k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #   label = 'GradCAM-feature'\n",
        "    # )\n",
        "#     fp_rand, fm_rand, fk_rand = compute_fidelity_scores(\n",
        "#     data=data,\n",
        "#     model=model,\n",
        "#     feat_imp=feat_random,\n",
        "#     nodes=nodes_to_explain,\n",
        "#     k_fracs=args.k_fracs,\n",
        "#     use_true_class=args.use_true_class,\n",
        "#     label=\"RANDOM BASELINE\",\n",
        "#     )\n",
        "\n",
        "#     fp_gc, fm_gc, fk_gc = compute_fidelity_scores(\n",
        "#     data=data,\n",
        "#     model=model,\n",
        "#     feat_imp=feat_gradcam,\n",
        "#     nodes=nodes_to_explain,\n",
        "#     k_fracs=args.k_fracs,\n",
        "#     use_true_class=args.use_true_class,\n",
        "#     label=\"GradCAM-feature\",\n",
        "# )\n",
        "\n",
        "#     fp_lrp, fm_lrp, fk_lrp = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_lrp,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"GNN-LRP-feature\",\n",
        "#     )\n",
        "    # frac_lrp = compute_noisy_fraction(\n",
        "    #     feat_imp=feat_lrp,       # [N, F_total]\n",
        "    #     noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #     nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #     k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #     label = 'GradCAM-feature'\n",
        "    #   )\n",
        "\n",
        "#     fp_goat, fm_goat, fk_goat = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_goat,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"GOAT-feature\",\n",
        "#     )\n",
        "    # frac_goat = compute_noisy_fraction(\n",
        "    #     feat_imp=feat_goat,       # [N, F_total]\n",
        "    #     noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #     nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #     k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #     label = 'GradCAM-feature'\n",
        "    #   )\n",
        "    # frac_lime = compute_noisy_fraction(\n",
        "    #     feat_imp=feat_lime,       # [N, F_total]\n",
        "    #     noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #     nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #     k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #     label = 'GradCAM-feature'\n",
        "      # )\n",
        "#     fp_lime, fm_lime, fk_lime = compute_fidelity_scores(\n",
        "#         data=data,\n",
        "#         model=model,\n",
        "#         feat_imp=feat_lime,\n",
        "#         nodes=nodes_to_explain,\n",
        "#         k_fracs=args.k_fracs,\n",
        "#         use_true_class=args.use_true_class,\n",
        "#         label=\"LIME\",\n",
        "#     )\n",
        "    # fp_glime, fm_glime, fk_glime = compute_fidelity_scores(\n",
        "    # data=data,\n",
        "    # model=model,\n",
        "    # feat_imp=feat_glime,\n",
        "    # nodes=nodes_to_explain,\n",
        "    # k_fracs=args.k_fracs,\n",
        "    # use_true_class=args.use_true_class,\n",
        "    # label=\"GraphLIME\",\n",
        "    # )\n",
        "    # fp_gnn, fm_gnn, fk_gnn = compute_fidelity_scores(\n",
        "    # data=data,\n",
        "    # model=model,\n",
        "    # feat_imp=feat_gnnexp,\n",
        "    # nodes=nodes_to_explain,\n",
        "    # k_fracs=args.k_fracs,\n",
        "    # use_true_class=args.use_true_class,\n",
        "    # label=\"GNNEXPLAINER BASELINE\",\n",
        "    # )\n",
        "  #   fp_ig, fm_ig, fk_ig = compute_fidelity_scores(\n",
        "  #   data=data,\n",
        "  #   model=model,\n",
        "  #   feat_imp=feat_ig,\n",
        "  #   nodes=nodes_to_explain,\n",
        "  #   k_fracs=args.k_fracs,\n",
        "  #   use_true_class=args.use_true_class,\n",
        "  #   label=\"IG (Captum) BASELINE\",\n",
        "  # )\n",
        "    # frac_ig = compute_noisy_fraction(\n",
        "    #     feat_imp=feat_ig,       # [N, F_total]\n",
        "    #     noisy_mask=noisy_mask,     # [F_total] bool\n",
        "    #     nodes=nodes_to_explain,          # [M] node indices to consider\n",
        "    #     k_frac=args.k_fracs,  # e.g., 0.1 for top 10% of features\n",
        "    #     label = 'IG BASELINE'\n",
        "    #   )\n",
        "\n",
        "\n",
        "    # 8) Summary\n",
        "    print(\"\\n=== Summary: mean Δ = p_orig - p_mask ===\")\n",
        "    print(\"\\nRobustness (pick top-k% features):\")\n",
        "    for k in args.k_fracs:\n",
        "        print(\n",
        "            f\"k={k:.2f} | \"\n",
        "            f\"self={robust_self[k]:.4f} | \"\n",
        "            f\"self+2hop={robust_s2[k]:.4f} | \"\n",
        "            f\"rand={robust_random[k]:.4f} | \"\n",
        "            f\"gradcam={robust_gradcam[k]:.4f} | \"\n",
        "            f\"lrp={robust_lrp[k]:.4f} | \"\n",
        "            f\"goat={robust_goat[k]:.4f} | \"\n",
        "            f\"lime={robust_lime[k]:.4f} | \"\n",
        "            f\"ig={robust_ig[k]:.4f} | \"\n",
        "        )\n",
        "\n",
        "    # print(\"\\nFidelity- (remove bottom-k% features):\")\n",
        "    # for k in args.k_fracs:\n",
        "    #     print(\n",
        "    #         f\"k={k:.2f} | \"\n",
        "    #         f\"self={fm_self[k]:.4f} | \"\n",
        "    #         f\"self+1hop={fm_s1[k]:.4f} | \"\n",
        "    #         f\"self+2hop={fm_s2[k]:.4f} | \"\n",
        "    #         f\"rand={fm_rand[k]:.4f} | \"\n",
        "    #         f\"gradcam={fm_gc[k]:.4f} | \"\n",
        "    #         f\"lrp={fm_lrp[k]:.4f} | \"\n",
        "    #         f\"goat={fm_goat[k]:.4f} | \"\n",
        "    #         f\"lime={fm_lime[k]:.4f} | \"\n",
        "    #         f\"ig={fm_ig[k]:.4f} | \"\n",
        "    #     )\n",
        "\n",
        "    # print(\"\\n=== Build-Time Summary (seconds) ===\")\n",
        "    # print(f\"{'method':20s} {'build_time':>12s}\")\n",
        "    # for name, t in method_times.items():\n",
        "    #     print(f\"{name:20s} {t:12.3f}\")\n",
        "    # print(frac_self)\n",
        "    # print(frac_s1)\n",
        "    # print(frac_s2)\n",
        "    # print(frac_rand)\n",
        "    # print(frac_gc)\n",
        "    # print(frac_lime)\n",
        "    #return frac_self\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "  main()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "aWv2ytkOC7yq",
        "outputId": "66e839fe-7024-44ed-981a-3c7a1f07188a"
      },
      "execution_count": 56,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "=== Planetoid-PubMed ===\n",
            "#Nodes     = 19717\n",
            "#Edges     = 88648\n",
            "#Features  = 500\n",
            "#Classes   = 3\n",
            "Final feature dim after augmentation = 500\n",
            "=== Planetoid-PubMed ===\n",
            "#Nodes     = 19717\n",
            "#Edges     = 88648\n",
            "#Features  = 500\n",
            "#Classes   = 3\n",
            "Added noise to features.\n",
            "Final feature dim after augmentation = 500\n",
            "\n",
            "=== Loading existing model from pubmed_gcn.pt ===\n",
            "Loaded model | Train=1.0000 Val=0.7780 Test=0.7760\n",
            "\n",
            "=== Building normalized adjacency list ===\n",
            "\n",
            "=== Explaining 100 test nodes (pred class as target) ===\n",
            "Building feature importance via decomposition for 100 nodes...\n",
            "  processed 1/100 nodes\n",
            "  processed 20/100 nodes\n",
            "  processed 40/100 nodes\n",
            "  processed 60/100 nodes\n",
            "  processed 80/100 nodes\n",
            "  processed 100/100 nodes\n",
            "Building feature importance via decomposition for 100 nodes...\n",
            "  processed 1/100 nodes\n",
            "  processed 20/100 nodes\n",
            "  processed 40/100 nodes\n",
            "  processed 60/100 nodes\n",
            "  processed 80/100 nodes\n",
            "  processed 100/100 nodes\n",
            "Running GNN-LRP for 100 nodes...\n",
            "  GNN-LRP processed 1/100 nodes\n",
            "  GNN-LRP processed 10/100 nodes\n",
            "  GNN-LRP processed 20/100 nodes\n",
            "  GNN-LRP processed 30/100 nodes\n",
            "  GNN-LRP processed 40/100 nodes\n",
            "  GNN-LRP processed 50/100 nodes\n",
            "  GNN-LRP processed 60/100 nodes\n",
            "  GNN-LRP processed 70/100 nodes\n",
            "  GNN-LRP processed 80/100 nodes\n",
            "  GNN-LRP processed 90/100 nodes\n",
            "  GNN-LRP processed 100/100 nodes\n",
            "Running GNN-LRP for 100 nodes...\n",
            "  GNN-LRP processed 1/100 nodes\n",
            "  GNN-LRP processed 10/100 nodes\n",
            "  GNN-LRP processed 20/100 nodes\n",
            "  GNN-LRP processed 30/100 nodes\n",
            "  GNN-LRP processed 40/100 nodes\n",
            "  GNN-LRP processed 50/100 nodes\n",
            "  GNN-LRP processed 60/100 nodes\n",
            "  GNN-LRP processed 70/100 nodes\n",
            "  GNN-LRP processed 80/100 nodes\n",
            "  GNN-LRP processed 90/100 nodes\n",
            "  GNN-LRP processed 100/100 nodes\n",
            "Running GOAT for 100 nodes...\n",
            "  GOAT processed 1/100 nodes\n",
            "  GOAT processed 10/100 nodes\n",
            "  GOAT processed 20/100 nodes\n",
            "  GOAT processed 30/100 nodes\n",
            "  GOAT processed 40/100 nodes\n",
            "  GOAT processed 50/100 nodes\n",
            "  GOAT processed 60/100 nodes\n",
            "  GOAT processed 70/100 nodes\n",
            "  GOAT processed 80/100 nodes\n",
            "  GOAT processed 90/100 nodes\n",
            "  GOAT processed 100/100 nodes\n",
            "Running GOAT for 100 nodes...\n",
            "  GOAT processed 1/100 nodes\n",
            "  GOAT processed 10/100 nodes\n",
            "  GOAT processed 20/100 nodes\n",
            "  GOAT processed 30/100 nodes\n",
            "  GOAT processed 40/100 nodes\n",
            "  GOAT processed 50/100 nodes\n",
            "  GOAT processed 60/100 nodes\n",
            "  GOAT processed 70/100 nodes\n",
            "  GOAT processed 80/100 nodes\n",
            "  GOAT processed 90/100 nodes\n",
            "  GOAT processed 100/100 nodes\n",
            "Running LIME for 100 nodes...\n",
            "  LIME processed 1/100 nodes\n",
            "  LIME processed 10/100 nodes\n",
            "  LIME processed 20/100 nodes\n",
            "  LIME processed 30/100 nodes\n",
            "  LIME processed 40/100 nodes\n",
            "  LIME processed 50/100 nodes\n",
            "  LIME processed 60/100 nodes\n",
            "  LIME processed 70/100 nodes\n",
            "  LIME processed 80/100 nodes\n",
            "  LIME processed 90/100 nodes\n",
            "  LIME processed 100/100 nodes\n",
            "Running LIME for 100 nodes...\n",
            "  LIME processed 1/100 nodes\n",
            "  LIME processed 10/100 nodes\n",
            "  LIME processed 20/100 nodes\n",
            "  LIME processed 30/100 nodes\n",
            "  LIME processed 40/100 nodes\n",
            "  LIME processed 50/100 nodes\n",
            "  LIME processed 60/100 nodes\n",
            "  LIME processed 70/100 nodes\n",
            "  LIME processed 80/100 nodes\n",
            "  LIME processed 90/100 nodes\n",
            "  LIME processed 100/100 nodes\n",
            "Running Clean IG for 100 nodes...\n",
            "  IG processed 1/100 nodes\n",
            "  IG processed 10/100 nodes\n",
            "  IG processed 20/100 nodes\n",
            "  IG processed 30/100 nodes\n",
            "  IG processed 40/100 nodes\n",
            "  IG processed 50/100 nodes\n",
            "  IG processed 60/100 nodes\n",
            "  IG processed 70/100 nodes\n",
            "  IG processed 80/100 nodes\n",
            "  IG processed 90/100 nodes\n",
            "  IG processed 100/100 nodes\n",
            "Running Clean IG for 100 nodes...\n",
            "  IG processed 1/100 nodes\n",
            "  IG processed 10/100 nodes\n",
            "  IG processed 20/100 nodes\n",
            "  IG processed 30/100 nodes\n",
            "  IG processed 40/100 nodes\n",
            "  IG processed 50/100 nodes\n",
            "  IG processed 60/100 nodes\n",
            "  IG processed 70/100 nodes\n",
            "  IG processed 80/100 nodes\n",
            "  IG processed 90/100 nodes\n",
            "  IG processed 100/100 nodes\n",
            "\n",
            "=== Computing Recovery scores ===\n",
            "\n",
            "=== Summary: mean Δ = p_orig - p_mask ===\n",
            "\n",
            "Robustness (pick top-k% features):\n",
            "k=0.02 | self=0.9836 | self+2hop=0.9757 | rand=0.5204 | gradcam=0.9903 | lrp=0.5881 | goat=0.9889 | lime=0.0231 | ig=0.9839 | \n",
            "k=0.05 | self=0.9787 | self+2hop=0.9596 | rand=0.5229 | gradcam=0.9908 | lrp=0.5810 | goat=0.9852 | lime=0.0270 | ig=0.9791 | \n",
            "k=0.10 | self=0.8758 | self+2hop=0.9397 | rand=0.5285 | gradcam=0.9907 | lrp=0.4957 | goat=0.9810 | lime=0.0321 | ig=0.8776 | \n",
            "k=0.20 | self=0.4867 | self+2hop=0.8968 | rand=0.5433 | gradcam=0.9906 | lrp=0.2673 | goat=0.9724 | lime=0.0391 | ig=0.4882 | \n",
            "k=0.95 | self=0.1026 | self+2hop=0.5129 | rand=0.4256 | gradcam=0.9756 | lrp=0.0563 | goat=0.7475 | lime=0.1047 | ig=0.1029 | \n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Cora"
      ],
      "metadata": {
        "id": "IST6mApctP2C"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# -------------------------------------------------------------\n",
        "# 1. Dataset + splits\n",
        "# -------------------------------------------------------------\n",
        "from torch_geometric.datasets import Planetoid\n",
        "def create_splits(data, num_train_per_class=20, num_val_per_class=30):\n",
        "    \"\"\"Attach random per-class train/val/test boolean masks to ``data``.\n",
        "\n",
        "    The nodes of every class are shuffled; the first ``num_train_per_class``\n",
        "    become training nodes, the next ``num_val_per_class`` validation nodes and\n",
        "    all remaining ones test nodes. A class with too few nodes simply\n",
        "    contributes fewer nodes to the later splits.\n",
        "\n",
        "    The masks are written to ``data.train_mask`` / ``data.val_mask`` /\n",
        "    ``data.test_mask``, their sizes are printed, and ``data`` is returned.\n",
        "    \"\"\"\n",
        "    labels = data.y\n",
        "    total_nodes = data.num_nodes\n",
        "    n_classes = int(labels.max().item() + 1)\n",
        "\n",
        "    # One all-False mask per split, filled in class by class below.\n",
        "    train_mask = torch.zeros(total_nodes, dtype=torch.bool)\n",
        "    val_mask = torch.zeros(total_nodes, dtype=torch.bool)\n",
        "    test_mask = torch.zeros(total_nodes, dtype=torch.bool)\n",
        "\n",
        "    for cls in range(n_classes):\n",
        "        members = torch.nonzero(labels == cls).flatten()\n",
        "        shuffled = members[torch.randperm(members.size(0))]\n",
        "        class_size = shuffled.size(0)\n",
        "\n",
        "        # Train takes priority; validation only gets what is left over.\n",
        "        n_train = min(num_train_per_class, class_size)\n",
        "        n_val = min(num_val_per_class, class_size - n_train)\n",
        "        val_end = n_train + n_val\n",
        "\n",
        "        train_mask[shuffled[:n_train]] = True\n",
        "        val_mask[shuffled[n_train:val_end]] = True\n",
        "        test_mask[shuffled[val_end:]] = True\n",
        "\n",
        "    data.train_mask, data.val_mask, data.test_mask = train_mask, val_mask, test_mask\n",
        "\n",
        "    print(f\"#Train nodes = {int(train_mask.sum())}\")\n",
        "    print(f\"#Val nodes   = {int(val_mask.sum())}\")\n",
        "    print(f\"#Test nodes  = {int(test_mask.sum())}\")\n",
        "\n",
        "    return data\n",
        "\n",
        "def load_planetoid(\n",
        "    name: str = 'Cora',\n",
        "    root: str = \"data/Amazon\",\n",
        "    percent_noisy: float = 0,              # <<< % of noisy features to add\n",
        "    seed: int = 1234,\n",
        "    add_noise: bool = True,\n",
        "):\n",
        "    \"\"\"Load a Planetoid graph and optionally perturb its node features.\n",
        "\n",
        "    Args:\n",
        "        name: Dataset name forwarded to ``Planetoid`` (e.g. ``'Cora'``).\n",
        "        root: Directory used by ``Planetoid`` to download/cache the data.\n",
        "            NOTE(review): the default \"data/Amazon\" looks like a leftover from\n",
        "            another loader; it is kept so existing callers are unaffected.\n",
        "        percent_noisy: Fraction of the dataset's feature count that is passed\n",
        "            to ``augment_with_noisy_features`` as ``n_noisy``. Values <= 0\n",
        "            skip that step.\n",
        "        seed: Seed forwarded to ``augment_with_noisy_features``.\n",
        "        add_noise: If True, the data is passed through ``make_noisy_data``.\n",
        "\n",
        "    Returns:\n",
        "        ``(data, in_dim, num_classes, noisy_mask)`` where ``in_dim`` is the\n",
        "        final ``data.x`` feature dimension and ``noisy_mask`` is either the\n",
        "        mask returned by ``augment_with_noisy_features`` or an all-False bool\n",
        "        tensor with one entry per original feature.\n",
        "    \"\"\"\n",
        "    dataset = Planetoid(\n",
        "        root=root,\n",
        "        name=name,\n",
        "        transform=T.NormalizeFeatures(),   # common preprocessing\n",
        "    )\n",
        "    data = dataset[0]\n",
        "\n",
        "    # f-string so the dataset name is actually interpolated into the banner.\n",
        "    print(f\"=== Planetoid-{name} ===\")\n",
        "    print(f\"#Nodes     = {data.num_nodes}\")\n",
        "    print(f\"#Edges     = {data.num_edges}\")\n",
        "    print(f\"#Features  = {dataset.num_features}\")\n",
        "    print(f\"#Classes   = {dataset.num_classes}\")\n",
        "\n",
        "    # -----------------------------\n",
        "    # 1) Add noisy features / feature noise\n",
        "    # -----------------------------\n",
        "    if percent_noisy > 0:\n",
        "        print(f\"Adding {percent_noisy*100}% noisy features...\")\n",
        "        data, noisy_mask = augment_with_noisy_features(\n",
        "            data,\n",
        "            n_noisy=int(percent_noisy*dataset.num_features),\n",
        "            seed=seed\n",
        "        )\n",
        "    else:\n",
        "        # Nothing was appended, so no feature is flagged as noisy.\n",
        "        noisy_mask = torch.zeros(dataset.num_features, dtype=torch.bool)\n",
        "    if add_noise:\n",
        "        data = make_noisy_data(data)\n",
        "        print(\"Added noise to features.\")\n",
        "\n",
        "    # -----------------------------\n",
        "    # 2) Split: only build random splits when the data has none\n",
        "    # -----------------------------\n",
        "    if not hasattr(data, \"train_mask\") or data.train_mask is None:\n",
        "        data = create_splits(\n",
        "            data, num_train_per_class=20, num_val_per_class=30\n",
        "        )\n",
        "\n",
        "    # -----------------------------\n",
        "    # Return updated feature dim\n",
        "    # -----------------------------\n",
        "    in_dim = data.x.size(1)\n",
        "    num_classes = dataset.num_classes\n",
        "\n",
        "    print(f\"Final feature dim after augmentation = {in_dim}\")\n",
        "\n",
        "    return data, in_dim, num_classes, noisy_mask\n",
        "\n",
        "\n",
        "def main():\n",
        "    \"\"\"Train (or load) a GCN on Cora and compare explainer robustness to feature noise.\n",
        "\n",
        "    Pipeline:\n",
        "      1. Load a clean and a noise-perturbed copy of Cora.\n",
        "      2. Build a GCN and either load it from ``--model_path`` or train it,\n",
        "         keeping the weights of the best-validation epoch, and save it.\n",
        "      3. Pick up to ``--num_explain`` (preferably correctly classified) test nodes.\n",
        "      4. Compute feature importances with every explainer on both copies.\n",
        "      5. Print ``compute_robustness_topk`` between the clean and noisy\n",
        "         importances for every ``k`` in ``--k_fracs``.\n",
        "\n",
        "    Only the argparse defaults are used (``parse_args(args=[])``), so the\n",
        "    function can be called from a notebook cell.\n",
        "    \"\"\"\n",
        "    parser = argparse.ArgumentParser()\n",
        "    parser.add_argument(\"--seed\", type=int, default=42)\n",
        "    parser.add_argument(\"--hidden_dim\", type=int, default=64)\n",
        "    parser.add_argument(\"--dropout\", type=float, default=0.5)\n",
        "    parser.add_argument(\"--epochs\", type=int, default=500)\n",
        "    parser.add_argument(\"--lr\", type=float, default=0.05)\n",
        "    parser.add_argument(\"--weight_decay\", type=float, default=5e-4)\n",
        "    parser.add_argument(\"--root\", type=str, default=\"data/Planetoid\")\n",
        "    parser.add_argument(\"--k_fracs\", type=float, nargs=\"+\",\n",
        "                        default=[0.02,0.05,0.1,0.2,0.95]) #0.95 sanity check\n",
        "    parser.add_argument(\"--num_explain\", type=int, default=100,\n",
        "                        help=\"how many test nodes to explain via decomposition\")\n",
        "    parser.add_argument(\"--use_true_class\", action=\"store_true\",\n",
        "                        help=\"use true label instead of predicted label as target class\")\n",
        "    parser.add_argument(\"--model_path\", type=str, default=\"cora_gcn.pt\",\n",
        "                        help=\"path to save/load trained GCN model\")\n",
        "    # Explicit empty argv: ignore whatever arguments the notebook kernel was started with.\n",
        "    args = parser.parse_args(args=[])\n",
        "\n",
        "    set_seed(args.seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # 1) Load data: a clean copy and a feature-noise copy of the same graph\n",
        "    data, in_dim, num_classes,_ = load_planetoid(name='Cora',root=args.root,add_noise=False)\n",
        "    data = data.to(device)\n",
        "\n",
        "    data_noise, _, _,_ = load_planetoid(name='Cora',root=args.root,add_noise=True)\n",
        "    data_noise = data_noise.to(device)\n",
        "\n",
        "    # 2) Build model\n",
        "    model = GCN(\n",
        "        in_channels=in_dim,\n",
        "        hidden_channels=args.hidden_dim,\n",
        "        out_channels=num_classes,\n",
        "        dropout=args.dropout,\n",
        "    ).to(device)\n",
        "\n",
        "    # 3) Train or load model\n",
        "    if os.path.exists(args.model_path):\n",
        "        print(f\"\\n=== Loading existing model from {args.model_path} ===\")\n",
        "        state = torch.load(args.model_path, map_location=device)\n",
        "        model.load_state_dict(state)\n",
        "        train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
        "        print(f\"Loaded model | Train={train_acc:.4f} Val={val_acc:.4f} Test={test_acc:.4f}\")\n",
        "    else:\n",
        "        print(\"\\n=== Training GCN (no saved model found) ===\")\n",
        "        optimizer = torch.optim.Adam(\n",
        "            model.parameters(),\n",
        "            lr=args.lr,\n",
        "            weight_decay=args.weight_decay,\n",
        "        )\n",
        "\n",
        "        best_val_acc = 0.0\n",
        "        best_state = None\n",
        "\n",
        "        for epoch in range(1, args.epochs + 1):\n",
        "            loss = train(model, data, optimizer)\n",
        "            train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
        "\n",
        "            if val_acc > best_val_acc:\n",
        "                best_val_acc = val_acc\n",
        "                best_state = {\n",
        "                    # state_dict() hands back references to the live tensors, so\n",
        "                    # clone them; otherwise later epochs overwrite this snapshot.\n",
        "                    \"model\": {\n",
        "                        key: value.detach().clone()\n",
        "                        for key, value in model.state_dict().items()\n",
        "                    },\n",
        "                    \"epoch\": epoch,\n",
        "                    \"test_acc\": test_acc,\n",
        "                }\n",
        "\n",
        "            if epoch % 20 == 0 or epoch == 1:\n",
        "                print(\n",
        "                    f\"Epoch {epoch:03d} | \"\n",
        "                    f\"Loss {loss:.4f} | \"\n",
        "                    f\"Train {train_acc:.4f} | \"\n",
        "                    f\"Val {val_acc:.4f} | \"\n",
        "                    f\"Test {test_acc:.4f}\"\n",
        "                )\n",
        "\n",
        "        if best_state is not None:\n",
        "            model.load_state_dict(best_state[\"model\"])\n",
        "            print(\n",
        "                f\"\\nBest epoch = {best_state['epoch']} | \"\n",
        "                f\"Best val acc = {best_val_acc:.4f} | \"\n",
        "                f\"Test acc @best = {best_state['test_acc']:.4f}\"\n",
        "            )\n",
        "\n",
        "        # Save trained model\n",
        "        torch.save(model.state_dict(), args.model_path)\n",
        "        print(f\"Model saved to {args.model_path}\")\n",
        "\n",
        "    # 4) Build normalized adjacency list (from the clean graph; reused for the noisy copy)\n",
        "    print(\"\\n=== Building normalized adjacency list ===\")\n",
        "    adj_list = build_normalized_adjacency_list(model, data)\n",
        "\n",
        "    # 5) Choose test nodes to explain (correct predictions preferred)\n",
        "    model.eval()\n",
        "    with torch.no_grad():  # predictions only; no autograd graph needed\n",
        "        logits = model(data.x, data.edge_index)\n",
        "    preds = logits.argmax(dim=-1)\n",
        "\n",
        "    test_nodes = torch.nonzero(data.test_mask, as_tuple=False).view(-1)\n",
        "    correct_mask = preds == data.y\n",
        "    test_correct_nodes = test_nodes[correct_mask[test_nodes]]\n",
        "\n",
        "    if test_correct_nodes.numel() == 0:\n",
        "        print(\"WARNING: no correctly classified test nodes, using all test nodes.\")\n",
        "        nodes_to_explain = test_nodes\n",
        "    else:\n",
        "        nodes_to_explain = test_correct_nodes\n",
        "\n",
        "    if nodes_to_explain.numel() > args.num_explain:\n",
        "        nodes_to_explain = nodes_to_explain[:args.num_explain]\n",
        "\n",
        "    print(\n",
        "        f\"\\n=== Explaining {nodes_to_explain.numel()} test nodes \"\n",
        "        f\"({'true' if args.use_true_class else 'pred'} class as target) ===\"\n",
        "    )\n",
        "\n",
        "    # 6) Feature importances, each computed on the clean and then the noisy graph\n",
        "    def explain_clean_and_noisy(build_fn, **kwargs):\n",
        "        \"\"\"Run one explainer on the clean data, then on the noisy data.\"\"\"\n",
        "        clean_result = build_fn(data=data, model=model, **kwargs)\n",
        "        noisy_result = build_fn(data=data_noise, model=model, **kwargs)\n",
        "        return clean_result, noisy_result\n",
        "\n",
        "    # Decomposition returns (self, self+1hop, self+2hop); the 1-hop variant is not reported.\n",
        "    (feat_self, _, feat_self_2hop), (feat_self_noise, _, feat_self_2hop_noise) = explain_clean_and_noisy(\n",
        "        build_feature_importance_decomposition,\n",
        "        adj_list=adj_list,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "        hop=2,\n",
        "        use_true_class=args.use_true_class,\n",
        "    )\n",
        "\n",
        "    # Random baseline: two independent draws stand in for the clean/noisy pair.\n",
        "    feat_random, feat_random_noise = (\n",
        "        build_random_importance(\n",
        "            num_nodes=data.num_nodes,\n",
        "            num_features=data.x.size(1),\n",
        "            device=device,\n",
        "        )\n",
        "        for _ in range(2)\n",
        "    )\n",
        "\n",
        "    # Grad-CAM\n",
        "    feat_gradcam, feat_gradcam_noise = explain_clean_and_noisy(\n",
        "        gradcam_feature_importance,\n",
        "        use_true_class=args.use_true_class,\n",
        "    )\n",
        "\n",
        "    # GNN-LRP\n",
        "    feat_lrp, feat_lrp_noise = explain_clean_and_noisy(\n",
        "        gnn_lrp_feature_importance,\n",
        "        eps=1e-6,\n",
        "        use_true_class=args.use_true_class,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "    )\n",
        "\n",
        "    # GOAT\n",
        "    feat_goat, feat_goat_noise = explain_clean_and_noisy(\n",
        "        build_goat_feature_importance,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "    )\n",
        "\n",
        "    # LIME\n",
        "    feat_lime, feat_lime_noise = explain_clean_and_noisy(\n",
        "        build_lime_importance,\n",
        "        nodes_to_explain=nodes_to_explain,  # or smaller subset\n",
        "        num_classes=num_classes,\n",
        "        num_samples=50,\n",
        "    )\n",
        "\n",
        "    # Integrated Gradients (Captum)\n",
        "    feat_ig, feat_ig_noise = explain_clean_and_noisy(\n",
        "        build_ig_importance,\n",
        "        nodes_to_explain=nodes_to_explain,\n",
        "    )\n",
        "\n",
        "    # 7) Robustness of each importance variant (clean vs noisy)\n",
        "    # NOTE(review): the banner below still says \"Recovery\" and the summary header\n",
        "    # \"mean Δ = p_orig - p_mask\"; both look like leftovers from another experiment.\n",
        "    # The strings are kept unchanged here - confirm and rename if desired.\n",
        "    print(\"\\n=== Computing Recovery scores ===\")\n",
        "\n",
        "    def robustness_of(clean_importance, noisy_importance):\n",
        "        \"\"\"Score one explainer's clean/noisy importances for every k in --k_fracs.\"\"\"\n",
        "        return compute_robustness_topk(\n",
        "            feat_imp_clean=clean_importance,\n",
        "            feat_imp_noisy=noisy_importance,\n",
        "            nodes=nodes_to_explain,\n",
        "            k_frac=args.k_fracs,  # Can be float or list of floats\n",
        "            eps=1e-8,\n",
        "        )\n",
        "\n",
        "    robust_self = robustness_of(feat_self, feat_self_noise)\n",
        "    robust_s2 = robustness_of(feat_self_2hop, feat_self_2hop_noise)\n",
        "    robust_gradcam = robustness_of(feat_gradcam, feat_gradcam_noise)\n",
        "    robust_lrp = robustness_of(feat_lrp, feat_lrp_noise)\n",
        "    robust_goat = robustness_of(feat_goat, feat_goat_noise)\n",
        "    robust_random = robustness_of(feat_random, feat_random_noise)\n",
        "    robust_lime = robustness_of(feat_lime, feat_lime_noise)\n",
        "    robust_ig = robustness_of(feat_ig, feat_ig_noise)\n",
        "\n",
        "    # 8) Summary: one row per k, one column per explainer\n",
        "    summary_columns = [\n",
        "        (\"self\", robust_self),\n",
        "        (\"self+2hop\", robust_s2),\n",
        "        (\"rand\", robust_random),\n",
        "        (\"gradcam\", robust_gradcam),\n",
        "        (\"lrp\", robust_lrp),\n",
        "        (\"goat\", robust_goat),\n",
        "        (\"lime\", robust_lime),\n",
        "        (\"ig\", robust_ig),\n",
        "    ]\n",
        "\n",
        "    print(\"\\n=== Summary: mean Δ = p_orig - p_mask ===\")\n",
        "    print(\"\\nRobustness (pick top-k% features):\")\n",
        "    for k in args.k_fracs:\n",
        "        cells = \"\".join(f\"{label}={scores[k]:.4f} | \" for label, scores in summary_columns)\n",
        "        print(f\"k={k:.2f} | {cells}\")\n",
        "\n",
        "# Run the Cora robustness experiment when this cell is executed.\n",
        "if __name__ == \"__main__\":\n",
        "    main()"
      ],
      "metadata": {
        "id": "hAkPRBvJtVZB",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "acc51208-c254-487f-fdb7-e63ce69b0895"
      },
      "execution_count": 57,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "=== Planetoid-{name} ===\n",
            "#Nodes     = 2708\n",
            "#Edges     = 10556\n",
            "#Features  = 1433\n",
            "#Classes   = 7\n",
            "Final feature dim after augmentation = 1433\n",
            "=== Planetoid-{name} ===\n",
            "#Nodes     = 2708\n",
            "#Edges     = 10556\n",
            "#Features  = 1433\n",
            "#Classes   = 7\n",
            "Added noise to features.\n",
            "Final feature dim after augmentation = 1433\n",
            "\n",
            "=== Loading existing model from cora_gcn.pt ===\n",
            "Loaded model | Train=1.0000 Val=0.7820 Test=0.8170\n",
            "\n",
            "=== Building normalized adjacency list ===\n",
            "\n",
            "=== Explaining 100 test nodes (pred class as target) ===\n",
            "Building feature importance via decomposition for 100 nodes...\n",
            "  processed 1/100 nodes\n",
            "  processed 20/100 nodes\n",
            "  processed 40/100 nodes\n",
            "  processed 60/100 nodes\n",
            "  processed 80/100 nodes\n",
            "  processed 100/100 nodes\n",
            "Building feature importance via decomposition for 100 nodes...\n",
            "  processed 1/100 nodes\n",
            "  processed 20/100 nodes\n",
            "  processed 40/100 nodes\n",
            "  processed 60/100 nodes\n",
            "  processed 80/100 nodes\n",
            "  processed 100/100 nodes\n",
            "Running GNN-LRP for 100 nodes...\n",
            "  GNN-LRP processed 1/100 nodes\n",
            "  GNN-LRP processed 10/100 nodes\n",
            "  GNN-LRP processed 20/100 nodes\n",
            "  GNN-LRP processed 30/100 nodes\n",
            "  GNN-LRP processed 40/100 nodes\n",
            "  GNN-LRP processed 50/100 nodes\n",
            "  GNN-LRP processed 60/100 nodes\n",
            "  GNN-LRP processed 70/100 nodes\n",
            "  GNN-LRP processed 80/100 nodes\n",
            "  GNN-LRP processed 90/100 nodes\n",
            "  GNN-LRP processed 100/100 nodes\n",
            "Running GNN-LRP for 100 nodes...\n",
            "  GNN-LRP processed 1/100 nodes\n",
            "  GNN-LRP processed 10/100 nodes\n",
            "  GNN-LRP processed 20/100 nodes\n",
            "  GNN-LRP processed 30/100 nodes\n",
            "  GNN-LRP processed 40/100 nodes\n",
            "  GNN-LRP processed 50/100 nodes\n",
            "  GNN-LRP processed 60/100 nodes\n",
            "  GNN-LRP processed 70/100 nodes\n",
            "  GNN-LRP processed 80/100 nodes\n",
            "  GNN-LRP processed 90/100 nodes\n",
            "  GNN-LRP processed 100/100 nodes\n",
            "Running GOAT for 100 nodes...\n",
            "  GOAT processed 1/100 nodes\n",
            "  GOAT processed 10/100 nodes\n",
            "  GOAT processed 20/100 nodes\n",
            "  GOAT processed 30/100 nodes\n",
            "  GOAT processed 40/100 nodes\n",
            "  GOAT processed 50/100 nodes\n",
            "  GOAT processed 60/100 nodes\n",
            "  GOAT processed 70/100 nodes\n",
            "  GOAT processed 80/100 nodes\n",
            "  GOAT processed 90/100 nodes\n",
            "  GOAT processed 100/100 nodes\n",
            "Running GOAT for 100 nodes...\n",
            "  GOAT processed 1/100 nodes\n",
            "  GOAT processed 10/100 nodes\n",
            "  GOAT processed 20/100 nodes\n",
            "  GOAT processed 30/100 nodes\n",
            "  GOAT processed 40/100 nodes\n",
            "  GOAT processed 50/100 nodes\n",
            "  GOAT processed 60/100 nodes\n",
            "  GOAT processed 70/100 nodes\n",
            "  GOAT processed 80/100 nodes\n",
            "  GOAT processed 90/100 nodes\n",
            "  GOAT processed 100/100 nodes\n",
            "Running LIME for 100 nodes...\n",
            "  LIME processed 1/100 nodes\n",
            "  LIME processed 10/100 nodes\n",
            "  LIME processed 20/100 nodes\n",
            "  LIME processed 30/100 nodes\n",
            "  LIME processed 40/100 nodes\n",
            "  LIME processed 50/100 nodes\n",
            "  LIME processed 60/100 nodes\n",
            "  LIME processed 70/100 nodes\n",
            "  LIME processed 80/100 nodes\n",
            "  LIME processed 90/100 nodes\n",
            "  LIME processed 100/100 nodes\n",
            "Running LIME for 100 nodes...\n",
            "  LIME processed 1/100 nodes\n",
            "  LIME processed 10/100 nodes\n",
            "  LIME processed 20/100 nodes\n",
            "  LIME processed 30/100 nodes\n",
            "  LIME processed 40/100 nodes\n",
            "  LIME processed 50/100 nodes\n",
            "  LIME processed 60/100 nodes\n",
            "  LIME processed 70/100 nodes\n",
            "  LIME processed 80/100 nodes\n",
            "  LIME processed 90/100 nodes\n",
            "  LIME processed 100/100 nodes\n",
            "Running Clean IG for 100 nodes...\n",
            "  IG processed 1/100 nodes\n",
            "  IG processed 10/100 nodes\n",
            "  IG processed 20/100 nodes\n",
            "  IG processed 30/100 nodes\n",
            "  IG processed 40/100 nodes\n",
            "  IG processed 50/100 nodes\n",
            "  IG processed 60/100 nodes\n",
            "  IG processed 70/100 nodes\n",
            "  IG processed 80/100 nodes\n",
            "  IG processed 90/100 nodes\n",
            "  IG processed 100/100 nodes\n",
            "Running Clean IG for 100 nodes...\n",
            "  IG processed 1/100 nodes\n",
            "  IG processed 10/100 nodes\n",
            "  IG processed 20/100 nodes\n",
            "  IG processed 30/100 nodes\n",
            "  IG processed 40/100 nodes\n",
            "  IG processed 50/100 nodes\n",
            "  IG processed 60/100 nodes\n",
            "  IG processed 70/100 nodes\n",
            "  IG processed 80/100 nodes\n",
            "  IG processed 90/100 nodes\n",
            "  IG processed 100/100 nodes\n",
            "\n",
            "=== Computing Recovery scores ===\n",
            "\n",
            "=== Summary: mean Δ = p_orig - p_mask ===\n",
            "\n",
            "Robustness (pick top-k% features):\n",
            "k=0.02 | self=0.6067 | self+2hop=0.9275 | rand=0.5136 | gradcam=0.9873 | lrp=0.3146 | goat=0.9697 | lime=0.0070 | ig=0.6167 | \n",
            "k=0.05 | self=0.2460 | self+2hop=0.8576 | rand=0.5149 | gradcam=0.9866 | lrp=0.1267 | goat=0.9530 | lime=0.0072 | ig=0.2493 | \n",
            "k=0.10 | self=0.1268 | self+2hop=0.7612 | rand=0.5244 | gradcam=0.9851 | lrp=0.0649 | goat=0.9255 | lime=0.0079 | ig=0.1285 | \n",
            "k=0.20 | self=0.0706 | self+2hop=0.5899 | rand=0.5421 | gradcam=0.9822 | lrp=0.0387 | goat=0.8587 | lime=0.0092 | ig=0.0712 | \n",
            "k=0.95 | self=0.0239 | self+2hop=0.1737 | rand=0.4234 | gradcam=0.9672 | lrp=0.0156 | goat=0.3639 | lime=0.0287 | ig=0.0240 | \n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "data, in_dim, num_classes, noisy_mask = load_cora(\"data/Planetoid\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "y2fJDuRaXOrj",
        "outputId": "8ecfea4d-7568-46ef-b8af-cf5a4dc78040"
      },
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "=== Planetoid-{name} ===\n",
            "#Nodes     = 2708\n",
            "#Edges     = 10556\n",
            "#Features  = 1433\n",
            "#Classes   = 7\n",
            "Adding 20.0% noisy features...\n",
            "augment_with_noisy_features: detected binary=False, F=1433, N=2708\n",
            "  Using Bernoulli noise with p=0.0130\n",
            "  Augmented features: original F=1433, noisy=286, total=1719\n",
            "Final feature dim after augmentation = 1719\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## BASHAPE"
      ],
      "metadata": {
        "id": "cRjK-RPUer0n"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch_geometric.datasets import BAShapes\n",
        "from torch_geometric.nn import GCNConv\n",
        "from torch_geometric.utils import add_remaining_self_loops\n",
        "from torch_scatter import scatter_add\n",
        "from torch_geometric.explain import Explainer\n",
        "from torch_geometric.explain.algorithm import GNNExplainer,DummyExplainer\n",
        "\n",
        "from torch_geometric.explain.algorithm import CaptumExplainer\n",
        "from captum.attr import IntegratedGradients\n",
        "\n",
        "from torch_geometric.utils import to_undirected, add_self_loops, degree\n",
        "\n",
        "# ============================================================\n",
        "# 0. Setup\n",
        "# ============================================================\n",
        "\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "torch.manual_seed(0)\n",
        "\n",
        "\n",
        "\n",
        "# ============================================================\n",
        "# 1. Load BA-Shapes dataset\n",
        "# ============================================================\n",
        "\n",
        "dataset = BAShapes()   # synthetic, generated in memory\n",
        "data = dataset[0].to(device)\n",
        "\n",
        "print(data)\n",
        "print(f\"#nodes = {data.num_nodes}, #edges = {data.num_edges}\")\n",
        "print(f\"#features = {data.num_node_features}, #classes = {int(data.y.max())+1}\")\n",
        "\n",
        "num_nodes = data.num_nodes\n",
        "num_features = data.num_node_features\n",
        "num_classes = int(data.y.max().item()) + 1\n",
        "\n",
        "# Simple random train/val/test split if not already present\n",
        "if not hasattr(data, \"train_mask\"):\n",
        "    perm = torch.randperm(num_nodes, device=device)\n",
        "    n_train = int(0.6 * num_nodes)\n",
        "    n_val   = int(0.2 * num_nodes)\n",
        "    n_test  = num_nodes - n_train - n_val\n",
        "\n",
        "    train_idx = perm[:n_train]\n",
        "    val_idx   = perm[n_train:n_train+n_val]\n",
        "    test_idx  = perm[n_train+n_val:]\n",
        "\n",
        "    data.train_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)\n",
        "    data.val_mask   = torch.zeros(num_nodes, dtype=torch.bool, device=device)\n",
        "    data.test_mask  = torch.zeros(num_nodes, dtype=torch.bool, device=device)\n",
        "\n",
        "    data.train_mask[train_idx] = True\n",
        "    data.val_mask[val_idx]     = True\n",
        "    data.test_mask[test_idx]   = True\n",
        "\n",
        "# ============================================================\n",
        "# 2. Standard 2-layer GCN (no modifications for training)\n",
        "# ============================================================\n",
        "\n",
        "class VanillaGCN(nn.Module):\n",
        "    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):\n",
        "        super().__init__()\n",
        "        self.conv1 = GCNConv(in_channels, hidden_channels, cached=False)\n",
        "        self.conv2 = GCNConv(hidden_channels, out_channels, cached=False)\n",
        "        self.dropout = dropout\n",
        "\n",
        "    def forward(self, x, edge_index):\n",
        "        x = self.conv1(x, edge_index)\n",
        "        x = F.relu(x)\n",
        "        x = F.dropout(x, p=self.dropout, training=self.training)\n",
        "        x = self.conv2(x, edge_index)  # logits\n",
        "        return x\n",
        "\n",
        "model = VanillaGCN(num_features, hidden_channels=32, out_channels=num_classes).to(device)\n",
        "optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n",
        "\n",
        "def train_one_epoch():\n",
        "    model.train()\n",
        "    optimizer.zero_grad()\n",
        "    logits = model(data.x, data.edge_index)\n",
        "    loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "    return loss.item()\n",
        "\n",
        "@torch.no_grad()\n",
        "def accuracy(mask):\n",
        "    model.eval()\n",
        "    logits = model(data.x, data.edge_index)\n",
        "    pred = logits.argmax(dim=-1)\n",
        "    correct = (pred[mask] == data.y[mask]).sum()\n",
        "    return float(correct) / int(mask.sum())\n",
        "\n",
        "print(\"\\nTraining GCN on BA-Shapes...\")\n",
        "for epoch in range(1, 1001):\n",
        "    loss = train_one_epoch()\n",
        "    if epoch % 20 == 0 or epoch == 1:\n",
        "        train_acc = accuracy(data.train_mask)\n",
        "        val_acc   = accuracy(data.val_mask)\n",
        "        test_acc  = accuracy(data.test_mask)\n",
        "        print(f\"Epoch {epoch:03d} | Loss {loss:.4f} | \"\n",
        "              f\"Train {train_acc:.3f} | Val {val_acc:.3f} | Test {test_acc:.3f}\")\n",
        "\n",
        "# ============================================================\n",
        "# 3. Post-hoc decomposition of the trained GCN (neighbor-level)\n",
        "# ============================================================\n",
        "\n",
        "@torch.no_grad()\n",
        "def compute_gcn_norm(edge_index, num_nodes, device, dtype=torch.float32):\n",
        "    \"\"\"\n",
        "    Rebuild GCN-like normalization:\n",
        "        A_hat = A + I, then D^{-1/2} A_hat D^{-1/2}.\n",
        "    We treat edge_index as [src, dst]; internally we use row=dst, col=src.\n",
        "    \"\"\"\n",
        "    edge_index = edge_index.flip(0)  # now [dst, src]\n",
        "    edge_index, _ = add_remaining_self_loops(edge_index, num_nodes=num_nodes)\n",
        "    row, col = edge_index\n",
        "\n",
        "    edge_weight = torch.ones(row.size(0), device=device, dtype=dtype)\n",
        "    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)\n",
        "    deg_inv_sqrt = deg.pow(-0.5)\n",
        "    deg_inv_sqrt[deg_inv_sqrt == float(\"inf\")] = 0\n",
        "\n",
        "    norm = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]  # [E]\n",
        "    return edge_index, norm  # row=dst, col=src with symmetric norm\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def layer1_decomposition(model, x, edge_index):\n",
        "    \"\"\"\n",
        "    Recompute layer-1 pre-activations and messages:\n",
        "        Z1 = A_norm X W1 + b1\n",
        "    Messages: m1_{u<-v} = norm_uv * (X_v W1)\n",
        "    \"\"\"\n",
        "    device = x.device\n",
        "    num_nodes = x.size(0)\n",
        "\n",
        "    W1 = model.conv1.lin.weight.T     # [F_in, H]\n",
        "    b1 = model.conv1.bias             # [H] or None\n",
        "\n",
        "    ei1, norm1 = compute_gcn_norm(edge_index, num_nodes, device)\n",
        "    row1, col1 = ei1  # row1=dst=u, col1=src=v\n",
        "\n",
        "    xW1 = x @ W1                      # [N, H]\n",
        "    m1 = xW1[col1] * norm1.view(-1, 1)  # [E1, H]\n",
        "\n",
        "    z1 = scatter_add(m1, row1, dim=0, dim_size=num_nodes)  # [N, H]\n",
        "    if b1 is not None:\n",
        "        z1 = z1 + b1\n",
        "    h1 = F.relu(z1)\n",
        "    return h1, z1, m1, ei1\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def layer2_decomposition(model, h1, edge_index):\n",
        "    \"\"\"\n",
        "    Recompute layer-2 pre-activations and messages:\n",
        "        Z2 = A_norm H1 W2 + b2\n",
        "    Messages: m2_{u<-v} = norm_uv * (H1_v W2)\n",
        "    \"\"\"\n",
        "    device = h1.device\n",
        "    num_nodes = h1.size(0)\n",
        "\n",
        "    W2 = model.conv2.lin.weight.T     # [H, C]\n",
        "    b2 = model.conv2.bias             # [C] or None\n",
        "\n",
        "    ei2, norm2 = compute_gcn_norm(edge_index, num_nodes, device)\n",
        "    row2, col2 = ei2\n",
        "\n",
        "    h1W2 = h1 @ W2                    # [N, C]\n",
        "    m2 = h1W2[col2] * norm2.view(-1, 1)  # [E2, C]\n",
        "\n",
        "    z2 = scatter_add(m2, row2, dim=0, dim_size=num_nodes)  # [N, C]\n",
        "    if b2 is not None:\n",
        "        z2 = z2 + b2\n",
        "    return z2, m2, ei2\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def posthoc_decompose(model, data):\n",
        "    \"\"\"\n",
        "    Full post-hoc decomposition of the trained 2-layer GCN:\n",
        "      - layer1 messages/embeddings\n",
        "      - layer2 messages/logits\n",
        "    \"\"\"\n",
        "    h1, z1, m1, ei1 = layer1_decomposition(model, data.x, data.edge_index)\n",
        "    z2, m2, ei2 = layer2_decomposition(model, h1, data.edge_index)\n",
        "    return {\n",
        "        \"h1\": h1,\n",
        "        \"z1\": z1,\n",
        "        \"messages1\": m1,\n",
        "        \"edge_index1\": ei1,\n",
        "        \"z2\": z2,            # logits\n",
        "        \"messages2\": m2,\n",
        "        \"edge_index2\": ei2,\n",
        "    }\n",
        "\n",
        "# sanity check\n",
        "with torch.no_grad():\n",
        "    expl_tmp = posthoc_decompose(model, data)\n",
        "    logits_model = model(data.x, data.edge_index)\n",
        "    diff = (expl_tmp[\"z2\"] - logits_model).abs().mean().item()\n",
        "    print(f\"\\nMean |Z2_decomp - logits_model| = {diff:.6f} (should be small-ish)\")\n",
        "\n",
        "# ============================================================\n",
        "# 4. Neighbor importance: Decomp, GNNExplainer, Random\n",
        "# ============================================================\n",
        "\n",
        "@torch.no_grad()\n",
        "def decomp_neighbor_scores(model, data, nid):\n",
        "    \"\"\"\n",
        "    Our method: neighbor scores from final-layer messages.\n",
        "    Returns:\n",
        "      neighbors: [K] neighbor node indices\n",
        "      scores:    [K] importance scores (abs contribution to predicted class)\n",
        "    \"\"\"\n",
        "    expl = posthoc_decompose(model, data)\n",
        "    z2 = expl[\"z2\"]               # [N, C] logits\n",
        "    m2 = expl[\"messages2\"]        # [E2, C]\n",
        "    ei2 = expl[\"edge_index2\"]     # [2, E2], row=dst, col=src\n",
        "\n",
        "    row2, col2 = ei2\n",
        "    probs = F.softmax(z2, dim=-1)\n",
        "    c = int(probs[nid].argmax().item())\n",
        "\n",
        "    mask = (row2 == nid)\n",
        "    neigh = col2[mask]                     # neighbors v such that v->nid\n",
        "    edge_scores = m2[mask, c].abs()        # [num_edges_into_nid]\n",
        "\n",
        "    if neigh.numel() == 0:\n",
        "        return neigh, edge_scores\n",
        "\n",
        "    unique_neigh, inv = torch.unique(neigh, return_inverse=True)\n",
        "    neigh_scores = scatter_add(edge_scores, inv, dim=0,\n",
        "                               dim_size=unique_neigh.size(0))\n",
        "    return unique_neigh, neigh_scores\n",
        "\n",
        "\n",
        "def get_in_neighbors(edge_index, nid):\n",
        "    src, dst = edge_index\n",
        "    mask = (dst == nid)\n",
        "    return src[mask]\n",
        "\n",
        "# --- GNNExplainer in edge mode (NO @torch.no_grad here!) ---\n",
        "expl_gnn_edge = Explainer(\n",
        "    model=model,\n",
        "    algorithm=GNNExplainer(epochs=100),\n",
        "    explanation_type='model',\n",
        "    node_mask_type=None,\n",
        "    edge_mask_type='object',\n",
        "    model_config=dict(\n",
        "        mode='multiclass_classification',\n",
        "        task_level='node',\n",
        "        return_type='raw',  # model returns logits\n",
        "    ),\n",
        ")\n",
        "\n",
        "def gnnexplainer_neighbor_scores(nid):\n",
        "    \"\"\"\n",
        "    Use GNNExplainer to get edge importance for node nid, then\n",
        "    aggregate to neighbor-level scores.\n",
        "    This must run with gradients enabled.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    exp = expl_gnn_edge(data.x, data.edge_index, index=int(nid))\n",
        "\n",
        "    e_src, e_dst = exp.edge_index\n",
        "    e_mask = exp.edge_mask\n",
        "\n",
        "    mask = (e_dst == nid)\n",
        "    neigh = e_src[mask]\n",
        "    edge_scores = e_mask[mask].abs()\n",
        "\n",
        "    if neigh.numel() == 0:\n",
        "        return neigh.to(device), edge_scores.to(device)\n",
        "\n",
        "    unique_neigh, inv = torch.unique(neigh, return_inverse=True)\n",
        "    neigh_scores = scatter_add(edge_scores, inv, dim=0,\n",
        "                               dim_size=unique_neigh.size(0))\n",
        "    return unique_neigh.to(device), neigh_scores.to(device)\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def random_neighbor_scores(nid):\n",
        "    \"\"\"\n",
        "    Random neighbor baseline: uniform random scores over incoming neighbors.\n",
        "    \"\"\"\n",
        "    neigh = get_in_neighbors(data.edge_index, nid)\n",
        "    if neigh.numel() == 0:\n",
        "        return neigh, torch.empty(0, device=device)\n",
        "    scores = torch.rand(neigh.size(0), device=device)\n",
        "    return neigh, scores\n",
        "\n",
        "expl_ig = Explainer(\n",
        "    model=model,\n",
        "    algorithm=CaptumExplainer(IntegratedGradients),\n",
        "    explanation_type='model',\n",
        "    node_mask_type='attributes', # Changed from None\n",
        "    edge_mask_type=None,         # Changed from 'object'\n",
        "    model_config=dict(\n",
        "        mode='multiclass_classification',\n",
        "        task_level='node',\n",
        "        return_type='raw',\n",
        "    ),\n",
        ")\n",
        "\n",
        "def ig_neighbor_scores(nid):\n",
        "    \"\"\"\n",
        "    IG produces feature scores for node nid.\n",
        "    We convert those into a SINGLE scalar importance for node nid,\n",
        "    and assign that importance to all neighbors.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    exp = expl_ig(data.x, data.edge_index, index=int(nid))\n",
        "\n",
        "    node_mask = exp.node_mask  # [F] for the target node\n",
        "\n",
        "    # Aggregate feature importance for the target node into a single scalar\n",
        "    node_importance = node_mask.abs().sum().item()\n",
        "\n",
        "    neighbors = get_in_neighbors(data.edge_index, nid)\n",
        "    if neighbors.numel() == 0:\n",
        "        return neighbors.to(device), torch.empty(0, device=device)\n",
        "\n",
        "    # Assign this scalar importance to all neighbors\n",
        "    scores = torch.full((neighbors.size(0),), node_importance, device=device)\n",
        "    return neighbors.to(device), scores\n",
        "\n",
        "\n",
        "###\n",
        "#GOAT Method#\n",
        "###\n",
        "\n",
        "# =========================\n",
        "# 1) Build dense Â for BA-Shapes\n",
        "# =========================\n",
        "def get_in_neighbors(edge_index: torch.Tensor, nid: int) -> torch.Tensor:\n",
        "    \"\"\"Return unique in-neighbors of nid (src where dst == nid).\"\"\"\n",
        "    src, dst = edge_index\n",
        "    mask = (dst == nid)\n",
        "    return src[mask].unique()\n",
        "\n",
        "N = data.num_nodes\n",
        "D_total = data.num_node_features\n",
        "\n",
        "def build_gcn_norm_dense(edge_index, num_nodes):\n",
        "    \"\"\"\n",
        "    Build symmetric-normalized adjacency Â = D^{-1/2} (A + I) D^{-1/2} in dense form.\n",
        "    \"\"\"\n",
        "    edge_index = to_undirected(edge_index, num_nodes=num_nodes)\n",
        "    edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)\n",
        "    row, col = edge_index\n",
        "    deg = degree(row, num_nodes=num_nodes)\n",
        "    deg_inv_sqrt = deg.pow(-0.5)\n",
        "    deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0.\n",
        "    w = deg_inv_sqrt[row] * deg_inv_sqrt[col]  # scalar per edge\n",
        "    A = torch.zeros((num_nodes, num_nodes), device=device)\n",
        "    A.index_put_((row, col), w, accumulate=True)\n",
        "    return A  # [N, N]\n",
        "\n",
        "A_hat = build_gcn_norm_dense(data.edge_index, N)  # [N, N]\n",
        "\n",
        "# Precompute X^T once\n",
        "XT = data.x.t()   # [D_total, N]\n",
        "\n",
        "\n",
        "# =========================\n",
        "# 2) Pre-activations and ReLU masks for 2-layer GCN\n",
        "# =========================\n",
        "\n",
        "@torch.no_grad()\n",
        "def preactivations_and_masks_2layer():\n",
        "    \"\"\"\n",
        "    For a 2-layer GCN with:\n",
        "        z1 = Â X W1\n",
        "        h1 = ReLU(z1)\n",
        "        z2 = Â h1 W2\n",
        "        logits = z2\n",
        "\n",
        "    We compute:\n",
        "      z1: [N, H], z2: [N, C],\n",
        "      M1: [N, H] ReLU mask for layer 1 (1 if z1>0, else 0),\n",
        "      W1: [D_total, H], W2: [H, C].\n",
        "    \"\"\"\n",
        "    # Extract linear weights from your GCNConv layers\n",
        "    W1 = model.conv1.lin.weight.T    # [D_total, H]\n",
        "    W2 = model.conv2.lin.weight.T    # [H, C]\n",
        "\n",
        "    # First layer pre-activation: z1 = Â X W1\n",
        "    z1 = A_hat @ (data.x @ W1)       # [N, H]\n",
        "    M1 = (z1 > 0).float()\n",
        "\n",
        "    # Second layer pre-activation: z2 = Â (ReLU(z1) W2)\n",
        "    h1 = M1 * z1                     # [N, H]\n",
        "    z2 = A_hat @ (h1 @ W2)           # [N, C]\n",
        "\n",
        "    return z1, z2, M1, W1, W2\n",
        "\n",
        "z1_goat, z2_goat, M1_goat, W1_goat, W2_goat = preactivations_and_masks_2layer()\n",
        "\n",
        "\n",
        "# =========================\n",
        "# 3) GOAt-style per-feature attribution (2-layer)\n",
        "# =========================\n",
        "\n",
        "@torch.no_grad()\n",
        "def goat_2layer_feature_importance_for_node(i: int):\n",
        "    \"\"\"\n",
        "    Returns phi: [D_total], feature contributions to the predicted class logit at node i\n",
        "    for a 2-layer GCN (conv1 -> conv2).\n",
        "\n",
        "    Model:\n",
        "        z1 = Â X W1          (before ReLU)\n",
        "        h1 = ReLU(z1) = M1 ⊙ z1\n",
        "        z2 = Â h1 W2         (logits)\n",
        "\n",
        "    We approximate:\n",
        "        phi_f(i,c) ≈ sum_{h,u} x_{u,f} * [ (r_i^T D1_h Â)_u ] * W1[f,h] * W2[h,c]\n",
        "\n",
        "    where:\n",
        "        r_i^T is row i of Â,\n",
        "        D1_h = diag(M1[:,h]) is the ReLU mask for hidden unit h across nodes.\n",
        "    \"\"\"\n",
        "    i = int(i)\n",
        "\n",
        "    # predicted class c at node i\n",
        "    logits = model(data.x, data.edge_index)     # [N, C]\n",
        "    c = int(logits[i].argmax().item())\n",
        "\n",
        "    # row i of Â: how i aggregates from all nodes\n",
        "    r = A_hat[i, :]                             # [N]\n",
        "\n",
        "    # Initialize feature importance vector\n",
        "    phi = torch.zeros(D_total, device=device)   # [D_total]\n",
        "\n",
        "    H = W1_goat.size(1)  # hidden dimension\n",
        "\n",
        "    for h in range(H):\n",
        "        # Mask for hidden unit h across all nodes\n",
        "        # D1_h applied as elementwise multiplication with M1[:,h]\n",
        "        # First: r_i^T D1_h Â  ->  (r * M1[:,h])^T Â\n",
        "        t1 = (r * M1_goat[:, h]) @ A_hat        # [N]\n",
        "\n",
        "        # Aggregate over source nodes u:\n",
        "        # v_f = sum_u x_{u,f} * t1[u] = (X^T @ t1)_f\n",
        "        v = XT @ t1                             # [D_total]\n",
        "\n",
        "        # Chain weight for path (f -> h -> c)\n",
        "        # W1[f,h] * W2[h,c], but W1 is [D_total, H], W2 is [H, C]\n",
        "        chain = W1_goat[:, h] * W2_goat[h, c]   # [D_total]\n",
        "\n",
        "        # Add contribution for this hidden unit\n",
        "        phi += v * chain                        # [D_total]\n",
        "\n",
        "    return phi  # [D_total], feature importance for node i, class c\n",
        "\n",
        "@torch.no_grad()\n",
        "def goat_neighbor_scores(nid):\n",
        "    phi = goat_2layer_feature_importance_for_node(nid)  # [D_total]\n",
        "    node_score = phi.abs().sum().item()                 # collapse to scalar\n",
        "\n",
        "    neigh = get_in_neighbors(data.edge_index, nid)\n",
        "    if neigh.numel() == 0:\n",
        "        return neigh, torch.empty(0, device=device)\n",
        "    scores = torch.full((neigh.size(0),), node_score, device=device)\n",
        "    return neigh, scores\n",
        "\n",
        "##GRAD-CAM\n",
        "def gradcam_neighbor_scores(model, data, nid: int):\n",
        "    \"\"\"\n",
        "    Grad-CAM for node nid on BA-Shapes:\n",
        "\n",
        "    1) Take conv1 outputs h1 (after ReLU), keep grad.\n",
        "    2) Backprop from node nid's predicted logit.\n",
        "    3) alpha_k = mean_n grad_{n,k}\n",
        "    4) node_score[n] = ReLU( sum_k alpha_k * h1[n,k] )\n",
        "    5) Return scores restricted to neighbors of nid.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    nid = int(nid)\n",
        "\n",
        "    x = data.x.to(device)\n",
        "    edge_index = data.edge_index.to(device)\n",
        "\n",
        "    # Forward up to conv1\n",
        "    x1 = model.conv1(x, edge_index)      # [N, H]\n",
        "    h1 = F.relu(x1)\n",
        "    h1.retain_grad()                     # we want grad wrt h1\n",
        "\n",
        "    # Forward to logits\n",
        "    out = model.conv2(h1, edge_index)    # [N, C]\n",
        "    logits = out\n",
        "    c = int(logits[nid].argmax().item())\n",
        "\n",
        "    # Backprop from target logit\n",
        "    model.zero_grad()\n",
        "    if h1.grad is not None:\n",
        "        h1.grad.zero_()\n",
        "    logits[nid, c].backward(retain_graph=False)\n",
        "\n",
        "    grads = h1.grad                      # [N, H]\n",
        "    # Channel weights α_k: global average over nodes\n",
        "    alpha = grads.mean(dim=0)            # [H]\n",
        "\n",
        "    # Node-wise Grad-CAM scores\n",
        "    node_scores = F.relu((h1 * alpha)    # [N, H]\n",
        "                         .sum(dim=1))    # [N]\n",
        "\n",
        "    neighbors = get_in_neighbors(edge_index, nid)\n",
        "    if neighbors.numel() == 0:\n",
        "        return neighbors, torch.empty(0, device=device)\n",
        "\n",
        "    scores = node_scores[neighbors]\n",
        "    return neighbors, scores\n",
        "\n",
        "##GNN-LRP\n",
        "def gnn_lrp_neighbor_scores(model, data, nid: int, eps: float = 1e-6):\n",
        "    \"\"\"\n",
        "    Simplified GNN-LRP for a 2-layer GCN on BA-Shapes.\n",
        "\n",
        "    We only propagate relevance through the last GCN layer:\n",
        "\n",
        "        z2 = Â h1 W2\n",
        "\n",
        "    For node nid and class c:\n",
        "        m_v = Â[nid, v] * (h1[v] @ W2[:, c])\n",
        "        R_v = (m_v / (sum_v m_v + eps)) * z2[nid, c]\n",
        "\n",
        "    Neighbor score = |R_v|.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    nid = int(nid)\n",
        "\n",
        "    x = data.x.to(device)\n",
        "    edge_index = data.edge_index.to(device)\n",
        "\n",
        "    # Forward 1st layer\n",
        "    z1 = model.conv1(x, edge_index)\n",
        "    h1 = F.relu(z1)                       # [N, H]\n",
        "\n",
        "    # Extract W2 from conv2\n",
        "    # conv2.lin.weight: [C, H]  -> transpose to [H, C]\n",
        "    W2 = model.conv2.lin.weight.T        # [H, C]\n",
        "\n",
        "    # Logits via explicit GCN conv2 computation\n",
        "    # support = h1 @ W2  # [N, C]\n",
        "    support = h1 @ W2                    # [N, C]\n",
        "    z2 = A_hat @ support                 # [N, C]  (same as conv2 without bias)\n",
        "\n",
        "    # target class and relevance at output\n",
        "    c = int(z2[nid].argmax().item())\n",
        "    R_out = z2[nid, c]                   # scalar relevance at (nid,c)\n",
        "\n",
        "    # message from each node v into (nid, c)\n",
        "    # A_hat[nid] is row vector [N]\n",
        "    row_nid = A_hat[nid]                 # [N]\n",
        "    support_c = support[:, c]           # [N]\n",
        "    m = row_nid * support_c             # [N]\n",
        "\n",
        "    Z = m.sum() + eps\n",
        "    R_nodes = (m / Z) * R_out           # [N]\n",
        "\n",
        "    neighbors = get_in_neighbors(edge_index, nid)\n",
        "    if neighbors.numel() == 0:\n",
        "        return neighbors, torch.empty(0, device=device)\n",
        "\n",
        "    scores = R_nodes[neighbors].abs()\n",
        "    return neighbors, scores\n",
        "\n",
        "def to_full_neighbor_vector(neighbor_fn):\n",
        "    \"\"\"\n",
        "    Adapt a neighbor-level explainer into one that returns a dense score vector.\n",
        "\n",
        "    neighbor_fn is called as neighbor_fn(model, data, nid) and must return\n",
        "    (neighbors, scores). The returned wrapper(nid) yields a zero tensor of\n",
        "    length N (on `device`) with the scores written at the neighbor positions.\n",
        "    Relies on the notebook-level globals model, data, N and device.\n",
        "    \"\"\"\n",
        "    def wrapper(nid: int):\n",
        "        nbr_ids, nbr_scores = neighbor_fn(model, data, nid)\n",
        "        dense = torch.zeros(N, device=device)\n",
        "        if nbr_ids.numel() == 0:\n",
        "            # No neighbors reported: the all-zero vector is the answer.\n",
        "            return dense\n",
        "        dense[nbr_ids] = nbr_scores\n",
        "        return dense\n",
        "    return wrapper\n",
        "\n",
        "# ============================================================\n",
        "# 5. Metrics: Fidelity+ / Fidelity− / Sparsity (neighbor-level)\n",
        "# ============================================================\n",
        "\n",
        "@torch.no_grad()\n",
        "def fidelity_negative_neighbors(model, data, nid, neighbors, scores, k_frac=0.5):\n",
        "    \"\"\"\n",
        "    Fidelity+ : keep ONLY top-k_frac neighbors into nid, drop others.\n",
        "    Returns ratio: prob_keep / prob_orig (for original predicted class).\n",
        "    \"\"\"\n",
        "    if neighbors.numel() == 0:\n",
        "        return 1.0\n",
        "\n",
        "    logits = model(data.x, data.edge_index)\n",
        "    probs = F.softmax(logits, dim=-1)\n",
        "    c = int(probs[nid].argmax().item())\n",
        "    orig = float(probs[nid, c].item())\n",
        "\n",
        "    deg_u = neighbors.size(0)\n",
        "    k = max(1, int(k_frac * deg_u))\n",
        "\n",
        "    abs_scores = scores.abs()\n",
        "    _, idx = torch.topk(abs_scores, k=k)\n",
        "    top_neighbors = neighbors[idx]\n",
        "\n",
        "    src, dst = data.edge_index\n",
        "    incoming = (dst == nid)\n",
        "    keep_from_top = incoming & torch.isin(src, top_neighbors)\n",
        "    keep_mask = (~incoming) | keep_from_top\n",
        "\n",
        "    edge_index_keep = data.edge_index[:, keep_mask]\n",
        "\n",
        "    logits2 = model(data.x, edge_index_keep)\n",
        "    probs2 = F.softmax(logits2, dim=-1)\n",
        "    keep = float(probs2[nid, c].item())\n",
        "\n",
        "    return orig - keep\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def fidelity_positive_neighbors(model, data, nid, neighbors, scores, k_frac=0.5):\n",
        "    \"\"\"\n",
        "    Fidelity− : REMOVE top-k_frac neighbors into nid.\n",
        "    Returns probability drop: orig_prob - new_prob (for original class).\n",
        "    \"\"\"\n",
        "    if neighbors.numel() == 0:\n",
        "        return 0.0\n",
        "\n",
        "    logits = model(data.x, data.edge_index)\n",
        "    probs = F.softmax(logits, dim=-1)\n",
        "    c = int(probs[nid].argmax().item())\n",
        "    orig = float(probs[nid, c].item())\n",
        "\n",
        "    deg_u = neighbors.size(0)\n",
        "    k = max(1, int(k_frac * deg_u))\n",
        "\n",
        "    abs_scores = scores.abs()\n",
        "    _, idx = torch.topk(abs_scores, k=k)\n",
        "    top_neighbors = neighbors[idx]\n",
        "\n",
        "    src, dst = data.edge_index\n",
        "    incoming = (dst == nid)\n",
        "    drop_edges = incoming & torch.isin(src, top_neighbors)\n",
        "    keep_mask = ~drop_edges\n",
        "\n",
        "    edge_index_drop = data.edge_index[:, keep_mask]\n",
        "\n",
        "    logits2 = model(data.x, edge_index_drop)\n",
        "    probs2 = F.softmax(logits2, dim=-1)\n",
        "    new = float(probs2[nid, c].item())\n",
        "\n",
        "    return orig - new\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def neighbor_sparsity(neighbors, scores, rel_thresh=0.1):\n",
        "    \"\"\"\n",
        "    Sparsity = 1 - (#active_neighbors / deg).\n",
        "    A neighbor is active if |score| >= rel_thresh * max(|score|).\n",
        "    \"\"\"\n",
        "    deg_u = neighbors.size(0)\n",
        "    if deg_u == 0:\n",
        "        return 1.0\n",
        "\n",
        "    abs_scores = scores.abs()\n",
        "    max_score = float(abs_scores.max().item())\n",
        "    if max_score == 0.0:\n",
        "        return 1.0\n",
        "\n",
        "    tau = rel_thresh * max_score\n",
        "    num_active = int((abs_scores >= tau).sum().item())\n",
        "    sparsity = 1.0 - num_active / deg_u\n",
        "    return sparsity\n",
        "\n",
        "# ============================================================\n",
        "# 6. Evaluation loop: mean metrics across test nodes & methods\n",
        "# ============================================================\n",
        "import time\n",
        "\n",
        "def timed(fn):\n",
        "    \"\"\"\n",
        "    Wraps an explanation function so it returns:\n",
        "        neighbors, scores, elapsed_time\n",
        "    \"\"\"\n",
        "    def wrapped(nid):\n",
        "        start = time.perf_counter()\n",
        "        neighbors, scores = fn(nid)\n",
        "        end = time.perf_counter()\n",
        "        return neighbors, scores, (end - start)\n",
        "    return wrapped\n",
        "\n",
        "# METHODS = {\n",
        "#     \"Decomp\":        lambda nid: decomp_neighbor_scores(model, data, nid),\n",
        "#     \"GNNExplainer\":  lambda nid: gnnexplainer_neighbor_scores(nid),\n",
        "#     \"IG\":            lambda nid: ig_neighbor_scores(nid),\n",
        "#     \"Random\":        lambda nid: random_neighbor_scores(nid),\n",
        "#     \"GOAT\":          lambda nid: goat_neighbor_scores(nid),\n",
        "#     \"GradCAM\":       lambda nid: gradcam_neighbor_scores(model, data, nid),\n",
        "#     \"GNN-LRP\":       lambda nid: gnn_lrp_neighbor_scores(model, data, nid),\n",
        "# }\n",
        "\n",
        "# Registry of explainers, in reporting order. Every entry maps nid -> (neighbors, scores)\n",
        "# and is wrapped with timed(), so calling it returns (neighbors, scores, elapsed_time).\n",
        "METHODS = {\n",
        "    label: timed(explainer)\n",
        "    for label, explainer in (\n",
        "        (\"Decomp\",       lambda nid: decomp_neighbor_scores(model, data, nid)),\n",
        "        (\"GNNExplainer\", lambda nid: gnnexplainer_neighbor_scores(nid)),\n",
        "        (\"IG\",           lambda nid: ig_neighbor_scores(nid)),\n",
        "        (\"Random\",       lambda nid: random_neighbor_scores(nid)),\n",
        "        (\"GOAT\",         lambda nid: goat_neighbor_scores(nid)),\n",
        "        (\"GradCAM\",      lambda nid: gradcam_neighbor_scores(model, data, nid)),\n",
        "        (\"GNN-LRP\",      lambda nid: gnn_lrp_neighbor_scores(model, data, nid)),\n",
        "    )\n",
        "}\n",
        "\n",
        "\n",
        "def evaluate_methods_on_nodes(node_indices, k_frac=0.5, rel_thresh=0.1,\n",
        "                              swap_fidelity_for=(\"IG\", \"GOAT\")):\n",
        "    \"\"\"\n",
        "    Run every explainer in METHODS on the given nodes and average the metrics.\n",
        "\n",
        "    Args:\n",
        "        node_indices: iterable of node ids (each one is cast with int()).\n",
        "        k_frac: fraction of in-neighbors kept/removed by the fidelity metrics.\n",
        "        rel_thresh: relative threshold forwarded to neighbor_sparsity.\n",
        "        swap_fidelity_for: collection of method names whose Fid+ and Fid- values\n",
        "            are exchanged before being recorded. The default reproduces the\n",
        "            special case that used to be hard-coded for IG and GOAT.\n",
        "            NOTE(review): the justification for this swap is not visible here;\n",
        "            pass () to score every method with identical definitions.\n",
        "\n",
        "    Returns:\n",
        "        dict: method_name -> {\"fid_pos_mean\", \"fid_neg_mean\", \"sparsity_mean\", \"time\"},\n",
        "        where \"time\" is the mean explanation time. Methods with no recorded sample\n",
        "        are omitted; a node for which a method returns no neighbors contributes\n",
        "        nothing to that method (timing included).\n",
        "\n",
        "    WARNING: we MUST allow gradients inside for GNNExplainer, so do not wrap\n",
        "    this function in torch.no_grad().\n",
        "    \"\"\"\n",
        "    results = {name: {\"fid_pos\": [], \"fid_neg\": [], \"sparsity\": [], \"time\": []}\n",
        "               for name in METHODS}\n",
        "\n",
        "    for nid in node_indices:\n",
        "        nid = int(nid)\n",
        "        for name, get_scores in METHODS.items():\n",
        "            neighbors, scores, elapsed = get_scores(nid)\n",
        "\n",
        "            if neighbors.numel() == 0:\n",
        "                continue\n",
        "\n",
        "            fid_pos = fidelity_positive_neighbors(model, data, nid, neighbors, scores, k_frac=k_frac)\n",
        "            fid_neg = fidelity_negative_neighbors(model, data, nid, neighbors, scores, k_frac=k_frac)\n",
        "            if name in swap_fidelity_for:\n",
        "                # Explicit, opt-out version of the former hard-coded IG/GOAT special case.\n",
        "                fid_pos, fid_neg = fid_neg, fid_pos\n",
        "            sparsity = neighbor_sparsity(neighbors, scores, rel_thresh=rel_thresh)\n",
        "\n",
        "            results[name][\"fid_pos\"].append(fid_pos)\n",
        "            results[name][\"fid_neg\"].append(fid_neg)\n",
        "            results[name][\"sparsity\"].append(sparsity)\n",
        "            results[name][\"time\"].append(elapsed)\n",
        "\n",
        "    summary = {}\n",
        "    for name, metrics in results.items():\n",
        "        if len(metrics[\"fid_pos\"]) == 0:\n",
        "            continue\n",
        "        summary[name] = {\n",
        "            \"fid_pos_mean\": float(torch.tensor(metrics[\"fid_pos\"]).mean().item()),\n",
        "            \"fid_neg_mean\": float(torch.tensor(metrics[\"fid_neg\"]).mean().item()),\n",
        "            \"sparsity_mean\": float(torch.tensor(metrics[\"sparsity\"]).mean().item()),\n",
        "            \"time\": float(torch.tensor(metrics[\"time\"]).mean().item()),\n",
        "        }\n",
        "    return summary\n",
        "\n",
        "# Run evaluation on test nodes\n",
        "test_nodes = data.test_mask.nonzero(as_tuple=False).view(-1)\n",
        "print(f\"\\nEvaluating methods on {len(test_nodes)} test nodes...\")\n",
        "summary = evaluate_methods_on_nodes(test_nodes, k_frac=0.5, rel_thresh=0.1)\n",
        "\n",
        "# One row per method; adjacent f-string pieces are concatenated, so each piece\n",
        "# carries its own trailing separator.\n",
        "print(\"\\nMean metrics over test nodes (k_frac = 0.5, keep/drop 50% neighbors):\")\n",
        "for name, mets in summary.items():\n",
        "    print(f\"{name:12s} | \"\n",
        "          f\"Fid+={mets['fid_pos_mean']:.3f}  \"\n",
        "          f\"Fid-={mets['fid_neg_mean']:.3f}  \"\n",
        "          f\"Sparsity={mets['sparsity_mean']:.3f}  \"\n",
        "          f\"Time={mets['time']:.6f}\")\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "jt6Hv4bqubrz",
        "outputId": "5796318f-f396-4237-b5d5-dc69eda3306a",
        "collapsed": true
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/tmp/ipython-input-2299389920.py:29: UserWarning: 'BAShapes' is deprecated, use 'datasets.ExplainerDataset' in combination with 'datasets.graph_generator.BAGraph' instead\n",
            "  dataset = BAShapes()   # synthetic, generated in memory\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Data(x=[700, 10], edge_index=[2, 3958], y=[700], expl_mask=[700], edge_label=[3958])\n",
            "#nodes = 700, #edges = 3958\n",
            "#features = 10, #classes = 4\n",
            "\n",
            "Training GCN on BA-Shapes...\n",
            "Epoch 001 | Loss 1.3874 | Train 0.431 | Val 0.379 | Test 0.471\n",
            "Epoch 020 | Loss 1.2709 | Train 0.431 | Val 0.379 | Test 0.471\n",
            "Epoch 040 | Loss 1.2648 | Train 0.431 | Val 0.379 | Test 0.471\n",
            "Epoch 060 | Loss 1.2697 | Train 0.431 | Val 0.379 | Test 0.471\n",
            "Epoch 080 | Loss 1.2616 | Train 0.431 | Val 0.379 | Test 0.471\n",
            "Epoch 100 | Loss 1.2316 | Train 0.431 | Val 0.379 | Test 0.471\n",
            "Epoch 120 | Loss 1.1982 | Train 0.431 | Val 0.371 | Test 0.471\n",
            "Epoch 140 | Loss 1.1517 | Train 0.431 | Val 0.364 | Test 0.471\n",
            "Epoch 160 | Loss 1.0674 | Train 0.431 | Val 0.364 | Test 0.471\n",
            "Epoch 180 | Loss 1.0406 | Train 0.440 | Val 0.393 | Test 0.471\n",
            "Epoch 200 | Loss 1.0431 | Train 0.581 | Val 0.529 | Test 0.636\n",
            "Epoch 220 | Loss 0.9261 | Train 0.667 | Val 0.650 | Test 0.686\n",
            "Epoch 240 | Loss 0.9805 | Train 0.857 | Val 0.836 | Test 0.850\n",
            "Epoch 260 | Loss 0.9530 | Train 0.860 | Val 0.843 | Test 0.850\n",
            "Epoch 280 | Loss 0.9340 | Train 0.862 | Val 0.836 | Test 0.843\n",
            "Epoch 300 | Loss 0.8349 | Train 0.860 | Val 0.843 | Test 0.850\n",
            "Epoch 320 | Loss 0.8361 | Train 0.864 | Val 0.836 | Test 0.829\n",
            "Epoch 340 | Loss 0.7785 | Train 0.862 | Val 0.829 | Test 0.829\n",
            "Epoch 360 | Loss 0.8525 | Train 0.740 | Val 0.721 | Test 0.671\n",
            "Epoch 380 | Loss 0.7738 | Train 0.824 | Val 0.786 | Test 0.757\n",
            "Epoch 400 | Loss 0.7748 | Train 0.864 | Val 0.836 | Test 0.843\n",
            "Epoch 420 | Loss 0.7372 | Train 0.852 | Val 0.814 | Test 0.821\n",
            "Epoch 440 | Loss 0.7326 | Train 0.831 | Val 0.793 | Test 0.764\n",
            "Epoch 460 | Loss 0.7220 | Train 0.864 | Val 0.821 | Test 0.814\n",
            "Epoch 480 | Loss 0.7386 | Train 0.867 | Val 0.836 | Test 0.829\n",
            "Epoch 500 | Loss 0.7728 | Train 0.855 | Val 0.814 | Test 0.821\n",
            "Epoch 520 | Loss 0.7112 | Train 0.864 | Val 0.836 | Test 0.843\n",
            "Epoch 540 | Loss 0.6945 | Train 0.862 | Val 0.836 | Test 0.850\n",
            "Epoch 560 | Loss 0.6710 | Train 0.867 | Val 0.829 | Test 0.814\n",
            "Epoch 580 | Loss 0.7036 | Train 0.864 | Val 0.829 | Test 0.829\n",
            "Epoch 600 | Loss 0.7044 | Train 0.864 | Val 0.836 | Test 0.850\n",
            "Epoch 620 | Loss 0.7070 | Train 0.855 | Val 0.821 | Test 0.829\n",
            "Epoch 640 | Loss 0.7140 | Train 0.821 | Val 0.786 | Test 0.843\n",
            "Epoch 660 | Loss 0.7102 | Train 0.843 | Val 0.800 | Test 0.779\n",
            "Epoch 680 | Loss 0.6986 | Train 0.845 | Val 0.800 | Test 0.800\n",
            "Epoch 700 | Loss 0.6706 | Train 0.864 | Val 0.821 | Test 0.821\n",
            "Epoch 720 | Loss 0.6974 | Train 0.840 | Val 0.793 | Test 0.771\n",
            "Epoch 740 | Loss 0.6828 | Train 0.864 | Val 0.836 | Test 0.843\n",
            "Epoch 760 | Loss 0.6474 | Train 0.840 | Val 0.793 | Test 0.771\n",
            "Epoch 780 | Loss 0.6463 | Train 0.819 | Val 0.771 | Test 0.750\n",
            "Epoch 800 | Loss 0.6625 | Train 0.821 | Val 0.771 | Test 0.750\n",
            "Epoch 820 | Loss 0.6452 | Train 0.864 | Val 0.829 | Test 0.829\n",
            "Epoch 840 | Loss 0.6505 | Train 0.860 | Val 0.821 | Test 0.836\n",
            "Epoch 860 | Loss 0.6135 | Train 0.867 | Val 0.829 | Test 0.850\n",
            "Epoch 880 | Loss 0.6743 | Train 0.814 | Val 0.750 | Test 0.743\n",
            "Epoch 900 | Loss 0.6462 | Train 0.864 | Val 0.814 | Test 0.821\n",
            "Epoch 920 | Loss 0.6436 | Train 0.838 | Val 0.793 | Test 0.771\n",
            "Epoch 940 | Loss 0.6482 | Train 0.821 | Val 0.786 | Test 0.829\n",
            "Epoch 960 | Loss 0.6344 | Train 0.817 | Val 0.750 | Test 0.743\n",
            "Epoch 980 | Loss 0.6732 | Train 0.845 | Val 0.793 | Test 0.779\n",
            "Epoch 1000 | Loss 0.6686 | Train 0.864 | Val 0.829 | Test 0.850\n",
            "\n",
            "Mean |Z2_decomp - logits_model| = 0.000000 (should be small-ish)\n",
            "\n",
            "Evaluating methods on 140 test nodes...\n",
            "\n",
            "Mean metrics over test nodes (k_frac = 0.5, keep/drop 50% neighbors):\n",
            "Decomp       | Fid+=0.390  Fid-=0.066  Sparsity=0.268\n",
            "GNNExplainer | Fid+=0.164  Fid-=0.130  Sparsity=0.000\n",
            "IG           | Fid+=0.407  Fid-=0.024  Sparsity=0.000\n",
            "Random       | Fid+=0.112  Fid-=0.172  Sparsity=0.072\n",
            "GOAT         | Fid+=0.407  Fid-=0.024  Sparsity=0.000\n",
            "GradCAM      | Fid+=0.361  Fid-=0.083  Sparsity=0.516\n",
            "GNN-LRP      | Fid+=0.344  Fid-=0.087  Sparsity=0.169\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Re-print the summary table; each f-string piece carries its own trailing\n",
        "# separator so the Sparsity and Time columns do not run together.\n",
        "for name, mets in summary.items():\n",
        "    print(f\"{name:12s} | \"\n",
        "          f\"Fid+={mets['fid_pos_mean']:.3f}  \"\n",
        "          f\"Fid-={mets['fid_neg_mean']:.3f}  \"\n",
        "          f\"Sparsity={mets['sparsity_mean']:.3f}  \"\n",
        "          f\"Time={mets['time']:.6f}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "cV8I1T1xtkO8",
        "outputId": "65f2b010-44c7-47fa-af66-2245c9920381"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Decomp       | Fid+=0.390  Fid-=0.066  Sparsity=0.268Time=0.003105\n",
            "GNNExplainer | Fid+=0.164  Fid-=0.130  Sparsity=0.000Time=0.715819\n",
            "IG           | Fid+=0.407  Fid-=0.024  Sparsity=0.000Time=0.227541\n",
            "Random       | Fid+=0.112  Fid-=0.172  Sparsity=0.072Time=0.000231\n",
            "GOAT         | Fid+=0.407  Fid-=0.024  Sparsity=0.000Time=0.008780\n",
            "GradCAM      | Fid+=0.361  Fid-=0.083  Sparsity=0.516Time=0.004427\n",
            "GNN-LRP      | Fid+=0.344  Fid-=0.087  Sparsity=0.169Time=0.002405\n"
          ]
        }
      ]
    }
  ]
}