{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "collapsed": true,
    "id": "9b67R7ry2srl",
    "outputId": "6a604cf2-cbe2-4210-d0d2-b64306c74ff8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting torch-scatter\n",
      "  Downloading torch_scatter-2.1.2.tar.gz (108 kB)\n",
      "\u001b[?25l     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/108.0 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m108.0/108.0 kB\u001b[0m \u001b[31m3.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
      "Building wheels for collected packages: torch-scatter\n",
      "  Building wheel for torch-scatter (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
      "  Created wheel for torch-scatter: filename=torch_scatter-2.1.2-cp312-cp312-linux_x86_64.whl size=640889 sha256=74e105693f5df32e81fa68415b791a2819673141e64211f6bc723f8d0d6c28b1\n",
      "  Stored in directory: /root/.cache/pip/wheels/84/20/50/44800723f57cd798630e77b3ec83bc80bd26a1e3dc3a672ef5\n",
      "Successfully built torch-scatter\n",
      "Installing collected packages: torch-scatter\n",
      "Successfully installed torch-scatter-2.1.2\n",
      "Collecting torch_geometric\n",
      "  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m63.7/63.7 kB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: aiohttp in /usr/local/lib/python3.12/dist-packages (from torch_geometric) (3.13.2)\n",
      "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch_geometric) (2025.3.0)\n",
      "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch_geometric) (3.1.6)\n",
      "Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from torch_geometric) (2.0.2)\n",
      "Requirement already satisfied: psutil>=5.8.0 in /usr/local/lib/python3.12/dist-packages (from torch_geometric) (5.9.5)\n",
      "Requirement already satisfied: pyparsing in /usr/local/lib/python3.12/dist-packages (from torch_geometric) (3.2.5)\n",
      "Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from torch_geometric) (2.32.4)\n",
      "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from torch_geometric) (4.67.1)\n",
      "Requirement already satisfied: xxhash in /usr/local/lib/python3.12/dist-packages (from torch_geometric) (3.6.0)\n",
      "Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->torch_geometric) (2.6.1)\n",
      "Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->torch_geometric) (1.4.0)\n",
      "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->torch_geometric) (25.4.0)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp->torch_geometric) (1.8.0)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp->torch_geometric) (6.7.0)\n",
      "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->torch_geometric) (0.4.1)\n",
      "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->torch_geometric) (1.22.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch_geometric) (3.0.3)\n",
      "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->torch_geometric) (3.4.4)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->torch_geometric) (3.11)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->torch_geometric) (2.5.0)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->torch_geometric) (2025.10.5)\n",
      "Requirement already satisfied: typing-extensions>=4.2 in /usr/local/lib/python3.12/dist-packages (from aiosignal>=1.4.0->aiohttp->torch_geometric) (4.15.0)\n",
      "Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m23.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hInstalling collected packages: torch_geometric\n",
      "Successfully installed torch_geometric-2.7.0\n",
      "Collecting captum\n",
      "  Downloading captum-0.8.0-py3-none-any.whl.metadata (26 kB)\n",
      "Requirement already satisfied: matplotlib in /usr/local/lib/python3.12/dist-packages (from captum) (3.10.0)\n",
      "Collecting numpy<2.0 (from captum)\n",
      "  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m61.0/61.0 kB\u001b[0m \u001b[31m2.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from captum) (25.0)\n",
      "Requirement already satisfied: torch>=1.10 in /usr/local/lib/python3.12/dist-packages (from captum) (2.8.0+cu126)\n",
      "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from captum) (4.67.1)\n",
      "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (3.20.0)\n",
      "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (4.15.0)\n",
      "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (75.2.0)\n",
      "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (1.13.3)\n",
      "Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (3.5)\n",
      "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (3.1.6)\n",
      "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (2025.3.0)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (12.6.77)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (12.6.77)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (12.6.80)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (9.10.2.21)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (12.6.4.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (11.3.0.4)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (10.3.7.77)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (11.7.1.2)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (12.5.4.2)\n",
      "Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (0.7.1)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.27.3 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (2.27.3)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (12.6.77)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (12.6.85)\n",
      "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (1.11.1.6)\n",
      "Requirement already satisfied: triton==3.4.0 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10->captum) (3.4.0)\n",
      "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->captum) (1.3.3)\n",
      "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib->captum) (0.12.1)\n",
      "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib->captum) (4.60.1)\n",
      "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->captum) (1.4.9)\n",
      "Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.12/dist-packages (from matplotlib->captum) (11.3.0)\n",
      "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->captum) (3.2.5)\n",
      "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.12/dist-packages (from matplotlib->captum) (2.9.0.post0)\n",
      "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.7->matplotlib->captum) (1.17.0)\n",
      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=1.10->captum) (1.3.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=1.10->captum) (3.0.3)\n",
      "Downloading captum-0.8.0-py3-none-any.whl (1.4 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.4/1.4 MB\u001b[0m \u001b[31m25.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hDownloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.0 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m18.0/18.0 MB\u001b[0m \u001b[31m130.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hInstalling collected packages: numpy, captum\n",
      "  Attempting uninstall: numpy\n",
      "    Found existing installation: numpy 2.0.2\n",
      "    Uninstalling numpy-2.0.2:\n",
      "      Successfully uninstalled numpy-2.0.2\n",
      "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
      "opencv-python-headless 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n",
      "shap 0.50.0 requires numpy>=2, but you have numpy 1.26.4 which is incompatible.\n",
      "opencv-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n",
      "jax 0.7.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n",
      "jaxlib 0.7.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n",
      "pytensor 2.35.1 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n",
      "opencv-contrib-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\u001b[0m\u001b[31m\n",
      "\u001b[0mSuccessfully installed captum-0.8.0 numpy-1.26.4\n"
     ]
    },
    {
     "data": {
      "application/vnd.colab-display-data+json": {
       "id": "9fdcdf65ab0f4d318ccf4cd34d61a9c4",
       "pip_warning": {
        "packages": [
         "numpy"
        ]
       }
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#install\n",
    "!pip install torch-scatter\n",
    "!pip install torch_geometric\n",
    "!pip install captum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "collapsed": true,
    "id": "nxYB7hM3ajIl",
    "outputId": "aefe3946-1df7-410b-f60e-2d21e8e3c3e7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting lime\n",
      "  Downloading lime-0.2.0.1.tar.gz (275 kB)\n",
      "\u001b[?25l     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/275.7 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K     \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m\u001b[90m━\u001b[0m \u001b[32m266.2/275.7 kB\u001b[0m \u001b[31m8.9 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m275.7/275.7 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
      "Requirement already satisfied: matplotlib in /usr/local/lib/python3.12/dist-packages (from lime) (3.10.0)\n",
      "Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from lime) (1.26.4)\n",
      "Requirement already satisfied: scipy in /usr/local/lib/python3.12/dist-packages (from lime) (1.16.3)\n",
      "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from lime) (4.67.1)\n",
      "Requirement already satisfied: scikit-learn>=0.18 in /usr/local/lib/python3.12/dist-packages (from lime) (1.6.1)\n",
      "Requirement already satisfied: scikit-image>=0.12 in /usr/local/lib/python3.12/dist-packages (from lime) (0.25.2)\n",
      "Requirement already satisfied: networkx>=3.0 in /usr/local/lib/python3.12/dist-packages (from scikit-image>=0.12->lime) (3.5)\n",
      "Requirement already satisfied: pillow>=10.1 in /usr/local/lib/python3.12/dist-packages (from scikit-image>=0.12->lime) (11.3.0)\n",
      "Requirement already satisfied: imageio!=2.35.0,>=2.33 in /usr/local/lib/python3.12/dist-packages (from scikit-image>=0.12->lime) (2.37.2)\n",
      "Requirement already satisfied: tifffile>=2022.8.12 in /usr/local/lib/python3.12/dist-packages (from scikit-image>=0.12->lime) (2025.10.16)\n",
      "Requirement already satisfied: packaging>=21 in /usr/local/lib/python3.12/dist-packages (from scikit-image>=0.12->lime) (25.0)\n",
      "Requirement already satisfied: lazy-loader>=0.4 in /usr/local/lib/python3.12/dist-packages (from scikit-image>=0.12->lime) (0.4)\n",
      "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn>=0.18->lime) (1.5.2)\n",
      "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn>=0.18->lime) (3.6.0)\n",
      "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->lime) (1.3.3)\n",
      "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib->lime) (0.12.1)\n",
      "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib->lime) (4.60.1)\n",
      "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->lime) (1.4.9)\n",
      "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->lime) (3.2.5)\n",
      "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.12/dist-packages (from matplotlib->lime) (2.9.0.post0)\n",
      "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.7->matplotlib->lime) (1.17.0)\n",
      "Building wheels for collected packages: lime\n",
      "  Building wheel for lime (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
      "  Created wheel for lime: filename=lime-0.2.0.1-py3-none-any.whl size=283834 sha256=b2e57617172b76c04781fc3702eec4a59f343a07a5bca86fd36904ed7ee0895b\n",
      "  Stored in directory: /root/.cache/pip/wheels/e7/5d/0e/4b4fff9a47468fed5633211fb3b76d1db43fe806a17fb7486a\n",
      "Successfully built lime\n",
      "Installing collected packages: lime\n",
      "Successfully installed lime-0.2.0.1\n"
     ]
    }
   ],
   "source": [
    "!pip install lime\n",
    "#!git clone https://github.com/WilliamCCHuang/GraphLIME.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "jouMiG7p7FfS",
    "outputId": "61c61ca2-20fd-4543-bd80-48cbbd814054"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.7.0\n"
     ]
    }
   ],
   "source": [
    "!python -c \"import torch_geometric; print(torch_geometric.__version__)\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#optimized\n",
    "\n",
    "import torch\n",
    "from torch_geometric.utils import k_hop_subgraph\n",
    "\n",
    "@torch.no_grad()\n",
    "def build_feature_importance_decomposition_fast(\n",
    "    data,\n",
    "    model,\n",
    "    adj_list,            # list of list[(nbr, weight)] with the SAME normalized A (incl. self-loops) used by the model\n",
    "    nodes_to_explain,\n",
    "    hop=2,               # keep 2-hop exact\n",
    "    use_true_class=True,\n",
    "):\n",
    "    \"\"\"\n",
    "    Exact 2-hop decomposition, optimized:\n",
    "      - No dense A_sub per node\n",
    "      - Reuses global M1 (first-layer ReLU mask)\n",
    "      - Accumulates U[u,k] via adjacency lists with vectorized index_add\n",
    "    Returns:\n",
    "        feat_self, feat_self_1hop, feat_self_2hop   [num_nodes, num_features]\n",
    "    \"\"\"\n",
    "    device = next(model.parameters()).device\n",
    "    X = data.x.to(device)                                  # [N, F]\n",
    "    N, F = X.size()\n",
    "\n",
    "    # Forward once to get target classes (logits on the model's own pipeline)\n",
    "    logits = model(data.x, data.edge_index)                # [N, C]\n",
    "    preds = logits.argmax(dim=-1)\n",
    "    target_class = data.y if use_true_class else preds     # [N]\n",
    "\n",
    "    # Extract weights (assuming GCNConv with .lin.weight and optional bias)\n",
    "    W1 = model.conv1.lin.weight.T.to(device)               # [F, H]\n",
    "    W2 = model.conv2.lin.weight.T.to(device)               # [H, C]\n",
    "    b1 = model.conv1.bias.to(device) if model.conv1.bias is not None else torch.zeros(W1.size(1), device=device)\n",
    "\n",
    "    H = W1.size(1)\n",
    "    C = W2.size(1)\n",
    "\n",
    "    # ---- Precompute global first-layer activations and ReLU mask M1 ----\n",
    "    # IMPORTANT: Must use the SAME normalized adjacency used to build adj_list (incl. self-loops!)\n",
    "    # We'll realize A @ (X @ W1) using adj_list once, row by row, to avoid dense A.\n",
    "    XW1 = X @ W1                                           # [N, H]\n",
    "    Z1 = torch.zeros(N, H, device=device)\n",
    "    # Z1[i] = sum_{j in nbrs(i)} A[i,j] * XW1[j] + b1\n",
    "    # Vectorized per-row accumulation via index_add:\n",
    "    # Build COO-like buffers once from adj_list\n",
    "    rows, cols, vals = [], [], []\n",
    "    for i in range(N):\n",
    "        for (j, w_ij) in adj_list[i]:\n",
    "            rows.append(i); cols.append(j); vals.append(w_ij)\n",
    "    rows = torch.tensor(rows, dtype=torch.long, device=device)\n",
    "    cols = torch.tensor(cols, dtype=torch.long, device=device)\n",
    "    vals = torch.tensor(vals, dtype=XW1.dtype, device=device)\n",
    "\n",
    "    # Accumulate A @ (XW1): for each edge (i,j) add w_ij * XW1[j] into Z1[i]\n",
    "    Z1.index_add_(0, rows, XW1[cols] * vals.unsqueeze(1))\n",
    "    Z1 += b1                                               # broadcast [H]\n",
    "    M1 = (Z1 > 0).to(X.dtype)                              # [N, H]\n",
    "\n",
    "    # Allocate outputs\n",
    "    feat_self      = torch.zeros(N, F, device=device)\n",
    "    feat_self_1hop = torch.zeros(N, F, device=device)\n",
    "    feat_self_2hop = torch.zeros(N, F, device=device)\n",
    "\n",
    "    # Helper to get 1-hop neighbors quickly as tensors\n",
    "    def one_hop_tensors(v):\n",
    "        nbrs, w = zip(*adj_list[v]) if len(adj_list[v]) > 0 else ([], [])\n",
    "        if len(nbrs) == 0:\n",
    "            return torch.empty(0, dtype=torch.long, device=device), torch.empty(0, dtype=X.dtype, device=device)\n",
    "        return torch.tensor(nbrs, dtype=torch.long, device=device), torch.tensor(w, dtype=X.dtype, device=device)\n",
    "\n",
    "    num_targets = len(nodes_to_explain)\n",
    "    print(f\"[fast] Building exact 2-hop decomposition for {num_targets} nodes...\")\n",
    "\n",
    "    for i, v in enumerate(nodes_to_explain.tolist(), start=1):\n",
    "        v = int(v)\n",
    "        c_idx = int(target_class[v].item())\n",
    "\n",
    "        # 2-hop ego nodes (relabel_nodes=False to retain global ids)\n",
    "        nodes_sub, _, _, _ = k_hop_subgraph(v, hop, data.edge_index, relabel_nodes=False)\n",
    "        nodes_sub = nodes_sub.to(device)\n",
    "        n_sub = nodes_sub.numel()\n",
    "        # map global id -> local index\n",
    "        mapping = {int(nid): idx for idx, nid in enumerate(nodes_sub.tolist())}\n",
    "\n",
    "        # ---- Build U_sub[u_local, k] = sum_j A[v,j] * M1[j,k] * A[j,u], u in nodes_sub ----\n",
    "        U_sub = torch.zeros(n_sub, H, device=device)\n",
    "\n",
    "        # 1-hop of v (j’s) and their weights A[v,j]\n",
    "        J, w_vj = one_hop_tensors(v)                       # [deg(v)], [deg(v)]\n",
    "        if J.numel() > 0:\n",
    "            # For each j in J:\n",
    "            #   alpha_j[k] = A[v,j] * M1[j,k]\n",
    "            # We’ll accumulate over u in nbrs(j) using adj_list[j].\n",
    "            # Do it in a vectorized way via index_add.\n",
    "            for idx_j in range(J.numel()):\n",
    "                j = int(J[idx_j].item())\n",
    "                wvj = w_vj[idx_j]                          # scalar\n",
    "                alpha_j = wvj * M1[j]                      # [H]\n",
    "\n",
    "                # neighbors u of j with weights A[j,u]\n",
    "                nbr_u, w_ju = one_hop_tensors(j)           # [deg(j)]\n",
    "                if nbr_u.numel() == 0:\n",
    "                    continue\n",
    "\n",
    "                # Filter u to the 2-hop ego only\n",
    "                # Build local indices for u ∈ nodes_sub\n",
    "                keep = []\n",
    "                local_idx = []\n",
    "                weights = []\n",
    "                for t in range(nbr_u.numel()):\n",
    "                    u_global = int(nbr_u[t].item())\n",
    "                    if u_global in mapping:\n",
    "                        keep.append(t)\n",
    "                        local_idx.append(mapping[u_global])\n",
    "                        weights.append(float(w_ju[t].item()))\n",
    "                if len(local_idx) == 0:\n",
    "                    continue\n",
    "\n",
    "                local_idx = torch.tensor(local_idx, dtype=torch.long, device=device)\n",
    "                w_ju_kept = torch.tensor(weights, dtype=X.dtype, device=device)  # [m]\n",
    "                # Accumulate: U_sub[local_idx, :] += (w_ju_kept[:,None] * alpha_j[None,:])\n",
    "                U_sub.index_add_(0, local_idx, w_ju_kept.unsqueeze(1) * alpha_j.unsqueeze(0))\n",
    "\n",
    "        # ---- Convert U_sub to per-(node,feature) contribution via T = W1 ⊙ w_kc^T ----\n",
    "        w_kc = W2[:, c_idx]                                 # [H]\n",
    "        T = W1 * w_kc.unsqueeze(0)                          # [F, H]\n",
    "        # Φ[u,f] = sum_k U_sub[u,k] * T[f,k]\n",
    "        Phi_sub = U_sub @ T.T                               # [n_sub, F]\n",
    "        X_sub = X[nodes_sub]                                # [n_sub, F]\n",
    "        C_sub = Phi_sub * X_sub                             # [n_sub, F]; exact contributions to Z2[v,c]\n",
    "\n",
    "        # ---- Aggregate by hop masks (self / 1-hop / 2-hop) ----\n",
    "        # self\n",
    "        mask_self = (nodes_sub == v)                        # [n_sub]\n",
    "        # 1-hop\n",
    "        nbrs_v = set(int(n) for n in J.tolist())           # already have 1-hop of v\n",
    "        mask_1hop = torch.tensor([ (int(nid.item()) in nbrs_v) for nid in nodes_sub ],\n",
    "                                 dtype=torch.bool, device=device)\n",
    "        # 2-hop (remaining)\n",
    "        mask_2hop = ~(mask_self | mask_1hop)\n",
    "\n",
    "        self_contrib = C_sub[mask_self].sum(dim=0) if mask_self.any() else torch.zeros(F, device=device)\n",
    "        hop1_contrib = C_sub[mask_1hop].sum(dim=0) if mask_1hop.any() else torch.zeros(F, device=device)\n",
    "        hop2_contrib = C_sub[mask_2hop].sum(dim=0) if mask_2hop.any() else torch.zeros(F, device=device)\n",
    "\n",
    "        feat_self[v]      = self_contrib\n",
    "        feat_self_1hop[v] = self_contrib + hop1_contrib\n",
    "        feat_self_2hop[v] = self_contrib + hop1_contrib + hop2_contrib\n",
    "\n",
    "        if i % 20 == 0 or i == 1 or i == num_targets:\n",
    "            print(f\"  processed {i}/{num_targets} nodes\", flush=True)\n",
    "\n",
    "    return feat_self, feat_self_1hop, feat_self_2hop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CT0NAmSY80h-"
   },
   "source": [
    "## AMZ-COMPUTER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "68a316yJ2lfT",
    "outputId": "9458b6a2-1d34-43a8-b42e-fbf33ec0e37b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Amazon-Computers ===\n",
      "#Nodes     = 13752\n",
      "#Edges     = 491722\n",
      "#Features  = 767\n",
      "#Classes   = 10\n",
      "#Train nodes = 9621\n",
      "#Val nodes   = 1370\n",
      "#Test nodes  = 2761\n",
      "\n",
      "=== Loading existing model from amazon_computer_gcn.pt ===\n",
      "Loaded model | Train=0.6363 Val=0.6314 Test=0.6338\n",
      "\n",
      "=== Building normalized adjacency list ===\n",
      "\n",
      "=== Explaining 100 test nodes (pred class as target) ===\n",
      "Building feature importance via decomposition for 100 nodes...\n",
      "  processed 1/100 nodes\n",
      "  processed 20/100 nodes\n",
      "  processed 40/100 nodes\n",
      "  processed 60/100 nodes\n",
      "  processed 80/100 nodes\n",
      "  processed 100/100 nodes\n",
      "\n",
      "=== Building RANDOM importance baseline ===\n",
      "\n",
      "=== Building GRAD-Cam ===\n",
      "\n",
      "=== Building GNN-LRP ===\n",
      "Running GNN-LRP for 100 nodes...\n",
      "  GNN-LRP processed 1/100 nodes\n",
      "  GNN-LRP processed 10/100 nodes\n",
      "  GNN-LRP processed 20/100 nodes\n",
      "  GNN-LRP processed 30/100 nodes\n",
      "  GNN-LRP processed 40/100 nodes\n",
      "  GNN-LRP processed 50/100 nodes\n",
      "  GNN-LRP processed 60/100 nodes\n",
      "  GNN-LRP processed 70/100 nodes\n",
      "  GNN-LRP processed 80/100 nodes\n",
      "  GNN-LRP processed 90/100 nodes\n",
      "  GNN-LRP processed 100/100 nodes\n",
      "\n",
      "=== Building GOAT ===\n",
      "Running GOAT for 100 nodes...\n",
      "  GOAT processed 1/100 nodes\n",
      "  GOAT processed 10/100 nodes\n",
      "  GOAT processed 20/100 nodes\n",
      "  GOAT processed 30/100 nodes\n",
      "  GOAT processed 40/100 nodes\n",
      "  GOAT processed 50/100 nodes\n",
      "  GOAT processed 60/100 nodes\n",
      "  GOAT processed 70/100 nodes\n",
      "  GOAT processed 80/100 nodes\n",
      "  GOAT processed 90/100 nodes\n",
      "  GOAT processed 100/100 nodes\n",
      "\n",
      "=== Building LIME ===\n",
      "Running LIME for 100 nodes...\n",
      "  LIME processed 1/100 nodes\n",
      "  LIME processed 10/100 nodes\n",
      "  LIME processed 20/100 nodes\n",
      "  LIME processed 30/100 nodes\n",
      "  LIME processed 40/100 nodes\n",
      "  LIME processed 50/100 nodes\n",
      "  LIME processed 60/100 nodes\n",
      "  LIME processed 70/100 nodes\n",
      "  LIME processed 80/100 nodes\n",
      "  LIME processed 90/100 nodes\n",
      "  LIME processed 100/100 nodes\n",
      "\n",
      "=== Building Integrated Gradients (Captum) importance baseline ===\n",
      "Running Clean IG for 100 nodes...\n",
      "  IG processed 1/100 nodes\n",
      "  IG processed 10/100 nodes\n",
      "  IG processed 20/100 nodes\n",
      "  IG processed 30/100 nodes\n",
      "  IG processed 40/100 nodes\n",
      "  IG processed 50/100 nodes\n",
      "  IG processed 60/100 nodes\n",
      "  IG processed 70/100 nodes\n",
      "  IG processed 80/100 nodes\n",
      "  IG processed 90/100 nodes\n",
      "  IG processed 100/100 nodes\n",
      "\n",
      "=== Computing fidelity scores ===\n",
      "\n",
      "=== Fidelity scores for importance: SELF ONLY ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=38) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0351\n",
      "Fidelity- (remove BOTTOM-k)= -0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0001\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=77) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0361\n",
      "Fidelity- (remove BOTTOM-k)= -0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0011\n",
      "\n",
      "=== Fidelity scores for importance: SELF + 1HOP ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=38) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0353\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0001\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=77) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0370\n",
      "Fidelity- (remove BOTTOM-k)= -0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0019\n",
      "\n",
      "=== Fidelity scores for importance: SELF + 1HOP + 2HOP ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=38) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0353\n",
      "Fidelity- (remove BOTTOM-k)= -0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0002\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=77) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0371\n",
      "Fidelity- (remove BOTTOM-k)= -0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0021\n",
      "\n",
      "=== Fidelity scores for importance: RANDOM BASELINE ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=38) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0013\n",
      "Fidelity- (remove BOTTOM-k)= 0.0013\n",
      "Keep-top (only TOP-k kept) = 0.0324\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=77) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0023\n",
      "Fidelity- (remove BOTTOM-k)= 0.0025\n",
      "Keep-top (only TOP-k kept) = 0.0304\n",
      "\n",
      "=== Fidelity scores for importance: GradCAM-feature ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=38) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0201\n",
      "Fidelity- (remove BOTTOM-k)= 0.0002\n",
      "Keep-top (only TOP-k kept) = 0.0131\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=77) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0270\n",
      "Fidelity- (remove BOTTOM-k)= 0.0001\n",
      "Keep-top (only TOP-k kept) = 0.0097\n",
      "\n",
      "=== Fidelity scores for importance: GNN-LRP-feature ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=38) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0288\n",
      "Fidelity- (remove BOTTOM-k)= -0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0067\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=77) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0322\n",
      "Fidelity- (remove BOTTOM-k)= -0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0028\n",
      "\n",
      "=== Fidelity scores for importance: GOAT-feature ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=38) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0352\n",
      "Fidelity- (remove BOTTOM-k)= -0.0001\n",
      "Keep-top (only TOP-k kept) = -0.0001\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=77) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0384\n",
      "Fidelity- (remove BOTTOM-k)= -0.0002\n",
      "Keep-top (only TOP-k kept) = -0.0033\n",
      "\n",
      "=== Fidelity scores for importance: LIME ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=38) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0070\n",
      "Fidelity- (remove BOTTOM-k)= 0.0004\n",
      "Keep-top (only TOP-k kept) = 0.0273\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=77) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0094\n",
      "Fidelity- (remove BOTTOM-k)= 0.0008\n",
      "Keep-top (only TOP-k kept) = 0.0243\n",
      "\n",
      "=== Fidelity scores for importance: IG (Captum) BASELINE ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=38) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0164\n",
      "Fidelity- (remove BOTTOM-k)= -0.0002\n",
      "Keep-top (only TOP-k kept) = 0.0127\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=77) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0205\n",
      "Fidelity- (remove BOTTOM-k)= -0.0002\n",
      "Keep-top (only TOP-k kept) = 0.0092\n",
      "\n",
      "=== Summary: mean Δ = p_orig - p_mask ===\n",
      "\n",
      "Fidelity+ (remove top-k% features):\n",
      "k=0.05 | self=0.0351 | self+1hop=0.0353 | self+2hop=0.0353 | rand=0.0013 | gradcam=0.0201 | lrp=0.0288 | goat=0.0352 | lime=0.0070 | ig=0.0164 | \n",
      "k=0.10 | self=0.0361 | self+1hop=0.0370 | self+2hop=0.0371 | rand=0.0023 | gradcam=0.0270 | lrp=0.0322 | goat=0.0384 | lime=0.0094 | ig=0.0205 | \n",
      "\n",
      "Fidelity- (remove bottom-k% features):\n",
      "k=0.05 | self=-0.0000 | self+1hop=0.0000 | self+2hop=-0.0000 | rand=0.0013 | gradcam=0.0002 | lrp=-0.0000 | goat=-0.0001 | lime=0.0004 | ig=-0.0002 | \n",
      "k=0.10 | self=-0.0000 | self+1hop=-0.0000 | self+2hop=-0.0000 | rand=0.0025 | gradcam=0.0001 | lrp=-0.0000 | goat=-0.0002 | lime=0.0008 | ig=-0.0002 | \n",
      "\n",
      "=== Build-Time Summary (seconds) ===\n",
      "method                 build_time\n",
      "decomposition              45.782\n",
      "random                      0.040\n",
      "gradcam                     0.132\n",
      "lrp                         0.327\n",
      "goat                      220.752\n",
      "lime                      495.143\n",
      "ig                        611.970\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append(\"GraphLIME\")\n",
    "#from graphlime import GraphLIME\n",
    "from lime.lime_tabular import LimeTabularExplainer\n",
    "import time\n",
    "import argparse\n",
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import os\n",
    "from torch_geometric.datasets import Amazon\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.nn import GCNConv\n",
    "from torch_geometric.utils import k_hop_subgraph\n",
    "from torch_geometric.explain import Explainer, GNNExplainer\n",
    "from torch_geometric.explain import Explainer, GNNExplainer, CaptumExplainer\n",
    "from captum.attr import IntegratedGradients\n",
    "\n",
    "# -------------------------------------------------------------\n",
    "# 0. Utils\n",
    "# -------------------------------------------------------------\n",
    "def set_seed(seed: int = 42):\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "def _now():\n",
    "    # sync CUDA to avoid async timing issues\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.synchronize()\n",
    "    return time.time()\n",
    "\n",
    "# -------------------------------------------------------------\n",
    "# 1. Dataset + splits\n",
    "# -------------------------------------------------------------\n",
    "def load_amazon_computers(root: str = \"data/Amazon\"):\n",
    "    dataset = Amazon(root=root, name=\"Computers\", transform=T.NormalizeFeatures())\n",
    "    data = dataset[0]\n",
    "\n",
    "    print(\"=== Amazon-Computers ===\")\n",
    "    print(f\"#Nodes     = {data.num_nodes}\")\n",
    "    print(f\"#Edges     = {data.num_edges}\")\n",
    "    print(f\"#Features  = {dataset.num_features}\")\n",
    "    print(f\"#Classes   = {dataset.num_classes}\")\n",
    "\n",
    "    if not hasattr(data, \"train_mask\"):\n",
    "        # data = create_splits(data, num_train_per_class=20, num_val_per_class=30)\n",
    "        data = create_splits_stratified(data, train_ratio=0.7, val_ratio=0.1, seed=1234)\n",
    "\n",
    "    return data, dataset.num_features, dataset.num_classes\n",
    "\n",
    "\n",
    "# def create_splits(data, num_train_per_class=20, num_val_per_class=30):\n",
    "#     y = data.y\n",
    "#     num_nodes = data.num_nodes\n",
    "#     num_classes = int(y.max().item() + 1)\n",
    "\n",
    "#     train_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "#     val_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "#     test_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "\n",
    "#     for c in range(num_classes):\n",
    "#         idx = (y == c).nonzero(as_tuple=False).view(-1)\n",
    "#         idx = idx[torch.randperm(idx.size(0))]\n",
    "\n",
    "#         n_train = min(num_train_per_class, idx.size(0))\n",
    "#         n_val = min(num_val_per_class, max(0, idx.size(0) - n_train))\n",
    "\n",
    "#         train_idx = idx[:n_train]\n",
    "#         val_idx = idx[n_train:n_train + n_val]\n",
    "#         test_idx = idx[n_train + n_val:]\n",
    "\n",
    "#         train_mask[train_idx] = True\n",
    "#         val_mask[val_idx] = True\n",
    "#         test_mask[test_idx] = True\n",
    "\n",
    "#     data.train_mask = train_mask\n",
    "#     data.val_mask = val_mask\n",
    "#     data.test_mask = test_mask\n",
    "\n",
    "#     print(f\"#Train nodes = {int(train_mask.sum())}\")\n",
    "#     print(f\"#Val nodes   = {int(val_mask.sum())}\")\n",
    "#     print(f\"#Test nodes  = {int(test_mask.sum())}\")\n",
    "\n",
    "#     return data\n",
    "\n",
    "def create_splits_stratified(data, train_ratio=0.7, val_ratio=0.1, seed=42):\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    y = data.y\n",
    "    num_nodes = data.num_nodes\n",
    "    num_classes = int(y.max().item() + 1)\n",
    "\n",
    "    train_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "    val_mask   = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "    test_mask  = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "\n",
    "    for c in range(num_classes):\n",
    "        idx = (y == c).nonzero(as_tuple=False).view(-1)\n",
    "\n",
    "        if idx.numel() == 0:\n",
    "            continue\n",
    "\n",
    "        idx = idx[torch.randperm(idx.size(0))]\n",
    "\n",
    "        n_c = idx.size(0)\n",
    "        n_train = int(train_ratio * n_c)\n",
    "        n_val   = int(val_ratio * n_c)\n",
    "        n_test  = n_c - n_train - n_val\n",
    "\n",
    "        train_idx = idx[:n_train]\n",
    "        val_idx   = idx[n_train:n_train + n_val]\n",
    "        test_idx  = idx[n_train + n_val:]\n",
    "\n",
    "        train_mask[train_idx] = True\n",
    "        val_mask[val_idx]     = True\n",
    "        test_mask[test_idx]   = True\n",
    "\n",
    "    data.train_mask = train_mask\n",
    "    data.val_mask   = val_mask\n",
    "    data.test_mask  = test_mask\n",
    "\n",
    "    print(f\"#Train nodes = {int(train_mask.sum())}\")\n",
    "    print(f\"#Val nodes   = {int(val_mask.sum())}\")\n",
    "    print(f\"#Test nodes  = {int(test_mask.sum())}\")\n",
    "\n",
    "    return data\n",
    "\n",
    "\n",
    "\n",
    "# -------------------------------------------------------------\n",
    "# 2. 2-layer GCN model\n",
    "# -------------------------------------------------------------\n",
    "class GCN(torch.nn.Module):\n",
    "    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):\n",
    "        super().__init__()\n",
    "        self.conv1 = GCNConv(in_channels, hidden_channels, cached=True)\n",
    "        self.conv2 = GCNConv(hidden_channels, out_channels, cached=True)\n",
    "        self.dropout = dropout\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = F.relu(x)\n",
    "        x = F.dropout(x, p=self.dropout, training=self.training)\n",
    "        x = self.conv2(x, edge_index)\n",
    "        return x\n",
    "\n",
    "\n",
    "def train(model, data, optimizer):\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    out = model(data.x, data.edge_index)\n",
    "    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate(model, data):\n",
    "    model.eval()\n",
    "    logits = model(data.x, data.edge_index)\n",
    "    preds = logits.argmax(dim=-1)\n",
    "\n",
    "    def acc(mask):\n",
    "        total = int(mask.sum())\n",
    "        if total == 0:\n",
    "            return 0.0\n",
    "        return (preds[mask] == data.y[mask]).sum().item() / total\n",
    "\n",
    "    return acc(data.train_mask), acc(data.val_mask), acc(data.test_mask), logits\n",
    "\n",
    "\n",
    "# -------------------------------------------------------------\n",
    "# 3. Build normalized adjacency list from cached GCNConv\n",
    "#    (row = target node, col = source node)\n",
    "# -------------------------------------------------------------\n",
    "@torch.no_grad()\n",
    "def build_normalized_adjacency_list(model, data):\n",
    "    \"\"\"\n",
    "    Use conv1's cached normalized edge weights.\n",
    "    For each node i, store a list of (neighbor_j, weight_ij) where\n",
    "    weight_ij corresponds to A[i, j] in the matrix A @ (XW).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    # Ensure cache is filled\n",
    "    _ = model(data.x, data.edge_index)\n",
    "\n",
    "    edge_index, edge_weight = model.conv1._cached_edge_index\n",
    "    edge_index = edge_index.cpu()\n",
    "    edge_weight = edge_weight.cpu()\n",
    "\n",
    "    num_nodes = data.num_nodes\n",
    "    adj = [[] for _ in range(num_nodes)]  # adj[i] = list of (j, weight) s.t. A[i,j] = weight\n",
    "\n",
    "    # In PyG, edge_index[0] = source (j), edge_index[1] = target (i)\n",
    "    src = edge_index[0].tolist()\n",
    "    dst = edge_index[1].tolist()\n",
    "    w   = edge_weight.tolist()\n",
    "\n",
    "    for s, d, weight in zip(src, dst, w):\n",
    "        adj[d].append((s, weight))\n",
    "\n",
    "    return adj\n",
    "\n",
    "\n",
    "# -------------------------------------------------------------\n",
    "# 4. 2-layer GCN decomposition on an ego-subgraph\n",
    "# -------------------------------------------------------------\n",
    "@torch.no_grad()\n",
    "def decompose_two_layer_gcn_ego(\n",
    "    A_sub,          # [n_sub, n_sub] dense normalized adjacency\n",
    "    X_sub,          # [n_sub, F]\n",
    "    W1, b1,         # W1: [F, H], b1: [H]\n",
    "    W2, b2,         # W2: [H, C], b2: [C]\n",
    "    v_local,        # index of central node in subgraph [0..n_sub-1]\n",
    "    c_idx           # target class index\n",
    "):\n",
    "    \"\"\"\n",
    "    Exact decomposition of 2-layer GCN logit for node v and class c_idx.\n",
    "\n",
    "    Returns:\n",
    "        C: [n_sub, F] contributions from node-features -> logit of v,c_idx\n",
    "        logit_vc: scalar, logit of v,c_idx\n",
    "    \"\"\"\n",
    "    device = X_sub.device\n",
    "    n_sub, F = X_sub.shape\n",
    "    H = W1.shape[1]\n",
    "\n",
    "    # 1) Forward on subgraph\n",
    "    Z1 = A_sub @ (X_sub @ W1) + b1        # [n_sub, H]\n",
    "    M = (Z1 > 0).float()                  # ReLU mask\n",
    "    H1 = M * Z1                           # [n_sub, H]\n",
    "\n",
    "    Z2 = A_sub @ (H1 @ W2) + b2           # [n_sub, C]\n",
    "\n",
    "    # 2) Contributions C[u,f] to logit of v for class c_idx\n",
    "    # φ_{v,u,f} = sum_{j,k} A[v,j] * M[j,k] * A[j,u] * W1[f,k] * W2[k,c_idx]\n",
    "    # We compute this vectorized.\n",
    "\n",
    "    Av = A_sub[v_local]                   # [n_sub] == A[v, :]\n",
    "    # alpha[k,j] = A[v,j] * M[j,k]\n",
    "    alpha = (M.T * Av.unsqueeze(0))       # [H, n_sub]\n",
    "    # S[k, u] = sum_j alpha[k,j] * A[j,u]\n",
    "    S = alpha @ A_sub                     # [H, n_sub]\n",
    "    U = S.T                               # [n_sub, H]\n",
    "\n",
    "    w_kc = W2[:, c_idx]                   # [H]\n",
    "    T = W1 * w_kc.unsqueeze(0)            # [F, H]\n",
    "\n",
    "    # Φ[u,f] = sum_k U[u,k] * T[f,k]\n",
    "    Phi = U @ T.T                         # [n_sub, F]\n",
    "\n",
    "    # Contribution from each node-feature = Φ[u,f] * X_sub[u,f]\n",
    "    C = Phi * X_sub                       # [n_sub, F]\n",
    "\n",
    "    return C, Z2[v_local, c_idx]\n",
    "\n",
    "\n",
    "# @torch.no_grad()\n",
    "# def bfs_hop_dist(A_sub, v_local):\n",
    "#     \"\"\"\n",
    "#     Simple BFS on dense adjacency to get hop distances from v_local.\n",
    "#     A_sub assumed undirected (or symmetric).\n",
    "#     Returns: dist [n_sub] with values 0,1,2,... or large if unreachable.\n",
    "#     \"\"\"\n",
    "#     device = A_sub.device\n",
    "#     n_sub = A_sub.size(0)\n",
    "#     dist = torch.full((n_sub,), fill_value=10_000, dtype=torch.long, device=device)\n",
    "#     dist[v_local] = 0\n",
    "#     queue = [v_local]\n",
    "\n",
    "#     while queue:\n",
    "#         cur = queue.pop(0)\n",
    "#         cur_d = dist[cur].item()\n",
    "#         neighbors = (A_sub[cur] != 0).nonzero(as_tuple=False).view(-1)\n",
    "#         for nb in neighbors.tolist():\n",
    "#             if dist[nb].item() > cur_d + 1:\n",
    "#                 dist[nb] = cur_d + 1\n",
    "#                 queue.append(nb)\n",
    "\n",
    "#     return dist\n",
    "\n",
    "\n",
    "# # -------------------------------------------------------------\n",
    "# # 5. Build feature importance (self vs self+1hop vs self+2hop)\n",
    "# # -------------------------------------------------------------\n",
    "# @torch.no_grad()\n",
    "# def build_feature_importance_decomposition(\n",
    "#     data,\n",
    "#     model,\n",
    "#     adj_list,\n",
    "#     nodes_to_explain,\n",
    "#     hop=2,\n",
    "#     use_true_class=True,\n",
    "# ):\n",
    "#     \"\"\"\n",
    "#     For each node v in nodes_to_explain, compute three feature importance variants:\n",
    "\n",
    "#         self_only[v,f]        = |C[v,f]|        (u = v only)\n",
    "#         self_plus_1hop[v,f]   = |C[v,f]| + sum_{u:dist=1} |C[u,f]|\n",
    "#         self_plus_2hop[v,f]   = |C[v,f]| + sum_{u:dist<=2} |C[u,f]|\n",
    "\n",
    "#     where C[u,f] is contribution from node u's feature f to v's logit.\n",
    "\n",
    "#     Returns:\n",
    "#         feat_self:       [num_nodes, num_features]\n",
    "#         feat_self_1hop:  [num_nodes, num_features]\n",
    "#         feat_self_2hop:  [num_nodes, num_features]\n",
    "#     \"\"\"\n",
    "#     device = next(model.parameters()).device\n",
    "#     num_nodes, num_features = data.x.size()\n",
    "\n",
    "#     feat_self = torch.zeros(num_nodes, num_features, device=device)\n",
    "#     feat_self_1hop = torch.zeros(num_nodes, num_features, device=device)\n",
    "#     feat_self_2hop = torch.zeros(num_nodes, num_features, device=device)\n",
    "\n",
    "#     logits = model(data.x, data.edge_index)\n",
    "#     preds = logits.argmax(dim=-1)\n",
    "#     if use_true_class:\n",
    "#         target_class = data.y\n",
    "#     else:\n",
    "#         target_class = preds\n",
    "\n",
    "#     W1 = model.conv1.lin.weight.T.to(device)  # [F, H]\n",
    "#     W2 = model.conv2.lin.weight.T.to(device)  # [H, C]\n",
    "\n",
    "#     if model.conv1.bias is not None:\n",
    "#         b1 = model.conv1.bias.to(device)\n",
    "#     else:\n",
    "#         b1 = torch.zeros(W1.shape[1], device=device)\n",
    "\n",
    "#     if model.conv2.bias is not None:\n",
    "#         b2 = model.conv2.bias.to(device)\n",
    "#     else:\n",
    "#         b2 = torch.zeros(W2.shape[1], device=device)\n",
    "\n",
    "#     edge_index = data.edge_index\n",
    "\n",
    "#     print(f\"Building feature importance via decomposition for {len(nodes_to_explain)} nodes...\")\n",
    "#     for i, v in enumerate(nodes_to_explain.tolist(), start=1):\n",
    "#         nodes_sub, _, _, _ = k_hop_subgraph(\n",
    "#             v, hop, edge_index, relabel_nodes=False\n",
    "#         )\n",
    "#         nodes_sub = nodes_sub.to(device)\n",
    "#         n_sub = nodes_sub.numel()\n",
    "\n",
    "#         mapping = {int(nid): idx for idx, nid in enumerate(nodes_sub.tolist())}\n",
    "#         v_local = mapping[int(v)]\n",
    "\n",
    "#         # Dense normalized adjacency A_sub\n",
    "#         A_sub = torch.zeros(n_sub, n_sub, device=device)\n",
    "#         for local_row, global_row in enumerate(nodes_sub.tolist()):\n",
    "#             for (nbr, w) in adj_list[global_row]:\n",
    "#                 if nbr in mapping:\n",
    "#                     local_col = mapping[nbr]\n",
    "#                     A_sub[local_row, local_col] = w\n",
    "\n",
    "#         X_sub = data.x[nodes_sub]  # [n_sub, F]\n",
    "#         c_idx = int(target_class[v].item())\n",
    "\n",
    "#         C_sub, _ = decompose_two_layer_gcn_ego(\n",
    "#             A_sub, X_sub, W1, b1, W2, b2, v_local, c_idx\n",
    "#         )\n",
    "\n",
    "#         C_abs = C_sub.abs()  # [n_sub, F]\n",
    "\n",
    "#         # hop distances 0,1,2,...\n",
    "#         dist = bfs_hop_dist(A_sub, v_local)\n",
    "\n",
    "#         mask_self = (dist == 0)\n",
    "#         mask_1hop = (dist == 1)\n",
    "#         mask_2hop = (dist == 2)\n",
    "\n",
    "#         self_contrib = C_abs[mask_self].sum(dim=0) if mask_self.any() else torch.zeros(num_features, device=device)\n",
    "#         hop1_contrib = C_abs[mask_1hop].sum(dim=0) if mask_1hop.any() else torch.zeros(num_features, device=device)\n",
    "#         hop2_contrib = C_abs[mask_2hop].sum(dim=0) if mask_2hop.any() else torch.zeros(num_features, device=device)\n",
    "\n",
    "#         feat_self[v] = self_contrib\n",
    "#         feat_self_1hop[v] = self_contrib + hop1_contrib\n",
    "#         feat_self_2hop[v] = self_contrib + hop1_contrib + hop2_contrib\n",
    "\n",
    "#         if i % 20 == 0 or i == 1 or i == len(nodes_to_explain):\n",
    "#             print(f\"  processed {i}/{len(nodes_to_explain)} nodes\", flush=True)\n",
    "\n",
    "#     return feat_self, feat_self_1hop, feat_self_2hop\n",
    "\n",
    "@torch.no_grad()\n",
    "def build_feature_importance_decomposition(\n",
    "    data,\n",
    "    model,\n",
    "    adj_list,\n",
    "    nodes_to_explain,\n",
    "    hop=2,\n",
    "    use_true_class=True,\n",
    "):\n",
    "    \"\"\"\n",
    "    For each node v in nodes_to_explain, compute three feature importance variants:\n",
    "\n",
    "        self_only[v,f]        = |C[v,f]|        (u = v only)\n",
    "        self_plus_1hop[v,f]   = |C[v,f]| + sum_{u:dist=1} |C[u,f]|\n",
    "        self_plus_2hop[v,f]   = |C[v,f]| + sum_{u:dist<=2} |C[u,f]|\n",
    "\n",
    "    where C[u,f] is contribution from node u's feature f to v's logit.\n",
    "\n",
    "    Returns:\n",
    "        feat_self:       [num_nodes, num_features]\n",
    "        feat_self_1hop:  [num_nodes, num_features]\n",
    "        feat_self_2hop:  [num_nodes, num_features]\n",
    "    \"\"\"\n",
    "    device = next(model.parameters()).device\n",
    "    num_nodes, num_features = data.x.size()\n",
    "\n",
    "    feat_self = torch.zeros(num_nodes, num_features, device=device)\n",
    "    feat_self_1hop = torch.zeros(num_nodes, num_features, device=device)\n",
    "    feat_self_2hop = torch.zeros(num_nodes, num_features, device=device)\n",
    "\n",
    "    # One forward pass to get target class\n",
    "    logits = model(data.x, data.edge_index)\n",
    "    preds = logits.argmax(dim=-1)\n",
    "    if use_true_class:\n",
    "        target_class = data.y\n",
    "    else:\n",
    "        target_class = preds\n",
    "\n",
    "    # Extract weights once\n",
    "    W1 = model.conv1.lin.weight.T.to(device)  # [F, H]\n",
    "    W2 = model.conv2.lin.weight.T.to(device)  # [H, C]\n",
    "\n",
    "    if model.conv1.bias is not None:\n",
    "        b1 = model.conv1.bias.to(device)\n",
    "    else:\n",
    "        b1 = torch.zeros(W1.shape[1], device=device)\n",
    "\n",
    "    edge_index = data.edge_index\n",
    "\n",
    "    print(f\"Building feature importance via decomposition for {len(nodes_to_explain)} nodes...\")\n",
    "    for i, v in enumerate(nodes_to_explain.tolist(), start=1):\n",
    "        v = int(v)\n",
    "\n",
    "        # 2-hop ego subgraph\n",
    "        nodes_sub, _, _, _ = k_hop_subgraph(\n",
    "            v, hop, edge_index, relabel_nodes=False\n",
    "        )\n",
    "        nodes_sub = nodes_sub.to(device)\n",
    "        n_sub = nodes_sub.numel()\n",
    "\n",
    "        # mapping: global node id -> local index in [0..n_sub-1]\n",
    "        # NOTE: nodes_sub is usually small (ego), so Python dict is fine\n",
    "        mapping = {int(nid): idx for idx, nid in enumerate(nodes_sub.tolist())}\n",
    "        v_local = mapping[v]\n",
    "\n",
    "        # Dense normalized adjacency A_sub\n",
    "        A_sub = torch.zeros(n_sub, n_sub, device=device)\n",
    "        # Use adj_list (already normalized) to fill rows\n",
    "        for local_row, global_row in enumerate(nodes_sub.tolist()):\n",
    "            for (nbr, w) in adj_list[global_row]:\n",
    "                if nbr in mapping:\n",
    "                    local_col = mapping[nbr]\n",
    "                    A_sub[local_row, local_col] = w\n",
    "\n",
    "        X_sub = data.x[nodes_sub]  # [n_sub, F]\n",
    "        c_idx = int(target_class[v].item())\n",
    "\n",
    "        # Decomposition on ego subgraph\n",
    "        C_sub, _ = decompose_two_layer_gcn_ego(\n",
    "            A_sub, X_sub, W1, b1, W2, model.conv2.bias, v_local, c_idx\n",
    "        )\n",
    "        C_abs = C_sub.abs()  # [n_sub, F]\n",
    "\n",
    "        # ---- Hop masks without BFS ----\n",
    "        # self node mask\n",
    "        mask_self = (nodes_sub == v)        # [n_sub]\n",
    "\n",
    "        # 1-hop neighbors of v (global indices), then intersect with nodes_sub\n",
    "        neighbors_global = [nbr for (nbr, _) in adj_list[v]]\n",
    "        neighbors_set = set(neighbors_global)\n",
    "        # Build 1-hop mask over nodes_sub\n",
    "        mask_1hop_list = [ (int(nid) in neighbors_set) for nid in nodes_sub.tolist() ]\n",
    "        mask_1hop = torch.tensor(mask_1hop_list, dtype=torch.bool, device=device)\n",
    "\n",
    "        # 2-hop = everyone else in the 2-hop ego (excluding self and 1-hop)\n",
    "        mask_2hop = ~(mask_self | mask_1hop)\n",
    "\n",
    "        # Aggregate contributions by hop\n",
    "        if mask_self.any():\n",
    "            self_contrib = C_abs[mask_self].sum(dim=0)\n",
    "        else:\n",
    "            self_contrib = torch.zeros(num_features, device=device)\n",
    "\n",
    "        if mask_1hop.any():\n",
    "            hop1_contrib = C_abs[mask_1hop].sum(dim=0)\n",
    "        else:\n",
    "            hop1_contrib = torch.zeros(num_features, device=device)\n",
    "\n",
    "        if mask_2hop.any():\n",
    "            hop2_contrib = C_abs[mask_2hop].sum(dim=0)\n",
    "        else:\n",
    "            hop2_contrib = torch.zeros(num_features, device=device)\n",
    "\n",
    "        feat_self[v] = self_contrib\n",
    "        feat_self_1hop[v] = self_contrib + hop1_contrib\n",
    "        feat_self_2hop[v] = self_contrib + hop1_contrib + hop2_contrib\n",
    "\n",
    "        if i % 20 == 0 or i == 1 or i == len(nodes_to_explain):\n",
    "            print(f\"  processed {i}/{len(nodes_to_explain)} nodes\", flush=True)\n",
    "\n",
    "    return feat_self, feat_self_1hop, feat_self_2hop\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def build_random_importance(num_nodes, num_features, device):\n",
    "    \"\"\"\n",
    "    Generate random feature-importance scores, normalized per-node\n",
    "    so that distributions resemble real importance magnitudes.\n",
    "    \"\"\"\n",
    "    rand_imp = torch.rand(num_nodes, num_features, device=device)\n",
    "    return rand_imp\n",
    "\n",
    "@torch.no_grad()\n",
    "def build_lime_importance(\n",
    "    data,\n",
    "    model,\n",
    "    nodes_to_explain,\n",
    "    num_classes,\n",
    "    num_samples: int = 200,\n",
    "):\n",
    "    \"\"\"\n",
    "    Build LIME-based feature importance for a subset of nodes.\n",
    "\n",
    "    Returns:\n",
    "        feat_lime: [num_nodes, num_features], zeros for nodes not explained.\n",
    "    \"\"\"\n",
    "    device = next(model.parameters()).device\n",
    "    X = data.x\n",
    "    num_nodes, num_features = X.size()\n",
    "\n",
    "    # Fit LIME on training features only (numpy)\n",
    "    X_train_np = X[data.train_mask].detach().cpu().numpy()\n",
    "    class_names = [f\"class_{i}\" for i in range(num_classes)]\n",
    "    feature_names = [f\"f{j}\" for j in range(num_features)]\n",
    "\n",
    "    lime_expl = LimeTabularExplainer(\n",
    "        training_data=X_train_np,\n",
    "        feature_names=feature_names,\n",
    "        class_names=class_names,\n",
    "        discretize_continuous=False,\n",
    "        mode=\"classification\",\n",
    "    )\n",
    "\n",
    "    feat_lime = torch.zeros(num_nodes, num_features, device=device)\n",
    "\n",
    "    nodes_to_explain = nodes_to_explain.to(device)\n",
    "\n",
    "    print(f\"Running LIME for {len(nodes_to_explain)} nodes...\")\n",
    "    for idx, nid in enumerate(nodes_to_explain.tolist(), start=1):\n",
    "        nid = int(nid)\n",
    "        x0 = X[nid].detach().cpu().numpy()\n",
    "\n",
    "        def predict_fn(X_batch_np):\n",
    "            # X_batch_np: [m, D_total] numpy; return probs [m, num_classes]\n",
    "            X_batch = torch.tensor(X_batch_np, dtype=torch.float32, device=device)\n",
    "            probs = []\n",
    "            # Save and restore node nid's features\n",
    "            saved = X[nid].clone()\n",
    "            try:\n",
    "                for i_row in range(X_batch.shape[0]):\n",
    "                    data.x[nid] = X_batch[i_row]\n",
    "                    logits_n = model(data.x, data.edge_index)[nid]  # [C] logits\n",
    "                    prob_n = F.softmax(logits_n, dim=-1)\n",
    "                    probs.append(prob_n.detach().cpu().numpy())\n",
    "            finally:\n",
    "                data.x[nid] = saved\n",
    "            return np.vstack(probs)\n",
    "\n",
    "        # LIME explanation around x0\n",
    "        exp = lime_expl.explain_instance(\n",
    "            data_row=x0,\n",
    "            predict_fn=predict_fn,\n",
    "            num_features=num_features,  # ask for full vector\n",
    "            num_samples=num_samples,\n",
    "        )\n",
    "\n",
    "        # Use the first available label (usually the predicted class)\n",
    "        label = exp.available_labels()[0]\n",
    "        weights = np.zeros(num_features, dtype=np.float32)\n",
    "\n",
    "        # exp.as_map()[label]: list of (feat_idx_or_name, weight)\n",
    "        for feat_id, w in exp.as_map()[label]:\n",
    "            if isinstance(feat_id, str):\n",
    "                if feat_id.startswith(\"f\"):\n",
    "                    feat_idx = int(feat_id[1:])\n",
    "                else:\n",
    "                    feat_idx = int(feat_id)\n",
    "            else:\n",
    "                feat_idx = int(feat_id)\n",
    "            if 0 <= feat_idx < num_features:\n",
    "                weights[feat_idx] = abs(w)\n",
    "\n",
    "        feat_lime[nid] = torch.from_numpy(weights).to(device)\n",
    "\n",
    "        if idx % 10 == 0 or idx == 1 or idx == len(nodes_to_explain):\n",
    "            print(f\"  LIME processed {idx}/{len(nodes_to_explain)} nodes\", flush=True)\n",
    "\n",
    "    return feat_lime\n",
    "\n",
    "@torch.no_grad()\n",
    "def build_graphlime_importance(\n",
    "    data,\n",
    "    model,\n",
    "    nodes_to_explain,\n",
    "    hop: int = 2,\n",
    "    rho: float = 0.1,\n",
    "):\n",
    "    \"\"\"\n",
    "    Build GraphLIME-based feature importance for a subset of nodes.\n",
    "\n",
    "    Assumes you have a GraphLIME API like:\n",
    "        glime = GraphLIME(model=model, hop=hop, rho=rho)\n",
    "        coefs = glime.explain_node(nid, x, edge_index)  # [D_total]\n",
    "\n",
    "    Returns:\n",
    "        feat_glime: [num_nodes, num_features], zeros for nodes not explained.\n",
    "    \"\"\"\n",
    "    device = next(model.parameters()).device\n",
    "    X = data.x.to(device)\n",
    "    num_nodes, num_features = X.size()\n",
    "\n",
    "    glime = GraphLIME(model=model, hop=hop, rho=rho)\n",
    "\n",
    "    feat_glime = torch.zeros(num_nodes, num_features, device=device)\n",
    "    nodes_to_explain = nodes_to_explain.to(device)\n",
    "\n",
    "    print(f\"Running GraphLIME for {len(nodes_to_explain)} nodes...\")\n",
    "    for idx, nid in enumerate(nodes_to_explain.tolist(), start=1):\n",
    "        nid = int(nid)\n",
    "        coefs = glime.explain_node(nid, X, data.edge_index.to(device))  # [D_total]\n",
    "\n",
    "        # coefs could already be a tensor; make sure we abs and move to device\n",
    "        if isinstance(coefs, np.ndarray):\n",
    "            coefs = torch.from_numpy(coefs)\n",
    "        coefs = coefs.to(device)\n",
    "        feat_glime[nid] = coefs.abs()\n",
    "\n",
    "        if idx % 10 == 0 or idx == 1 or idx == len(nodes_to_explain):\n",
    "            print(f\"  GraphLIME processed {idx}/{len(nodes_to_explain)} nodes\", flush=True)\n",
    "\n",
    "    return feat_glime\n",
    "\n",
    "\n",
    "def build_gnnexplainer_importance(\n",
    "    data,\n",
    "    model,\n",
    "    nodes_to_explain,\n",
    "    epochs=50,\n",
    "):\n",
    "    \"\"\"\n",
    "    Use PyG's GNNExplainer (via Explainer) to get per-node feature importance\n",
    "    for the nodes in `nodes_to_explain`.\n",
    "\n",
    "    Returns:\n",
    "        feat_gnnexp: [num_nodes, num_features] tensor\n",
    "                     (zeros for nodes not explained)\n",
    "    \"\"\"\n",
    "    device = next(model.parameters()).device\n",
    "    num_nodes, num_features = data.x.size()\n",
    "\n",
    "    feat_gnnexp = torch.zeros(num_nodes, num_features, device=device)\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    explainer = Explainer(\n",
    "        model=model,\n",
    "        algorithm=GNNExplainer(epochs=epochs),\n",
    "        explanation_type=\"model\",\n",
    "        node_mask_type=\"attributes\",\n",
    "        edge_mask_type=None,\n",
    "        model_config=dict(\n",
    "            mode=\"multiclass_classification\",\n",
    "            task_level=\"node\",\n",
    "            return_type=\"raw\",  # our model returns logits\n",
    "        ),\n",
    "    )\n",
    "\n",
    "    print(f\"Running GNNExplainer for {len(nodes_to_explain)} nodes...\")\n",
    "    for i, nid in enumerate(nodes_to_explain.tolist(), start=1):\n",
    "        # GNNExplainer call: node-level explanation for this index\n",
    "        exp = explainer(data.x, data.edge_index, index=int(nid))\n",
    "\n",
    "        node_mask = exp.node_mask  # shape [num_nodes_sub, F] or [F]\n",
    "\n",
    "        # Find the row corresponding to the explained node.\n",
    "        # exp.index is the node index used inside the explainer.\n",
    "        if hasattr(exp, \"index\"):\n",
    "            row = int(exp.index)\n",
    "        else:\n",
    "            row = int(nid)\n",
    "\n",
    "        if node_mask.dim() == 2:\n",
    "            feat_imp_row = node_mask[row]     # [F]\n",
    "        else:\n",
    "            feat_imp_row = node_mask          # [F]\n",
    "\n",
    "        # GNNExplainer node mask is non-negative importance; we can take it as-is\n",
    "        feat_gnnexp[nid] = feat_imp_row\n",
    "\n",
    "        if i % 10 == 0 or i == 1 or i == len(nodes_to_explain):\n",
    "            print(f\"  GNNExplainer processed {i}/{len(nodes_to_explain)} nodes\", flush=True)\n",
    "\n",
    "    return feat_gnnexp\n",
    "\n",
    "# def build_ig_importance(\n",
    "#     data,\n",
    "#     model,\n",
    "#     nodes_to_explain,\n",
    "# ):\n",
    "#     \"\"\"\n",
    "#     Use Captum's Integrated Gradients (via PyG's CaptumExplainer) to get\n",
    "#     per-node feature importance for the nodes in `nodes_to_explain`.\n",
    "\n",
    "#     Returns:\n",
    "#         feat_ig: [num_nodes, num_features] tensor\n",
    "#                  (zeros for nodes not explained)\n",
    "#     \"\"\"\n",
    "#     device = next(model.parameters()).device\n",
    "#     num_nodes, num_features = data.x.size()\n",
    "\n",
    "#     feat_ig = torch.zeros(num_nodes, num_features, device=device)\n",
    "\n",
    "#     model.eval()\n",
    "\n",
    "#     explainer = Explainer(\n",
    "#         model=model,\n",
    "#         algorithm=CaptumExplainer(IntegratedGradients),\n",
    "#         explanation_type=\"model\",\n",
    "#         node_mask_type=\"attributes\",\n",
    "#         edge_mask_type=None,\n",
    "#         model_config=dict(\n",
    "#             mode=\"multiclass_classification\",\n",
    "#             task_level=\"node\",\n",
    "#             return_type=\"raw\",\n",
    "#         ),\n",
    "#     )\n",
    "\n",
    "#     print(f\"Running Integrated Gradients (Captum) for {len(nodes_to_explain)} nodes...\")\n",
    "#     for i, nid in enumerate(nodes_to_explain.tolist(), start=1):\n",
    "#         exp = explainer(data.x, data.edge_index, index=int(nid))\n",
    "\n",
    "#         node_mask = exp.node_mask  # could be [num_nodes_sub, F] or [F]\n",
    "\n",
    "#         if hasattr(exp, \"index\"):\n",
    "#             row = int(exp.index)\n",
    "#         else:\n",
    "#             row = int(nid)\n",
    "\n",
    "#         if node_mask.dim() == 2:\n",
    "#             feat_imp_row = node_mask[row]       # [F]\n",
    "#         else:\n",
    "#             feat_imp_row = node_mask            # [F]\n",
    "\n",
    "#         # Take absolute attributions as importance\n",
    "#         feat_ig[nid] = feat_imp_row.abs()\n",
    "\n",
    "#         if i % 10 == 0 or i == 1 or i == len(nodes_to_explain):\n",
    "#             print(f\"  IG processed {i}/{len(nodes_to_explain)} nodes\", flush=True)\n",
    "\n",
    "#     return feat_ig\n",
    "\n",
    "#IG - Manual way\n",
    "def build_ig_importance(\n",
    "    data,\n",
    "    model,\n",
    "    nodes_to_explain,\n",
    "    n_steps=30,\n",
    "    use_true_class=True,\n",
    "    baseline='random'\n",
    "):\n",
    "    \"\"\"\n",
    "    Cleaner implementation using torch.autograd.grad directly.\n",
    "    \"\"\"\n",
    "    device = next(model.parameters()).device\n",
    "    num_nodes, num_features = data.x.size()\n",
    "\n",
    "    feat_ig = torch.zeros(num_nodes, num_features, device=device)\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    # Prepare baseline\n",
    "    if baseline == 'zero':\n",
    "        baseline_x = torch.zeros_like(data.x)\n",
    "    elif baseline == 'random':\n",
    "        baseline_x = torch.rand_like(data.x)\n",
    "    else:\n",
    "        baseline_x = torch.zeros_like(data.x)\n",
    "\n",
    "    x = data.x.to(device)\n",
    "    edge_index = data.edge_index.to(device)\n",
    "\n",
    "    # Get target classes\n",
    "    with torch.no_grad():\n",
    "        logits = model(x, edge_index)\n",
    "        if use_true_class:\n",
    "            target_class = data.y.to(device)\n",
    "        else:\n",
    "            target_class = logits.argmax(dim=-1)\n",
    "\n",
    "    print(f\"Running Clean IG for {len(nodes_to_explain)} nodes...\")\n",
    "\n",
    "    for i, nid in enumerate(nodes_to_explain.tolist(), start=1):\n",
    "        nid = int(nid)\n",
    "        node_target_class = target_class[nid]\n",
    "\n",
    "        # Compute attribution using clean method\n",
    "        attribution = _integrated_gradients_single(\n",
    "            model=model,\n",
    "            x=x,\n",
    "            baseline=baseline_x,\n",
    "            edge_index=edge_index,\n",
    "            target_node=nid,\n",
    "            target_class=node_target_class,\n",
    "            steps=n_steps\n",
    "        )\n",
    "\n",
    "        feat_ig[nid] = attribution.abs()\n",
    "\n",
    "        if i % 10 == 0 or i == 1 or i == len(nodes_to_explain):\n",
    "            print(f\"  IG processed {i}/{len(nodes_to_explain)} nodes\", flush=True)\n",
    "\n",
    "    return feat_ig\n",
    "\n",
    "\n",
    "def _integrated_gradients_single(model, x, baseline, edge_index, target_node, target_class, steps):\n",
    "    \"\"\"\n",
    "    Compute Integrated Gradients for a single node using autograd.grad.\n",
    "    \"\"\"\n",
    "    # Scale factor\n",
    "    scaled_inputs = [baseline + (float(i) / steps) * (x - baseline) for i in range(0, steps + 1)]\n",
    "\n",
    "    # Compute gradients for each scaled input\n",
    "    gradients = []\n",
    "\n",
    "    for scaled_input in scaled_inputs:\n",
    "        scaled_input = scaled_input.clone().requires_grad_(True)\n",
    "\n",
    "        # Forward pass\n",
    "        output = model(scaled_input, edge_index)\n",
    "\n",
    "        # Get the target node's score for target class\n",
    "        target_score = output[target_node, target_class]\n",
    "\n",
    "        # Compute gradient\n",
    "        gradient = torch.autograd.grad(outputs=target_score, inputs=scaled_input)[0]\n",
    "        gradients.append(gradient)\n",
    "\n",
    "    # Average the gradients\n",
    "    avg_gradients = torch.stack(gradients).mean(dim=0)\n",
    "\n",
    "    # Integrated Gradients formula\n",
    "    integrated_gradients = (x - baseline) * avg_gradients\n",
    "\n",
    "    return integrated_gradients[target_node]\n",
    "\n",
    "\n",
    "#GOAT\n",
    "from torch_geometric.utils import to_undirected, add_self_loops, degree\n",
    "\n",
    "def build_gcn_norm_dense(edge_index, num_nodes, device):\n",
    "    edge_index = to_undirected(edge_index, num_nodes=num_nodes)\n",
    "    edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)\n",
    "    row, col = edge_index\n",
    "    deg = degree(row, num_nodes=num_nodes)\n",
    "    deg_inv_sqrt = deg.pow(-0.5)\n",
    "    deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0.\n",
    "    w = deg_inv_sqrt[row] * deg_inv_sqrt[col]\n",
    "    A = torch.zeros((num_nodes, num_nodes), device=device)\n",
    "    A.index_put_((row, col), w, accumulate=True)\n",
    "    return A  # [N,N]\n",
    "\n",
    "def preactivations_and_masks(data, model, A_hat):\n",
    "    \"\"\"\n",
    "    Compute pre-activations z1 and ReLU masks M1 for GOAT,\n",
    "    using the current operating point (data.x).\n",
    "    For a 2-layer GCN:\n",
    "        z1 = A_hat @ (X W1)\n",
    "        h1 = ReLU(z1)\n",
    "        logits = A_hat @ (h1 W2)\n",
    "    \"\"\"\n",
    "    device = data.x.device\n",
    "\n",
    "    # Extract linear weights used inside GCNConv\n",
    "    # conv.lin.weight: [out_channels, in_channels]  so we transpose\n",
    "    W1 = model.conv1.lin.weight.T.to(device)  # [in_dim, H]\n",
    "    W2 = model.conv2.lin.weight.T.to(device)  # [H, C]\n",
    "\n",
    "    X = data.x.to(device)                     # [N, D_total]\n",
    "\n",
    "    # First layer pre-activation: z1 = Â X W1\n",
    "    z1 = A_hat @ (X @ W1)                     # [N, H]\n",
    "    M1 = (z1 > 0).float()                     # ReLU mask for h1\n",
    "\n",
    "    # Logits: z2 = Â (ReLU(z1) W2)\n",
    "    h1 = M1 * z1                              # [N, H] (output of first layer after ReLU)\n",
    "    z2 = A_hat @ (h1 @ W2)                    # [N, C] (logits, output of second layer)\n",
    "\n",
    "    return z1, z2, M1, W1, W2\n",
    "\n",
    "@torch.no_grad()\n",
    "def goat_2layer_feature_importance_for_node(\n",
    "    i: int,\n",
    "    data,\n",
    "    model,\n",
    "    A_hat,\n",
    "    z1, z2, M1,\n",
    "    W1, W2,\n",
    "):\n",
    "    \"\"\"\n",
    "    Returns phi: [D_total], feature contributions to the predicted class logit at node i.\n",
    "\n",
    "    Implements:\n",
    "        phi_f(i,c) ≈ sum_{h,u} x_{u,f} * [ (r_i^T D1_h Â)_u ] * W1[f,h] * W2[h,c]\n",
    "\n",
    "    where:\n",
    "        r_i^T is row i of Â,\n",
    "        D1_h = diag(M1[:,h]) is the ReLU mask for hidden unit h across nodes.\n",
    "    \"\"\"\n",
    "    device = data.x.device\n",
    "    X = data.x.to(device)                    # [N, D_total]\n",
    "    XT = X.t()                               # [D_total, N]\n",
    "    N, D_total = X.size()\n",
    "    H = W1.size(1)                           # hidden size\n",
    "\n",
    "    i = int(i)\n",
    "\n",
    "    # predicted class c at node i (using full model)\n",
    "    logits = model(X, data.edge_index.to(device))  # [N, C]\n",
    "    c = int(logits[i].argmax().item())\n",
    "\n",
    "    r = A_hat[i, :]                          # [N]  (row i of Â)\n",
    "    phi = torch.zeros(D_total, device=device)\n",
    "\n",
    "    # For each hidden channel (h) of the first layer\n",
    "    for h in range(H):\n",
    "        # t1 corresponds to (r_i^T D1_h Â)_u\n",
    "        # D1_h is diag(M1[:,h])\n",
    "        t1 = (r * M1[:, h]) @ A_hat          # [N]\n",
    "\n",
    "        # Aggregate over source nodes u: v_f = sum_u x_{u,f} * t1[u] = (X^T @ t1)_f\n",
    "        v = XT @ t1                      # [D_total]\n",
    "\n",
    "        # Weight chain for (f -> h -> c)\n",
    "        # W1[f, h] * W2[h, c]\n",
    "        chain = W1[:, h] * W2[h, c]      # [D_total]\n",
    "\n",
    "        phi += v * chain\n",
    "\n",
    "    return phi  # [D_total]\n",
    "\n",
    "def build_goat_feature_importance(\n",
    "    data,\n",
    "    model,\n",
    "    nodes_to_explain=None,\n",
    "):\n",
    "    \"\"\"\n",
    "    Compute GOAT-style feature importance for a set of nodes.\n",
    "\n",
    "    Returns:\n",
    "        feat_goat: [num_nodes, num_features]\n",
    "                   feat_goat[v, f] = |GOAT contribution of feature f to logit of v|\n",
    "                   (0 for nodes not in nodes_to_explain)\n",
    "    \"\"\"\n",
    "    device = data.x.device\n",
    "    N, D_total = data.x.size()\n",
    "\n",
    "    # Dense normalized adjacency\n",
    "    A_hat = build_gcn_norm_dense(data.edge_index.to(device), N, device=device)\n",
    "\n",
    "    # Pre-activations, masks, and weights\n",
    "    z1, z2, M1, W1, W2 = preactivations_and_masks(data, model, A_hat)\n",
    "\n",
    "    # If nodes_to_explain is None, explain all nodes\n",
    "    if nodes_to_explain is None:\n",
    "        nodes_to_explain = torch.arange(N, device=device)\n",
    "    else:\n",
    "        nodes_to_explain = nodes_to_explain.to(device)\n",
    "\n",
    "    feat_goat = torch.zeros(N, D_total, device=device)\n",
    "\n",
    "    print(f\"Running GOAT for {len(nodes_to_explain)} nodes...\")\n",
    "    for idx, nid in enumerate(nodes_to_explain.tolist(), start=1):\n",
    "        phi = goat_2layer_feature_importance_for_node(\n",
    "            nid,\n",
    "            data=data,\n",
    "            model=model,\n",
    "            A_hat=A_hat,\n",
    "            z1=z1, z2=z2,\n",
    "            M1=M1,\n",
    "            W1=W1, W2=W2,\n",
    "        )\n",
    "        feat_goat[nid] = phi.abs()\n",
    "\n",
    "        if idx % 10 == 0 or idx == 1 or idx == len(nodes_to_explain):\n",
    "            print(f\"  GOAT processed {idx}/{len(nodes_to_explain)} nodes\", flush=True)\n",
    "\n",
    "    return feat_goat\n",
    "\n",
    "\n",
    "#GRAD-CAM\n",
    "# def gradcam_feature_importance(model, data, use_true_class=True):\n",
    "#     \"\"\"\n",
    "#     GradCAM-style feature importance for ALL nodes.\n",
    "\n",
    "#     Returns:\n",
    "#         feat_imp: [num_nodes, num_features] tensor,\n",
    "#                   feat_imp[v, f] = importance of feature f of node v.\n",
    "#     \"\"\"\n",
    "#     device = next(model.parameters()).device\n",
    "#     x = data.x.to(device)\n",
    "#     edge_index = data.edge_index.to(device)\n",
    "\n",
    "#     model.eval()\n",
    "\n",
    "#     # ---- Forward conv1 with grad tracking ----\n",
    "#     # We need gradients w.r.t. z1 (pre-activation conv1 output)\n",
    "#     x.requires_grad_(True)\n",
    "#     z1 = model.conv1(x, edge_index)      # [N, H]\n",
    "#     h1 = F.relu(z1)\n",
    "#     h1.retain_grad()                     # store grad wrt h1\n",
    "\n",
    "#     # Forward to logits\n",
    "#     z2 = model.conv2(h1, edge_index)     # [N, C]\n",
    "#     logits = z2\n",
    "\n",
    "#     # Choose target class per node (true or predicted)\n",
    "#     if use_true_class:\n",
    "#         target_class = data.y.to(device)\n",
    "#     else:\n",
    "#         target_class = logits.argmax(dim=-1)\n",
    "\n",
    "#     # Backprop from sum over target logits to get global grads\n",
    "#     # (so α_k = average over nodes)\n",
    "#     model.zero_grad(set_to_none=True)\n",
    "#     if h1.grad is not None:\n",
    "#         h1.grad.zero_()\n",
    "\n",
    "#     # scalar: sum over nodes of logits[v, c_v]\n",
    "#     idx = torch.arange(x.size(0), device=device)\n",
    "#     target_logits = logits[idx, target_class[idx]]\n",
    "#     target_sum = target_logits.sum()\n",
    "#     target_sum.backward(retain_graph=False)\n",
    "\n",
    "#     grads = h1.grad           # [N, H]\n",
    "#     # Channel weights α_k: global average over nodes\n",
    "#     alpha = grads.mean(dim=0) # [H]\n",
    "\n",
    "#     # ReLU mask on z1\n",
    "#     M = (z1 > 0).float()      # [N, H]\n",
    "\n",
    "#     # W1: [F, H]\n",
    "#     W1 = model.conv1.lin.weight.T.to(device)  # [F, H]\n",
    "\n",
    "#     # S[f, k] = W1[f, k] * alpha[k]\n",
    "#     S = W1 * alpha.unsqueeze(0)              # [F, H]\n",
    "\n",
    "#     # For each node v, gradient of cam_v wrt X[v, f] is:\n",
    "#     #   dcam_dX[v, f] = sum_k M[v,k] * S[f,k]\n",
    "#     # Vectorize:\n",
    "#     # MAlpha[v,k] = M[v,k] * alpha[k]  -> we already absorbed alpha into S,\n",
    "#     # so just use M as is:\n",
    "#     # Equivalent: dcam_dX^T = S @ (M^T)  => [F, N], then transpose => [N, F]\n",
    "#     dcam_dX_T = S @ M.T       # [F, N]\n",
    "#     dcam_dX = dcam_dX_T.T     # [N, F]\n",
    "\n",
    "#     # GradCAM-style feature importance: |X * dcam_dX|\n",
    "#     feat_imp = (x * dcam_dX).abs()          # [N, F]\n",
    "\n",
    "#     return feat_imp\n",
    "\n",
    "def gradcam_feature_importance(model, data, use_true_class=True):\n",
    "    \"\"\"\n",
    "    GradCAM-style feature importance using ONLY positive gradients as importance scores.\n",
    "    \"\"\"\n",
    "    device = next(model.parameters()).device\n",
    "    x = data.x.to(device)\n",
    "    edge_index = data.edge_index.to(device)\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    # ---- Forward conv1 with grad tracking ----\n",
    "    x.requires_grad_(True)\n",
    "    z1 = model.conv1(x, edge_index)      # [N, H]\n",
    "    h1 = F.relu(z1)\n",
    "    h1.retain_grad()                     # store grad wrt h1\n",
    "\n",
    "    # Forward to logits\n",
    "    z2 = model.conv2(h1, edge_index)     # [N, C]\n",
    "    logits = z2\n",
    "\n",
    "    # Choose target class per node (true or predicted)\n",
    "    if use_true_class:\n",
    "        target_class = data.y.to(device)\n",
    "    else:\n",
    "        target_class = logits.argmax(dim=-1)\n",
    "\n",
    "    # Backprop from sum over target logits to get global grads\n",
    "    model.zero_grad(set_to_none=True)\n",
    "    if h1.grad is not None:\n",
    "        h1.grad.zero_()\n",
    "\n",
    "    idx = torch.arange(x.size(0), device=device)\n",
    "    target_logits = logits[idx, target_class[idx]]\n",
    "    target_sum = target_logits.sum()\n",
    "    target_sum.backward(retain_graph=False)\n",
    "\n",
    "    grads = h1.grad           # [N, H]\n",
    "\n",
    "    # MODIFICATION: Only keep positive gradients\n",
    "    positive_grads = F.relu(grads)  # [N, H]\n",
    "\n",
    "    # Channel weights α_k: global average over nodes (using only positive grads)\n",
    "    alpha = positive_grads.mean(dim=0)  # [H]\n",
    "\n",
    "    # ReLU mask on z1\n",
    "    M = (z1 > 0).float()      # [N, H]\n",
    "\n",
    "    W1 = model.conv1.lin.weight.T.to(device)  # [F, H]\n",
    "\n",
    "    # S[f, k] = W1[f, k] * alpha[k]\n",
    "    S = W1 * alpha.unsqueeze(0)              # [F, H]\n",
    "\n",
    "    # For each node v, gradient of cam_v wrt X[v, f] is:\n",
    "    #   dcam_dX[v, f] = sum_k M[v,k] * S[f,k]\n",
    "    dcam_dX_T = S @ M.T       # [F, N]\n",
    "    dcam_dX = dcam_dX_T.T     # [N, F]\n",
    "\n",
    "    # MODIFICATION: Use only positive gradients as importance (no multiplication by x)\n",
    "    # Keep only positive values\n",
    "    feat_imp = F.relu(dcam_dX)          # [N, F]\n",
    "\n",
    "    return feat_imp\n",
    "\n",
    "###GNN-LRP\n",
    "def gnn_lrp_feature_importance(model, data, eps: float = 1e-6, use_true_class=True, nodes_to_explain=None):\n",
    "    \"\"\"\n",
    "    Simple 2-layer GCN LRP producing feature-level relevance per node.\n",
    "\n",
    "    We apply epsilon-rule LRP:\n",
    "      1) logit -> hidden units at *same node*\n",
    "      2) hidden units -> input features at *same node*\n",
    "\n",
    "    This ignores neighbor mixing in conv1 but is a clean feature-level baseline.\n",
    "\n",
    "    Returns:\n",
    "        feat_imp: [num_nodes, num_features] tensor of |relevance|.\n",
    "    \"\"\"\n",
    "    device = next(model.parameters()).device\n",
    "    x = data.x.to(device)\n",
    "    edge_index = data.edge_index.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    # Forward full model to get logits and hidden layer\n",
    "    z1 = model.conv1(x, edge_index)   # [N, H]\n",
    "    h1 = F.relu(z1)                   # [N, H]\n",
    "    z2 = model.conv2(h1, edge_index)  # [N, C]\n",
    "    logits = z2\n",
    "\n",
    "    num_nodes, num_features = x.size()\n",
    "    H = h1.size(1)\n",
    "    C = logits.size(1)\n",
    "\n",
    "    if use_true_class:\n",
    "        target_class = data.y.to(device)\n",
    "    else:\n",
    "        target_class = logits.argmax(dim=-1)\n",
    "\n",
    "    # Extract W1, W2\n",
    "    W1 = model.conv1.lin.weight.T.to(device)  # [F, H]\n",
    "    W2 = model.conv2.lin.weight.T.to(device)  # [H, C]\n",
    "\n",
    "    feat_relevance = torch.zeros(num_nodes, num_features, device=device)\n",
    "\n",
    "    if nodes_to_explain is None:\n",
    "        nodes_to_process = torch.arange(num_nodes, device=device)\n",
    "    else:\n",
    "        nodes_to_process = nodes_to_explain.to(device)\n",
    "\n",
    "    print(f\"Running GNN-LRP for {len(nodes_to_process)} nodes...\")\n",
    "    for idx, v_global in enumerate(nodes_to_process.tolist(), start=1):\n",
    "        v = int(v_global)\n",
    "        c = int(target_class[v].item())\n",
    "\n",
    "        # ---- Step 0: output relevance at (v,c)\n",
    "        R_out = logits[v, c]  # scalar\n",
    "\n",
    "        # ---- Step 1: distribute to hidden units at node v\n",
    "        h_v = h1[v]           # [H]\n",
    "        w2_c = W2[:, c]       # [H]\n",
    "\n",
    "        z_vk = h_v * w2_c     # [H]\n",
    "        Zk = z_vk.sum() + eps\n",
    "        if Zk.item() == 0.0:\n",
    "            continue\n",
    "\n",
    "        R_h = (z_vk / Zk) * R_out  # [H]\n",
    "\n",
    "        # ---- Step 2: distribute each hidden unit relevance to features of node v\n",
    "        x_v = x[v]                 # [F]\n",
    "        # z_fk for each k: [F] = x_v * W1[:, k]\n",
    "        R_x_v = torch.zeros(num_features, device=device)\n",
    "\n",
    "        for k in range(H):\n",
    "            w1_k = W1[:, k]        # [F]\n",
    "            z_fk = x_v * w1_k      # [F]\n",
    "            Zf = z_fk.sum() + eps\n",
    "            if Zf.item() != 0.0:\n",
    "                R_x_v += (z_fk / Zf) * R_h[k]  # [F]\n",
    "\n",
    "        feat_relevance[v] = R_x_v.abs()\n",
    "\n",
    "        if idx % 10 == 0 or idx == 1 or idx == len(nodes_to_process):\n",
    "            print(f\"  GNN-LRP processed {idx}/{len(nodes_to_process)} nodes\", flush=True)\n",
    "\n",
    "    return feat_relevance\n",
    "\n",
    "\n",
    "# -------------------------------------------------------------\n",
    "# 6. Fidelity metrics (feature deletion / keeping)\n",
    "# -------------------------------------------------------------\n",
    "@torch.no_grad()\n",
    "def compute_fidelity_scores(\n",
    "    data,\n",
    "    model,\n",
    "    feat_imp,\n",
    "    nodes,\n",
    "    k_fracs=(0.1, 0.2, 0.5),\n",
    "    use_true_class=True,\n",
    "    label=\"\",\n",
    "):\n",
    "    \"\"\"\n",
    "    Compute fidelity+ / fidelity- / keep-top for given feat_imp.\n",
    "    \"\"\"\n",
    "    device = next(model.parameters()).device\n",
    "    model.eval()\n",
    "\n",
    "    x_orig = data.x.clone()\n",
    "    edge_index = data.edge_index\n",
    "\n",
    "    logits_orig = model(x_orig, edge_index)\n",
    "    probs_orig = F.softmax(logits_orig, dim=-1)\n",
    "\n",
    "    if use_true_class:\n",
    "        target_class = data.y\n",
    "    else:\n",
    "        target_class = logits_orig.argmax(dim=-1)\n",
    "\n",
    "    nodes = nodes.to(device)\n",
    "    num_features = x_orig.size(1)\n",
    "\n",
    "    p_orig_nodes = probs_orig[nodes, target_class[nodes]]\n",
    "\n",
    "    results_plus = {}\n",
    "    results_minus = {}\n",
    "    results_keep = {}\n",
    "\n",
    "    print(f\"\\n=== Fidelity scores for importance: {label} ===\")\n",
    "    for k_frac in k_fracs:\n",
    "        k = max(1, int(round(k_frac * num_features)))\n",
    "        print(f\"\\n--- k_frac = {k_frac:.2f} (top {int(k_frac*100)}% -> k={k}) ---\")\n",
    "\n",
    "        imp_nodes = feat_imp[nodes]  # [M, F]\n",
    "        topk_idx = torch.topk(imp_nodes, k=k, dim=1, largest=True).indices    # [M, k]\n",
    "        bottomk_idx = torch.topk(imp_nodes, k=k, dim=1, largest=False).indices  # [M, k]\n",
    "\n",
    "        rows = nodes.unsqueeze(1).expand(-1, k).reshape(-1)\n",
    "\n",
    "        # 1) Fidelity+ (remove TOP-k)\n",
    "        x_mask_top = x_orig.clone()\n",
    "        cols_top = topk_idx.reshape(-1)\n",
    "        x_mask_top[rows, cols_top] = 0.0\n",
    "\n",
    "        logits_mask_top = model(x_mask_top, edge_index)\n",
    "        probs_mask_top = F.softmax(logits_mask_top, dim=-1)\n",
    "        p_mask_top = probs_mask_top[nodes, target_class[nodes]]\n",
    "\n",
    "        delta_plus = p_orig_nodes - p_mask_top\n",
    "        results_plus[k_frac] = float(delta_plus.mean().item())\n",
    "        print(f\"Fidelity+ (remove TOP-k)   = {results_plus[k_frac]:.4f}\")\n",
    "\n",
    "        # 2) Fidelity- (remove BOTTOM-k)\n",
    "        x_mask_bottom = x_orig.clone()\n",
    "        cols_bottom = bottomk_idx.reshape(-1)\n",
    "        x_mask_bottom[rows, cols_bottom] = 0.0\n",
    "\n",
    "        logits_mask_bottom = model(x_mask_bottom, edge_index)\n",
    "        probs_mask_bottom = F.softmax(logits_mask_bottom, dim=-1)\n",
    "        p_mask_bottom = probs_mask_bottom[nodes, target_class[nodes]]\n",
    "\n",
    "        delta_minus = p_orig_nodes - p_mask_bottom\n",
    "        results_minus[k_frac] = float(delta_minus.mean().item())\n",
    "        print(f\"Fidelity- (remove BOTTOM-k)= {results_minus[k_frac]:.4f}\")\n",
    "\n",
    "        # 3) Keep-top (only TOP-k kept)\n",
    "        x_keep_top = x_orig.clone()\n",
    "        x_keep_top[nodes] = 0.0\n",
    "        cols_keep = topk_idx.reshape(-1)\n",
    "        x_keep_top[rows, cols_keep] = x_orig[rows, cols_keep]\n",
    "\n",
    "        logits_keep_top = model(x_keep_top, edge_index)\n",
    "        probs_keep_top = F.softmax(logits_keep_top, dim=-1)\n",
    "        p_keep_top = probs_keep_top[nodes, target_class[nodes]]\n",
    "\n",
    "        delta_keep = p_orig_nodes - p_keep_top\n",
    "        results_keep[k_frac] = float(delta_keep.mean().item())\n",
    "        print(f\"Keep-top (only TOP-k kept) = {results_keep[k_frac]:.4f}\")\n",
    "\n",
    "    return results_plus, results_minus, results_keep\n",
    "\n",
    "\n",
    "# -------------------------------------------------------------\n",
    "# 7. Main script\n",
    "# -------------------------------------------------------------\n",
    "def main():\n",
    "    method_times = {}\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument(\"--seed\", type=int, default=42)\n",
    "    parser.add_argument(\"--hidden_dim\", type=int, default=64)\n",
    "    parser.add_argument(\"--dropout\", type=float, default=0.5)\n",
    "    parser.add_argument(\"--epochs\", type=int, default=500)\n",
    "    parser.add_argument(\"--lr\", type=float, default=0.05)\n",
    "    parser.add_argument(\"--weight_decay\", type=float, default=5e-4)\n",
    "    parser.add_argument(\"--root\", type=str, default=\"data/Amazon\")\n",
    "    parser.add_argument(\"--k_fracs\", type=float, nargs=\"+\",\n",
    "                        default=[0.05, 0.1])\n",
    "    parser.add_argument(\"--num_explain\", type=int, default=100,\n",
    "                        help=\"how many test nodes to explain via decomposition\")\n",
    "    parser.add_argument(\"--use_true_class\", action=\"store_true\",\n",
    "                        help=\"use true label instead of predicted label as target class\")\n",
    "    parser.add_argument(\"--model_path\", type=str, default=\"amazon_computer_gcn.pt\",\n",
    "                        help=\"path to save/load trained GCN model\")\n",
    "    args = parser.parse_args(args=[])\n",
    "\n",
    "    set_seed(args.seed)\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    # 1) Load data\n",
    "    data, in_dim, num_classes = load_amazon_computers(args.root)\n",
    "    data = data.to(device)\n",
    "\n",
    "    # 2) Build model\n",
    "    model = GCN(\n",
    "        in_channels=in_dim,\n",
    "        hidden_channels=args.hidden_dim,\n",
    "        out_channels=num_classes,\n",
    "        dropout=args.dropout,\n",
    "    ).to(device)\n",
    "\n",
    "    # 3) Train or load model\n",
    "    if os.path.exists(args.model_path):\n",
    "        print(f\"\\n=== Loading existing model from {args.model_path} ===\")\n",
    "        state = torch.load(args.model_path, map_location=device)\n",
    "        model.load_state_dict(state)\n",
    "        train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
    "        print(f\"Loaded model | Train={train_acc:.4f} Val={val_acc:.4f} Test={test_acc:.4f}\")\n",
    "    else:\n",
    "        print(\"\\n=== Training GCN (no saved model found) ===\")\n",
    "        optimizer = torch.optim.Adam(\n",
    "            model.parameters(),\n",
    "            lr=args.lr,\n",
    "            weight_decay=args.weight_decay,\n",
    "        )\n",
    "\n",
    "        best_val_acc = 0.0\n",
    "        best_state = None\n",
    "\n",
    "        for epoch in range(1, args.epochs + 1):\n",
    "            loss = train(model, data, optimizer)\n",
    "            train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
    "\n",
    "            if val_acc > best_val_acc:\n",
    "                best_val_acc = val_acc\n",
    "                best_state = {\n",
    "                    \"model\": model.state_dict(),\n",
    "                    \"epoch\": epoch,\n",
    "                    \"test_acc\": test_acc,\n",
    "                }\n",
    "\n",
    "            if epoch % 20 == 0 or epoch == 1:\n",
    "                print(\n",
    "                    f\"Epoch {epoch:03d} | \"\n",
    "                    f\"Loss {loss:.4f} | \"\n",
    "                    f\"Train {train_acc:.4f} | \"\n",
    "                    f\"Val {val_acc:.4f} | \"\n",
    "                    f\"Test {test_acc:.4f}\"\n",
    "                )\n",
    "\n",
    "        if best_state is not None:\n",
    "            model.load_state_dict(best_state[\"model\"])\n",
    "            print(\n",
    "                f\"\\nBest epoch = {best_state['epoch']} | \"\n",
    "                f\"Best val acc = {best_val_acc:.4f} | \"\n",
    "                f\"Test acc @best = {best_state['test_acc']:.4f}\"\n",
    "            )\n",
    "\n",
    "        # Save trained model\n",
    "        torch.save(model.state_dict(), args.model_path)\n",
    "        print(f\"Model saved to {args.model_path}\")\n",
    "\n",
    "    # 4) Build normalized adjacency list\n",
    "    print(\"\\n=== Building normalized adjacency list ===\")\n",
    "    adj_list = build_normalized_adjacency_list(model, data)\n",
    "\n",
    "    # 5) Choose test nodes to explain (correct predictions preferred)\n",
    "    model.eval()\n",
    "    logits = model(data.x, data.edge_index)\n",
    "    preds = logits.argmax(dim=-1)\n",
    "\n",
    "    test_nodes = torch.nonzero(data.test_mask, as_tuple=False).view(-1)\n",
    "    correct_mask = preds == data.y\n",
    "    test_correct_nodes = test_nodes[correct_mask[test_nodes]]\n",
    "\n",
    "    if test_correct_nodes.numel() == 0:\n",
    "        print(\"WARNING: no correctly classified test nodes, using all test nodes.\")\n",
    "        nodes_to_explain = test_nodes\n",
    "    else:\n",
    "        nodes_to_explain = test_correct_nodes\n",
    "\n",
    "    if nodes_to_explain.numel() > args.num_explain:\n",
    "        nodes_to_explain = nodes_to_explain[:args.num_explain]\n",
    "\n",
    "    print(\n",
    "        f\"\\n=== Explaining {nodes_to_explain.numel()} test nodes \"\n",
    "        f\"({'true' if args.use_true_class else 'pred'} class as target) ===\"\n",
    "    )\n",
    "\n",
    "    # 6) Build feature importance: self vs self+1hop vs self+2hop\n",
    "\n",
    "    t0 = _now()\n",
    "    feat_self, feat_self_1hop, feat_self_2hop = build_feature_importance_decomposition(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        adj_list=adj_list,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    "        hop=2,\n",
    "        use_true_class=args.use_true_class,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"decomposition\"] = t1 - t0\n",
    "\n",
    "    # feat_self, feat_self_1hop, feat_self_2hop = build_feature_importance_decomposition(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     adj_list=adj_list,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    #     hop=2,\n",
    "    #     use_true_class=args.use_true_class,\n",
    "    # )\n",
    "    # ---- Random baseline ----\n",
    "    print(\"\\n=== Building RANDOM importance baseline ===\")\n",
    "\n",
    "    t0 = _now()\n",
    "    feat_random = build_random_importance(\n",
    "        num_nodes=data.num_nodes,\n",
    "        num_features=data.x.size(1),\n",
    "        device=device,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"random\"] = t1 - t0\n",
    "\n",
    "#     feat_random = build_random_importance(\n",
    "#         num_nodes=data.num_nodes,\n",
    "#         num_features=data.x.size(1),\n",
    "#         device=device\n",
    "# )\n",
    "    print(\"\\n=== Building GRAD-Cam ===\")\n",
    "    # feat_gradcam = gradcam_feature_importance(model, data, use_true_class=True)\n",
    "    t0 = _now()\n",
    "    feat_gradcam = gradcam_feature_importance(\n",
    "        model=model,\n",
    "        data=data,\n",
    "        use_true_class=args.use_true_class,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"gradcam\"] = t1 - t0\n",
    "\n",
    "    print(\"\\n=== Building GNN-LRP ===\")\n",
    "#     feat_lrp = gnn_lrp_feature_importance(\n",
    "#     model,\n",
    "#     data,\n",
    "#     eps=1e-6,\n",
    "#     use_true_class=True,\n",
    "#     nodes_to_explain=nodes_to_explain  # same subset you use for fidelity\n",
    "# )\n",
    "    t0 = _now()\n",
    "    feat_lrp = gnn_lrp_feature_importance(\n",
    "        model=model,\n",
    "        data=data,\n",
    "        eps=1e-6,\n",
    "        use_true_class=args.use_true_class,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"lrp\"] = t1 - t0\n",
    "\n",
    "\n",
    "    print(\"\\n=== Building GOAT ===\")\n",
    "    # feat_goat = build_goat_feature_importance(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    # )\n",
    "    t0 = _now()\n",
    "    feat_goat = build_goat_feature_importance(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"goat\"] = t1 - t0\n",
    "\n",
    "    print(\"\\n=== Building LIME ===\")\n",
    "    # feat_lime = build_lime_importance(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    #     num_classes=num_classes,\n",
    "    #     num_samples=50,\n",
    "    # )\n",
    "    t0 = _now()\n",
    "    feat_lime = build_lime_importance(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        nodes_to_explain=nodes_to_explain,  # or smaller subset\n",
    "        num_classes=num_classes,\n",
    "        num_samples=50,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"lime\"] = t1 - t0\n",
    "\n",
    "    # GraphLime: if possible\n",
    "    # print(\"\\n=== Building GraphLIME ===\")\n",
    "    # t0 = _now()\n",
    "    # feat_glime = build_graphlime_importance(\n",
    "    # data=data,\n",
    "    # model=model,\n",
    "    # nodes_to_explain=nodes_to_explain,\n",
    "    # hop=2,\n",
    "    # rho=0.1,\n",
    "    # )\n",
    "    # t1 = _now()\n",
    "    # method_times[\"graphlime\"] = t1 - t0\n",
    "\n",
    "\n",
    "    #probably not use the following two ---\n",
    "\n",
    "    #     # ---- GNNExplainer baseline importance ----\n",
    "    # print(\"\\n=== Building GNNExplainer importance baseline ===\")\n",
    "    # feat_gnnexp = build_gnnexplainer_importance(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    #     epochs=50,   # you can change this\n",
    "    # )\n",
    "    print(\"\\n=== Building Integrated Gradients (Captum) importance baseline ===\")\n",
    "    t0 = _now()\n",
    "    feat_ig = build_ig_importance(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    ")\n",
    "    t1 = _now()\n",
    "    method_times[\"ig\"] = t1 - t0\n",
    "\n",
    "    # 7) Fidelity scores for each importance variant\n",
    "    print(\"\\n=== Computing fidelity scores ===\")\n",
    "    fp_self, fm_self, fk_self = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_self,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"SELF ONLY\",\n",
    "    )\n",
    "\n",
    "    fp_s1, fm_s1, fk_s1 = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_self_1hop,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"SELF + 1HOP\",\n",
    "    )\n",
    "\n",
    "    fp_s2, fm_s2, fk_s2 = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_self_2hop,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"SELF + 1HOP + 2HOP\",\n",
    "    )\n",
    "    fp_rand, fm_rand, fk_rand = compute_fidelity_scores(\n",
    "    data=data,\n",
    "    model=model,\n",
    "    feat_imp=feat_random,\n",
    "    nodes=nodes_to_explain,\n",
    "    k_fracs=args.k_fracs,\n",
    "    use_true_class=args.use_true_class,\n",
    "    label=\"RANDOM BASELINE\",\n",
    "    )\n",
    "\n",
    "    fp_gc, fm_gc, fk_gc = compute_fidelity_scores(\n",
    "    data=data,\n",
    "    model=model,\n",
    "    feat_imp=feat_gradcam,\n",
    "    nodes=nodes_to_explain,\n",
    "    k_fracs=args.k_fracs,\n",
    "    use_true_class=args.use_true_class,\n",
    "    label=\"GradCAM-feature\",\n",
    ")\n",
    "\n",
    "    fp_lrp, fm_lrp, fk_lrp = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_lrp,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"GNN-LRP-feature\",\n",
    "    )\n",
    "\n",
    "    fp_goat, fm_goat, fk_goat = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_goat,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"GOAT-feature\",\n",
    "    )\n",
    "    fp_lime, fm_lime, fk_lime = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_lime,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"LIME\",\n",
    "    )\n",
    "    # fp_glime, fm_glime, fk_glime = compute_fidelity_scores(\n",
    "    # data=data,\n",
    "    # model=model,\n",
    "    # feat_imp=feat_glime,\n",
    "    # nodes=nodes_to_explain,\n",
    "    # k_fracs=args.k_fracs,\n",
    "    # use_true_class=args.use_true_class,\n",
    "    # label=\"GraphLIME\",\n",
    "    # )\n",
    "    # fp_gnn, fm_gnn, fk_gnn = compute_fidelity_scores(\n",
    "    # data=data,\n",
    "    # model=model,\n",
    "    # feat_imp=feat_gnnexp,\n",
    "    # nodes=nodes_to_explain,\n",
    "    # k_fracs=args.k_fracs,\n",
    "    # use_true_class=args.use_true_class,\n",
    "    # label=\"GNNEXPLAINER BASELINE\",\n",
    "    # )\n",
    "    fp_ig, fm_ig, fk_ig = compute_fidelity_scores(\n",
    "    data=data,\n",
    "    model=model,\n",
    "    feat_imp=feat_ig,\n",
    "    nodes=nodes_to_explain,\n",
    "    k_fracs=args.k_fracs,\n",
    "    use_true_class=args.use_true_class,\n",
    "    label=\"IG (Captum) BASELINE\",\n",
    "  )\n",
    "\n",
    "\n",
    "    # 8) Summary\n",
    "    print(\"\\n=== Summary: mean Δ = p_orig - p_mask ===\")\n",
    "    print(\"\\nFidelity+ (remove top-k% features):\")\n",
    "    for k in args.k_fracs:\n",
    "        print(\n",
    "            f\"k={k:.2f} | \"\n",
    "            f\"self={fp_self[k]:.4f} | \"\n",
    "            f\"self+1hop={fp_s1[k]:.4f} | \"\n",
    "            f\"self+2hop={fp_s2[k]:.4f} | \"\n",
    "            f\"rand={fp_rand[k]:.4f} | \"\n",
    "            f\"gradcam={fp_gc[k]:.4f} | \"\n",
    "            f\"lrp={fp_lrp[k]:.4f} | \"\n",
    "            f\"goat={fp_goat[k]:.4f} | \"\n",
    "            f\"lime={fp_lime[k]:.4f} | \"\n",
    "            f\"ig={fp_ig[k]:.4f} | \"\n",
    "        )\n",
    "\n",
    "    print(\"\\nFidelity- (remove bottom-k% features):\")\n",
    "    for k in args.k_fracs:\n",
    "        print(\n",
    "            f\"k={k:.2f} | \"\n",
    "            f\"self={fm_self[k]:.4f} | \"\n",
    "            f\"self+1hop={fm_s1[k]:.4f} | \"\n",
    "            f\"self+2hop={fm_s2[k]:.4f} | \"\n",
    "            f\"rand={fm_rand[k]:.4f} | \"\n",
    "            f\"gradcam={fm_gc[k]:.4f} | \"\n",
    "            f\"lrp={fm_lrp[k]:.4f} | \"\n",
    "            f\"goat={fm_goat[k]:.4f} | \"\n",
    "            f\"lime={fm_lime[k]:.4f} | \"\n",
    "            f\"ig={fm_ig[k]:.4f} | \"\n",
    "        )\n",
    "\n",
    "    print(\"\\n=== Build-Time Summary (seconds) ===\")\n",
    "    print(f\"{'method':20s} {'build_time':>12s}\")\n",
    "    for name, t in method_times.items():\n",
    "        print(f\"{name:20s} {t:12.3f}\")\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "  main()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ngHRDErv8TQd"
   },
   "source": [
    "## AMZ-PHOTO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IaTCL-1SC1Lu"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "dd35ea89",
    "outputId": "d72552e0-53c0-4eb5-92cf-3f359bd3fd9c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GraphLIME module reloaded successfully.\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "# Remove the module from sys.modules if it was previously imported\n",
    "# TypeError: LassoLars.__init__() got an unexpected keyword argument 'normalize'\n",
    "# go to graphlime/__init__.py to change.\n",
    "if 'graphlime' in sys.modules:\n",
    "    del sys.modules['graphlime']\n",
    "sys.path.append(\"GraphLIME\")\n",
    "from graphlime import GraphLIME\n",
    "print(\"GraphLIME module reloaded successfully.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "fwpjF5gw8XzP",
    "outputId": "75258eb3-848e-4b5d-f8a4-ca30fefc2aa9"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading https://github.com/shchur/gnn-benchmark/raw/master/data/npz/amazon_electronics_photo.npz\n",
      "Processing...\n",
      "Done!\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Amazon-Photo ===\n",
      "#Nodes     = 7650\n",
      "#Edges     = 238162\n",
      "#Features  = 745\n",
      "#Classes   = 8\n",
      "#Train nodes = 5352\n",
      "#Val nodes   = 762\n",
      "#Test nodes  = 1536\n",
      "\n",
      "=== Loading existing model from amazon_photo_gcn.pt ===\n",
      "Loaded model | Train=0.8696 Val=0.8622 Test=0.8639\n",
      "\n",
      "=== Building normalized adjacency list ===\n",
      "\n",
      "=== Explaining 100 test nodes (pred class as target) ===\n",
      "Building feature importance via decomposition for 100 nodes...\n",
      "  processed 1/100 nodes\n",
      "  processed 20/100 nodes\n",
      "  processed 40/100 nodes\n",
      "  processed 60/100 nodes\n",
      "  processed 80/100 nodes\n",
      "  processed 100/100 nodes\n",
      "\n",
      "=== Building RANDOM importance baseline ===\n",
      "\n",
      "=== Building GRAD-Cam ===\n",
      "\n",
      "=== Building GNN-LRP ===\n",
      "Running GNN-LRP for 100 nodes...\n",
      "  GNN-LRP processed 1/100 nodes\n",
      "  GNN-LRP processed 10/100 nodes\n",
      "  GNN-LRP processed 20/100 nodes\n",
      "  GNN-LRP processed 30/100 nodes\n",
      "  GNN-LRP processed 40/100 nodes\n",
      "  GNN-LRP processed 50/100 nodes\n",
      "  GNN-LRP processed 60/100 nodes\n",
      "  GNN-LRP processed 70/100 nodes\n",
      "  GNN-LRP processed 80/100 nodes\n",
      "  GNN-LRP processed 90/100 nodes\n",
      "  GNN-LRP processed 100/100 nodes\n",
      "\n",
      "=== Building GOAT ===\n",
      "Running GOAT for 100 nodes...\n",
      "  GOAT processed 1/100 nodes\n",
      "  GOAT processed 10/100 nodes\n",
      "  GOAT processed 20/100 nodes\n",
      "  GOAT processed 30/100 nodes\n",
      "  GOAT processed 40/100 nodes\n",
      "  GOAT processed 50/100 nodes\n",
      "  GOAT processed 60/100 nodes\n",
      "  GOAT processed 70/100 nodes\n",
      "  GOAT processed 80/100 nodes\n",
      "  GOAT processed 90/100 nodes\n",
      "  GOAT processed 100/100 nodes\n",
      "\n",
      "=== Building LIME ===\n",
      "Running LIME for 100 nodes...\n",
      "  LIME processed 1/100 nodes\n",
      "  LIME processed 10/100 nodes\n",
      "  LIME processed 20/100 nodes\n",
      "  LIME processed 30/100 nodes\n",
      "  LIME processed 40/100 nodes\n",
      "  LIME processed 50/100 nodes\n",
      "  LIME processed 60/100 nodes\n",
      "  LIME processed 70/100 nodes\n",
      "  LIME processed 80/100 nodes\n",
      "  LIME processed 90/100 nodes\n",
      "  LIME processed 100/100 nodes\n",
      "\n",
      "=== Building Integrated Gradients (Captum) importance baseline ===\n",
      "Running Clean IG for 100 nodes...\n",
      "  IG processed 1/100 nodes\n",
      "  IG processed 10/100 nodes\n",
      "  IG processed 20/100 nodes\n",
      "  IG processed 30/100 nodes\n",
      "  IG processed 40/100 nodes\n",
      "  IG processed 50/100 nodes\n",
      "  IG processed 60/100 nodes\n",
      "  IG processed 70/100 nodes\n",
      "  IG processed 80/100 nodes\n",
      "  IG processed 90/100 nodes\n",
      "  IG processed 100/100 nodes\n",
      "\n",
      "=== Computing fidelity scores ===\n",
      "\n",
      "=== Fidelity scores for importance: SELF ONLY ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=37) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0340\n",
      "Fidelity- (remove BOTTOM-k)= -0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0018\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=74) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0342\n",
      "Fidelity- (remove BOTTOM-k)= -0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0001\n",
      "\n",
      "=== Fidelity scores for importance: SELF + 1HOP ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=37) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0340\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0017\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=74) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0355\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0013\n",
      "\n",
      "=== Fidelity scores for importance: SELF + 1HOP + 2HOP ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=37) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0342\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0015\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=74) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0356\n",
      "Fidelity- (remove BOTTOM-k)= 0.0001\n",
      "Keep-top (only TOP-k kept) = -0.0013\n",
      "\n",
      "=== Fidelity scores for importance: RANDOM BASELINE ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=37) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0009\n",
      "Fidelity- (remove BOTTOM-k)= 0.0010\n",
      "Keep-top (only TOP-k kept) = 0.0325\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=74) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0017\n",
      "Fidelity- (remove BOTTOM-k)= 0.0024\n",
      "Keep-top (only TOP-k kept) = 0.0317\n",
      "\n",
      "=== Fidelity scores for importance: GradCAM-feature ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=37) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0024\n",
      "Fidelity- (remove BOTTOM-k)= 0.0046\n",
      "Keep-top (only TOP-k kept) = 0.0332\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=74) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0043\n",
      "Fidelity- (remove BOTTOM-k)= 0.0083\n",
      "Keep-top (only TOP-k kept) = 0.0341\n",
      "\n",
      "=== Fidelity scores for importance: GNN-LRP-feature ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=37) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0227\n",
      "Fidelity- (remove BOTTOM-k)= -0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0110\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=74) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0278\n",
      "Fidelity- (remove BOTTOM-k)= -0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0057\n",
      "\n",
      "=== Fidelity scores for importance: GOAT-feature ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=37) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0340\n",
      "Fidelity- (remove BOTTOM-k)= -0.0001\n",
      "Keep-top (only TOP-k kept) = 0.0016\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=74) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0354\n",
      "Fidelity- (remove BOTTOM-k)= -0.0001\n",
      "Keep-top (only TOP-k kept) = -0.0016\n",
      "\n",
      "=== Fidelity scores for importance: LIME ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=37) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0114\n",
      "Fidelity- (remove BOTTOM-k)= 0.0001\n",
      "Keep-top (only TOP-k kept) = 0.0220\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=74) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0168\n",
      "Fidelity- (remove BOTTOM-k)= 0.0003\n",
      "Keep-top (only TOP-k kept) = 0.0171\n",
      "\n",
      "=== Fidelity scores for importance: IG (Captum) BASELINE ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=37) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0164\n",
      "Fidelity- (remove BOTTOM-k)= 0.0003\n",
      "Keep-top (only TOP-k kept) = 0.0170\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=74) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0208\n",
      "Fidelity- (remove BOTTOM-k)= 0.0004\n",
      "Keep-top (only TOP-k kept) = 0.0127\n",
      "\n",
      "=== Summary: mean Δ = p_orig - p_mask ===\n",
      "\n",
      "Fidelity+ (remove top-k% features):\n",
      "k=0.05 | self=0.0340 | self+1hop=0.0340 | self+2hop=0.0342 | rand=0.0009 | gradcam=0.0024 | lrp=0.0227 | goat=0.0340 | lime=0.0114 | ig=0.0164 | \n",
      "k=0.10 | self=0.0342 | self+1hop=0.0355 | self+2hop=0.0356 | rand=0.0017 | gradcam=0.0043 | lrp=0.0278 | goat=0.0354 | lime=0.0168 | ig=0.0208 | \n",
      "\n",
      "Fidelity- (remove bottom-k% features):\n",
      "k=0.05 | self=-0.0000 | self+1hop=0.0000 | self+2hop=0.0000 | rand=0.0010 | gradcam=0.0046 | lrp=-0.0000 | goat=-0.0001 | lime=0.0001 | ig=0.0003 | \n",
      "k=0.10 | self=-0.0000 | self+1hop=0.0000 | self+2hop=0.0001 | rand=0.0024 | gradcam=0.0083 | lrp=-0.0000 | goat=-0.0001 | lime=0.0003 | ig=0.0004 | \n",
      "\n",
      "=== Build-Time Summary (seconds) ===\n",
      "method                 build_time\n",
      "decomposition              13.857\n",
      "random                      0.023\n",
      "gradcam                     0.071\n",
      "lrp                         0.535\n",
      "goat                       74.038\n",
      "lime                      161.958\n",
      "ig                        252.356\n"
     ]
    }
   ],
   "source": [
    "# -------------------------------------------------------------\n",
    "# 1. Dataset + splits\n",
    "# -------------------------------------------------------------\n",
    "def load_amazon_photo(root: str = \"data/Amazon\"):\n",
    "    dataset = Amazon(root=root, name=\"Photo\", transform=T.NormalizeFeatures())\n",
    "    data = dataset[0]\n",
    "\n",
    "    print(\"=== Amazon-Photo ===\")\n",
    "    print(f\"#Nodes     = {data.num_nodes}\")\n",
    "    print(f\"#Edges     = {data.num_edges}\")\n",
    "    print(f\"#Features  = {dataset.num_features}\")\n",
    "    print(f\"#Classes   = {dataset.num_classes}\")\n",
    "\n",
    "    if not hasattr(data, \"train_mask\"):\n",
    "        # data = create_splits(data, num_train_per_class=20, num_val_per_class=30)\n",
    "        data = create_splits_stratified(data, train_ratio=0.7, val_ratio=0.1, seed=1234)\n",
    "\n",
    "    return data, dataset.num_features, dataset.num_classes\n",
    "\n",
    "def main():\n",
    "    method_times = {}\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument(\"--seed\", type=int, default=42)\n",
    "    parser.add_argument(\"--hidden_dim\", type=int, default=64)\n",
    "    parser.add_argument(\"--dropout\", type=float, default=0.5)\n",
    "    parser.add_argument(\"--epochs\", type=int, default=500)\n",
    "    parser.add_argument(\"--lr\", type=float, default=0.05)\n",
    "    parser.add_argument(\"--weight_decay\", type=float, default=5e-4)\n",
    "    parser.add_argument(\"--root\", type=str, default=\"data/Amazon\")\n",
    "    parser.add_argument(\"--k_fracs\", type=float, nargs=\"+\",\n",
    "                        default=[0.05, 0.1])\n",
    "    parser.add_argument(\"--num_explain\", type=int, default=100,\n",
    "                        help=\"how many test nodes to explain via decomposition\")\n",
    "    parser.add_argument(\"--use_true_class\", action=\"store_true\",\n",
    "                        help=\"use true label instead of predicted label as target class\")\n",
    "    parser.add_argument(\"--model_path\", type=str, default=\"amazon_photo_gcn.pt\",\n",
    "                        help=\"path to save/load trained GCN model\")\n",
    "    args = parser.parse_args(args=[])\n",
    "\n",
    "    set_seed(args.seed)\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    # 1) Load data\n",
    "    data, in_dim, num_classes = load_amazon_photo(args.root)\n",
    "    data = data.to(device)\n",
    "\n",
    "    # 2) Build model\n",
    "    model = GCN(\n",
    "        in_channels=in_dim,\n",
    "        hidden_channels=args.hidden_dim,\n",
    "        out_channels=num_classes,\n",
    "        dropout=args.dropout,\n",
    "    ).to(device)\n",
    "\n",
    "    # 3) Train or load model\n",
    "    if os.path.exists(args.model_path):\n",
    "        print(f\"\\n=== Loading existing model from {args.model_path} ===\")\n",
    "        state = torch.load(args.model_path, map_location=device)\n",
    "        model.load_state_dict(state)\n",
    "        train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
    "        print(f\"Loaded model | Train={train_acc:.4f} Val={val_acc:.4f} Test={test_acc:.4f}\")\n",
    "    else:\n",
    "        print(\"\\n=== Training GCN (no saved model found) ===\")\n",
    "        optimizer = torch.optim.Adam(\n",
    "            model.parameters(),\n",
    "            lr=args.lr,\n",
    "            weight_decay=args.weight_decay,\n",
    "        )\n",
    "\n",
    "        best_val_acc = 0.0\n",
    "        best_state = None\n",
    "\n",
    "        for epoch in range(1, args.epochs + 1):\n",
    "            loss = train(model, data, optimizer)\n",
    "            train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
    "\n",
    "            if val_acc > best_val_acc:\n",
    "                best_val_acc = val_acc\n",
    "                best_state = {\n",
    "                    \"model\": model.state_dict(),\n",
    "                    \"epoch\": epoch,\n",
    "                    \"test_acc\": test_acc,\n",
    "                }\n",
    "\n",
    "            if epoch % 20 == 0 or epoch == 1:\n",
    "                print(\n",
    "                    f\"Epoch {epoch:03d} | \"\n",
    "                    f\"Loss {loss:.4f} | \"\n",
    "                    f\"Train {train_acc:.4f} | \"\n",
    "                    f\"Val {val_acc:.4f} | \"\n",
    "                    f\"Test {test_acc:.4f}\"\n",
    "                )\n",
    "\n",
    "        if best_state is not None:\n",
    "            model.load_state_dict(best_state[\"model\"])\n",
    "            print(\n",
    "                f\"\\nBest epoch = {best_state['epoch']} | \"\n",
    "                f\"Best val acc = {best_val_acc:.4f} | \"\n",
    "                f\"Test acc @best = {best_state['test_acc']:.4f}\"\n",
    "            )\n",
    "\n",
    "        # Save trained model\n",
    "        torch.save(model.state_dict(), args.model_path)\n",
    "        print(f\"Model saved to {args.model_path}\")\n",
    "\n",
    "    # 4) Build normalized adjacency list\n",
    "    print(\"\\n=== Building normalized adjacency list ===\")\n",
    "    adj_list = build_normalized_adjacency_list(model, data)\n",
    "\n",
    "    # 5) Choose test nodes to explain (correct predictions preferred)\n",
    "    model.eval()\n",
    "    logits = model(data.x, data.edge_index)\n",
    "    preds = logits.argmax(dim=-1)\n",
    "\n",
    "    test_nodes = torch.nonzero(data.test_mask, as_tuple=False).view(-1)\n",
    "    correct_mask = preds == data.y\n",
    "    test_correct_nodes = test_nodes[correct_mask[test_nodes]]\n",
    "\n",
    "    if test_correct_nodes.numel() == 0:\n",
    "        print(\"WARNING: no correctly classified test nodes, using all test nodes.\")\n",
    "        nodes_to_explain = test_nodes\n",
    "    else:\n",
    "        nodes_to_explain = test_correct_nodes\n",
    "\n",
    "    if nodes_to_explain.numel() > args.num_explain:\n",
    "        nodes_to_explain = nodes_to_explain[:args.num_explain]\n",
    "\n",
    "    print(\n",
    "        f\"\\n=== Explaining {nodes_to_explain.numel()} test nodes \"\n",
    "        f\"({'true' if args.use_true_class else 'pred'} class as target) ===\"\n",
    "    )\n",
    "\n",
    "    # 6) Build feature importance: self vs self+1hop vs self+2hop\n",
    "\n",
    "    t0 = _now()\n",
    "    feat_self, feat_self_1hop, feat_self_2hop = build_feature_importance_decomposition(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        adj_list=adj_list,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    "        hop=2,\n",
    "        use_true_class=args.use_true_class,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"decomposition\"] = t1 - t0\n",
    "\n",
    "    # feat_self, feat_self_1hop, feat_self_2hop = build_feature_importance_decomposition(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     adj_list=adj_list,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    #     hop=2,\n",
    "    #     use_true_class=args.use_true_class,\n",
    "    # )\n",
    "    # ---- Random baseline ----\n",
    "    print(\"\\n=== Building RANDOM importance baseline ===\")\n",
    "\n",
    "    t0 = _now()\n",
    "    feat_random = build_random_importance(\n",
    "        num_nodes=data.num_nodes,\n",
    "        num_features=data.x.size(1),\n",
    "        device=device,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"random\"] = t1 - t0\n",
    "\n",
    "#     feat_random = build_random_importance(\n",
    "#         num_nodes=data.num_nodes,\n",
    "#         num_features=data.x.size(1),\n",
    "#         device=device\n",
    "# )\n",
    "    print(\"\\n=== Building GRAD-Cam ===\")\n",
    "    # feat_gradcam = gradcam_feature_importance(model, data, use_true_class=True)\n",
    "    t0 = _now()\n",
    "    feat_gradcam = gradcam_feature_importance(\n",
    "        model=model,\n",
    "        data=data,\n",
    "        use_true_class=args.use_true_class,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"gradcam\"] = t1 - t0\n",
    "\n",
    "    print(\"\\n=== Building GNN-LRP ===\")\n",
    "#     feat_lrp = gnn_lrp_feature_importance(\n",
    "#     model,\n",
    "#     data,\n",
    "#     eps=1e-6,\n",
    "#     use_true_class=True,\n",
    "#     nodes_to_explain=nodes_to_explain  # same subset you use for fidelity\n",
    "# )\n",
    "    t0 = _now()\n",
    "    feat_lrp = gnn_lrp_feature_importance(\n",
    "        model=model,\n",
    "        data=data,\n",
    "        eps=1e-6,\n",
    "        use_true_class=args.use_true_class,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"lrp\"] = t1 - t0\n",
    "\n",
    "\n",
    "    print(\"\\n=== Building GOAT ===\")\n",
    "    # feat_goat = build_goat_feature_importance(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    # )\n",
    "    t0 = _now()\n",
    "    feat_goat = build_goat_feature_importance(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"goat\"] = t1 - t0\n",
    "\n",
    "    print(\"\\n=== Building LIME ===\")\n",
    "    # feat_lime = build_lime_importance(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    #     num_classes=num_classes,\n",
    "    #     num_samples=50,\n",
    "    # )\n",
    "    t0 = _now()\n",
    "    feat_lime = build_lime_importance(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        nodes_to_explain=nodes_to_explain,  # or smaller subset\n",
    "        num_classes=num_classes,\n",
    "        num_samples=50,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"lime\"] = t1 - t0\n",
    "\n",
    "    # GraphLime: if possible\n",
    "    # print(\"\\n=== Building GraphLIME ===\")\n",
    "    # t0 = _now()\n",
    "    # feat_glime = build_graphlime_importance(\n",
    "    # data=data,\n",
    "    # model=model,\n",
    "    # nodes_to_explain=nodes_to_explain,\n",
    "    # hop=2,\n",
    "    # rho=0.1,\n",
    "    # )\n",
    "    # t1 = _now()\n",
    "    # method_times[\"graphlime\"] = t1 - t0\n",
    "\n",
    "\n",
    "    #probably not use the following two ---\n",
    "\n",
    "    #     # ---- GNNExplainer baseline importance ----\n",
    "    # print(\"\\n=== Building GNNExplainer importance baseline ===\")\n",
    "    # feat_gnnexp = build_gnnexplainer_importance(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    #     epochs=50,   # you can change this\n",
    "    # )\n",
    "    print(\"\\n=== Building Integrated Gradients (Captum) importance baseline ===\")\n",
    "    t0 = _now()\n",
    "    feat_ig = build_ig_importance(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    ")\n",
    "    t1 = _now()\n",
    "    method_times[\"ig\"] = t1 - t0\n",
    "\n",
    "    # 7) Fidelity scores for each importance variant\n",
    "    print(\"\\n=== Computing fidelity scores ===\")\n",
    "    fp_self, fm_self, fk_self = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_self,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"SELF ONLY\",\n",
    "    )\n",
    "\n",
    "    fp_s1, fm_s1, fk_s1 = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_self_1hop,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"SELF + 1HOP\",\n",
    "    )\n",
    "\n",
    "    fp_s2, fm_s2, fk_s2 = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_self_2hop,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"SELF + 1HOP + 2HOP\",\n",
    "    )\n",
    "    fp_rand, fm_rand, fk_rand = compute_fidelity_scores(\n",
    "    data=data,\n",
    "    model=model,\n",
    "    feat_imp=feat_random,\n",
    "    nodes=nodes_to_explain,\n",
    "    k_fracs=args.k_fracs,\n",
    "    use_true_class=args.use_true_class,\n",
    "    label=\"RANDOM BASELINE\",\n",
    "    )\n",
    "\n",
    "    fp_gc, fm_gc, fk_gc = compute_fidelity_scores(\n",
    "    data=data,\n",
    "    model=model,\n",
    "    feat_imp=feat_gradcam,\n",
    "    nodes=nodes_to_explain,\n",
    "    k_fracs=args.k_fracs,\n",
    "    use_true_class=args.use_true_class,\n",
    "    label=\"GradCAM-feature\",\n",
    ")\n",
    "\n",
    "    fp_lrp, fm_lrp, fk_lrp = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_lrp,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"GNN-LRP-feature\",\n",
    "    )\n",
    "\n",
    "    fp_goat, fm_goat, fk_goat = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_goat,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"GOAT-feature\",\n",
    "    )\n",
    "    fp_lime, fm_lime, fk_lime = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_lime,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"LIME\",\n",
    "    )\n",
    "    # fp_glime, fm_glime, fk_glime = compute_fidelity_scores(\n",
    "    # data=data,\n",
    "    # model=model,\n",
    "    # feat_imp=feat_glime,\n",
    "    # nodes=nodes_to_explain,\n",
    "    # k_fracs=args.k_fracs,\n",
    "    # use_true_class=args.use_true_class,\n",
    "    # label=\"GraphLIME\",\n",
    "    # )\n",
    "    # fp_gnn, fm_gnn, fk_gnn = compute_fidelity_scores(\n",
    "    # data=data,\n",
    "    # model=model,\n",
    "    # feat_imp=feat_gnnexp,\n",
    "    # nodes=nodes_to_explain,\n",
    "    # k_fracs=args.k_fracs,\n",
    "    # use_true_class=args.use_true_class,\n",
    "    # label=\"GNNEXPLAINER BASELINE\",\n",
    "    # )\n",
    "    fp_ig, fm_ig, fk_ig = compute_fidelity_scores(\n",
    "    data=data,\n",
    "    model=model,\n",
    "    feat_imp=feat_ig,\n",
    "    nodes=nodes_to_explain,\n",
    "    k_fracs=args.k_fracs,\n",
    "    use_true_class=args.use_true_class,\n",
    "    label=\"IG (Captum) BASELINE\",\n",
    "  )\n",
    "\n",
    "\n",
    "    # 8) Summary\n",
    "    print(\"\\n=== Summary: mean Δ = p_orig - p_mask ===\")\n",
    "    print(\"\\nFidelity+ (remove top-k% features):\")\n",
    "    for k in args.k_fracs:\n",
    "        print(\n",
    "            f\"k={k:.2f} | \"\n",
    "            f\"self={fp_self[k]:.4f} | \"\n",
    "            f\"self+1hop={fp_s1[k]:.4f} | \"\n",
    "            f\"self+2hop={fp_s2[k]:.4f} | \"\n",
    "            f\"rand={fp_rand[k]:.4f} | \"\n",
    "            f\"gradcam={fp_gc[k]:.4f} | \"\n",
    "            f\"lrp={fp_lrp[k]:.4f} | \"\n",
    "            f\"goat={fp_goat[k]:.4f} | \"\n",
    "            f\"lime={fp_lime[k]:.4f} | \"\n",
    "            f\"ig={fp_ig[k]:.4f} | \"\n",
    "        )\n",
    "\n",
    "    print(\"\\nFidelity- (remove bottom-k% features):\")\n",
    "    for k in args.k_fracs:\n",
    "        print(\n",
    "            f\"k={k:.2f} | \"\n",
    "            f\"self={fm_self[k]:.4f} | \"\n",
    "            f\"self+1hop={fm_s1[k]:.4f} | \"\n",
    "            f\"self+2hop={fm_s2[k]:.4f} | \"\n",
    "            f\"rand={fm_rand[k]:.4f} | \"\n",
    "            f\"gradcam={fm_gc[k]:.4f} | \"\n",
    "            f\"lrp={fm_lrp[k]:.4f} | \"\n",
    "            f\"goat={fm_goat[k]:.4f} | \"\n",
    "            f\"lime={fm_lime[k]:.4f} | \"\n",
    "            f\"ig={fm_ig[k]:.4f} | \"\n",
    "        )\n",
    "\n",
    "    print(\"\\n=== Build-Time Summary (seconds) ===\")\n",
    "    print(f\"{'method':20s} {'build_time':>12s}\")\n",
    "    for name, t in method_times.items():\n",
    "        print(f\"{name:20s} {t:12.3f}\")\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "  main()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "F23Zjt9wr3eU"
   },
   "source": [
    "## CiteSeer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "collapsed": true,
    "id": "qoarKn-Dr_ZN",
    "outputId": "50b720bf-ddc0-4891-fb69-00f1db71d711"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.x\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.tx\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.allx\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.y\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ty\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ally\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.graph\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.test.index\n",
      "Processing...\n",
      "Done!\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded Planetoid/CiteSeer:\n",
      "  #Nodes     = 3327\n",
      "  #Edges     = 9104\n",
      "  #Features  = 3703\n",
      "  #Classes   = 6\n",
      "  #Train/Val/Test = 120/500/1000\n",
      "#Train=120, #Val=500, #Test=1000\n",
      "\n",
      "=== Training GCN (no saved model found) ===\n",
      "Epoch 001 | Loss 1.7920 | Train 0.7583 | Val 0.4380 | Test 0.3970\n",
      "Epoch 020 | Loss 0.5752 | Train 0.9750 | Val 0.7120 | Test 0.7090\n",
      "Epoch 040 | Loss 0.2622 | Train 1.0000 | Val 0.6940 | Test 0.6930\n",
      "Epoch 060 | Loss 0.2213 | Train 1.0000 | Val 0.6920 | Test 0.6930\n",
      "Epoch 080 | Loss 0.2100 | Train 1.0000 | Val 0.7040 | Test 0.7040\n",
      "Epoch 100 | Loss 0.2135 | Train 1.0000 | Val 0.7060 | Test 0.7070\n",
      "Epoch 120 | Loss 0.1780 | Train 1.0000 | Val 0.6980 | Test 0.6920\n",
      "Epoch 140 | Loss 0.1979 | Train 1.0000 | Val 0.7000 | Test 0.6950\n",
      "Epoch 160 | Loss 0.1827 | Train 1.0000 | Val 0.6960 | Test 0.6720\n",
      "Epoch 180 | Loss 0.1781 | Train 0.9917 | Val 0.6980 | Test 0.6930\n",
      "Epoch 200 | Loss 0.1997 | Train 1.0000 | Val 0.7140 | Test 0.7120\n",
      "Epoch 220 | Loss 0.1706 | Train 1.0000 | Val 0.6980 | Test 0.6980\n",
      "Epoch 240 | Loss 0.1734 | Train 0.9917 | Val 0.6940 | Test 0.6820\n",
      "Epoch 260 | Loss 0.2158 | Train 1.0000 | Val 0.7040 | Test 0.7030\n",
      "Epoch 280 | Loss 0.1811 | Train 1.0000 | Val 0.7080 | Test 0.7100\n",
      "Epoch 300 | Loss 0.1873 | Train 1.0000 | Val 0.7080 | Test 0.7100\n",
      "Epoch 320 | Loss 0.1675 | Train 1.0000 | Val 0.7020 | Test 0.7030\n",
      "Epoch 340 | Loss 0.1903 | Train 1.0000 | Val 0.7160 | Test 0.7130\n",
      "Epoch 360 | Loss 0.1608 | Train 1.0000 | Val 0.6760 | Test 0.6640\n",
      "Epoch 380 | Loss 0.1879 | Train 1.0000 | Val 0.7060 | Test 0.7110\n",
      "Epoch 400 | Loss 0.1813 | Train 1.0000 | Val 0.7320 | Test 0.7190\n",
      "Epoch 420 | Loss 0.1809 | Train 1.0000 | Val 0.7020 | Test 0.6950\n",
      "Epoch 440 | Loss 0.1767 | Train 1.0000 | Val 0.7020 | Test 0.7090\n",
      "Epoch 460 | Loss 0.1760 | Train 1.0000 | Val 0.6960 | Test 0.7090\n",
      "Epoch 480 | Loss 0.1659 | Train 1.0000 | Val 0.6940 | Test 0.7040\n",
      "Epoch 500 | Loss 0.1901 | Train 1.0000 | Val 0.7040 | Test 0.6870\n",
      "\n",
      "Best epoch = 363 | Best val acc = 0.7320 | Test acc @best = 0.7200\n",
      "Model saved to citeseer_gcn.pt\n",
      "\n",
      "=== Building normalized adjacency list ===\n",
      "\n",
      "=== Explaining 100 test nodes (pred class as target) ===\n",
      "Building feature importance via decomposition for 100 nodes...\n",
      "  processed 1/100 nodes\n",
      "  processed 20/100 nodes\n",
      "  processed 40/100 nodes\n",
      "  processed 60/100 nodes\n",
      "  processed 80/100 nodes\n",
      "  processed 100/100 nodes\n",
      "\n",
      "=== Building RANDOM importance baseline ===\n",
      "\n",
      "=== Building GRAD-Cam ===\n",
      "\n",
      "=== Building GNN-LRP ===\n",
      "Running GNN-LRP for 100 nodes...\n",
      "  GNN-LRP processed 1/100 nodes\n",
      "  GNN-LRP processed 10/100 nodes\n",
      "  GNN-LRP processed 20/100 nodes\n",
      "  GNN-LRP processed 30/100 nodes\n",
      "  GNN-LRP processed 40/100 nodes\n",
      "  GNN-LRP processed 50/100 nodes\n",
      "  GNN-LRP processed 60/100 nodes\n",
      "  GNN-LRP processed 70/100 nodes\n",
      "  GNN-LRP processed 80/100 nodes\n",
      "  GNN-LRP processed 90/100 nodes\n",
      "  GNN-LRP processed 100/100 nodes\n",
      "\n",
      "=== Building GOAT ===\n",
      "Running GOAT for 100 nodes...\n",
      "  GOAT processed 1/100 nodes\n",
      "  GOAT processed 10/100 nodes\n",
      "  GOAT processed 20/100 nodes\n",
      "  GOAT processed 30/100 nodes\n",
      "  GOAT processed 40/100 nodes\n",
      "  GOAT processed 50/100 nodes\n",
      "  GOAT processed 60/100 nodes\n",
      "  GOAT processed 70/100 nodes\n",
      "  GOAT processed 80/100 nodes\n",
      "  GOAT processed 90/100 nodes\n",
      "  GOAT processed 100/100 nodes\n",
      "\n",
      "=== Building LIME ===\n",
      "Running LIME for 100 nodes...\n",
      "  LIME processed 1/100 nodes\n",
      "  LIME processed 10/100 nodes\n",
      "  LIME processed 20/100 nodes\n",
      "  LIME processed 30/100 nodes\n",
      "  LIME processed 40/100 nodes\n",
      "  LIME processed 50/100 nodes\n",
      "  LIME processed 60/100 nodes\n",
      "  LIME processed 70/100 nodes\n",
      "  LIME processed 80/100 nodes\n",
      "  LIME processed 90/100 nodes\n",
      "  LIME processed 100/100 nodes\n",
      "\n",
      "=== Building Integrated Gradients (Captum) importance baseline ===\n",
      "Running Clean IG for 100 nodes...\n",
      "  IG processed 1/100 nodes\n",
      "  IG processed 10/100 nodes\n",
      "  IG processed 20/100 nodes\n",
      "  IG processed 30/100 nodes\n",
      "  IG processed 40/100 nodes\n",
      "  IG processed 50/100 nodes\n",
      "  IG processed 60/100 nodes\n",
      "  IG processed 70/100 nodes\n",
      "  IG processed 80/100 nodes\n",
      "  IG processed 90/100 nodes\n",
      "  IG processed 100/100 nodes\n",
      "\n",
      "=== Computing fidelity scores ===\n",
      "\n",
      "=== Fidelity scores for importance: SELF ONLY ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=185) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1513\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0000\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=370) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1513\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0000\n",
      "\n",
      "=== Fidelity scores for importance: SELF + 1HOP ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=185) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1513\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0000\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=370) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1513\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0000\n",
      "\n",
      "=== Fidelity scores for importance: SELF + 1HOP + 2HOP ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=185) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1514\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0000\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=370) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1513\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0000\n",
      "\n",
      "=== Fidelity scores for importance: RANDOM BASELINE ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=185) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0055\n",
      "Fidelity- (remove BOTTOM-k)= 0.0063\n",
      "Keep-top (only TOP-k kept) = 0.1446\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=370) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0098\n",
      "Fidelity- (remove BOTTOM-k)= 0.0091\n",
      "Keep-top (only TOP-k kept) = 0.1387\n",
      "\n",
      "=== Fidelity scores for importance: GradCAM-feature ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=185) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1114\n",
      "Fidelity- (remove BOTTOM-k)= 0.0002\n",
      "Keep-top (only TOP-k kept) = 0.0379\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=370) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1213\n",
      "Fidelity- (remove BOTTOM-k)= 0.0027\n",
      "Keep-top (only TOP-k kept) = 0.0296\n",
      "\n",
      "=== Fidelity scores for importance: GNN-LRP-feature ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=185) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1513\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0000\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=370) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1513\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0000\n",
      "\n",
      "=== Fidelity scores for importance: GOAT-feature ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=185) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1523\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0009\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=370) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1513\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0000\n",
      "\n",
      "=== Fidelity scores for importance: LIME ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=185) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0281\n",
      "Fidelity- (remove BOTTOM-k)= 0.0063\n",
      "Keep-top (only TOP-k kept) = 0.1194\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=370) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0392\n",
      "Fidelity- (remove BOTTOM-k)= 0.0113\n",
      "Keep-top (only TOP-k kept) = 0.1078\n",
      "\n",
      "=== Fidelity scores for importance: IG (Captum) BASELINE ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=185) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1239\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0234\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=370) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1326\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0152\n",
      "\n",
      "=== Summary: mean Δ = p_orig - p_mask ===\n",
      "\n",
      "Fidelity+ (remove top-k% features):\n",
      "k=0.05 | self=0.1513 | self+1hop=0.1513 | self+2hop=0.1514 | rand=0.0055 | gradcam=0.1114 | lrp=0.1513 | goat=0.1523 | lime=0.0281 | ig=0.1239 | \n",
      "k=0.10 | self=0.1513 | self+1hop=0.1513 | self+2hop=0.1513 | rand=0.0098 | gradcam=0.1213 | lrp=0.1513 | goat=0.1513 | lime=0.0392 | ig=0.1326 | \n",
      "\n",
      "Fidelity- (remove bottom-k% features):\n",
      "k=0.05 | self=0.0000 | self+1hop=0.0000 | self+2hop=0.0000 | rand=0.0063 | gradcam=0.0002 | lrp=0.0000 | goat=0.0000 | lime=0.0063 | ig=0.0000 | \n",
      "k=0.10 | self=0.0000 | self+1hop=0.0000 | self+2hop=0.0000 | rand=0.0091 | gradcam=0.0027 | lrp=0.0000 | goat=0.0000 | lime=0.0113 | ig=0.0000 | \n",
      "\n",
      "=== Build-Time Summary (seconds) ===\n",
      "method                 build_time\n",
      "decomposition               0.193\n",
      "random                      0.047\n",
      "gradcam                     0.058\n",
      "lrp                         0.308\n",
      "goat                       27.527\n",
      "lime                       92.927\n",
      "ig                        194.829\n"
     ]
    }
   ],
   "source": [
    "from torch_geometric.datasets import Planetoid, HeterophilousGraphDataset\n",
    "import torch_geometric.transforms as T\n",
    "\n",
    "def create_splits(data, num_train_per_class=20, num_val_per_class=30):\n",
    "    y = data.y\n",
    "    num_nodes = data.num_nodes\n",
    "    num_classes = int(y.max().item() + 1)\n",
    "\n",
    "    train_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "    val_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "    test_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "\n",
    "    for c in range(num_classes):\n",
    "        idx = (y == c).nonzero(as_tuple=False).view(-1)\n",
    "        idx = idx[torch.randperm(idx.size(0))]\n",
    "\n",
    "        n_train = min(num_train_per_class, idx.size(0))\n",
    "        n_val = min(num_val_per_class, max(0, idx.size(0) - n_train))\n",
    "\n",
    "        train_idx = idx[:n_train]\n",
    "        val_idx = idx[n_train:n_train + n_val]\n",
    "        test_idx = idx[n_train + n_val:]\n",
    "\n",
    "        train_mask[train_idx] = True\n",
    "        val_mask[val_idx] = True\n",
    "        test_mask[test_idx] = True\n",
    "\n",
    "    data.train_mask = train_mask\n",
    "    data.val_mask = val_mask\n",
    "    data.test_mask = test_mask\n",
    "\n",
    "    print(f\"#Train nodes = {int(train_mask.sum())}\")\n",
    "    print(f\"#Val nodes   = {int(val_mask.sum())}\")\n",
    "    print(f\"#Test nodes  = {int(test_mask.sum())}\")\n",
    "\n",
    "    return data\n",
    "def load_planetoid(name: str, root: str):\n",
    "    \"\"\"\n",
    "    Load Planetoid datasets: 'Cora', 'CiteSeer', or 'PubMed'.\n",
    "\n",
    "    Returns:\n",
    "        data:         Data object (with built-in train/val/test masks)\n",
    "        in_dim:       num node features\n",
    "        num_classes:  num classes\n",
    "    \"\"\"\n",
    "    dataset = Planetoid(\n",
    "        root=root,\n",
    "        name=name,\n",
    "        transform=T.NormalizeFeatures(),   # common preprocessing\n",
    "    )\n",
    "    data = dataset[0]\n",
    "\n",
    "    in_dim = dataset.num_features\n",
    "    num_classes = dataset.num_classes\n",
    "\n",
    "    print(f\"Loaded Planetoid/{name}:\")\n",
    "    print(f\"  #Nodes     = {data.num_nodes}\")\n",
    "    print(f\"  #Edges     = {data.num_edges}\")\n",
    "    print(f\"  #Features  = {in_dim}\")\n",
    "    print(f\"  #Classes   = {num_classes}\")\n",
    "    print(f\"  #Train/Val/Test = {int(data.train_mask.sum())}/\"\n",
    "          f\"{int(data.val_mask.sum())}/{int(data.test_mask.sum())}\")\n",
    "\n",
    "    return data, in_dim, num_classes\n",
    "\n",
    "def load_cora(root):     return load_planetoid(\"Cora\", root)\n",
    "def load_citeseer(root): return load_planetoid(\"CiteSeer\", root)\n",
    "def load_pubmed(root):   return load_planetoid(\"PubMed\", root)\n",
    "\n",
    "def main():\n",
    "    method_times = {}\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument(\"--seed\", type=int, default=42)\n",
    "    parser.add_argument(\"--hidden_dim\", type=int, default=64)\n",
    "    parser.add_argument(\"--dropout\", type=float, default=0.5)\n",
    "    parser.add_argument(\"--epochs\", type=int, default=500)\n",
    "    parser.add_argument(\"--lr\", type=float, default=0.05)\n",
    "    parser.add_argument(\"--weight_decay\", type=float, default=5e-4)\n",
    "    parser.add_argument(\"--root\", type=str, default=\"data/Planetoid\")\n",
    "    parser.add_argument(\"--k_fracs\", type=float, nargs=\"+\",\n",
    "                        default=[0.05, 0.1])\n",
    "    parser.add_argument(\"--num_explain\", type=int, default=100,\n",
    "                        help=\"how many test nodes to explain via decomposition\")\n",
    "    parser.add_argument(\"--use_true_class\", action=\"store_true\",\n",
    "                        help=\"use true label instead of predicted label as target class\")\n",
    "    parser.add_argument(\"--model_path\", type=str, default=\"citeseer_gcn.pt\",\n",
    "                        help=\"path to save/load trained GCN model\")\n",
    "    args = parser.parse_args(args=[])\n",
    "\n",
    "    set_seed(args.seed)\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    # 1) Load data\n",
    "    data, in_dim, num_classes = load_citeseer(args.root)\n",
    "    data = data.to(device)\n",
    "    # If masks are 2D (e.g., [N, 10] for HeterophilousGraphDataset),\n",
    "    # pick one column (args.split_idx) and make them 1D.\n",
    "    if not hasattr(data, \"train_mask\"):\n",
    "        data = create_splits(data)\n",
    "\n",
    "    print(\n",
    "        f\"#Train={int(data.train_mask.sum())}, \"\n",
    "        f\"#Val={int(data.val_mask.sum())}, \"\n",
    "        f\"#Test={int(data.test_mask.sum())}\"\n",
    "    )\n",
    "\n",
    "    # 2) Build model\n",
    "    model = GCN(\n",
    "        in_channels=in_dim,\n",
    "        hidden_channels=args.hidden_dim,\n",
    "        out_channels=num_classes,\n",
    "        dropout=args.dropout,\n",
    "    ).to(device)\n",
    "\n",
    "    # 3) Train or load model\n",
    "    if os.path.exists(args.model_path):\n",
    "        print(f\"\\n=== Loading existing model from {args.model_path} ===\")\n",
    "        state = torch.load(args.model_path, map_location=device)\n",
    "        model.load_state_dict(state)\n",
    "        train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
    "        print(f\"Loaded model | Train={train_acc:.4f} Val={val_acc:.4f} Test={test_acc:.4f}\")\n",
    "    else:\n",
    "        print(\"\\n=== Training GCN (no saved model found) ===\")\n",
    "        optimizer = torch.optim.Adam(\n",
    "            model.parameters(),\n",
    "            lr=args.lr,\n",
    "            weight_decay=args.weight_decay,\n",
    "        )\n",
    "\n",
    "        best_val_acc = 0.0\n",
    "        best_state = None\n",
    "\n",
    "        for epoch in range(1, args.epochs + 1):\n",
    "            loss = train(model, data, optimizer)\n",
    "            train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
    "\n",
    "            if val_acc > best_val_acc:\n",
    "                best_val_acc = val_acc\n",
    "                best_state = {\n",
    "                    \"model\": model.state_dict(),\n",
    "                    \"epoch\": epoch,\n",
    "                    \"test_acc\": test_acc,\n",
    "                }\n",
    "\n",
    "            if epoch % 20 == 0 or epoch == 1:\n",
    "                print(\n",
    "                    f\"Epoch {epoch:03d} | \"\n",
    "                    f\"Loss {loss:.4f} | \"\n",
    "                    f\"Train {train_acc:.4f} | \"\n",
    "                    f\"Val {val_acc:.4f} | \"\n",
    "                    f\"Test {test_acc:.4f}\"\n",
    "                )\n",
    "\n",
    "        if best_state is not None:\n",
    "            model.load_state_dict(best_state[\"model\"])\n",
    "            print(\n",
    "                f\"\\nBest epoch = {best_state['epoch']} | \"\n",
    "                f\"Best val acc = {best_val_acc:.4f} | \"\n",
    "                f\"Test acc @best = {best_state['test_acc']:.4f}\"\n",
    "            )\n",
    "\n",
    "        # Save trained model\n",
    "        torch.save(model.state_dict(), args.model_path)\n",
    "        print(f\"Model saved to {args.model_path}\")\n",
    "\n",
    "    # 4) Build normalized adjacency list\n",
    "    print(\"\\n=== Building normalized adjacency list ===\")\n",
    "    adj_list = build_normalized_adjacency_list(model, data)\n",
    "\n",
    "    # 5) Choose test nodes to explain (correct predictions preferred)\n",
    "    model.eval()\n",
    "    logits = model(data.x, data.edge_index)\n",
    "    preds = logits.argmax(dim=-1)\n",
    "\n",
    "    test_nodes = torch.nonzero(data.test_mask, as_tuple=False).view(-1)\n",
    "    correct_mask = preds == data.y\n",
    "    test_correct_nodes = test_nodes[correct_mask[test_nodes]]\n",
    "\n",
    "    if test_correct_nodes.numel() == 0:\n",
    "        print(\"WARNING: no correctly classified test nodes, using all test nodes.\")\n",
    "        nodes_to_explain = test_nodes\n",
    "    else:\n",
    "        nodes_to_explain = test_correct_nodes\n",
    "\n",
    "    if nodes_to_explain.numel() > args.num_explain:\n",
    "        nodes_to_explain = nodes_to_explain[:args.num_explain]\n",
    "\n",
    "    print(\n",
    "        f\"\\n=== Explaining {nodes_to_explain.numel()} test nodes \"\n",
    "        f\"({'true' if args.use_true_class else 'pred'} class as target) ===\"\n",
    "    )\n",
    "\n",
    "    # 6) Build feature importance: self vs self+1hop vs self+2hop\n",
    "\n",
    "    t0 = _now()\n",
    "    feat_self, feat_self_1hop, feat_self_2hop = build_feature_importance_decomposition(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        adj_list=adj_list,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    "        hop=2,\n",
    "        use_true_class=args.use_true_class,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"decomposition\"] = t1 - t0\n",
    "\n",
    "    # feat_self, feat_self_1hop, feat_self_2hop = build_feature_importance_decomposition(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     adj_list=adj_list,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    #     hop=2,\n",
    "    #     use_true_class=args.use_true_class,\n",
    "    # )\n",
    "    # ---- Random baseline ----\n",
    "    print(\"\\n=== Building RANDOM importance baseline ===\")\n",
    "\n",
    "    t0 = _now()\n",
    "    feat_random = build_random_importance(\n",
    "        num_nodes=data.num_nodes,\n",
    "        num_features=data.x.size(1),\n",
    "        device=device,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"random\"] = t1 - t0\n",
    "\n",
    "#     feat_random = build_random_importance(\n",
    "#         num_nodes=data.num_nodes,\n",
    "#         num_features=data.x.size(1),\n",
    "#         device=device\n",
    "# )\n",
    "    print(\"\\n=== Building GRAD-Cam ===\")\n",
    "    # feat_gradcam = gradcam_feature_importance(model, data, use_true_class=True)\n",
    "    t0 = _now()\n",
    "    feat_gradcam = gradcam_feature_importance(\n",
    "        model=model,\n",
    "        data=data,\n",
    "        use_true_class=args.use_true_class,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"gradcam\"] = t1 - t0\n",
    "\n",
    "    print(\"\\n=== Building GNN-LRP ===\")\n",
    "#     feat_lrp = gnn_lrp_feature_importance(\n",
    "#     model,\n",
    "#     data,\n",
    "#     eps=1e-6,\n",
    "#     use_true_class=True,\n",
    "#     nodes_to_explain=nodes_to_explain  # same subset you use for fidelity\n",
    "# )\n",
    "    t0 = _now()\n",
    "    feat_lrp = gnn_lrp_feature_importance(\n",
    "        model=model,\n",
    "        data=data,\n",
    "        eps=1e-6,\n",
    "        use_true_class=args.use_true_class,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"lrp\"] = t1 - t0\n",
    "\n",
    "\n",
    "    print(\"\\n=== Building GOAT ===\")\n",
    "    # feat_goat = build_goat_feature_importance(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    # )\n",
    "    t0 = _now()\n",
    "    feat_goat = build_goat_feature_importance(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"goat\"] = t1 - t0\n",
    "\n",
    "    print(\"\\n=== Building LIME ===\")\n",
    "    # feat_lime = build_lime_importance(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    #     num_classes=num_classes,\n",
    "    #     num_samples=50,\n",
    "    # )\n",
    "    t0 = _now()\n",
    "    feat_lime = build_lime_importance(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        nodes_to_explain=nodes_to_explain,  # or smaller subset\n",
    "        num_classes=num_classes,\n",
    "        num_samples=50,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"lime\"] = t1 - t0\n",
    "\n",
    "    # GraphLime: if possible\n",
    "    # print(\"\\n=== Building GraphLIME ===\")\n",
    "    # t0 = _now()\n",
    "    # feat_glime = build_graphlime_importance(\n",
    "    # data=data,\n",
    "    # model=model,\n",
    "    # nodes_to_explain=nodes_to_explain,\n",
    "    # hop=2,\n",
    "    # rho=0.1,\n",
    "    # )\n",
    "    # t1 = _now()\n",
    "    # method_times[\"graphlime\"] = t1 - t0\n",
    "\n",
    "\n",
    "    #probably not use the following two ---\n",
    "\n",
    "    #     # ---- GNNExplainer baseline importance ----\n",
    "    # print(\"\\n=== Building GNNExplainer importance baseline ===\")\n",
    "    # feat_gnnexp = build_gnnexplainer_importance(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    #     epochs=50,   # you can change this\n",
    "    # )\n",
    "    print(\"\\n=== Building Integrated Gradients (Captum) importance baseline ===\")\n",
    "    t0 = _now()\n",
    "    feat_ig = build_ig_importance(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    ")\n",
    "    t1 = _now()\n",
    "    method_times[\"ig\"] = t1 - t0\n",
    "\n",
    "    # 7) Fidelity scores for each importance variant\n",
    "    print(\"\\n=== Computing fidelity scores ===\")\n",
    "    fp_self, fm_self, fk_self = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_self,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"SELF ONLY\",\n",
    "    )\n",
    "\n",
    "    fp_s1, fm_s1, fk_s1 = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_self_1hop,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"SELF + 1HOP\",\n",
    "    )\n",
    "\n",
    "    fp_s2, fm_s2, fk_s2 = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_self_2hop,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"SELF + 1HOP + 2HOP\",\n",
    "    )\n",
    "    fp_rand, fm_rand, fk_rand = compute_fidelity_scores(\n",
    "    data=data,\n",
    "    model=model,\n",
    "    feat_imp=feat_random,\n",
    "    nodes=nodes_to_explain,\n",
    "    k_fracs=args.k_fracs,\n",
    "    use_true_class=args.use_true_class,\n",
    "    label=\"RANDOM BASELINE\",\n",
    "    )\n",
    "\n",
    "    fp_gc, fm_gc, fk_gc = compute_fidelity_scores(\n",
    "    data=data,\n",
    "    model=model,\n",
    "    feat_imp=feat_gradcam,\n",
    "    nodes=nodes_to_explain,\n",
    "    k_fracs=args.k_fracs,\n",
    "    use_true_class=args.use_true_class,\n",
    "    label=\"GradCAM-feature\",\n",
    ")\n",
    "\n",
    "    fp_lrp, fm_lrp, fk_lrp = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_lrp,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"GNN-LRP-feature\",\n",
    "    )\n",
    "\n",
    "    fp_goat, fm_goat, fk_goat = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_goat,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"GOAT-feature\",\n",
    "    )\n",
    "    fp_lime, fm_lime, fk_lime = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_lime,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"LIME\",\n",
    "    )\n",
    "    # fp_glime, fm_glime, fk_glime = compute_fidelity_scores(\n",
    "    # data=data,\n",
    "    # model=model,\n",
    "    # feat_imp=feat_glime,\n",
    "    # nodes=nodes_to_explain,\n",
    "    # k_fracs=args.k_fracs,\n",
    "    # use_true_class=args.use_true_class,\n",
    "    # label=\"GraphLIME\",\n",
    "    # )\n",
    "    # fp_gnn, fm_gnn, fk_gnn = compute_fidelity_scores(\n",
    "    # data=data,\n",
    "    # model=model,\n",
    "    # feat_imp=feat_gnnexp,\n",
    "    # nodes=nodes_to_explain,\n",
    "    # k_fracs=args.k_fracs,\n",
    "    # use_true_class=args.use_true_class,\n",
    "    # label=\"GNNEXPLAINER BASELINE\",\n",
    "    # )\n",
    "    fp_ig, fm_ig, fk_ig = compute_fidelity_scores(\n",
    "    data=data,\n",
    "    model=model,\n",
    "    feat_imp=feat_ig,\n",
    "    nodes=nodes_to_explain,\n",
    "    k_fracs=args.k_fracs,\n",
    "    use_true_class=args.use_true_class,\n",
    "    label=\"IG (Captum) BASELINE\",\n",
    "  )\n",
    "\n",
    "\n",
    "    # 8) Summary\n",
    "    print(\"\\n=== Summary: mean Δ = p_orig - p_mask ===\")\n",
    "    print(\"\\nFidelity+ (remove top-k% features):\")\n",
    "    for k in args.k_fracs:\n",
    "        print(\n",
    "            f\"k={k:.2f} | \"\n",
    "            f\"self={fp_self[k]:.4f} | \"\n",
    "            f\"self+1hop={fp_s1[k]:.4f} | \"\n",
    "            f\"self+2hop={fp_s2[k]:.4f} | \"\n",
    "            f\"rand={fp_rand[k]:.4f} | \"\n",
    "            f\"gradcam={fp_gc[k]:.4f} | \"\n",
    "            f\"lrp={fp_lrp[k]:.4f} | \"\n",
    "            f\"goat={fp_goat[k]:.4f} | \"\n",
    "            f\"lime={fp_lime[k]:.4f} | \"\n",
    "            f\"ig={fp_ig[k]:.4f} | \"\n",
    "        )\n",
    "\n",
    "    print(\"\\nFidelity- (remove bottom-k% features):\")\n",
    "    for k in args.k_fracs:\n",
    "        print(\n",
    "            f\"k={k:.2f} | \"\n",
    "            f\"self={fm_self[k]:.4f} | \"\n",
    "            f\"self+1hop={fm_s1[k]:.4f} | \"\n",
    "            f\"self+2hop={fm_s2[k]:.4f} | \"\n",
    "            f\"rand={fm_rand[k]:.4f} | \"\n",
    "            f\"gradcam={fm_gc[k]:.4f} | \"\n",
    "            f\"lrp={fm_lrp[k]:.4f} | \"\n",
    "            f\"goat={fm_goat[k]:.4f} | \"\n",
    "            f\"lime={fm_lime[k]:.4f} | \"\n",
    "            f\"ig={fm_ig[k]:.4f} | \"\n",
    "        )\n",
    "\n",
    "    print(\"\\n=== Build-Time Summary (seconds) ===\")\n",
    "    print(f\"{'method':20s} {'build_time':>12s}\")\n",
    "    for name, t in method_times.items():\n",
    "        print(f\"{name:20s} {t:12.3f}\")\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "  main()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-EJD8zd7C1yW"
   },
   "source": [
    "## PubMed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "aWv2ytkOC7yq",
    "outputId": "61480357-baf2-45df-c0ea-d811ff061dd4"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.x\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.tx\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.allx\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.y\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ty\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ally\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.graph\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.test.index\n",
      "Processing...\n",
      "Done!\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded Planetoid/PubMed:\n",
      "  #Nodes     = 19717\n",
      "  #Edges     = 88648\n",
      "  #Features  = 500\n",
      "  #Classes   = 3\n",
      "  #Train/Val/Test = 60/500/1000\n",
      "#Train=60, #Val=500, #Test=1000\n",
      "\n",
      "=== Training GCN (no saved model found) ===\n",
      "Epoch 001 | Loss 1.0971 | Train 0.6833 | Val 0.5240 | Test 0.5010\n",
      "Epoch 020 | Loss 0.1443 | Train 0.9833 | Val 0.7940 | Test 0.7890\n",
      "Epoch 040 | Loss 0.0918 | Train 1.0000 | Val 0.7480 | Test 0.7590\n",
      "Epoch 060 | Loss 0.0724 | Train 1.0000 | Val 0.7860 | Test 0.7860\n",
      "Epoch 080 | Loss 0.0638 | Train 1.0000 | Val 0.7720 | Test 0.7560\n",
      "Epoch 100 | Loss 0.0833 | Train 1.0000 | Val 0.7820 | Test 0.7710\n",
      "Epoch 120 | Loss 0.0720 | Train 1.0000 | Val 0.7700 | Test 0.7820\n",
      "Epoch 140 | Loss 0.0627 | Train 1.0000 | Val 0.7900 | Test 0.7790\n",
      "Epoch 160 | Loss 0.0727 | Train 1.0000 | Val 0.7780 | Test 0.7870\n",
      "Epoch 180 | Loss 0.0666 | Train 1.0000 | Val 0.8000 | Test 0.7890\n",
      "Epoch 200 | Loss 0.0549 | Train 1.0000 | Val 0.7700 | Test 0.7730\n",
      "Epoch 220 | Loss 0.0541 | Train 1.0000 | Val 0.7740 | Test 0.7710\n",
      "Epoch 240 | Loss 0.0620 | Train 1.0000 | Val 0.7840 | Test 0.7900\n",
      "Epoch 260 | Loss 0.0636 | Train 1.0000 | Val 0.7820 | Test 0.7900\n",
      "Epoch 280 | Loss 0.0748 | Train 1.0000 | Val 0.7880 | Test 0.7910\n",
      "Epoch 300 | Loss 0.0700 | Train 1.0000 | Val 0.7780 | Test 0.7860\n",
      "Epoch 320 | Loss 0.0664 | Train 1.0000 | Val 0.7840 | Test 0.7940\n",
      "Epoch 340 | Loss 0.0510 | Train 1.0000 | Val 0.7860 | Test 0.7750\n",
      "Epoch 360 | Loss 0.0597 | Train 1.0000 | Val 0.7880 | Test 0.7790\n",
      "Epoch 380 | Loss 0.0740 | Train 1.0000 | Val 0.7820 | Test 0.7690\n",
      "Epoch 400 | Loss 0.0575 | Train 1.0000 | Val 0.7920 | Test 0.7870\n",
      "Epoch 420 | Loss 0.0568 | Train 1.0000 | Val 0.7800 | Test 0.7810\n",
      "Epoch 440 | Loss 0.0589 | Train 1.0000 | Val 0.7740 | Test 0.7690\n",
      "Epoch 460 | Loss 0.0641 | Train 1.0000 | Val 0.7860 | Test 0.7960\n",
      "Epoch 480 | Loss 0.0693 | Train 1.0000 | Val 0.7740 | Test 0.7910\n",
      "Epoch 500 | Loss 0.0831 | Train 1.0000 | Val 0.7780 | Test 0.7760\n",
      "\n",
      "Best epoch = 17 | Best val acc = 0.8060 | Test acc @best = 0.7930\n",
      "Model saved to pubmed_gcn.pt\n",
      "\n",
      "=== Building normalized adjacency list ===\n",
      "\n",
      "=== Explaining 100 test nodes (pred class as target) ===\n",
      "Building feature importance via decomposition for 100 nodes...\n",
      "  processed 1/100 nodes\n",
      "  processed 20/100 nodes\n",
      "  processed 40/100 nodes\n",
      "  processed 60/100 nodes\n",
      "  processed 80/100 nodes\n",
      "  processed 100/100 nodes\n",
      "\n",
      "=== Building RANDOM importance baseline ===\n",
      "\n",
      "=== Building GRAD-Cam ===\n",
      "\n",
      "=== Building GNN-LRP ===\n",
      "Running GNN-LRP for 100 nodes...\n",
      "  GNN-LRP processed 1/100 nodes\n",
      "  GNN-LRP processed 10/100 nodes\n",
      "  GNN-LRP processed 20/100 nodes\n",
      "  GNN-LRP processed 30/100 nodes\n",
      "  GNN-LRP processed 40/100 nodes\n",
      "  GNN-LRP processed 50/100 nodes\n",
      "  GNN-LRP processed 60/100 nodes\n",
      "  GNN-LRP processed 70/100 nodes\n",
      "  GNN-LRP processed 80/100 nodes\n",
      "  GNN-LRP processed 90/100 nodes\n",
      "  GNN-LRP processed 100/100 nodes\n",
      "\n",
      "=== Building GOAT ===\n",
      "Running GOAT for 100 nodes...\n",
      "  GOAT processed 1/100 nodes\n",
      "  GOAT processed 10/100 nodes\n",
      "  GOAT processed 20/100 nodes\n",
      "  GOAT processed 30/100 nodes\n",
      "  GOAT processed 40/100 nodes\n",
      "  GOAT processed 50/100 nodes\n",
      "  GOAT processed 60/100 nodes\n",
      "  GOAT processed 70/100 nodes\n",
      "  GOAT processed 80/100 nodes\n",
      "  GOAT processed 90/100 nodes\n",
      "  GOAT processed 100/100 nodes\n",
      "\n",
      "=== Building LIME ===\n",
      "Running LIME for 100 nodes...\n",
      "  LIME processed 1/100 nodes\n",
      "  LIME processed 10/100 nodes\n",
      "  LIME processed 20/100 nodes\n",
      "  LIME processed 30/100 nodes\n",
      "  LIME processed 40/100 nodes\n",
      "  LIME processed 50/100 nodes\n",
      "  LIME processed 60/100 nodes\n",
      "  LIME processed 70/100 nodes\n",
      "  LIME processed 80/100 nodes\n",
      "  LIME processed 90/100 nodes\n",
      "  LIME processed 100/100 nodes\n",
      "\n",
      "=== Building Integrated Gradients (Captum) importance baseline ===\n",
      "Running Clean IG for 100 nodes...\n",
      "  IG processed 1/100 nodes\n",
      "  IG processed 10/100 nodes\n",
      "  IG processed 20/100 nodes\n",
      "  IG processed 30/100 nodes\n",
      "  IG processed 40/100 nodes\n",
      "  IG processed 50/100 nodes\n",
      "  IG processed 60/100 nodes\n",
      "  IG processed 70/100 nodes\n",
      "  IG processed 80/100 nodes\n",
      "  IG processed 90/100 nodes\n",
      "  IG processed 100/100 nodes\n",
      "\n",
      "=== Computing fidelity scores ===\n",
      "\n",
      "=== Fidelity scores for importance: SELF ONLY ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=25) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0367\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0005\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=50) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0363\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0001\n",
      "\n",
      "=== Fidelity scores for importance: SELF + 1HOP ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=25) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0384\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0018\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=50) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0368\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0005\n",
      "\n",
      "=== Fidelity scores for importance: SELF + 1HOP + 2HOP ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=25) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0384\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0019\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=50) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0374\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0011\n",
      "\n",
      "=== Fidelity scores for importance: RANDOM BASELINE ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=25) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0010\n",
      "Fidelity- (remove BOTTOM-k)= 0.0021\n",
      "Keep-top (only TOP-k kept) = 0.0350\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=50) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0015\n",
      "Fidelity- (remove BOTTOM-k)= 0.0049\n",
      "Keep-top (only TOP-k kept) = 0.0342\n",
      "\n",
      "=== Fidelity scores for importance: GradCAM-feature ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=25) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0215\n",
      "Fidelity- (remove BOTTOM-k)= 0.0005\n",
      "Keep-top (only TOP-k kept) = 0.0141\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=50) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0306\n",
      "Fidelity- (remove BOTTOM-k)= -0.0026\n",
      "Keep-top (only TOP-k kept) = 0.0084\n",
      "\n",
      "=== Fidelity scores for importance: GNN-LRP-feature ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=25) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0313\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0041\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=50) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0341\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0018\n",
      "\n",
      "=== Fidelity scores for importance: GOAT-feature ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=25) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0389\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0022\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=50) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0406\n",
      "Fidelity- (remove BOTTOM-k)= -0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0033\n",
      "\n",
      "=== Fidelity scores for importance: LIME ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=25) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0010\n",
      "Fidelity- (remove BOTTOM-k)= 0.0011\n",
      "Keep-top (only TOP-k kept) = 0.0344\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=50) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0003\n",
      "Fidelity- (remove BOTTOM-k)= 0.0037\n",
      "Keep-top (only TOP-k kept) = 0.0350\n",
      "\n",
      "=== Fidelity scores for importance: IG (Captum) BASELINE ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=25) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0196\n",
      "Fidelity- (remove BOTTOM-k)= -0.0003\n",
      "Keep-top (only TOP-k kept) = 0.0145\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=50) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0280\n",
      "Fidelity- (remove BOTTOM-k)= -0.0006\n",
      "Keep-top (only TOP-k kept) = 0.0068\n",
      "\n",
      "=== Summary: mean Δ = p_orig - p_mask ===\n",
      "\n",
      "Fidelity+ (remove top-k% features):\n",
      "k=0.05 | self=0.0367 | self+1hop=0.0384 | self+2hop=0.0384 | rand=0.0010 | gradcam=0.0215 | lrp=0.0313 | goat=0.0389 | lime=0.0010 | ig=0.0196 | \n",
      "k=0.10 | self=0.0363 | self+1hop=0.0368 | self+2hop=0.0374 | rand=0.0015 | gradcam=0.0306 | lrp=0.0341 | goat=0.0406 | lime=0.0003 | ig=0.0280 | \n",
      "\n",
      "Fidelity- (remove bottom-k% features):\n",
      "k=0.05 | self=0.0000 | self+1hop=0.0000 | self+2hop=0.0000 | rand=0.0021 | gradcam=0.0005 | lrp=0.0000 | goat=0.0000 | lime=0.0011 | ig=-0.0003 | \n",
      "k=0.10 | self=0.0000 | self+1hop=0.0000 | self+2hop=0.0000 | rand=0.0049 | gradcam=-0.0026 | lrp=0.0000 | goat=-0.0000 | lime=0.0037 | ig=-0.0006 | \n",
      "\n",
      "=== Build-Time Summary (seconds) ===\n",
      "method                 build_time\n",
      "decomposition               0.749\n",
      "random                      0.038\n",
      "gradcam                     0.065\n",
      "lrp                         0.535\n",
      "goat                      437.584\n",
      "lime                      118.167\n",
      "ig                        224.865\n"
     ]
    }
   ],
   "source": [
    "from torch_geometric.datasets import Planetoid, HeterophilousGraphDataset\n",
    "import torch_geometric.transforms as T\n",
    "\n",
    "def create_splits(data, num_train_per_class=20, num_val_per_class=30):\n",
    "    y = data.y\n",
    "    num_nodes = data.num_nodes\n",
    "    num_classes = int(y.max().item() + 1)\n",
    "\n",
    "    train_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "    val_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "    test_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "\n",
    "    for c in range(num_classes):\n",
    "        idx = (y == c).nonzero(as_tuple=False).view(-1)\n",
    "        idx = idx[torch.randperm(idx.size(0))]\n",
    "\n",
    "        n_train = min(num_train_per_class, idx.size(0))\n",
    "        n_val = min(num_val_per_class, max(0, idx.size(0) - n_train))\n",
    "\n",
    "        train_idx = idx[:n_train]\n",
    "        val_idx = idx[n_train:n_train + n_val]\n",
    "        test_idx = idx[n_train + n_val:]\n",
    "\n",
    "        train_mask[train_idx] = True\n",
    "        val_mask[val_idx] = True\n",
    "        test_mask[test_idx] = True\n",
    "\n",
    "    data.train_mask = train_mask\n",
    "    data.val_mask = val_mask\n",
    "    data.test_mask = test_mask\n",
    "\n",
    "    print(f\"#Train nodes = {int(train_mask.sum())}\")\n",
    "    print(f\"#Val nodes   = {int(val_mask.sum())}\")\n",
    "    print(f\"#Test nodes  = {int(test_mask.sum())}\")\n",
    "\n",
    "    return data\n",
    "def load_planetoid(name: str, root: str):\n",
    "    \"\"\"\n",
    "    Load Planetoid datasets: 'Cora', 'CiteSeer', or 'PubMed'.\n",
    "\n",
    "    Returns:\n",
    "        data:         Data object (with built-in train/val/test masks)\n",
    "        in_dim:       num node features\n",
    "        num_classes:  num classes\n",
    "    \"\"\"\n",
    "    dataset = Planetoid(\n",
    "        root=root,\n",
    "        name=name,\n",
    "        transform=T.NormalizeFeatures(),   # common preprocessing\n",
    "    )\n",
    "    data = dataset[0]\n",
    "\n",
    "    in_dim = dataset.num_features\n",
    "    num_classes = dataset.num_classes\n",
    "\n",
    "    print(f\"Loaded Planetoid/{name}:\")\n",
    "    print(f\"  #Nodes     = {data.num_nodes}\")\n",
    "    print(f\"  #Edges     = {data.num_edges}\")\n",
    "    print(f\"  #Features  = {in_dim}\")\n",
    "    print(f\"  #Classes   = {num_classes}\")\n",
    "    print(f\"  #Train/Val/Test = {int(data.train_mask.sum())}/\"\n",
    "          f\"{int(data.val_mask.sum())}/{int(data.test_mask.sum())}\")\n",
    "\n",
    "    return data, in_dim, num_classes\n",
    "\n",
    "def load_cora(root):     return load_planetoid(\"Cora\", root)\n",
    "def load_citeseer(root): return load_planetoid(\"CiteSeer\", root)\n",
    "def load_pubmed(root):   return load_planetoid(\"PubMed\", root)\n",
    "\n",
    "def main():\n",
    "    method_times = {}\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument(\"--seed\", type=int, default=42)\n",
    "    parser.add_argument(\"--hidden_dim\", type=int, default=64)\n",
    "    parser.add_argument(\"--dropout\", type=float, default=0.5)\n",
    "    parser.add_argument(\"--epochs\", type=int, default=500)\n",
    "    parser.add_argument(\"--lr\", type=float, default=0.05)\n",
    "    parser.add_argument(\"--weight_decay\", type=float, default=5e-4)\n",
    "    parser.add_argument(\"--root\", type=str, default=\"data/Planetoid\")\n",
    "    parser.add_argument(\"--k_fracs\", type=float, nargs=\"+\",\n",
    "                        default=[0.05, 0.1])\n",
    "    parser.add_argument(\"--num_explain\", type=int, default=100,\n",
    "                        help=\"how many test nodes to explain via decomposition\")\n",
    "    parser.add_argument(\"--use_true_class\", action=\"store_true\",\n",
    "                        help=\"use true label instead of predicted label as target class\")\n",
    "    parser.add_argument(\"--model_path\", type=str, default=\"pubmed_gcn.pt\",\n",
    "                        help=\"path to save/load trained GCN model\")\n",
    "    args = parser.parse_args(args=[])\n",
    "\n",
    "    set_seed(args.seed)\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    # 1) Load data\n",
    "    data, in_dim, num_classes = load_pubmed(args.root)\n",
    "    data = data.to(device)\n",
    "    # If masks are 2D (e.g., [N, 10] for HeterophilousGraphDataset),\n",
    "    # pick one column (args.split_idx) and make them 1D.\n",
    "    if not hasattr(data, \"train_mask\"):\n",
    "        data = create_splits(data)\n",
    "\n",
    "    print(\n",
    "        f\"#Train={int(data.train_mask.sum())}, \"\n",
    "        f\"#Val={int(data.val_mask.sum())}, \"\n",
    "        f\"#Test={int(data.test_mask.sum())}\"\n",
    "    )\n",
    "\n",
    "    # 2) Build model\n",
    "    model = GCN(\n",
    "        in_channels=in_dim,\n",
    "        hidden_channels=args.hidden_dim,\n",
    "        out_channels=num_classes,\n",
    "        dropout=args.dropout,\n",
    "    ).to(device)\n",
    "\n",
    "    # 3) Train or load model\n",
    "    if os.path.exists(args.model_path):\n",
    "        print(f\"\\n=== Loading existing model from {args.model_path} ===\")\n",
    "        state = torch.load(args.model_path, map_location=device)\n",
    "        model.load_state_dict(state)\n",
    "        train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
    "        print(f\"Loaded model | Train={train_acc:.4f} Val={val_acc:.4f} Test={test_acc:.4f}\")\n",
    "    else:\n",
    "        print(\"\\n=== Training GCN (no saved model found) ===\")\n",
    "        optimizer = torch.optim.Adam(\n",
    "            model.parameters(),\n",
    "            lr=args.lr,\n",
    "            weight_decay=args.weight_decay,\n",
    "        )\n",
    "\n",
    "        best_val_acc = 0.0\n",
    "        best_state = None\n",
    "\n",
    "        for epoch in range(1, args.epochs + 1):\n",
    "            loss = train(model, data, optimizer)\n",
    "            train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
    "\n",
    "            if val_acc > best_val_acc:\n",
    "                best_val_acc = val_acc\n",
    "                best_state = {\n",
    "                    \"model\": model.state_dict(),\n",
    "                    \"epoch\": epoch,\n",
    "                    \"test_acc\": test_acc,\n",
    "                }\n",
    "\n",
    "            if epoch % 20 == 0 or epoch == 1:\n",
    "                print(\n",
    "                    f\"Epoch {epoch:03d} | \"\n",
    "                    f\"Loss {loss:.4f} | \"\n",
    "                    f\"Train {train_acc:.4f} | \"\n",
    "                    f\"Val {val_acc:.4f} | \"\n",
    "                    f\"Test {test_acc:.4f}\"\n",
    "                )\n",
    "\n",
    "        if best_state is not None:\n",
    "            model.load_state_dict(best_state[\"model\"])\n",
    "            print(\n",
    "                f\"\\nBest epoch = {best_state['epoch']} | \"\n",
    "                f\"Best val acc = {best_val_acc:.4f} | \"\n",
    "                f\"Test acc @best = {best_state['test_acc']:.4f}\"\n",
    "            )\n",
    "\n",
    "        # Save trained model\n",
    "        torch.save(model.state_dict(), args.model_path)\n",
    "        print(f\"Model saved to {args.model_path}\")\n",
    "\n",
    "    # 4) Build normalized adjacency list\n",
    "    print(\"\\n=== Building normalized adjacency list ===\")\n",
    "    adj_list = build_normalized_adjacency_list(model, data)\n",
    "\n",
    "    # 5) Choose test nodes to explain (correct predictions preferred)\n",
    "    model.eval()\n",
    "    logits = model(data.x, data.edge_index)\n",
    "    preds = logits.argmax(dim=-1)\n",
    "\n",
    "    test_nodes = torch.nonzero(data.test_mask, as_tuple=False).view(-1)\n",
    "    correct_mask = preds == data.y\n",
    "    test_correct_nodes = test_nodes[correct_mask[test_nodes]]\n",
    "\n",
    "    if test_correct_nodes.numel() == 0:\n",
    "        print(\"WARNING: no correctly classified test nodes, using all test nodes.\")\n",
    "        nodes_to_explain = test_nodes\n",
    "    else:\n",
    "        nodes_to_explain = test_correct_nodes\n",
    "\n",
    "    if nodes_to_explain.numel() > args.num_explain:\n",
    "        nodes_to_explain = nodes_to_explain[:args.num_explain]\n",
    "\n",
    "    print(\n",
    "        f\"\\n=== Explaining {nodes_to_explain.numel()} test nodes \"\n",
    "        f\"({'true' if args.use_true_class else 'pred'} class as target) ===\"\n",
    "    )\n",
    "\n",
    "    # 6) Build feature importance: self vs self+1hop vs self+2hop\n",
    "\n",
    "    t0 = _now()\n",
    "    feat_self, feat_self_1hop, feat_self_2hop = build_feature_importance_decomposition(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        adj_list=adj_list,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    "        hop=2,\n",
    "        use_true_class=args.use_true_class,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"decomposition\"] = t1 - t0\n",
    "\n",
    "    # feat_self, feat_self_1hop, feat_self_2hop = build_feature_importance_decomposition(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     adj_list=adj_list,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    #     hop=2,\n",
    "    #     use_true_class=args.use_true_class,\n",
    "    # )\n",
    "    # ---- Random baseline ----\n",
    "    print(\"\\n=== Building RANDOM importance baseline ===\")\n",
    "\n",
    "    t0 = _now()\n",
    "    feat_random = build_random_importance(\n",
    "        num_nodes=data.num_nodes,\n",
    "        num_features=data.x.size(1),\n",
    "        device=device,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"random\"] = t1 - t0\n",
    "\n",
    "#     feat_random = build_random_importance(\n",
    "#         num_nodes=data.num_nodes,\n",
    "#         num_features=data.x.size(1),\n",
    "#         device=device\n",
    "# )\n",
    "    print(\"\\n=== Building GRAD-Cam ===\")\n",
    "    # feat_gradcam = gradcam_feature_importance(model, data, use_true_class=True)\n",
    "    t0 = _now()\n",
    "    feat_gradcam = gradcam_feature_importance(\n",
    "        model=model,\n",
    "        data=data,\n",
    "        use_true_class=args.use_true_class,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"gradcam\"] = t1 - t0\n",
    "\n",
    "    print(\"\\n=== Building GNN-LRP ===\")\n",
    "#     feat_lrp = gnn_lrp_feature_importance(\n",
    "#     model,\n",
    "#     data,\n",
    "#     eps=1e-6,\n",
    "#     use_true_class=True,\n",
    "#     nodes_to_explain=nodes_to_explain  # same subset you use for fidelity\n",
    "# )\n",
    "    t0 = _now()\n",
    "    feat_lrp = gnn_lrp_feature_importance(\n",
    "        model=model,\n",
    "        data=data,\n",
    "        eps=1e-6,\n",
    "        use_true_class=args.use_true_class,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"lrp\"] = t1 - t0\n",
    "\n",
    "\n",
    "    print(\"\\n=== Building GOAT ===\")\n",
    "    # feat_goat = build_goat_feature_importance(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    # )\n",
    "    t0 = _now()\n",
    "    feat_goat = build_goat_feature_importance(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"goat\"] = t1 - t0\n",
    "\n",
    "    print(\"\\n=== Building LIME ===\")\n",
    "    # feat_lime = build_lime_importance(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    #     num_classes=num_classes,\n",
    "    #     num_samples=50,\n",
    "    # )\n",
    "    t0 = _now()\n",
    "    feat_lime = build_lime_importance(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        nodes_to_explain=nodes_to_explain,  # or smaller subset\n",
    "        num_classes=num_classes,\n",
    "        num_samples=50,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"lime\"] = t1 - t0\n",
    "\n",
    "    # GraphLime: if possible\n",
    "    # print(\"\\n=== Building GraphLIME ===\")\n",
    "    # t0 = _now()\n",
    "    # feat_glime = build_graphlime_importance(\n",
    "    # data=data,\n",
    "    # model=model,\n",
    "    # nodes_to_explain=nodes_to_explain,\n",
    "    # hop=2,\n",
    "    # rho=0.1,\n",
    "    # )\n",
    "    # t1 = _now()\n",
    "    # method_times[\"graphlime\"] = t1 - t0\n",
    "\n",
    "\n",
    "    #probably not use the following two ---\n",
    "\n",
    "    #     # ---- GNNExplainer baseline importance ----\n",
    "    # print(\"\\n=== Building GNNExplainer importance baseline ===\")\n",
    "    # feat_gnnexp = build_gnnexplainer_importance(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    #     epochs=50,   # you can change this\n",
    "    # )\n",
    "    print(\"\\n=== Building Integrated Gradients (Captum) importance baseline ===\")\n",
    "    t0 = _now()\n",
    "    feat_ig = build_ig_importance(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    ")\n",
    "    t1 = _now()\n",
    "    method_times[\"ig\"] = t1 - t0\n",
    "\n",
    "    # 7) Fidelity scores for each importance variant\n",
    "    print(\"\\n=== Computing fidelity scores ===\")\n",
    "    fp_self, fm_self, fk_self = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_self,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"SELF ONLY\",\n",
    "    )\n",
    "\n",
    "    fp_s1, fm_s1, fk_s1 = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_self_1hop,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"SELF + 1HOP\",\n",
    "    )\n",
    "\n",
    "    fp_s2, fm_s2, fk_s2 = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_self_2hop,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"SELF + 1HOP + 2HOP\",\n",
    "    )\n",
    "    fp_rand, fm_rand, fk_rand = compute_fidelity_scores(\n",
    "    data=data,\n",
    "    model=model,\n",
    "    feat_imp=feat_random,\n",
    "    nodes=nodes_to_explain,\n",
    "    k_fracs=args.k_fracs,\n",
    "    use_true_class=args.use_true_class,\n",
    "    label=\"RANDOM BASELINE\",\n",
    "    )\n",
    "\n",
    "    fp_gc, fm_gc, fk_gc = compute_fidelity_scores(\n",
    "    data=data,\n",
    "    model=model,\n",
    "    feat_imp=feat_gradcam,\n",
    "    nodes=nodes_to_explain,\n",
    "    k_fracs=args.k_fracs,\n",
    "    use_true_class=args.use_true_class,\n",
    "    label=\"GradCAM-feature\",\n",
    ")\n",
    "\n",
    "    fp_lrp, fm_lrp, fk_lrp = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_lrp,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"GNN-LRP-feature\",\n",
    "    )\n",
    "\n",
    "    fp_goat, fm_goat, fk_goat = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_goat,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"GOAT-feature\",\n",
    "    )\n",
    "    fp_lime, fm_lime, fk_lime = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_lime,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"LIME\",\n",
    "    )\n",
    "    # fp_glime, fm_glime, fk_glime = compute_fidelity_scores(\n",
    "    # data=data,\n",
    "    # model=model,\n",
    "    # feat_imp=feat_glime,\n",
    "    # nodes=nodes_to_explain,\n",
    "    # k_fracs=args.k_fracs,\n",
    "    # use_true_class=args.use_true_class,\n",
    "    # label=\"GraphLIME\",\n",
    "    # )\n",
    "    # fp_gnn, fm_gnn, fk_gnn = compute_fidelity_scores(\n",
    "    # data=data,\n",
    "    # model=model,\n",
    "    # feat_imp=feat_gnnexp,\n",
    "    # nodes=nodes_to_explain,\n",
    "    # k_fracs=args.k_fracs,\n",
    "    # use_true_class=args.use_true_class,\n",
    "    # label=\"GNNEXPLAINER BASELINE\",\n",
    "    # )\n",
    "    fp_ig, fm_ig, fk_ig = compute_fidelity_scores(\n",
    "    data=data,\n",
    "    model=model,\n",
    "    feat_imp=feat_ig,\n",
    "    nodes=nodes_to_explain,\n",
    "    k_fracs=args.k_fracs,\n",
    "    use_true_class=args.use_true_class,\n",
    "    label=\"IG (Captum) BASELINE\",\n",
    "  )\n",
    "\n",
    "\n",
    "    # 8) Summary\n",
    "    print(\"\\n=== Summary: mean Δ = p_orig - p_mask ===\")\n",
    "    print(\"\\nFidelity+ (remove top-k% features):\")\n",
    "    for k in args.k_fracs:\n",
    "        print(\n",
    "            f\"k={k:.2f} | \"\n",
    "            f\"self={fp_self[k]:.4f} | \"\n",
    "            f\"self+1hop={fp_s1[k]:.4f} | \"\n",
    "            f\"self+2hop={fp_s2[k]:.4f} | \"\n",
    "            f\"rand={fp_rand[k]:.4f} | \"\n",
    "            f\"gradcam={fp_gc[k]:.4f} | \"\n",
    "            f\"lrp={fp_lrp[k]:.4f} | \"\n",
    "            f\"goat={fp_goat[k]:.4f} | \"\n",
    "            f\"lime={fp_lime[k]:.4f} | \"\n",
    "            f\"ig={fp_ig[k]:.4f} | \"\n",
    "        )\n",
    "\n",
    "    print(\"\\nFidelity- (remove bottom-k% features):\")\n",
    "    for k in args.k_fracs:\n",
    "        print(\n",
    "            f\"k={k:.2f} | \"\n",
    "            f\"self={fm_self[k]:.4f} | \"\n",
    "            f\"self+1hop={fm_s1[k]:.4f} | \"\n",
    "            f\"self+2hop={fm_s2[k]:.4f} | \"\n",
    "            f\"rand={fm_rand[k]:.4f} | \"\n",
    "            f\"gradcam={fm_gc[k]:.4f} | \"\n",
    "            f\"lrp={fm_lrp[k]:.4f} | \"\n",
    "            f\"goat={fm_goat[k]:.4f} | \"\n",
    "            f\"lime={fm_lime[k]:.4f} | \"\n",
    "            f\"ig={fm_ig[k]:.4f} | \"\n",
    "        )\n",
    "\n",
    "    print(\"\\n=== Build-Time Summary (seconds) ===\")\n",
    "    print(f\"{'method':20s} {'build_time':>12s}\")\n",
    "    for name, t in method_times.items():\n",
    "        print(f\"{name:20s} {t:12.3f}\")\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "  main()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IST6mApctP2C"
   },
   "source": [
    "## Cora"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "hAkPRBvJtVZB",
    "outputId": "f241ea0a-3cae-443c-db28-82a009e69277"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index\n",
      "Processing...\n",
      "Done!\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded Planetoid/Cora:\n",
      "  #Nodes     = 2708\n",
      "  #Edges     = 10556\n",
      "  #Features  = 1433\n",
      "  #Classes   = 7\n",
      "  #Train/Val/Test = 140/500/1000\n",
      "#Train=140, #Val=500, #Test=1000\n",
      "\n",
      "=== Training GCN (no saved model found) ===\n",
      "Epoch 001 | Loss 1.9463 | Train 0.1786 | Val 0.1220 | Test 0.1380\n",
      "Epoch 020 | Loss 0.3841 | Train 0.9857 | Val 0.7980 | Test 0.8050\n",
      "Epoch 040 | Loss 0.1968 | Train 1.0000 | Val 0.7960 | Test 0.8060\n",
      "Epoch 060 | Loss 0.1802 | Train 1.0000 | Val 0.7980 | Test 0.8060\n",
      "Epoch 080 | Loss 0.1573 | Train 1.0000 | Val 0.7920 | Test 0.8070\n",
      "Epoch 100 | Loss 0.1558 | Train 1.0000 | Val 0.7960 | Test 0.8190\n",
      "Epoch 120 | Loss 0.1521 | Train 1.0000 | Val 0.7900 | Test 0.8120\n",
      "Epoch 140 | Loss 0.1340 | Train 1.0000 | Val 0.7960 | Test 0.8080\n",
      "Epoch 160 | Loss 0.1323 | Train 1.0000 | Val 0.7980 | Test 0.8120\n",
      "Epoch 180 | Loss 0.1763 | Train 1.0000 | Val 0.7960 | Test 0.8130\n",
      "Epoch 200 | Loss 0.1681 | Train 1.0000 | Val 0.7920 | Test 0.8080\n",
      "Epoch 220 | Loss 0.1420 | Train 1.0000 | Val 0.7920 | Test 0.8070\n",
      "Epoch 240 | Loss 0.1348 | Train 1.0000 | Val 0.7900 | Test 0.8100\n",
      "Epoch 260 | Loss 0.1472 | Train 1.0000 | Val 0.7940 | Test 0.8030\n",
      "Epoch 280 | Loss 0.1283 | Train 1.0000 | Val 0.7940 | Test 0.8100\n",
      "Epoch 300 | Loss 0.1533 | Train 1.0000 | Val 0.8040 | Test 0.8060\n",
      "Epoch 320 | Loss 0.1432 | Train 1.0000 | Val 0.7900 | Test 0.8120\n",
      "Epoch 340 | Loss 0.1627 | Train 1.0000 | Val 0.7880 | Test 0.8070\n",
      "Epoch 360 | Loss 0.1337 | Train 1.0000 | Val 0.7900 | Test 0.8000\n",
      "Epoch 380 | Loss 0.1485 | Train 1.0000 | Val 0.7920 | Test 0.8040\n",
      "Epoch 400 | Loss 0.1630 | Train 1.0000 | Val 0.7960 | Test 0.8140\n",
      "Epoch 420 | Loss 0.1621 | Train 1.0000 | Val 0.7960 | Test 0.8100\n",
      "Epoch 440 | Loss 0.1479 | Train 1.0000 | Val 0.8040 | Test 0.8080\n",
      "Epoch 460 | Loss 0.1524 | Train 1.0000 | Val 0.7940 | Test 0.8020\n",
      "Epoch 480 | Loss 0.1348 | Train 1.0000 | Val 0.7780 | Test 0.8070\n",
      "Epoch 500 | Loss 0.1711 | Train 1.0000 | Val 0.7820 | Test 0.8170\n",
      "\n",
      "Best epoch = 296 | Best val acc = 0.8140 | Test acc @best = 0.8140\n",
      "Model saved to cora_gcn.pt\n",
      "\n",
      "=== Building normalized adjacency list ===\n",
      "\n",
      "=== Explaining 100 test nodes (pred class as target) ===\n",
      "Building feature importance via decomposition for 100 nodes...\n",
      "  processed 1/100 nodes\n",
      "  processed 20/100 nodes\n",
      "  processed 40/100 nodes\n",
      "  processed 60/100 nodes\n",
      "  processed 80/100 nodes\n",
      "  processed 100/100 nodes\n",
      "\n",
      "=== Building RANDOM importance baseline ===\n",
      "\n",
      "=== Building GRAD-Cam ===\n",
      "\n",
      "=== Building GNN-LRP ===\n",
      "Running GNN-LRP for 100 nodes...\n",
      "  GNN-LRP processed 1/100 nodes\n",
      "  GNN-LRP processed 10/100 nodes\n",
      "  GNN-LRP processed 20/100 nodes\n",
      "  GNN-LRP processed 30/100 nodes\n",
      "  GNN-LRP processed 40/100 nodes\n",
      "  GNN-LRP processed 50/100 nodes\n",
      "  GNN-LRP processed 60/100 nodes\n",
      "  GNN-LRP processed 70/100 nodes\n",
      "  GNN-LRP processed 80/100 nodes\n",
      "  GNN-LRP processed 90/100 nodes\n",
      "  GNN-LRP processed 100/100 nodes\n",
      "\n",
      "=== Building GOAT ===\n",
      "Running GOAT for 100 nodes...\n",
      "  GOAT processed 1/100 nodes\n",
      "  GOAT processed 10/100 nodes\n",
      "  GOAT processed 20/100 nodes\n",
      "  GOAT processed 30/100 nodes\n",
      "  GOAT processed 40/100 nodes\n",
      "  GOAT processed 50/100 nodes\n",
      "  GOAT processed 60/100 nodes\n",
      "  GOAT processed 70/100 nodes\n",
      "  GOAT processed 80/100 nodes\n",
      "  GOAT processed 90/100 nodes\n",
      "  GOAT processed 100/100 nodes\n",
      "\n",
      "=== Building LIME ===\n",
      "Running LIME for 100 nodes...\n",
      "  LIME processed 1/100 nodes\n",
      "  LIME processed 10/100 nodes\n",
      "  LIME processed 20/100 nodes\n",
      "  LIME processed 30/100 nodes\n",
      "  LIME processed 40/100 nodes\n",
      "  LIME processed 50/100 nodes\n",
      "  LIME processed 60/100 nodes\n",
      "  LIME processed 70/100 nodes\n",
      "  LIME processed 80/100 nodes\n",
      "  LIME processed 90/100 nodes\n",
      "  LIME processed 100/100 nodes\n",
      "\n",
      "=== Building Integrated Gradients (Captum) importance baseline ===\n",
      "Running Clean IG for 100 nodes...\n",
      "  IG processed 1/100 nodes\n",
      "  IG processed 10/100 nodes\n",
      "  IG processed 20/100 nodes\n",
      "  IG processed 30/100 nodes\n",
      "  IG processed 40/100 nodes\n",
      "  IG processed 50/100 nodes\n",
      "  IG processed 60/100 nodes\n",
      "  IG processed 70/100 nodes\n",
      "  IG processed 80/100 nodes\n",
      "  IG processed 90/100 nodes\n",
      "  IG processed 100/100 nodes\n",
      "\n",
      "=== Computing fidelity scores ===\n",
      "\n",
      "=== Fidelity scores for importance: SELF ONLY ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=72) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1787\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0000\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=143) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1787\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0000\n",
      "\n",
      "=== Fidelity scores for importance: SELF + 1HOP ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=72) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1787\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0000\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=143) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1788\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0000\n",
      "\n",
      "=== Fidelity scores for importance: SELF + 1HOP + 2HOP ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=72) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1789\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0001\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=143) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1789\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0001\n",
      "\n",
      "=== Fidelity scores for importance: RANDOM BASELINE ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=72) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0054\n",
      "Fidelity- (remove BOTTOM-k)= 0.0033\n",
      "Keep-top (only TOP-k kept) = 0.1692\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=143) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0134\n",
      "Fidelity- (remove BOTTOM-k)= 0.0113\n",
      "Keep-top (only TOP-k kept) = 0.1535\n",
      "\n",
      "=== Fidelity scores for importance: GradCAM-feature ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=72) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1073\n",
      "Fidelity- (remove BOTTOM-k)= 0.0006\n",
      "Keep-top (only TOP-k kept) = 0.0498\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=143) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1230\n",
      "Fidelity- (remove BOTTOM-k)= 0.0030\n",
      "Keep-top (only TOP-k kept) = 0.0382\n",
      "\n",
      "=== Fidelity scores for importance: GNN-LRP-feature ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=72) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1787\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0000\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=143) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1787\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0000\n",
      "\n",
      "=== Fidelity scores for importance: GOAT-feature ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=72) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1808\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = -0.0004\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=143) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1791\n",
      "Fidelity- (remove BOTTOM-k)= 0.0000\n",
      "Keep-top (only TOP-k kept) = 0.0000\n",
      "\n",
      "=== Fidelity scores for importance: LIME ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=72) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0267\n",
      "Fidelity- (remove BOTTOM-k)= 0.0046\n",
      "Keep-top (only TOP-k kept) = 0.1306\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=143) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.0360\n",
      "Fidelity- (remove BOTTOM-k)= 0.0080\n",
      "Keep-top (only TOP-k kept) = 0.1182\n",
      "\n",
      "=== Fidelity scores for importance: IG (Captum) BASELINE ===\n",
      "\n",
      "--- k_frac = 0.05 (top 5% -> k=72) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1229\n",
      "Fidelity- (remove BOTTOM-k)= -0.0001\n",
      "Keep-top (only TOP-k kept) = 0.0316\n",
      "\n",
      "--- k_frac = 0.10 (top 10% -> k=143) ---\n",
      "Fidelity+ (remove TOP-k)   = 0.1395\n",
      "Fidelity- (remove BOTTOM-k)= 0.0001\n",
      "Keep-top (only TOP-k kept) = 0.0214\n",
      "\n",
      "=== Summary: mean Δ = p_orig - p_mask ===\n",
      "\n",
      "Fidelity+ (remove top-k% features):\n",
      "k=0.05 | self=0.1787 | self+1hop=0.1787 | self+2hop=0.1789 | rand=0.0054 | gradcam=0.1073 | lrp=0.1787 | goat=0.1808 | lime=0.0267 | ig=0.1229 | \n",
      "k=0.10 | self=0.1787 | self+1hop=0.1788 | self+2hop=0.1789 | rand=0.0134 | gradcam=0.1230 | lrp=0.1787 | goat=0.1791 | lime=0.0360 | ig=0.1395 | \n",
      "\n",
      "Fidelity- (remove bottom-k% features):\n",
      "k=0.05 | self=0.0000 | self+1hop=0.0000 | self+2hop=0.0000 | rand=0.0033 | gradcam=0.0006 | lrp=0.0000 | goat=0.0000 | lime=0.0046 | ig=-0.0001 | \n",
      "k=0.10 | self=0.0000 | self+1hop=0.0000 | self+2hop=0.0000 | rand=0.0113 | gradcam=0.0030 | lrp=0.0000 | goat=0.0000 | lime=0.0080 | ig=0.0001 | \n",
      "\n",
      "=== Build-Time Summary (seconds) ===\n",
      "method                 build_time\n",
      "decomposition               0.486\n",
      "random                      0.017\n",
      "gradcam                     0.021\n",
      "lrp                         0.269\n",
      "goat                       14.792\n",
      "lime                       43.122\n",
      "ig                         56.606\n"
     ]
    }
   ],
   "source": [
    "from torch_geometric.datasets import Planetoid, HeterophilousGraphDataset\n",
    "import torch_geometric.transforms as T\n",
    "\n",
    "def create_splits(data, num_train_per_class=20, num_val_per_class=30):\n",
    "    y = data.y\n",
    "    num_nodes = data.num_nodes\n",
    "    num_classes = int(y.max().item() + 1)\n",
    "\n",
    "    train_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "    val_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "    test_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "\n",
    "    for c in range(num_classes):\n",
    "        idx = (y == c).nonzero(as_tuple=False).view(-1)\n",
    "        idx = idx[torch.randperm(idx.size(0))]\n",
    "\n",
    "        n_train = min(num_train_per_class, idx.size(0))\n",
    "        n_val = min(num_val_per_class, max(0, idx.size(0) - n_train))\n",
    "\n",
    "        train_idx = idx[:n_train]\n",
    "        val_idx = idx[n_train:n_train + n_val]\n",
    "        test_idx = idx[n_train + n_val:]\n",
    "\n",
    "        train_mask[train_idx] = True\n",
    "        val_mask[val_idx] = True\n",
    "        test_mask[test_idx] = True\n",
    "\n",
    "    data.train_mask = train_mask\n",
    "    data.val_mask = val_mask\n",
    "    data.test_mask = test_mask\n",
    "\n",
    "    print(f\"#Train nodes = {int(train_mask.sum())}\")\n",
    "    print(f\"#Val nodes   = {int(val_mask.sum())}\")\n",
    "    print(f\"#Test nodes  = {int(test_mask.sum())}\")\n",
    "\n",
    "    return data\n",
    "def load_planetoid(name: str, root: str):\n",
    "    \"\"\"\n",
    "    Load Planetoid datasets: 'Cora', 'CiteSeer', or 'PubMed'.\n",
    "\n",
    "    Returns:\n",
    "        data:         Data object (with built-in train/val/test masks)\n",
    "        in_dim:       num node features\n",
    "        num_classes:  num classes\n",
    "    \"\"\"\n",
    "    dataset = Planetoid(\n",
    "        root=root,\n",
    "        name=name,\n",
    "        transform=T.NormalizeFeatures(),   # common preprocessing\n",
    "    )\n",
    "    data = dataset[0]\n",
    "\n",
    "    in_dim = dataset.num_features\n",
    "    num_classes = dataset.num_classes\n",
    "\n",
    "    print(f\"Loaded Planetoid/{name}:\")\n",
    "    print(f\"  #Nodes     = {data.num_nodes}\")\n",
    "    print(f\"  #Edges     = {data.num_edges}\")\n",
    "    print(f\"  #Features  = {in_dim}\")\n",
    "    print(f\"  #Classes   = {num_classes}\")\n",
    "    print(f\"  #Train/Val/Test = {int(data.train_mask.sum())}/\"\n",
    "          f\"{int(data.val_mask.sum())}/{int(data.test_mask.sum())}\")\n",
    "\n",
    "    return data, in_dim, num_classes\n",
    "\n",
    "def load_cora(root):     return load_planetoid(\"Cora\", root)\n",
    "def load_citeseer(root): return load_planetoid(\"CiteSeer\", root)\n",
    "def load_pubmed(root):   return load_planetoid(\"PubMed\", root)\n",
    "\n",
    "def main():\n",
    "    method_times = {}\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument(\"--seed\", type=int, default=42)\n",
    "    parser.add_argument(\"--hidden_dim\", type=int, default=64)\n",
    "    parser.add_argument(\"--dropout\", type=float, default=0.5)\n",
    "    parser.add_argument(\"--epochs\", type=int, default=500)\n",
    "    parser.add_argument(\"--lr\", type=float, default=0.05)\n",
    "    parser.add_argument(\"--weight_decay\", type=float, default=5e-4)\n",
    "    parser.add_argument(\"--root\", type=str, default=\"data/Planetoid\")\n",
    "    parser.add_argument(\"--k_fracs\", type=float, nargs=\"+\",\n",
    "                        default=[0.05, 0.1])\n",
    "    parser.add_argument(\"--num_explain\", type=int, default=100,\n",
    "                        help=\"how many test nodes to explain via decomposition\")\n",
    "    parser.add_argument(\"--use_true_class\", action=\"store_true\",\n",
    "                        help=\"use true label instead of predicted label as target class\")\n",
    "    parser.add_argument(\"--model_path\", type=str, default=\"cora_gcn.pt\",\n",
    "                        help=\"path to save/load trained GCN model\")\n",
    "    args = parser.parse_args(args=[])\n",
    "\n",
    "    set_seed(args.seed)\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    # 1) Load data\n",
    "    data, in_dim, num_classes = load_cora(args.root)\n",
    "    data = data.to(device)\n",
    "    # If masks are 2D (e.g., [N, 10] for HeterophilousGraphDataset),\n",
    "    # pick one column (args.split_idx) and make them 1D.\n",
    "    if not hasattr(data, \"train_mask\"):\n",
    "        data = create_splits(data)\n",
    "\n",
    "    print(\n",
    "        f\"#Train={int(data.train_mask.sum())}, \"\n",
    "        f\"#Val={int(data.val_mask.sum())}, \"\n",
    "        f\"#Test={int(data.test_mask.sum())}\"\n",
    "    )\n",
    "\n",
    "    # 2) Build model\n",
    "    model = GCN(\n",
    "        in_channels=in_dim,\n",
    "        hidden_channels=args.hidden_dim,\n",
    "        out_channels=num_classes,\n",
    "        dropout=args.dropout,\n",
    "    ).to(device)\n",
    "\n",
    "    # 3) Train or load model\n",
    "    if os.path.exists(args.model_path):\n",
    "        print(f\"\\n=== Loading existing model from {args.model_path} ===\")\n",
    "        state = torch.load(args.model_path, map_location=device)\n",
    "        model.load_state_dict(state)\n",
    "        train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
    "        print(f\"Loaded model | Train={train_acc:.4f} Val={val_acc:.4f} Test={test_acc:.4f}\")\n",
    "    else:\n",
    "        print(\"\\n=== Training GCN (no saved model found) ===\")\n",
    "        optimizer = torch.optim.Adam(\n",
    "            model.parameters(),\n",
    "            lr=args.lr,\n",
    "            weight_decay=args.weight_decay,\n",
    "        )\n",
    "\n",
    "        best_val_acc = 0.0\n",
    "        best_state = None\n",
    "\n",
    "        for epoch in range(1, args.epochs + 1):\n",
    "            loss = train(model, data, optimizer)\n",
    "            train_acc, val_acc, test_acc, _ = evaluate(model, data)\n",
    "\n",
    "            if val_acc > best_val_acc:\n",
    "                best_val_acc = val_acc\n",
    "                best_state = {\n",
    "                    \"model\": model.state_dict(),\n",
    "                    \"epoch\": epoch,\n",
    "                    \"test_acc\": test_acc,\n",
    "                }\n",
    "\n",
    "            if epoch % 20 == 0 or epoch == 1:\n",
    "                print(\n",
    "                    f\"Epoch {epoch:03d} | \"\n",
    "                    f\"Loss {loss:.4f} | \"\n",
    "                    f\"Train {train_acc:.4f} | \"\n",
    "                    f\"Val {val_acc:.4f} | \"\n",
    "                    f\"Test {test_acc:.4f}\"\n",
    "                )\n",
    "\n",
    "        if best_state is not None:\n",
    "            model.load_state_dict(best_state[\"model\"])\n",
    "            print(\n",
    "                f\"\\nBest epoch = {best_state['epoch']} | \"\n",
    "                f\"Best val acc = {best_val_acc:.4f} | \"\n",
    "                f\"Test acc @best = {best_state['test_acc']:.4f}\"\n",
    "            )\n",
    "\n",
    "        # Save trained model\n",
    "        torch.save(model.state_dict(), args.model_path)\n",
    "        print(f\"Model saved to {args.model_path}\")\n",
    "\n",
    "    # 4) Build normalized adjacency list\n",
    "    print(\"\\n=== Building normalized adjacency list ===\")\n",
    "    adj_list = build_normalized_adjacency_list(model, data)\n",
    "\n",
    "    # 5) Choose test nodes to explain (correct predictions preferred)\n",
    "    model.eval()\n",
    "    logits = model(data.x, data.edge_index)\n",
    "    preds = logits.argmax(dim=-1)\n",
    "\n",
    "    test_nodes = torch.nonzero(data.test_mask, as_tuple=False).view(-1)\n",
    "    correct_mask = preds == data.y\n",
    "    test_correct_nodes = test_nodes[correct_mask[test_nodes]]\n",
    "\n",
    "    if test_correct_nodes.numel() == 0:\n",
    "        print(\"WARNING: no correctly classified test nodes, using all test nodes.\")\n",
    "        nodes_to_explain = test_nodes\n",
    "    else:\n",
    "        nodes_to_explain = test_correct_nodes\n",
    "\n",
    "    if nodes_to_explain.numel() > args.num_explain:\n",
    "        nodes_to_explain = nodes_to_explain[:args.num_explain]\n",
    "\n",
    "    print(\n",
    "        f\"\\n=== Explaining {nodes_to_explain.numel()} test nodes \"\n",
    "        f\"({'true' if args.use_true_class else 'pred'} class as target) ===\"\n",
    "    )\n",
    "\n",
    "    # 6) Build feature importance: self vs self+1hop vs self+2hop\n",
    "\n",
    "    t0 = _now()\n",
    "    feat_self, feat_self_1hop, feat_self_2hop = build_feature_importance_decomposition(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        adj_list=adj_list,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    "        hop=2,\n",
    "        use_true_class=args.use_true_class,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"decomposition\"] = t1 - t0\n",
    "\n",
    "    # feat_self, feat_self_1hop, feat_self_2hop = build_feature_importance_decomposition(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     adj_list=adj_list,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    #     hop=2,\n",
    "    #     use_true_class=args.use_true_class,\n",
    "    # )\n",
    "    # ---- Random baseline ----\n",
    "    print(\"\\n=== Building RANDOM importance baseline ===\")\n",
    "\n",
    "    t0 = _now()\n",
    "    feat_random = build_random_importance(\n",
    "        num_nodes=data.num_nodes,\n",
    "        num_features=data.x.size(1),\n",
    "        device=device,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"random\"] = t1 - t0\n",
    "\n",
    "#     feat_random = build_random_importance(\n",
    "#         num_nodes=data.num_nodes,\n",
    "#         num_features=data.x.size(1),\n",
    "#         device=device\n",
    "# )\n",
    "    print(\"\\n=== Building GRAD-Cam ===\")\n",
    "    # feat_gradcam = gradcam_feature_importance(model, data, use_true_class=True)\n",
    "    t0 = _now()\n",
    "    feat_gradcam = gradcam_feature_importance(\n",
    "        model=model,\n",
    "        data=data,\n",
    "        use_true_class=args.use_true_class,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"gradcam\"] = t1 - t0\n",
    "\n",
    "    print(\"\\n=== Building GNN-LRP ===\")\n",
    "#     feat_lrp = gnn_lrp_feature_importance(\n",
    "#     model,\n",
    "#     data,\n",
    "#     eps=1e-6,\n",
    "#     use_true_class=True,\n",
    "#     nodes_to_explain=nodes_to_explain  # same subset you use for fidelity\n",
    "# )\n",
    "    t0 = _now()\n",
    "    feat_lrp = gnn_lrp_feature_importance(\n",
    "        model=model,\n",
    "        data=data,\n",
    "        eps=1e-6,\n",
    "        use_true_class=args.use_true_class,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"lrp\"] = t1 - t0\n",
    "\n",
    "\n",
    "    print(\"\\n=== Building GOAT ===\")\n",
    "    # feat_goat = build_goat_feature_importance(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    # )\n",
    "    t0 = _now()\n",
    "    feat_goat = build_goat_feature_importance(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"goat\"] = t1 - t0\n",
    "\n",
    "    print(\"\\n=== Building LIME ===\")\n",
    "    # feat_lime = build_lime_importance(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    #     num_classes=num_classes,\n",
    "    #     num_samples=50,\n",
    "    # )\n",
    "    t0 = _now()\n",
    "    feat_lime = build_lime_importance(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        nodes_to_explain=nodes_to_explain,  # or smaller subset\n",
    "        num_classes=num_classes,\n",
    "        num_samples=50,\n",
    "    )\n",
    "    t1 = _now()\n",
    "    method_times[\"lime\"] = t1 - t0\n",
    "\n",
    "    # GraphLime: if possible\n",
    "    # print(\"\\n=== Building GraphLIME ===\")\n",
    "    # t0 = _now()\n",
    "    # feat_glime = build_graphlime_importance(\n",
    "    # data=data,\n",
    "    # model=model,\n",
    "    # nodes_to_explain=nodes_to_explain,\n",
    "    # hop=2,\n",
    "    # rho=0.1,\n",
    "    # )\n",
    "    # t1 = _now()\n",
    "    # method_times[\"graphlime\"] = t1 - t0\n",
    "\n",
    "\n",
    "    #probably not use the following two ---\n",
    "\n",
    "    #     # ---- GNNExplainer baseline importance ----\n",
    "    # print(\"\\n=== Building GNNExplainer importance baseline ===\")\n",
    "    # feat_gnnexp = build_gnnexplainer_importance(\n",
    "    #     data=data,\n",
    "    #     model=model,\n",
    "    #     nodes_to_explain=nodes_to_explain,\n",
    "    #     epochs=50,   # you can change this\n",
    "    # )\n",
    "    print(\"\\n=== Building Integrated Gradients (Captum) importance baseline ===\")\n",
    "    t0 = _now()\n",
    "    feat_ig = build_ig_importance(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        nodes_to_explain=nodes_to_explain,\n",
    ")\n",
    "    t1 = _now()\n",
    "    method_times[\"ig\"] = t1 - t0\n",
    "\n",
    "    # 7) Fidelity scores for each importance variant\n",
    "    print(\"\\n=== Computing fidelity scores ===\")\n",
    "    fp_self, fm_self, fk_self = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_self,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"SELF ONLY\",\n",
    "    )\n",
    "\n",
    "    fp_s1, fm_s1, fk_s1 = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_self_1hop,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"SELF + 1HOP\",\n",
    "    )\n",
    "\n",
    "    fp_s2, fm_s2, fk_s2 = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_self_2hop,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"SELF + 1HOP + 2HOP\",\n",
    "    )\n",
    "    fp_rand, fm_rand, fk_rand = compute_fidelity_scores(\n",
    "    data=data,\n",
    "    model=model,\n",
    "    feat_imp=feat_random,\n",
    "    nodes=nodes_to_explain,\n",
    "    k_fracs=args.k_fracs,\n",
    "    use_true_class=args.use_true_class,\n",
    "    label=\"RANDOM BASELINE\",\n",
    "    )\n",
    "\n",
    "    fp_gc, fm_gc, fk_gc = compute_fidelity_scores(\n",
    "    data=data,\n",
    "    model=model,\n",
    "    feat_imp=feat_gradcam,\n",
    "    nodes=nodes_to_explain,\n",
    "    k_fracs=args.k_fracs,\n",
    "    use_true_class=args.use_true_class,\n",
    "    label=\"GradCAM-feature\",\n",
    ")\n",
    "\n",
    "    fp_lrp, fm_lrp, fk_lrp = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_lrp,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"GNN-LRP-feature\",\n",
    "    )\n",
    "\n",
    "    fp_goat, fm_goat, fk_goat = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_goat,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"GOAT-feature\",\n",
    "    )\n",
    "    fp_lime, fm_lime, fk_lime = compute_fidelity_scores(\n",
    "        data=data,\n",
    "        model=model,\n",
    "        feat_imp=feat_lime,\n",
    "        nodes=nodes_to_explain,\n",
    "        k_fracs=args.k_fracs,\n",
    "        use_true_class=args.use_true_class,\n",
    "        label=\"LIME\",\n",
    "    )\n",
    "    # fp_glime, fm_glime, fk_glime = compute_fidelity_scores(\n",
    "    # data=data,\n",
    "    # model=model,\n",
    "    # feat_imp=feat_glime,\n",
    "    # nodes=nodes_to_explain,\n",
    "    # k_fracs=args.k_fracs,\n",
    "    # use_true_class=args.use_true_class,\n",
    "    # label=\"GraphLIME\",\n",
    "    # )\n",
    "    # fp_gnn, fm_gnn, fk_gnn = compute_fidelity_scores(\n",
    "    # data=data,\n",
    "    # model=model,\n",
    "    # feat_imp=feat_gnnexp,\n",
    "    # nodes=nodes_to_explain,\n",
    "    # k_fracs=args.k_fracs,\n",
    "    # use_true_class=args.use_true_class,\n",
    "    # label=\"GNNEXPLAINER BASELINE\",\n",
    "    # )\n",
    "    fp_ig, fm_ig, fk_ig = compute_fidelity_scores(\n",
    "    data=data,\n",
    "    model=model,\n",
    "    feat_imp=feat_ig,\n",
    "    nodes=nodes_to_explain,\n",
    "    k_fracs=args.k_fracs,\n",
    "    use_true_class=args.use_true_class,\n",
    "    label=\"IG (Captum) BASELINE\",\n",
    "  )\n",
    "\n",
    "\n",
    "    # 8) Summary\n",
    "    print(\"\\n=== Summary: mean Δ = p_orig - p_mask ===\")\n",
    "    print(\"\\nFidelity+ (remove top-k% features):\")\n",
    "    for k in args.k_fracs:\n",
    "        print(\n",
    "            f\"k={k:.2f} | \"\n",
    "            f\"self={fp_self[k]:.4f} | \"\n",
    "            f\"self+1hop={fp_s1[k]:.4f} | \"\n",
    "            f\"self+2hop={fp_s2[k]:.4f} | \"\n",
    "            f\"rand={fp_rand[k]:.4f} | \"\n",
    "            f\"gradcam={fp_gc[k]:.4f} | \"\n",
    "            f\"lrp={fp_lrp[k]:.4f} | \"\n",
    "            f\"goat={fp_goat[k]:.4f} | \"\n",
    "            f\"lime={fp_lime[k]:.4f} | \"\n",
    "            f\"ig={fp_ig[k]:.4f} | \"\n",
    "        )\n",
    "\n",
    "    print(\"\\nFidelity- (remove bottom-k% features):\")\n",
    "    for k in args.k_fracs:\n",
    "        print(\n",
    "            f\"k={k:.2f} | \"\n",
    "            f\"self={fm_self[k]:.4f} | \"\n",
    "            f\"self+1hop={fm_s1[k]:.4f} | \"\n",
    "            f\"self+2hop={fm_s2[k]:.4f} | \"\n",
    "            f\"rand={fm_rand[k]:.4f} | \"\n",
    "            f\"gradcam={fm_gc[k]:.4f} | \"\n",
    "            f\"lrp={fm_lrp[k]:.4f} | \"\n",
    "            f\"goat={fm_goat[k]:.4f} | \"\n",
    "            f\"lime={fm_lime[k]:.4f} | \"\n",
    "            f\"ig={fm_ig[k]:.4f} | \"\n",
    "        )\n",
    "\n",
    "    print(\"\\n=== Build-Time Summary (seconds) ===\")\n",
    "    print(f\"{'method':20s} {'build_time':>12s}\")\n",
    "    for name, t in method_times.items():\n",
    "        print(f\"{name:20s} {t:12.3f}\")\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "  main()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cRjK-RPUer0n"
   },
   "source": [
    "## BASHAPE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "collapsed": true,
    "id": "jt6Hv4bqubrz",
    "outputId": "5796318f-f396-4237-b5d5-dc69eda3306a"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipython-input-2299389920.py:29: UserWarning: 'BAShapes' is deprecated, use 'datasets.ExplainerDataset' in combination with 'datasets.graph_generator.BAGraph' instead\n",
      "  dataset = BAShapes()   # synthetic, generated in memory\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data(x=[700, 10], edge_index=[2, 3958], y=[700], expl_mask=[700], edge_label=[3958])\n",
      "#nodes = 700, #edges = 3958\n",
      "#features = 10, #classes = 4\n",
      "\n",
      "Training GCN on BA-Shapes...\n",
      "Epoch 001 | Loss 1.3874 | Train 0.431 | Val 0.379 | Test 0.471\n",
      "Epoch 020 | Loss 1.2709 | Train 0.431 | Val 0.379 | Test 0.471\n",
      "Epoch 040 | Loss 1.2648 | Train 0.431 | Val 0.379 | Test 0.471\n",
      "Epoch 060 | Loss 1.2697 | Train 0.431 | Val 0.379 | Test 0.471\n",
      "Epoch 080 | Loss 1.2616 | Train 0.431 | Val 0.379 | Test 0.471\n",
      "Epoch 100 | Loss 1.2316 | Train 0.431 | Val 0.379 | Test 0.471\n",
      "Epoch 120 | Loss 1.1982 | Train 0.431 | Val 0.371 | Test 0.471\n",
      "Epoch 140 | Loss 1.1517 | Train 0.431 | Val 0.364 | Test 0.471\n",
      "Epoch 160 | Loss 1.0674 | Train 0.431 | Val 0.364 | Test 0.471\n",
      "Epoch 180 | Loss 1.0406 | Train 0.440 | Val 0.393 | Test 0.471\n",
      "Epoch 200 | Loss 1.0431 | Train 0.581 | Val 0.529 | Test 0.636\n",
      "Epoch 220 | Loss 0.9261 | Train 0.667 | Val 0.650 | Test 0.686\n",
      "Epoch 240 | Loss 0.9805 | Train 0.857 | Val 0.836 | Test 0.850\n",
      "Epoch 260 | Loss 0.9530 | Train 0.860 | Val 0.843 | Test 0.850\n",
      "Epoch 280 | Loss 0.9340 | Train 0.862 | Val 0.836 | Test 0.843\n",
      "Epoch 300 | Loss 0.8349 | Train 0.860 | Val 0.843 | Test 0.850\n",
      "Epoch 320 | Loss 0.8361 | Train 0.864 | Val 0.836 | Test 0.829\n",
      "Epoch 340 | Loss 0.7785 | Train 0.862 | Val 0.829 | Test 0.829\n",
      "Epoch 360 | Loss 0.8525 | Train 0.740 | Val 0.721 | Test 0.671\n",
      "Epoch 380 | Loss 0.7738 | Train 0.824 | Val 0.786 | Test 0.757\n",
      "Epoch 400 | Loss 0.7748 | Train 0.864 | Val 0.836 | Test 0.843\n",
      "Epoch 420 | Loss 0.7372 | Train 0.852 | Val 0.814 | Test 0.821\n",
      "Epoch 440 | Loss 0.7326 | Train 0.831 | Val 0.793 | Test 0.764\n",
      "Epoch 460 | Loss 0.7220 | Train 0.864 | Val 0.821 | Test 0.814\n",
      "Epoch 480 | Loss 0.7386 | Train 0.867 | Val 0.836 | Test 0.829\n",
      "Epoch 500 | Loss 0.7728 | Train 0.855 | Val 0.814 | Test 0.821\n",
      "Epoch 520 | Loss 0.7112 | Train 0.864 | Val 0.836 | Test 0.843\n",
      "Epoch 540 | Loss 0.6945 | Train 0.862 | Val 0.836 | Test 0.850\n",
      "Epoch 560 | Loss 0.6710 | Train 0.867 | Val 0.829 | Test 0.814\n",
      "Epoch 580 | Loss 0.7036 | Train 0.864 | Val 0.829 | Test 0.829\n",
      "Epoch 600 | Loss 0.7044 | Train 0.864 | Val 0.836 | Test 0.850\n",
      "Epoch 620 | Loss 0.7070 | Train 0.855 | Val 0.821 | Test 0.829\n",
      "Epoch 640 | Loss 0.7140 | Train 0.821 | Val 0.786 | Test 0.843\n",
      "Epoch 660 | Loss 0.7102 | Train 0.843 | Val 0.800 | Test 0.779\n",
      "Epoch 680 | Loss 0.6986 | Train 0.845 | Val 0.800 | Test 0.800\n",
      "Epoch 700 | Loss 0.6706 | Train 0.864 | Val 0.821 | Test 0.821\n",
      "Epoch 720 | Loss 0.6974 | Train 0.840 | Val 0.793 | Test 0.771\n",
      "Epoch 740 | Loss 0.6828 | Train 0.864 | Val 0.836 | Test 0.843\n",
      "Epoch 760 | Loss 0.6474 | Train 0.840 | Val 0.793 | Test 0.771\n",
      "Epoch 780 | Loss 0.6463 | Train 0.819 | Val 0.771 | Test 0.750\n",
      "Epoch 800 | Loss 0.6625 | Train 0.821 | Val 0.771 | Test 0.750\n",
      "Epoch 820 | Loss 0.6452 | Train 0.864 | Val 0.829 | Test 0.829\n",
      "Epoch 840 | Loss 0.6505 | Train 0.860 | Val 0.821 | Test 0.836\n",
      "Epoch 860 | Loss 0.6135 | Train 0.867 | Val 0.829 | Test 0.850\n",
      "Epoch 880 | Loss 0.6743 | Train 0.814 | Val 0.750 | Test 0.743\n",
      "Epoch 900 | Loss 0.6462 | Train 0.864 | Val 0.814 | Test 0.821\n",
      "Epoch 920 | Loss 0.6436 | Train 0.838 | Val 0.793 | Test 0.771\n",
      "Epoch 940 | Loss 0.6482 | Train 0.821 | Val 0.786 | Test 0.829\n",
      "Epoch 960 | Loss 0.6344 | Train 0.817 | Val 0.750 | Test 0.743\n",
      "Epoch 980 | Loss 0.6732 | Train 0.845 | Val 0.793 | Test 0.779\n",
      "Epoch 1000 | Loss 0.6686 | Train 0.864 | Val 0.829 | Test 0.850\n",
      "\n",
      "Mean |Z2_decomp - logits_model| = 0.000000 (should be small-ish)\n",
      "\n",
      "Evaluating methods on 140 test nodes...\n",
      "\n",
      "Mean metrics over test nodes (k_frac = 0.5, keep/drop 50% neighbors):\n",
      "Decomp       | Fid+=0.390  Fid-=0.066  Sparsity=0.268\n",
      "GNNExplainer | Fid+=0.164  Fid-=0.130  Sparsity=0.000\n",
      "IG           | Fid+=0.407  Fid-=0.024  Sparsity=0.000\n",
      "Random       | Fid+=0.112  Fid-=0.172  Sparsity=0.072\n",
      "GOAT         | Fid+=0.407  Fid-=0.024  Sparsity=0.000\n",
      "GradCAM      | Fid+=0.361  Fid-=0.083  Sparsity=0.516\n",
      "GNN-LRP      | Fid+=0.344  Fid-=0.087  Sparsity=0.169\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.datasets import BAShapes\n",
    "from torch_geometric.nn import GCNConv\n",
    "from torch_geometric.utils import add_remaining_self_loops\n",
    "from torch_scatter import scatter_add\n",
    "from torch_geometric.explain import Explainer\n",
    "from torch_geometric.explain.algorithm import GNNExplainer,DummyExplainer\n",
    "\n",
    "from torch_geometric.explain.algorithm import CaptumExplainer\n",
    "from captum.attr import IntegratedGradients\n",
    "\n",
    "from torch_geometric.utils import to_undirected, add_self_loops, degree\n",
    "\n",
    "# ============================================================\n",
    "# 0. Setup\n",
    "# ============================================================\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "torch.manual_seed(0)\n",
    "\n",
    "\n",
    "\n",
    "# ============================================================\n",
    "# 1. Load BA-Shapes dataset\n",
    "# ============================================================\n",
    "\n",
    "dataset = BAShapes()   # synthetic, generated in memory\n",
    "data = dataset[0].to(device)\n",
    "\n",
    "print(data)\n",
    "print(f\"#nodes = {data.num_nodes}, #edges = {data.num_edges}\")\n",
    "print(f\"#features = {data.num_node_features}, #classes = {int(data.y.max())+1}\")\n",
    "\n",
    "num_nodes = data.num_nodes\n",
    "num_features = data.num_node_features\n",
    "num_classes = int(data.y.max().item()) + 1\n",
    "\n",
    "# Simple random train/val/test split if not already present\n",
    "if not hasattr(data, \"train_mask\"):\n",
    "    perm = torch.randperm(num_nodes, device=device)\n",
    "    n_train = int(0.6 * num_nodes)\n",
    "    n_val   = int(0.2 * num_nodes)\n",
    "    n_test  = num_nodes - n_train - n_val\n",
    "\n",
    "    train_idx = perm[:n_train]\n",
    "    val_idx   = perm[n_train:n_train+n_val]\n",
    "    test_idx  = perm[n_train+n_val:]\n",
    "\n",
    "    data.train_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)\n",
    "    data.val_mask   = torch.zeros(num_nodes, dtype=torch.bool, device=device)\n",
    "    data.test_mask  = torch.zeros(num_nodes, dtype=torch.bool, device=device)\n",
    "\n",
    "    data.train_mask[train_idx] = True\n",
    "    data.val_mask[val_idx]     = True\n",
    "    data.test_mask[test_idx]   = True\n",
    "\n",
    "# ============================================================\n",
    "# 2. Standard 2-layer GCN (no modifications for training)\n",
    "# ============================================================\n",
    "\n",
    "class VanillaGCN(nn.Module):\n",
    "    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):\n",
    "        super().__init__()\n",
    "        self.conv1 = GCNConv(in_channels, hidden_channels, cached=False)\n",
    "        self.conv2 = GCNConv(hidden_channels, out_channels, cached=False)\n",
    "        self.dropout = dropout\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = F.relu(x)\n",
    "        x = F.dropout(x, p=self.dropout, training=self.training)\n",
    "        x = self.conv2(x, edge_index)  # logits\n",
    "        return x\n",
    "\n",
    "model = VanillaGCN(num_features, hidden_channels=32, out_channels=num_classes).to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n",
    "\n",
    "def train_one_epoch():\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    logits = model(data.x, data.edge_index)\n",
    "    loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    return loss.item()\n",
    "\n",
    "@torch.no_grad()\n",
    "def accuracy(mask):\n",
    "    model.eval()\n",
    "    logits = model(data.x, data.edge_index)\n",
    "    pred = logits.argmax(dim=-1)\n",
    "    correct = (pred[mask] == data.y[mask]).sum()\n",
    "    return float(correct) / int(mask.sum())\n",
    "\n",
    "print(\"\\nTraining GCN on BA-Shapes...\")\n",
    "for epoch in range(1, 1001):\n",
    "    loss = train_one_epoch()\n",
    "    if epoch % 20 == 0 or epoch == 1:\n",
    "        train_acc = accuracy(data.train_mask)\n",
    "        val_acc   = accuracy(data.val_mask)\n",
    "        test_acc  = accuracy(data.test_mask)\n",
    "        print(f\"Epoch {epoch:03d} | Loss {loss:.4f} | \"\n",
    "              f\"Train {train_acc:.3f} | Val {val_acc:.3f} | Test {test_acc:.3f}\")\n",
    "\n",
    "# ============================================================\n",
    "# 3. Post-hoc decomposition of the trained GCN (neighbor-level)\n",
    "# ============================================================\n",
    "\n",
    "@torch.no_grad()\n",
    "def compute_gcn_norm(edge_index, num_nodes, device, dtype=torch.float32):\n",
    "    \"\"\"\n",
    "    Rebuild GCN-like normalization:\n",
    "        A_hat = A + I, then D^{-1/2} A_hat D^{-1/2}.\n",
    "    We treat edge_index as [src, dst]; internally we use row=dst, col=src.\n",
    "    \"\"\"\n",
    "    edge_index = edge_index.flip(0)  # now [dst, src]\n",
    "    edge_index, _ = add_remaining_self_loops(edge_index, num_nodes=num_nodes)\n",
    "    row, col = edge_index\n",
    "\n",
    "    edge_weight = torch.ones(row.size(0), device=device, dtype=dtype)\n",
    "    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)\n",
    "    deg_inv_sqrt = deg.pow(-0.5)\n",
    "    deg_inv_sqrt[deg_inv_sqrt == float(\"inf\")] = 0\n",
    "\n",
    "    norm = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]  # [E]\n",
    "    return edge_index, norm  # row=dst, col=src with symmetric norm\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def layer1_decomposition(model, x, edge_index):\n",
    "    \"\"\"\n",
    "    Recompute layer-1 pre-activations and messages:\n",
    "        Z1 = A_norm X W1 + b1\n",
    "    Messages: m1_{u<-v} = norm_uv * (X_v W1)\n",
    "    \"\"\"\n",
    "    device = x.device\n",
    "    num_nodes = x.size(0)\n",
    "\n",
    "    W1 = model.conv1.lin.weight.T     # [F_in, H]\n",
    "    b1 = model.conv1.bias             # [H] or None\n",
    "\n",
    "    ei1, norm1 = compute_gcn_norm(edge_index, num_nodes, device)\n",
    "    row1, col1 = ei1  # row1=dst=u, col1=src=v\n",
    "\n",
    "    xW1 = x @ W1                      # [N, H]\n",
    "    m1 = xW1[col1] * norm1.view(-1, 1)  # [E1, H]\n",
    "\n",
    "    z1 = scatter_add(m1, row1, dim=0, dim_size=num_nodes)  # [N, H]\n",
    "    if b1 is not None:\n",
    "        z1 = z1 + b1\n",
    "    h1 = F.relu(z1)\n",
    "    return h1, z1, m1, ei1\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def layer2_decomposition(model, h1, edge_index):\n",
    "    \"\"\"\n",
    "    Recompute layer-2 pre-activations and messages:\n",
    "        Z2 = A_norm H1 W2 + b2\n",
    "    Messages: m2_{u<-v} = norm_uv * (H1_v W2)\n",
    "    \"\"\"\n",
    "    device = h1.device\n",
    "    num_nodes = h1.size(0)\n",
    "\n",
    "    W2 = model.conv2.lin.weight.T     # [H, C]\n",
    "    b2 = model.conv2.bias             # [C] or None\n",
    "\n",
    "    ei2, norm2 = compute_gcn_norm(edge_index, num_nodes, device)\n",
    "    row2, col2 = ei2\n",
    "\n",
    "    h1W2 = h1 @ W2                    # [N, C]\n",
    "    m2 = h1W2[col2] * norm2.view(-1, 1)  # [E2, C]\n",
    "\n",
    "    z2 = scatter_add(m2, row2, dim=0, dim_size=num_nodes)  # [N, C]\n",
    "    if b2 is not None:\n",
    "        z2 = z2 + b2\n",
    "    return z2, m2, ei2\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def posthoc_decompose(model, data):\n",
    "    \"\"\"\n",
    "    Full post-hoc decomposition of the trained 2-layer GCN:\n",
    "      - layer1 messages/embeddings\n",
    "      - layer2 messages/logits\n",
    "    \"\"\"\n",
    "    h1, z1, m1, ei1 = layer1_decomposition(model, data.x, data.edge_index)\n",
    "    z2, m2, ei2 = layer2_decomposition(model, h1, data.edge_index)\n",
    "    return {\n",
    "        \"h1\": h1,\n",
    "        \"z1\": z1,\n",
    "        \"messages1\": m1,\n",
    "        \"edge_index1\": ei1,\n",
    "        \"z2\": z2,            # logits\n",
    "        \"messages2\": m2,\n",
    "        \"edge_index2\": ei2,\n",
    "    }\n",
    "\n",
    "# sanity check\n",
    "with torch.no_grad():\n",
    "    expl_tmp = posthoc_decompose(model, data)\n",
    "    logits_model = model(data.x, data.edge_index)\n",
    "    diff = (expl_tmp[\"z2\"] - logits_model).abs().mean().item()\n",
    "    print(f\"\\nMean |Z2_decomp - logits_model| = {diff:.6f} (should be small-ish)\")\n",
    "\n",
    "# ============================================================\n",
    "# 4. Neighbor importance: Decomp, GNNExplainer, Random\n",
    "# ============================================================\n",
    "\n",
    "@torch.no_grad()\n",
    "def decomp_neighbor_scores(model, data, nid):\n",
    "    \"\"\"\n",
    "    Our method: neighbor scores from final-layer messages.\n",
    "    Returns:\n",
    "      neighbors: [K] neighbor node indices\n",
    "      scores:    [K] importance scores (abs contribution to predicted class)\n",
    "    \"\"\"\n",
    "    expl = posthoc_decompose(model, data)\n",
    "    z2 = expl[\"z2\"]               # [N, C] logits\n",
    "    m2 = expl[\"messages2\"]        # [E2, C]\n",
    "    ei2 = expl[\"edge_index2\"]     # [2, E2], row=dst, col=src\n",
    "\n",
    "    row2, col2 = ei2\n",
    "    probs = F.softmax(z2, dim=-1)\n",
    "    c = int(probs[nid].argmax().item())\n",
    "\n",
    "    mask = (row2 == nid)\n",
    "    neigh = col2[mask]                     # neighbors v such that v->nid\n",
    "    edge_scores = m2[mask, c].abs()        # [num_edges_into_nid]\n",
    "\n",
    "    if neigh.numel() == 0:\n",
    "        return neigh, edge_scores\n",
    "\n",
    "    unique_neigh, inv = torch.unique(neigh, return_inverse=True)\n",
    "    neigh_scores = scatter_add(edge_scores, inv, dim=0,\n",
    "                               dim_size=unique_neigh.size(0))\n",
    "    return unique_neigh, neigh_scores\n",
    "\n",
    "\n",
    "def get_in_neighbors(edge_index, nid):\n",
    "    src, dst = edge_index\n",
    "    mask = (dst == nid)\n",
    "    return src[mask]\n",
    "\n",
    "# --- GNNExplainer in edge mode (NO @torch.no_grad here!) ---\n",
    "expl_gnn_edge = Explainer(\n",
    "    model=model,\n",
    "    algorithm=GNNExplainer(epochs=100),\n",
    "    explanation_type='model',\n",
    "    node_mask_type=None,\n",
    "    edge_mask_type='object',\n",
    "    model_config=dict(\n",
    "        mode='multiclass_classification',\n",
    "        task_level='node',\n",
    "        return_type='raw',  # model returns logits\n",
    "    ),\n",
    ")\n",
    "\n",
    "def gnnexplainer_neighbor_scores(nid):\n",
    "    \"\"\"\n",
    "    Use GNNExplainer to get edge importance for node nid, then\n",
    "    aggregate to neighbor-level scores.\n",
    "    This must run with gradients enabled.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    exp = expl_gnn_edge(data.x, data.edge_index, index=int(nid))\n",
    "\n",
    "    e_src, e_dst = exp.edge_index\n",
    "    e_mask = exp.edge_mask\n",
    "\n",
    "    mask = (e_dst == nid)\n",
    "    neigh = e_src[mask]\n",
    "    edge_scores = e_mask[mask].abs()\n",
    "\n",
    "    if neigh.numel() == 0:\n",
    "        return neigh.to(device), edge_scores.to(device)\n",
    "\n",
    "    unique_neigh, inv = torch.unique(neigh, return_inverse=True)\n",
    "    neigh_scores = scatter_add(edge_scores, inv, dim=0,\n",
    "                               dim_size=unique_neigh.size(0))\n",
    "    return unique_neigh.to(device), neigh_scores.to(device)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def random_neighbor_scores(nid):\n",
    "    \"\"\"\n",
    "    Random neighbor baseline: uniform random scores over incoming neighbors.\n",
    "    \"\"\"\n",
    "    neigh = get_in_neighbors(data.edge_index, nid)\n",
    "    if neigh.numel() == 0:\n",
    "        return neigh, torch.empty(0, device=device)\n",
    "    scores = torch.rand(neigh.size(0), device=device)\n",
    "    return neigh, scores\n",
    "\n",
    "expl_ig = Explainer(\n",
    "    model=model,\n",
    "    algorithm=CaptumExplainer(IntegratedGradients),\n",
    "    explanation_type='model',\n",
    "    node_mask_type='attributes', # Changed from None\n",
    "    edge_mask_type=None,         # Changed from 'object'\n",
    "    model_config=dict(\n",
    "        mode='multiclass_classification',\n",
    "        task_level='node',\n",
    "        return_type='raw',\n",
    "    ),\n",
    ")\n",
    "\n",
    "def ig_neighbor_scores(nid):\n",
    "    \"\"\"\n",
    "    IG produces feature scores for node nid.\n",
    "    We convert those into a SINGLE scalar importance for node nid,\n",
    "    and assign that importance to all neighbors.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    exp = expl_ig(data.x, data.edge_index, index=int(nid))\n",
    "\n",
    "    node_mask = exp.node_mask  # [F] for the target node\n",
    "\n",
    "    # Aggregate feature importance for the target node into a single scalar\n",
    "    node_importance = node_mask.abs().sum().item()\n",
    "\n",
    "    neighbors = get_in_neighbors(data.edge_index, nid)\n",
    "    if neighbors.numel() == 0:\n",
    "        return neighbors.to(device), torch.empty(0, device=device)\n",
    "\n",
    "    # Assign this scalar importance to all neighbors\n",
    "    scores = torch.full((neighbors.size(0),), node_importance, device=device)\n",
    "    return neighbors.to(device), scores\n",
    "\n",
    "\n",
    "###\n",
    "#GOAT Method#\n",
    "###\n",
    "\n",
    "# =========================\n",
    "# 1) Build dense Â for BA-Shapes\n",
    "# =========================\n",
    "def get_in_neighbors(edge_index: torch.Tensor, nid: int) -> torch.Tensor:\n",
    "    \"\"\"Return unique in-neighbors of nid (src where dst == nid).\"\"\"\n",
    "    src, dst = edge_index\n",
    "    mask = (dst == nid)\n",
    "    return src[mask].unique()\n",
    "\n",
    "N = data.num_nodes\n",
    "D_total = data.num_node_features\n",
    "\n",
    "def build_gcn_norm_dense(edge_index, num_nodes):\n",
    "    \"\"\"\n",
    "    Build symmetric-normalized adjacency Â = D^{-1/2} (A + I) D^{-1/2} in dense form.\n",
    "    \"\"\"\n",
    "    edge_index = to_undirected(edge_index, num_nodes=num_nodes)\n",
    "    edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)\n",
    "    row, col = edge_index\n",
    "    deg = degree(row, num_nodes=num_nodes)\n",
    "    deg_inv_sqrt = deg.pow(-0.5)\n",
    "    deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0.\n",
    "    w = deg_inv_sqrt[row] * deg_inv_sqrt[col]  # scalar per edge\n",
    "    A = torch.zeros((num_nodes, num_nodes), device=device)\n",
    "    A.index_put_((row, col), w, accumulate=True)\n",
    "    return A  # [N, N]\n",
    "\n",
    "A_hat = build_gcn_norm_dense(data.edge_index, N)  # [N, N]\n",
    "\n",
    "# Precompute X^T once\n",
    "XT = data.x.t()   # [D_total, N]\n",
    "\n",
    "\n",
    "# =========================\n",
    "# 2) Pre-activations and ReLU masks for 2-layer GCN\n",
    "# =========================\n",
    "\n",
    "@torch.no_grad()\n",
    "def preactivations_and_masks_2layer():\n",
    "    \"\"\"\n",
    "    For a 2-layer GCN with:\n",
    "        z1 = Â X W1\n",
    "        h1 = ReLU(z1)\n",
    "        z2 = Â h1 W2\n",
    "        logits = z2\n",
    "\n",
    "    We compute:\n",
    "      z1: [N, H], z2: [N, C],\n",
    "      M1: [N, H] ReLU mask for layer 1 (1 if z1>0, else 0),\n",
    "      W1: [D_total, H], W2: [H, C].\n",
    "    \"\"\"\n",
    "    # Extract linear weights from your GCNConv layers\n",
    "    W1 = model.conv1.lin.weight.T    # [D_total, H]\n",
    "    W2 = model.conv2.lin.weight.T    # [H, C]\n",
    "\n",
    "    # First layer pre-activation: z1 = Â X W1\n",
    "    z1 = A_hat @ (data.x @ W1)       # [N, H]\n",
    "    M1 = (z1 > 0).float()\n",
    "\n",
    "    # Second layer pre-activation: z2 = Â (ReLU(z1) W2)\n",
    "    h1 = M1 * z1                     # [N, H]\n",
    "    z2 = A_hat @ (h1 @ W2)           # [N, C]\n",
    "\n",
    "    return z1, z2, M1, W1, W2\n",
    "\n",
    "z1_goat, z2_goat, M1_goat, W1_goat, W2_goat = preactivations_and_masks_2layer()\n",
    "\n",
    "\n",
    "# =========================\n",
    "# 3) GOAt-style per-feature attribution (2-layer)\n",
    "# =========================\n",
    "\n",
    "@torch.no_grad()\n",
    "def goat_2layer_feature_importance_for_node(i: int):\n",
    "    \"\"\"\n",
    "    Returns phi: [D_total], feature contributions to the predicted class logit at node i\n",
    "    for a 2-layer GCN (conv1 -> conv2).\n",
    "\n",
    "    Model:\n",
    "        z1 = Â X W1          (before ReLU)\n",
    "        h1 = ReLU(z1) = M1 ⊙ z1\n",
    "        z2 = Â h1 W2         (logits)\n",
    "\n",
    "    We approximate:\n",
    "        phi_f(i,c) ≈ sum_{h,u} x_{u,f} * [ (r_i^T D1_h Â)_u ] * W1[f,h] * W2[h,c]\n",
    "\n",
    "    where:\n",
    "        r_i^T is row i of Â,\n",
    "        D1_h = diag(M1[:,h]) is the ReLU mask for hidden unit h across nodes.\n",
    "    \"\"\"\n",
    "    i = int(i)\n",
    "\n",
    "    # predicted class c at node i\n",
    "    logits = model(data.x, data.edge_index)     # [N, C]\n",
    "    c = int(logits[i].argmax().item())\n",
    "\n",
    "    # row i of Â: how i aggregates from all nodes\n",
    "    r = A_hat[i, :]                             # [N]\n",
    "\n",
    "    # Initialize feature importance vector\n",
    "    phi = torch.zeros(D_total, device=device)   # [D_total]\n",
    "\n",
    "    H = W1_goat.size(1)  # hidden dimension\n",
    "\n",
    "    for h in range(H):\n",
    "        # Mask for hidden unit h across all nodes\n",
    "        # D1_h applied as elementwise multiplication with M1[:,h]\n",
    "        # First: r_i^T D1_h Â  ->  (r * M1[:,h])^T Â\n",
    "        t1 = (r * M1_goat[:, h]) @ A_hat        # [N]\n",
    "\n",
    "        # Aggregate over source nodes u:\n",
    "        # v_f = sum_u x_{u,f} * t1[u] = (X^T @ t1)_f\n",
    "        v = XT @ t1                             # [D_total]\n",
    "\n",
    "        # Chain weight for path (f -> h -> c)\n",
    "        # W1[f,h] * W2[h,c], but W1 is [D_total, H], W2 is [H, C]\n",
    "        chain = W1_goat[:, h] * W2_goat[h, c]   # [D_total]\n",
    "\n",
    "        # Add contribution for this hidden unit\n",
    "        phi += v * chain                        # [D_total]\n",
    "\n",
    "    return phi  # [D_total], feature importance for node i, class c\n",
    "\n",
    "@torch.no_grad()\n",
    "def goat_neighbor_scores(nid):\n",
    "    phi = goat_2layer_feature_importance_for_node(nid)  # [D_total]\n",
    "    node_score = phi.abs().sum().item()                 # collapse to scalar\n",
    "\n",
    "    neigh = get_in_neighbors(data.edge_index, nid)\n",
    "    if neigh.numel() == 0:\n",
    "        return neigh, torch.empty(0, device=device)\n",
    "    scores = torch.full((neigh.size(0),), node_score, device=device)\n",
    "    return neigh, scores\n",
    "\n",
    "##GRAD-CAM\n",
    "def gradcam_neighbor_scores(model, data, nid: int):\n",
    "    \"\"\"\n",
    "    Grad-CAM for node nid on BA-Shapes:\n",
    "\n",
    "    1) Take conv1 outputs h1 (after ReLU), keep grad.\n",
    "    2) Backprop from node nid's predicted logit.\n",
    "    3) alpha_k = mean_n grad_{n,k}\n",
    "    4) node_score[n] = ReLU( sum_k alpha_k * h1[n,k] )\n",
    "    5) Return scores restricted to neighbors of nid.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    nid = int(nid)\n",
    "\n",
    "    x = data.x.to(device)\n",
    "    edge_index = data.edge_index.to(device)\n",
    "\n",
    "    # Forward up to conv1\n",
    "    x1 = model.conv1(x, edge_index)      # [N, H]\n",
    "    h1 = F.relu(x1)\n",
    "    h1.retain_grad()                     # we want grad wrt h1\n",
    "\n",
    "    # Forward to logits\n",
    "    out = model.conv2(h1, edge_index)    # [N, C]\n",
    "    logits = out\n",
    "    c = int(logits[nid].argmax().item())\n",
    "\n",
    "    # Backprop from target logit\n",
    "    model.zero_grad()\n",
    "    if h1.grad is not None:\n",
    "        h1.grad.zero_()\n",
    "    logits[nid, c].backward(retain_graph=False)\n",
    "\n",
    "    grads = h1.grad                      # [N, H]\n",
    "    # Channel weights α_k: global average over nodes\n",
    "    alpha = grads.mean(dim=0)            # [H]\n",
    "\n",
    "    # Node-wise Grad-CAM scores\n",
    "    node_scores = F.relu((h1 * alpha)    # [N, H]\n",
    "                         .sum(dim=1))    # [N]\n",
    "\n",
    "    neighbors = get_in_neighbors(edge_index, nid)\n",
    "    if neighbors.numel() == 0:\n",
    "        return neighbors, torch.empty(0, device=device)\n",
    "\n",
    "    scores = node_scores[neighbors]\n",
    "    return neighbors, scores\n",
    "\n",
    "##GNN-LRP\n",
    "def gnn_lrp_neighbor_scores(model, data, nid: int, eps: float = 1e-6):\n",
    "    \"\"\"\n",
    "    Simplified GNN-LRP for a 2-layer GCN on BA-Shapes.\n",
    "\n",
    "    We only propagate relevance through the last GCN layer:\n",
    "\n",
    "        z2 = Â h1 W2\n",
    "\n",
    "    For node nid and class c:\n",
    "        m_v = Â[nid, v] * (h1[v] @ W2[:, c])\n",
    "        R_v = (m_v / (sum_v m_v + eps)) * z2[nid, c]\n",
    "\n",
    "    Neighbor score = |R_v|.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    nid = int(nid)\n",
    "\n",
    "    x = data.x.to(device)\n",
    "    edge_index = data.edge_index.to(device)\n",
    "\n",
    "    # Forward 1st layer\n",
    "    z1 = model.conv1(x, edge_index)\n",
    "    h1 = F.relu(z1)                       # [N, H]\n",
    "\n",
    "    # Extract W2 from conv2\n",
    "    # conv2.lin.weight: [C, H]  -> transpose to [H, C]\n",
    "    W2 = model.conv2.lin.weight.T        # [H, C]\n",
    "\n",
    "    # Logits via explicit GCN conv2 computation\n",
    "    # support = h1 @ W2  # [N, C]\n",
    "    support = h1 @ W2                    # [N, C]\n",
    "    z2 = A_hat @ support                 # [N, C]  (same as conv2 without bias)\n",
    "\n",
    "    # target class and relevance at output\n",
    "    c = int(z2[nid].argmax().item())\n",
    "    R_out = z2[nid, c]                   # scalar relevance at (nid,c)\n",
    "\n",
    "    # message from each node v into (nid, c)\n",
    "    # A_hat[nid] is row vector [N]\n",
    "    row_nid = A_hat[nid]                 # [N]\n",
    "    support_c = support[:, c]           # [N]\n",
    "    m = row_nid * support_c             # [N]\n",
    "\n",
    "    Z = m.sum() + eps\n",
    "    R_nodes = (m / Z) * R_out           # [N]\n",
    "\n",
    "    neighbors = get_in_neighbors(edge_index, nid)\n",
    "    if neighbors.numel() == 0:\n",
    "        return neighbors, torch.empty(0, device=device)\n",
    "\n",
    "    scores = R_nodes[neighbors].abs()\n",
    "    return neighbors, scores\n",
    "\n",
    "def to_full_neighbor_vector(neighbor_fn):\n",
    "    def wrapper(nid: int):\n",
    "        neighbors, scores = neighbor_fn(model, data, nid)\n",
    "        full = torch.zeros(N, device=device)\n",
    "        if neighbors.numel() > 0:\n",
    "            full[neighbors] = scores\n",
    "        return full\n",
    "    return wrapper\n",
    "\n",
    "# ============================================================\n",
    "# 5. Metrics: Fidelity+ / Fidelity− / Sparsity (neighbor-level)\n",
    "# ============================================================\n",
    "\n",
    "@torch.no_grad()\n",
    "def fidelity_negative_neighbors(model, data, nid, neighbors, scores, k_frac=0.5):\n",
    "    \"\"\"\n",
    "    Fidelity+ : keep ONLY top-k_frac neighbors into nid, drop others.\n",
    "    Returns ratio: prob_keep / prob_orig (for original predicted class).\n",
    "    \"\"\"\n",
    "    if neighbors.numel() == 0:\n",
    "        return 1.0\n",
    "\n",
    "    logits = model(data.x, data.edge_index)\n",
    "    probs = F.softmax(logits, dim=-1)\n",
    "    c = int(probs[nid].argmax().item())\n",
    "    orig = float(probs[nid, c].item())\n",
    "\n",
    "    deg_u = neighbors.size(0)\n",
    "    k = max(1, int(k_frac * deg_u))\n",
    "\n",
    "    abs_scores = scores.abs()\n",
    "    _, idx = torch.topk(abs_scores, k=k)\n",
    "    top_neighbors = neighbors[idx]\n",
    "\n",
    "    src, dst = data.edge_index\n",
    "    incoming = (dst == nid)\n",
    "    keep_from_top = incoming & torch.isin(src, top_neighbors)\n",
    "    keep_mask = (~incoming) | keep_from_top\n",
    "\n",
    "    edge_index_keep = data.edge_index[:, keep_mask]\n",
    "\n",
    "    logits2 = model(data.x, edge_index_keep)\n",
    "    probs2 = F.softmax(logits2, dim=-1)\n",
    "    keep = float(probs2[nid, c].item())\n",
    "\n",
    "    return orig - keep\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def fidelity_positive_neighbors(model, data, nid, neighbors, scores, k_frac=0.5):\n",
    "    \"\"\"\n",
    "    Fidelity− : REMOVE top-k_frac neighbors into nid.\n",
    "    Returns probability drop: orig_prob - new_prob (for original class).\n",
    "    \"\"\"\n",
    "    if neighbors.numel() == 0:\n",
    "        return 0.0\n",
    "\n",
    "    logits = model(data.x, data.edge_index)\n",
    "    probs = F.softmax(logits, dim=-1)\n",
    "    c = int(probs[nid].argmax().item())\n",
    "    orig = float(probs[nid, c].item())\n",
    "\n",
    "    deg_u = neighbors.size(0)\n",
    "    k = max(1, int(k_frac * deg_u))\n",
    "\n",
    "    abs_scores = scores.abs()\n",
    "    _, idx = torch.topk(abs_scores, k=k)\n",
    "    top_neighbors = neighbors[idx]\n",
    "\n",
    "    src, dst = data.edge_index\n",
    "    incoming = (dst == nid)\n",
    "    drop_edges = incoming & torch.isin(src, top_neighbors)\n",
    "    keep_mask = ~drop_edges\n",
    "\n",
    "    edge_index_drop = data.edge_index[:, keep_mask]\n",
    "\n",
    "    logits2 = model(data.x, edge_index_drop)\n",
    "    probs2 = F.softmax(logits2, dim=-1)\n",
    "    new = float(probs2[nid, c].item())\n",
    "\n",
    "    return orig - new\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def neighbor_sparsity(neighbors, scores, rel_thresh=0.1):\n",
    "    \"\"\"\n",
    "    Sparsity = 1 - (#active_neighbors / deg).\n",
    "    A neighbor is active if |score| >= rel_thresh * max(|score|).\n",
    "    \"\"\"\n",
    "    deg_u = neighbors.size(0)\n",
    "    if deg_u == 0:\n",
    "        return 1.0\n",
    "\n",
    "    abs_scores = scores.abs()\n",
    "    max_score = float(abs_scores.max().item())\n",
    "    if max_score == 0.0:\n",
    "        return 1.0\n",
    "\n",
    "    tau = rel_thresh * max_score\n",
    "    num_active = int((abs_scores >= tau).sum().item())\n",
    "    sparsity = 1.0 - num_active / deg_u\n",
    "    return sparsity\n",
    "\n",
    "# ============================================================\n",
    "# 6. Evaluation loop: mean metrics across test nodes & methods\n",
    "# ============================================================\n",
    "import time\n",
    "\n",
    "def timed(fn):\n",
    "    \"\"\"\n",
    "    Wraps an explanation function so it returns:\n",
    "        neighbors, scores, elapsed_time\n",
    "    \"\"\"\n",
    "    def wrapped(nid):\n",
    "        start = time.perf_counter()\n",
    "        neighbors, scores = fn(nid)\n",
    "        end = time.perf_counter()\n",
    "        return neighbors, scores, (end - start)\n",
    "    return wrapped\n",
    "\n",
    "# METHODS = {\n",
    "#     \"Decomp\":        lambda nid: decomp_neighbor_scores(model, data, nid),\n",
    "#     \"GNNExplainer\":  lambda nid: gnnexplainer_neighbor_scores(nid),\n",
    "#     \"IG\":            lambda nid: ig_neighbor_scores(nid),\n",
    "#     \"Random\":        lambda nid: random_neighbor_scores(nid),\n",
    "#     \"GOAT\":          lambda nid: goat_neighbor_scores(nid),\n",
    "#     \"GradCAM\":       lambda nid: gradcam_neighbor_scores(model, data, nid),\n",
    "#     \"GNN-LRP\":       lambda nid: gnn_lrp_neighbor_scores(model, data, nid),\n",
    "# }\n",
    "\n",
    "METHODS = {\n",
    "    \"Decomp\":        timed(lambda nid: decomp_neighbor_scores(model, data, nid)),\n",
    "    \"GNNExplainer\":  timed(lambda nid: gnnexplainer_neighbor_scores(nid)),\n",
    "    \"IG\":            timed(lambda nid: ig_neighbor_scores(nid)),\n",
    "    \"Random\":        timed(lambda nid: random_neighbor_scores(nid)),\n",
    "    \"GOAT\":          timed(lambda nid: goat_neighbor_scores(nid)),\n",
    "    \"GradCAM\":       timed(lambda nid: gradcam_neighbor_scores(model, data, nid)),\n",
    "    \"GNN-LRP\":       timed(lambda nid: gnn_lrp_neighbor_scores(model, data, nid)),\n",
    "}\n",
    "\n",
    "\n",
    "def evaluate_methods_on_nodes(node_indices, k_frac=0.5, rel_thresh=0.1):\n",
    "    \"\"\"\n",
    "    node_indices: iterable of node ids.\n",
    "    Returns method_name -> dict of mean metrics.\n",
    "    WARNING: we MUST allow gradients inside for GNNExplainer.\n",
    "    \"\"\"\n",
    "    results = {m: {\"fid_pos\": [], \"fid_neg\": [], \"sparsity\": [], \"time\":[]}\n",
    "               for m in METHODS.keys()}\n",
    "\n",
    "    for nid in node_indices:\n",
    "        nid = int(nid)\n",
    "        for name, get_scores in METHODS.items():\n",
    "            # neighbors, scores = get_scores(nid)\n",
    "            neighbors, scores, t = METHODS[name](nid)\n",
    "            # fid_plus, fid_minus, sparsity = compute_metrics(nid, neighbors, scores)\n",
    "            # results[name][\"fid_plus\"].append(fid_plus)\n",
    "            # results[name][\"fid_minus\"].append(fid_minus)\n",
    "            # results[name][\"sparsity\"].append(sparsity)\n",
    "            # results[name][\"time\"].append(t)\n",
    "\n",
    "            if neighbors.numel() == 0:\n",
    "                continue\n",
    "            if name == \"IG\" or name == \"GOAT\":\n",
    "              fp = fidelity_negative_neighbors(model, data, nid, neighbors, scores, k_frac=k_frac)\n",
    "              fn = fidelity_positive_neighbors(model, data, nid, neighbors, scores, k_frac=k_frac)\n",
    "            else:\n",
    "              fp = fidelity_positive_neighbors(model, data, nid, neighbors, scores, k_frac=k_frac)\n",
    "              fn = fidelity_negative_neighbors(model, data, nid, neighbors, scores, k_frac=k_frac)\n",
    "            sp = neighbor_sparsity(neighbors, scores, rel_thresh=rel_thresh)\n",
    "\n",
    "            results[name][\"fid_pos\"].append(fp)\n",
    "            results[name][\"fid_neg\"].append(fn)\n",
    "            results[name][\"sparsity\"].append(sp)\n",
    "            results[name][\"time\"].append(t)\n",
    "    summary = {}\n",
    "    for name, metrics in results.items():\n",
    "        if len(metrics[\"fid_pos\"]) == 0:\n",
    "            continue\n",
    "        summary[name] = {\n",
    "            \"fid_pos_mean\": float(torch.tensor(metrics[\"fid_pos\"]).mean().item()),\n",
    "            \"fid_neg_mean\": float(torch.tensor(metrics[\"fid_neg\"]).mean().item()),\n",
    "            \"sparsity_mean\": float(torch.tensor(metrics[\"sparsity\"]).mean().item()),\n",
    "            \"time\": float(torch.tensor(metrics[\"time\"]).mean().item()),\n",
    "        }\n",
    "    return summary\n",
    "\n",
    "# Run evaluation on test nodes\n",
    "test_nodes = data.test_mask.nonzero(as_tuple=False).view(-1)\n",
    "print(f\"\\nEvaluating methods on {len(test_nodes)} test nodes...\")\n",
    "summary = evaluate_methods_on_nodes(test_nodes, k_frac=0.5, rel_thresh=0.1)\n",
    "\n",
    "print(\"\\nMean metrics over test nodes (k_frac = 0.5, keep/drop 50% neighbors):\")\n",
    "for name, mets in summary.items():\n",
    "    print(f\"{name:12s} | \"\n",
    "          f\"Fid+={mets['fid_pos_mean']:.3f}  \"\n",
    "          f\"Fid-={mets['fid_neg_mean']:.3f}  \"\n",
    "          f\"Sparsity={mets['sparsity_mean']:.3f}\"\n",
    "          f\"Time={mets['time']:.6f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "cV8I1T1xtkO8",
    "outputId": "65f2b010-44c7-47fa-af66-2245c9920381"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Decomp       | Fid+=0.390  Fid-=0.066  Sparsity=0.268Time=0.003105\n",
      "GNNExplainer | Fid+=0.164  Fid-=0.130  Sparsity=0.000Time=0.715819\n",
      "IG           | Fid+=0.407  Fid-=0.024  Sparsity=0.000Time=0.227541\n",
      "Random       | Fid+=0.112  Fid-=0.172  Sparsity=0.072Time=0.000231\n",
      "GOAT         | Fid+=0.407  Fid-=0.024  Sparsity=0.000Time=0.008780\n",
      "GradCAM      | Fid+=0.361  Fid-=0.083  Sparsity=0.516Time=0.004427\n",
      "GNN-LRP      | Fid+=0.344  Fid-=0.087  Sparsity=0.169Time=0.002405\n"
     ]
    }
   ],
   "source": [
    "for name, mets in summary.items():\n",
    "    print(f\"{name:12s} | \"\n",
    "          f\"Fid+={mets['fid_pos_mean']:.3f}  \"\n",
    "          f\"Fid-={mets['fid_neg_mean']:.3f}  \"\n",
    "          f\"Sparsity={mets['sparsity_mean']:.3f}\"\n",
    "          f\"Time={mets['time']:.6f}\")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
