{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 114,
   "id": "3ff9343e-98a7-41ed-89a0-7e0f8ded95ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product\n",
    "\n",
    "import argparse\n",
    "from datasets import get_dataset\n",
    "from ours_train_eval import *\n",
    "\n",
    "from gib_gin import GIBGIN, Discriminator\n",
    "from gib_gat import GIBGAT\n",
    "from gib_sage import GIBSAGE\n",
    "from gib_gcn  import GIBGCN\n",
    "import numpy as np\n",
    "\n",
    "import time\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import tensor\n",
    "from torch.optim import Adam\n",
    "from sklearn.model_selection import StratifiedKFold, train_test_split\n",
    "from torch_geometric.data import DataLoader, DenseDataLoader as DenseLoader\n",
    "from torch_geometric.utils import subgraph\n",
    "import numpy as np\n",
    "import pickle\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "id": "89168de3-7c84-463f-81da-0f23ed8043c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.utils import subgraph, add_self_loops\n",
    "from torch_geometric.data import Data,Batch\n",
    "import tqdm\n",
    "from collections import defaultdict\n",
    "def _predict_on_kept_nodes(data, keep, model, remove_nodes):\n",
    "    \"\"\"Return the model's softmax class probabilities after hiding every node where `keep` is False.\n",
    "\n",
    "    With `remove_nodes=True` the hidden nodes (and their incident edges) are deleted from the\n",
    "    graph; otherwise the graph structure is kept and only the hidden nodes' features are zeroed.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if node removal would leave a graph with no nodes.\n",
    "    \"\"\"\n",
    "    if remove_nodes:\n",
    "        if not bool(keep.any()):\n",
    "            # The model cannot pool an empty graph; fail with an explicit message instead.\n",
    "            raise ValueError('node mask removes every node; the perturbed graph would be empty')\n",
    "        kept_edge_index, _ = subgraph(keep, edge_index=data.edge_index, num_nodes=data.x.size(0), relabel_nodes=True)\n",
    "        perturbed = Data(x=data.x[keep], edge_index=kept_edge_index)\n",
    "    else:\n",
    "        masked_x = data.x.clone()\n",
    "        masked_x[~keep] = 0\n",
    "        perturbed = Data(x=masked_x, edge_index=data.edge_index)\n",
    "    perturbed = Batch.from_data_list([perturbed]).to(data.x.device)\n",
    "    return F.softmax(model(perturbed)[0], dim=1)\n",
    "\n",
    "\n",
    "def calculate_fidelity(data, node_mask, model, remove_nodes=True, top_k=None, compute_inv_fidelity=False):\n",
    "    \"\"\"Score a hard node-level explanation of `model`'s prediction on a single graph.\n",
    "\n",
    "    Args:\n",
    "        data: a single graph; it is wrapped into a batch of size 1 and moved to the model's device.\n",
    "        node_mask: per-node 0/1 mask; 1 marks an explanation node.\n",
    "        model: classifier whose first output is treated as class logits.\n",
    "        remove_nodes: if True, hidden nodes are deleted from the graph; if False, their\n",
    "            features are zeroed instead.\n",
    "        top_k: unused; kept for backward compatibility.\n",
    "        compute_inv_fidelity: if True, also run the model on the explanation-only graph to\n",
    "            obtain InvFidelity; otherwise InvFidelity is reported as 0 (previous behaviour).\n",
    "\n",
    "    Returns:\n",
    "        dict with \"Fidelity\" (1.0 if hiding the explanation nodes changes the predicted class,\n",
    "        else 0.0), \"InvFidelity\" (same test when only the explanation nodes are kept) and\n",
    "        \"HFidelity\", which combines both after weighting them by the mask sparsity.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `remove_nodes` is True and a perturbation would delete every node.\n",
    "    \"\"\"\n",
    "    data = Batch.from_data_list([data])\n",
    "    device = next(model.parameters()).device\n",
    "    data = data.to(device)\n",
    "    # Keep the mask on the same device as the graph so it can index x / edge_index.\n",
    "    node_mask = torch.as_tensor(node_mask, device=device)\n",
    "\n",
    "    # Sparsity = fraction of nodes that are NOT part of the explanation.\n",
    "    total_nodes = data.x.shape[0]\n",
    "    sparsity = 1 - node_mask.sum().item() / total_nodes if total_nodes > 0 else 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        original_pred = F.softmax(model(data)[0], dim=1)\n",
    "\n",
    "        # Fidelity: hide the explanation nodes and check whether the predicted class flips.\n",
    "        removed_pred = _predict_on_kept_nodes(data, node_mask == 0, model, remove_nodes)\n",
    "        fidelity = (original_pred.argmax(-1) != removed_pred.argmax(-1)).float().item()\n",
    "\n",
    "        # InvFidelity: keep only the explanation nodes and check whether the class flips.\n",
    "        inv_fidelity = 0\n",
    "        if compute_inv_fidelity:\n",
    "            retained_pred = _predict_on_kept_nodes(data, node_mask == 1, model, remove_nodes)\n",
    "            inv_fidelity = (original_pred.argmax(-1) != retained_pred.argmax(-1)).float().item()\n",
    "\n",
    "    # Sparsity-weighted scores (n_fidelity must come from fidelity, not inv_fidelity;\n",
    "    # otherwise HFidelity is stuck at 0.5 whenever inv_fidelity is 0).\n",
    "    n_fidelity = fidelity * sparsity\n",
    "    n_inv_fidelity = inv_fidelity * (1 - sparsity)\n",
    "\n",
    "    # HFidelity: harmonic-style combination of the two normalised scores; guard the actual denominator.\n",
    "    denominator = 2 + n_fidelity - n_inv_fidelity\n",
    "    hfidelity = (1 + n_fidelity) * (1 - n_inv_fidelity) / denominator if denominator != 0 else 0\n",
    "\n",
    "    return {\n",
    "        \"Fidelity\": fidelity,\n",
    "        \"InvFidelity\": inv_fidelity,\n",
    "        \"HFidelity\": hfidelity\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "id": "40fd79c6-81cf-4cd3-9b1f-a10ae112c46b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.nn as nn\n",
    "from torch.nn import Linear, Sequential, ReLU, BatchNorm1d as BN\n",
    "from torch_geometric.nn import GINConv, global_mean_pool, JumpingKnowledge\n",
    "from torch_geometric.utils import to_dense_adj\n",
    "def aggregate(self, assignment, x, batch, edge_index):\n",
    "    \"\"\"Pool node embeddings per graph through a soft node assignment (patched onto the model).\n",
    "\n",
    "    Args:\n",
    "        self: the model; only its `mse_loss` attribute is used.\n",
    "        assignment: soft node-to-group assignment, one row per node and one column per group\n",
    "            (presumably 2 groups with column 0 the selected subgraph -- TODO confirm).\n",
    "        x: node embeddings of the whole batch, row-aligned with `assignment` and `batch`.\n",
    "        batch: graph index of every node; each graph's nodes are assumed contiguous, as in PyG batches.\n",
    "        edge_index: edge list of the whole batch.\n",
    "\n",
    "    Returns:\n",
    "        (pos_embedding, graph_embedding, pos_penalty): per-graph pooled embedding of assignment\n",
    "        column 0, per-graph mean node embedding, and the diagonal penalty averaged over graphs.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `batch` is empty.\n",
    "    \"\"\"\n",
    "    if batch.numel() == 0:\n",
    "        # torch.max on an empty tensor fails with an opaque message; name the real cause.\n",
    "        raise ValueError('aggregate() received an empty batch (graph without nodes)')\n",
    "    num_graphs = int(batch.max()) + 1\n",
    "    # Target for the diagonal of the row-normalised coarsened adjacency (one entry per group),\n",
    "    # created on the inputs' device instead of on CUDA whenever CUDA merely exists.\n",
    "    ones_target = torch.ones(assignment.size(1), device=assignment.device, dtype=assignment.dtype)\n",
    "\n",
    "    all_adj = to_dense_adj(edge_index, max_num_nodes=x.shape[0])[0]\n",
    "    # Node count of every graph: consecutive slices of these sizes delimit the graphs.\n",
    "    nodes_per_graph = torch.bincount(batch, minlength=num_graphs).tolist()\n",
    "\n",
    "    all_pos_penalty = 0\n",
    "    all_graph_embedding = []\n",
    "    all_pos_embedding = []\n",
    "\n",
    "    st = 0\n",
    "    for n_nodes in nodes_per_graph:\n",
    "        end = st + n_nodes\n",
    "        one_batch_x = x[st:end]\n",
    "        one_batch_assignment = assignment[st:end]\n",
    "\n",
    "        group_features = torch.mm(torch.t(one_batch_assignment), one_batch_x)\n",
    "        pos_embedding = group_features[0].unsqueeze(dim=0)\n",
    "\n",
    "        adj = all_adj[st:end, st:end]\n",
    "        new_adj = torch.mm(torch.mm(torch.t(one_batch_assignment), adj), one_batch_assignment)\n",
    "        normalize_new_adj = F.normalize(new_adj, p=1, dim=1)\n",
    "        norm_diag = torch.diag(normalize_new_adj)\n",
    "        pos_penalty = self.mse_loss(norm_diag, ones_target)\n",
    "        # Mean over this graph's own nodes (the whole batch was averaged before).\n",
    "        graph_embedding = torch.mean(one_batch_x, dim=0, keepdim=True)\n",
    "\n",
    "        all_pos_embedding.append(pos_embedding)\n",
    "        all_graph_embedding.append(graph_embedding)\n",
    "        all_pos_penalty = all_pos_penalty + pos_penalty\n",
    "        st = end\n",
    "\n",
    "    all_pos_embedding = torch.cat(all_pos_embedding, dim=0)\n",
    "    all_graph_embedding = torch.cat(all_graph_embedding, dim=0)\n",
    "    all_pos_penalty = all_pos_penalty / num_graphs\n",
    "\n",
    "    return all_pos_embedding, all_graph_embedding, all_pos_penalty\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "id": "f6cccd69-01fa-45ec-bed4-727c7864c697",
   "metadata": {},
   "outputs": [],
   "source": [
    "# def generate_hard_masks(soft_mask):\n",
    "#     sparsity_levels = torch.arange(0.5,1, 0.05)\n",
    "#     hard_masks = []\n",
    "#     for sparsity in sparsity_levels:\n",
    "#         threshold = np.percentile(soft_mask, sparsity * 100)\n",
    "#         hard_mask = (soft_mask > threshold).int()\n",
    "#         if hard_mask.sum() == 0:\n",
    "#             hard_mask = (soft_mask > soft_mask.min()).int()\n",
    "#         hard_masks.append(hard_mask)\n",
    "#     return list(zip(sparsity_levels, hard_masks))\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "\n",
    "def generate_hard_masks(soft_mask):\n",
    "    soft_mask_flat = soft_mask.flatten()\n",
    "    total_elements = soft_mask_flat.numel()\n",
    "    sparsity_levels = torch.arange(0.5, 1.0, 0.05)\n",
    "    hard_masks = []\n",
    "\n",
    "    # Get sorted indices (ascending: lowest values first)\n",
    "    sorted_indices = torch.argsort(soft_mask_flat)\n",
    "\n",
    "    for sparsity in sparsity_levels:\n",
    "        num_to_mask = int(sparsity.item() * total_elements)\n",
    "        mask_flat = torch.ones_like(soft_mask_flat, dtype=torch.int)\n",
    "        \n",
    "        if num_to_mask >= total_elements:\n",
    "            num_to_mask = total_elements - 1  # keep at least one element\n",
    "        \n",
    "        # Zero out the lowest `num_to_mask` elements\n",
    "        mask_flat[sorted_indices[:num_to_mask]] = 0\n",
    "        \n",
    "        # Reshape to original shape\n",
    "        hard_mask = mask_flat.view_as(soft_mask)\n",
    "        hard_masks.append(hard_mask)\n",
    "\n",
    "    return list(zip(sparsity_levels, hard_masks))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "id": "bb7fe06f-b81c-4a4c-bf59-4bfa6a08a671",
   "metadata": {},
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: 'results/bbbp.pkl'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn [126], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m dataset_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbbbp\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m----> 2\u001b[0m seeds_results \u001b[38;5;241m=\u001b[39m pickle\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mresults/\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mdataset_name\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m.pkl\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mrb\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m)\n",
      "File \u001b[0;32m~/.conda/envs/imp/lib/python3.8/site-packages/IPython/core/interactiveshell.py:282\u001b[0m, in \u001b[0;36m_modified_open\u001b[0;34m(file, *args, **kwargs)\u001b[0m\n\u001b[1;32m    275\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m file \u001b[38;5;129;01min\u001b[39;00m {\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m}:\n\u001b[1;32m    276\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m    277\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIPython won\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt let you open fd=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m by default \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    278\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mas it is likely to crash IPython. If you know what you are doing, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    279\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124myou can use builtins\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m open.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    280\u001b[0m     )\n\u001b[0;32m--> 282\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mio_open\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'results/bbbp.pkl'"
     ]
    }
   ],
   "source": [
    "dataset_name = 'BBBP'\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load result files produced by your own training runs.\n",
    "with open(f'results/{dataset_name}.pkl', 'rb') as f:\n",
    "    seeds_results = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "id": "19fb7849-114e-4a01-9a41-7c503a9c92a7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|██▉       | 33/111 [00:01<00:04, 16.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 77%|███████▋  | 85/111 [00:04<00:01, 20.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 111/111 [00:06<00:00, 16.77it/s]\n",
      "100%|██████████| 111/111 [00:05<00:00, 20.61it/s]\n",
      "  6%|▋         | 7/111 [00:00<00:06, 15.52it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 23%|██▎       | 25/111 [00:01<00:04, 17.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|██▊       | 31/111 [00:01<00:05, 13.48it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 49%|████▊     | 54/111 [00:03<00:03, 17.14it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 111/111 [00:06<00:00, 17.24it/s]\n",
      " 23%|██▎       | 26/111 [00:01<00:04, 18.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 56/111 [00:03<00:03, 13.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 57%|█████▋    | 63/111 [00:03<00:02, 23.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 111/111 [00:05<00:00, 19.44it/s]\n",
      " 25%|██▌       | 28/111 [00:00<00:02, 29.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 111/111 [00:05<00:00, 19.77it/s]\n"
     ]
    }
   ],
   "source": [
    "import functools\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# Per-seed explanation fidelity on the test split. Each entry of `l` is a (metric x sparsity)\n",
    "# DataFrame averaged over the explained test graphs of one seed.\n",
    "# NOTE(review): `device` and `eval_acc` are not defined in this notebook; presumably they come\n",
    "# from `from ours_train_eval import *` -- confirm.\n",
    "l = []\n",
    "batch_size = 1\n",
    "dataset = get_dataset(dataset_name, sparse=True)  # identical for every seed, so load it once\n",
    "for seed in range(len(seeds_results)):\n",
    "    train_idx, val_idx, test_idx, model = seeds_results[seed][0]\n",
    "    # Bind this seed's model explicitly; a lambda would look up the global `model` at call time.\n",
    "    model.aggregate = functools.partial(aggregate, model)\n",
    "    model = model.to(device)\n",
    "    model.eval()  # inference mode for dropout / batch norm during the perturbation passes\n",
    "    train_dataset = dataset[train_idx]\n",
    "    test_dataset = dataset[test_idx]\n",
    "\n",
    "    # Only the test split is evaluated here.\n",
    "    loader_cls = DenseLoader if 'adj' in train_dataset[0] else DataLoader\n",
    "    test_loader = loader_cls(test_dataset, batch_size, shuffle=False)\n",
    "    acc = eval_acc(model, test_loader)\n",
    "\n",
    "    seed_expl_res = []\n",
    "    n_skipped = 0\n",
    "    for data in tqdm.tqdm(test_loader):\n",
    "        if data.y == 0:  # only graphs with a non-zero label are explained\n",
    "            continue\n",
    "        try:\n",
    "            with torch.no_grad():\n",
    "                out, _, _, _, assignment = model(data.to(device), with_assignment=True)\n",
    "            c = out.argmax(-1).item()\n",
    "            masks = generate_hard_masks(assignment[:, c].cpu().detach())\n",
    "            d = {sparsity.item(): calculate_fidelity(data, mask, model) for sparsity, mask in masks}\n",
    "        except (RuntimeError, ValueError) as e:\n",
    "            # e.g. a perturbation that leaves an empty graph; skip this graph but report it.\n",
    "            n_skipped += 1\n",
    "            print(f'seed {seed}: skipping test graph ({e})')\n",
    "            continue\n",
    "        seed_expl_res.append(pd.DataFrame(d))\n",
    "\n",
    "    if not seed_expl_res:\n",
    "        print(f'seed {seed}: no test graph could be explained ({n_skipped} skipped); seed ignored')\n",
    "        continue\n",
    "    l.append(sum(seed_expl_res) / len(seed_expl_res))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "id": "6887ade9-cec0-4ca5-85b2-92d8072a37fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert each seed's (metric x sparsity) DataFrame into a plain NumPy array so seeds can be stacked.\n",
    "l = list(map(np.array, l))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "id": "71d3c43a-c5e1-4f81-9c6d-f9a7ff1b5eb8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.39599591, 0.37936604, 0.38892924, 0.37779093, 0.34179498,\n",
       "       0.33154018, 0.30713466, 0.27391674, 0.25240666, 0.23323247])"
      ]
     },
     "execution_count": 124,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean over seeds of row 0 of the metric table (Fidelity), one value per sparsity level.\n",
    "np.mean(l, axis=0)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "9e510259-b9bb-43c4-8287-60ebc8a24e9b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.09983972, 0.1066425 , 0.11667583, 0.11883211, 0.11872186,\n",
       "       0.12351298, 0.13286887, 0.14500764, 0.13102667, 0.09104237])"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Standard deviation over seeds of row 0 of the metric table (Fidelity), one value per sparsity level.\n",
    "np.std(l, axis=0)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "dd9c6901-8073-4562-83a7-aabe878c756f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0.50</th>\n",
       "      <th>0.55</th>\n",
       "      <th>0.60</th>\n",
       "      <th>0.65</th>\n",
       "      <th>0.70</th>\n",
       "      <th>0.75</th>\n",
       "      <th>0.80</th>\n",
       "      <th>0.85</th>\n",
       "      <th>0.90</th>\n",
       "      <th>0.95</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Fidelity</th>\n",
       "      <td>0.365079</td>\n",
       "      <td>0.323413</td>\n",
       "      <td>0.292063</td>\n",
       "      <td>0.324206</td>\n",
       "      <td>0.248413</td>\n",
       "      <td>0.327778</td>\n",
       "      <td>0.274206</td>\n",
       "      <td>0.217063</td>\n",
       "      <td>0.18373</td>\n",
       "      <td>0.101587</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>InvFidelity</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.00000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>HFidelity</th>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.50000</td>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 0.50      0.55      0.60      0.65      0.70      0.75  \\\n",
       "Fidelity     0.365079  0.323413  0.292063  0.324206  0.248413  0.327778   \n",
       "InvFidelity  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   \n",
       "HFidelity    0.500000  0.500000  0.500000  0.500000  0.500000  0.500000   \n",
       "\n",
       "                 0.80      0.85     0.90      0.95  \n",
       "Fidelity     0.274206  0.217063  0.18373  0.101587  \n",
       "InvFidelity  0.000000  0.000000  0.00000  0.000000  \n",
       "HFidelity    0.500000  0.500000  0.50000  0.500000  "
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Average the per-seed metric tables: a DataFrame while `l` holds DataFrames, a plain array\n",
    "# once the conversion cell has turned them into NumPy arrays.\n",
    "sum(l)/len(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f10ee04e-17b7-47a8-9008-ed4006da7f05",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_geometric.data import Data, Batch\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "def plot_activations(batch_ids, batch, attr):\n",
    "    \"\"\"Draw graphs from a PyG batch, outlining nodes whose `attr` value is 1 in red.\n",
    "\n",
    "    Args:\n",
    "        batch_ids: index (or list of indices) of the graphs inside `batch` to draw.\n",
    "        batch: batched graph exposing `batch`, `edge_index` and `y`.\n",
    "        attr: one value per node of the whole batch (aligned with `batch.batch`); nodes with\n",
    "            value 1 get a red border, all other nodes a black border.\n",
    "    \"\"\"\n",
    "    if type(batch_ids) != list:\n",
    "        batch_ids = [batch_ids]\n",
    "    cols = 5\n",
    "    rows = math.ceil(len(batch_ids) / cols)\n",
    "\n",
    "    fig, axs = plt.subplots(rows, cols, figsize=(16 * 5, 8 * rows))\n",
    "    # subplots may return a single Axes or a 2-D grid; normalise to a flat array for indexing.\n",
    "    axs = np.atleast_1d(axs).flatten()\n",
    "\n",
    "    # Work on CPU copies so boolean masks and indices all live on the same device.\n",
    "    attr = torch.as_tensor(attr).detach().cpu()\n",
    "    node_batch = batch.batch.cpu()\n",
    "    edge_index = batch.edge_index.cpu()\n",
    "\n",
    "    for i, batch_id in enumerate(batch_ids):\n",
    "        in_graph = node_batch == batch_id  # nodes that belong to graph `batch_id`\n",
    "        if not bool(in_graph.any()):\n",
    "            continue\n",
    "        node_indices = torch.nonzero(in_graph, as_tuple=True)[0]\n",
    "        graph_attr = attr[in_graph]\n",
    "\n",
    "        edge_in_graph = in_graph[edge_index[0]] & in_graph[edge_index[1]]\n",
    "        subgraph_edges = edge_index[:, edge_in_graph]\n",
    "\n",
    "        node_mapping = {old_idx.item(): new_idx for new_idx, old_idx in enumerate(node_indices)}\n",
    "        G = nx.Graph()\n",
    "        # Add all nodes up front so isolated nodes (no incident edges) are drawn too.\n",
    "        G.add_nodes_from(range(len(node_indices)))\n",
    "        G.add_edges_from((node_mapping[u], node_mapping[v]) for u, v in subgraph_edges.t().tolist())\n",
    "        nx.set_node_attributes(G, {v: k for k, v in node_mapping.items()}, 'original_id')\n",
    "\n",
    "        node_colors = ['lightblue'] * G.number_of_nodes()\n",
    "        node_borders = ['red' if graph_attr[node] == 1 else 'black' for node in G.nodes]\n",
    "\n",
    "        pos = nx.kamada_kawai_layout(G)\n",
    "        nx.draw(\n",
    "            G, pos,\n",
    "            node_color=node_colors,\n",
    "            edgecolors=node_borders,\n",
    "            node_size=700,\n",
    "            with_labels=True,\n",
    "            ax=axs[i],\n",
    "        )\n",
    "        axs[i].set_title(f'Class = {batch.y[batch_id]}')\n",
    "\n",
    "    # Hide the unused panels of the last row.\n",
    "    for ax in axs[len(batch_ids):]:\n",
    "        ax.axis('off')\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb716290-4c10-4b48-8447-0fe62ba08f04",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61e1b510-0cfc-4a5f-8250-9573b1d977ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick one graph from the full dataset for a qualitative look at its explanation.\n",
    "# NOTE(review): `dataset` is the one loaded by the per-seed evaluation cell; index 499 looks hand-picked.\n",
    "data = dataset[499]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60520a0d-8712-4dcb-bd2b-b07ed785bca3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward pass on the selected graph (as a batch of one). As in the evaluation loop, the fifth\n",
    "# output (`assignment`) supplies the per-node soft scores used as the explanation.\n",
    "pred, _, _, _, assignment = model(Batch.from_data_list([data]).to(device), with_assignment=True)\n",
    "print(pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "026f6619-01f8-453c-bcdb-2b2b83142a4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "c = pred.argmax(-1).item()\n",
    "print(c)\n",
    "# generate_hard_masks returns (sparsity, mask) pairs: [3] picks the fourth sparsity level\n",
    "# (0.65 with the default levels) and [1] takes its 0/1 node mask.\n",
    "node_mask = generate_hard_masks(assignment[:,c])[3][1]\n",
    "# Nodes with mask value 1 (the explanation) are drawn with a red border.\n",
    "plot_activations([0], Batch.from_data_list([data]), node_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23283770-e4c4-40a2-93b4-87f8996e1d75",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f82a9bf9-6329-4afe-aeb0-63be48dceb17",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06e343e9-3771-424b-9cac-f9c03e6efe5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Edge-removal perturbation (an alternative to deleting nodes, as calculate_fidelity does by default):\n",
    "# keep every node and its features, but drop all edges touching an explanation node (node_mask == 1).\n",
    "keep_node = ~node_mask.bool().to(data.edge_index.device)  # align devices before indexing\n",
    "edge_mask = keep_node[data.edge_index[0]] & keep_node[data.edge_index[1]]\n",
    "new_data = Batch.from_data_list([Data(x=data.x, edge_index=data.edge_index[:,edge_mask], y=0)])\n",
    "with torch.no_grad():\n",
    "    masked_pred = model(new_data.to(device))[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0928903-d07d-4632-a6cf-c006d3fad544",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model output (first returned value) for the edge-perturbed graph.\n",
    "masked_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b753ba4b-c058-47c8-a980-67cbb7bebf0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the edge-perturbed graph; an all-ones attr gives every node a red border.\n",
    "plot_activations([0], Batch.from_data_list([new_data]), torch.ones(new_data.x.shape[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdcea588-e080-4367-bd6d-baa52dc08469",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node-feature matrix shape of the perturbed graph (all nodes are kept; only edges were removed).\n",
    "new_data.x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4f0a065-8042-43cc-bf9a-c6cf5a160dcc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-imp]",
   "language": "python",
   "name": "conda-env-.conda-imp-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
