{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "phantom-newman",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pandas as pd\n",
    "\n",
    "from torch_geometric.data import DataLoader, Data\n",
    "from torch_geometric.datasets import PPI\n",
    "from torch_geometric.utils import remove_isolated_nodes\n",
    "\n",
    "from dig.sslgraph.utils import Encoder\n",
    "from dig.sslgraph.dataset import get_node_dataset\n",
    "\n",
    "from downstream import MLP, EndtoEnd, train_MLP\n",
    "from dig.xgraph.evaluation import XCollector\n",
    "\n",
    "# Prefer the second GPU when the machine has one, otherwise fall back to the first GPU\n",
    "# and finally to the CPU, so the notebook also runs on single-GPU and CPU-only hosts.\n",
    "if torch.cuda.device_count() > 1:\n",
    "    device = torch.device(\"cuda:1\")\n",
    "elif torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "hungry-stuart",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_task(idx):\n",
    "    \"\"\"Return a transform that keeps only label column ``idx`` of each graph.\"\"\"\n",
    "    def transform(data):\n",
    "        # Features and edges pass through untouched; only the label matrix is sliced.\n",
    "        task_labels = data.y[:, idx]\n",
    "        return Data(x=data.x, edge_index=data.edge_index, y=task_labels)\n",
    "    return transform\n",
    "\n",
    "def get_task_rm_iso(idx):\n",
    "    \"\"\"Return a transform that drops isolated nodes and keeps label column ``idx``.\"\"\"\n",
    "    def transform(data):\n",
    "        n_nodes = data.x.shape[0]\n",
    "        # The third return value is used as a node mask for both features and labels.\n",
    "        kept_edges, _, keep = remove_isolated_nodes(data.edge_index, num_nodes=n_nodes)\n",
    "        return Data(x=data.x[keep], edge_index=kept_edges, y=data.y[keep, idx])\n",
    "    return transform\n",
    "    \n",
    "# Default dataset for task 0 (isolated nodes removed), served one graph per batch.\n",
    "ppi = PPI(root='node_dataset/ppi/', transform=get_task_rm_iso(0))\n",
    "loader = DataLoader(ppi, batch_size=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cosmetic-conclusion",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the node-level GCN encoder (input width taken from the first graph's features)\n",
    "# and load its weights from the checkpoint file, mapped onto the CPU.\n",
    "encoder = Encoder(\n",
    "    feat_dim=ppi[0].x.shape[1],\n",
    "    hidden_dim=600,\n",
    "    n_layers=2,\n",
    "    gnn='gcn',\n",
    "    node_level=True,\n",
    "    graph_level=False,\n",
    ")\n",
    "encoder.load_state_dict(\n",
    "    torch.load('ckpts_model/ppi_pretrain_grace600_h2.pth', map_location='cpu')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "saving-engineering",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tagexplainer import TAGExplainer, MLPExplainer\n",
    "# Node-level explainer wrapped around the encoder; embed_dim matches the encoder's hidden_dim.\n",
    "enc_explainer = TAGExplainer(\n",
    "    encoder,\n",
    "    embed_dim=600,\n",
    "    device=device,\n",
    "    explain_graph=False,\n",
    "    grad_scale=0.1,\n",
    "    coff_size=0.05,\n",
    "    coff_ent=0.002,\n",
    "    loss_type='JSE',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "surprising-daily",
   "metadata": {},
   "source": [
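    "#### To retrain the explainer (overwriting `ckpts_explainer/explain_ppi_grace.pt`), uncomment the following cell. Otherwise, the saved checkpoint is loaded in the cell after it."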
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "quality-kernel",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 390/390 [00:31<00:00, 12.25it/s, loss=-.24, log=0.4901, 0.9028, 0.1317, 0.1679]   \n",
      "100%|██████████| 345/345 [00:27<00:00, 12.40it/s, loss=-.262, log=0.4368, 0.8754, -1.0137, 0.1058]  \n",
      "100%|██████████| 566/566 [01:02<00:00,  9.04it/s, loss=-.359, log=2.0513, 0.9053, -1.1522, 0.1367]  \n",
      "100%|██████████| 585/585 [01:10<00:00,  8.27it/s, loss=-.465, log=5.0124, 0.9107, -1.9695, 0.1184]  \n",
      "100%|██████████| 395/395 [00:33<00:00, 11.84it/s, loss=0.133, log=3.8440, 0.9185, -2.5551, 0.1190] \n",
      "100%|██████████| 256/256 [00:17<00:00, 14.82it/s, loss=-.32, log=3.0553, 0.8555, -1.9011, 0.1751]  \n",
      "100%|██████████| 456/456 [00:41<00:00, 10.87it/s, loss=-.127, log=3.7548, 0.8240, -2.5014, 0.0741] \n",
      "100%|██████████| 622/622 [01:17<00:00,  8.05it/s, loss=-.397, log=10.1655, 0.9261, -4.8972, 0.0603]  \n",
      "100%|██████████| 148/148 [00:08<00:00, 18.12it/s, loss=-.394, log=10.5567, 0.9312, -1.6987, 0.2227]\n",
      "100%|██████████| 828/828 [02:19<00:00,  5.94it/s, loss=-.27, log=13.6232, 0.9427, -9.1332, 0.0388]  \n",
      "100%|██████████| 601/601 [01:09<00:00,  8.67it/s, loss=-.405, log=6.4849, 0.8588, -6.5346, 0.0813]  \n",
      "100%|██████████| 470/470 [00:45<00:00, 10.34it/s, loss=-.355, log=9.2056, 0.9030, -5.9569, 0.0418]   \n",
      "100%|██████████| 455/455 [00:44<00:00, 10.18it/s, loss=-.0371, log=8.9175, 0.9016, -5.8061, 0.0773]  \n",
      "100%|██████████| 870/870 [02:19<00:00,  6.22it/s, loss=-.411, log=18.2061, 0.9869, -8.4921, 0.0640]  \n",
      "100%|██████████| 699/699 [01:40<00:00,  6.96it/s, loss=-.128, log=17.9049, 0.9764, -13.0799, 0.0216]  \n",
      "100%|██████████| 582/582 [01:03<00:00,  9.14it/s, loss=-.392, log=12.9812, 0.9487, -9.8769, 0.0544]   \n",
      "100%|██████████| 663/663 [01:28<00:00,  7.47it/s, loss=-.835, log=6.7467, 0.8425, -13.8651, 0.0107]  \n",
      "100%|██████████| 704/704 [01:41<00:00,  6.97it/s, loss=-.188, log=18.0522, 0.9812, -15.0068, 0.0208]   \n",
      "100%|██████████| 791/791 [01:54<00:00,  6.89it/s, loss=-.458, log=16.1347, 0.9659, -11.5103, 0.0328]  \n",
      "100%|██████████| 756/756 [01:50<00:00,  6.83it/s, loss=-.424, log=33.1915, 0.9987, -12.8400, 0.0121]  \n"
     ]
    }
   ],
   "source": [
    "# enc_explainer.train_explainer_node(loader, batch_size=4, lr=5e-6, epochs=1)\n",
    "# torch.save(enc_explainer.state_dict(), 'ckpts_explainer/explain_ppi_grace.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "aerial-wound",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# map_location='cpu' (as for the other checkpoints in this notebook) lets a checkpoint\n",
    "# that was saved from a GPU load on machines that do not have that device.\n",
    "state_dict = torch.load('ckpts_explainer/explain_ppi_grace.pt', map_location='cpu')\n",
    "enc_explainer.load_state_dict(state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "proper-prime",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_results(task_id, top_k):\n",
    "    \"\"\"Explain every node with a non-zero label for one PPI task and print the metrics.\n",
    "\n",
    "    Args:\n",
    "        task_id: Label column of PPI to evaluate; also selects the downstream MLP\n",
    "            checkpoint ('ckpts_model/downstream_ppi%d_grace600.pth' % task_id).\n",
    "        top_k: Forwarded unchanged to ``enc_explainer`` as ``top_k``.\n",
    "\n",
    "    Returns:\n",
    "        None. A per-node progress line and the aggregated fidelity and sparsity\n",
    "        reported by ``XCollector`` are printed.\n",
    "    \"\"\"\n",
    "    ppi = PPI('node_dataset/ppi/', transform=get_task_rm_iso(task_id))\n",
    "    loader = DataLoader(ppi, 1)\n",
    "\n",
    "    mlp_model = MLP(num_layer = 2, emb_dim = 600, hidden_dim = 600, out_dim = 2)\n",
    "    mlp_model.load_state_dict(torch.load('ckpts_model/downstream_ppi%d_grace600.pth'%task_id, map_location='cpu'))\n",
    "    mlp_explainer = MLPExplainer(mlp_model, device)\n",
    "\n",
    "    x_collector = XCollector()\n",
    "    status_width = 0  # longest progress line printed so far\n",
    "    for i, data in enumerate(loader):\n",
    "        # Select the target nodes before the transfer so the node indices stay on the\n",
    "        # CPU as before, then move the graph to the device once per graph instead of\n",
    "        # once per explained node.\n",
    "        target_nodes = torch.where(data.y)[0]\n",
    "        data = data.to(device)\n",
    "        for node_idx in target_nodes:\n",
    "            _, masks, related_preds = enc_explainer(\n",
    "                data, mlp_explainer, node_idx=node_idx, top_k=top_k)\n",
    "            # Fidelity: prediction on the original input minus prediction with the\n",
    "            # explanation masked out.\n",
    "            fidelity = related_preds[0]['origin'] - related_preds[0]['maskout']\n",
    "\n",
    "            status = f'explain graph {i} node {node_idx}'+' fidelity %.4f'%fidelity\n",
    "            # Pad to the longest line so '\\r' overwrites leave no stale trailing characters.\n",
    "            status_width = max(status_width, len(status))\n",
    "            print(status.ljust(status_width), end='\\r')\n",
    "            x_collector.collect_data(masks, related_preds)\n",
    "\n",
    "    fid, fid_std = x_collector.fidelity\n",
    "    spa, spa_std = x_collector.sparsity\n",
    "\n",
    "    print()\n",
    "    print(f'Fidelity: {fid:.4f} ±{fid_std:.4f}\\n'\n",
    "          f'Sparsity: {spa:.4f} ±{spa_std:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "filled-visitor",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "explain graph 19 node 3020 fidelity 0.05188\n",
      "Fidelity: 0.2694 ±0.3878\n",
      "Sparsity: 0.8545 ±0.1814\n"
     ]
    }
   ],
   "source": [
    "get_results(task_id=0, top_k=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "comfortable-athletics",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "explain graph 19 node 3020 fidelity -0.0007\n",
      "Fidelity: 0.3038 ±0.4385\n",
      "Sparsity: 0.8671 ±0.1770\n"
     ]
    }
   ],
   "source": [
    "get_results(task_id=1, top_k=120)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "worldwide-vanilla",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "explain graph 19 node 3015 fidelity 0.00002\n",
      "Fidelity: 0.5042 ±0.4782\n",
      "Sparsity: 0.8444 ±0.2278\n"
     ]
    }
   ],
   "source": [
    "get_results(task_id=2, top_k=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "respective-michael",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "explain graph 19 node 3015 fidelity 0.30530\n",
      "Fidelity: 0.2763 ±0.4332\n",
      "Sparsity: 0.8541 ±0.2171\n"
     ]
    }
   ],
   "source": [
    "get_results(task_id=3, top_k=400)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "major-czech",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "explain graph 19 node 3015 fidelity 0.01282\n",
      "Fidelity: 0.3234 ±0.4460\n",
      "Sparsity: 0.8547 ±0.2490\n"
     ]
    }
   ],
   "source": [
    "get_results(task_id=4, top_k=400)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "enhanced-career",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
