{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example code for running GISST\n",
    "This notebook is a quick tutorial on using the APIs in this repository. It walks through training a node-classification GISST model and then obtaining node-feature and edge interpretation results for a specific node. The other models and interpretation methods follow the same API design.\n",
    "\n",
    "Dataset splitting for cross-validation, model evaluation, and hyperparameter tuning are intentionally omitted here for simplicity. For those details, please refer to source files such as `synthesize_graph.py` and `train_model.py`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "ename": "ImportError",
     "evalue": "attempted relative import with no known parent package",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mImportError\u001b[0m                               Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[1], line 3\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch_geometric\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Data\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msynthesize_graph\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m syn_ba_house, syn_node_feat\n\u001b[1;32m      4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msig\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msynthetic_graph\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m build_house, build_cycle\n\u001b[1;32m      5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msig\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyg_utils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m get_pyg_edge_index\n",
      "File \u001b[0;32m~/Desktop/PostDOC/Steve/code/EvaluationOfSEGNNs-main/models/GISST/synthesize_graph.py:10\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch_geometric\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Data\n\u001b[0;32m---> 10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m sig\u001b[38;5;241m.\u001b[39mutils \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01msynthetic_graph\u001b[39;00m  \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01msg\u001b[39;00m\n\u001b[1;32m     11\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m sig\u001b[38;5;241m.\u001b[39mutils \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mgraph_utils\u001b[39;00m\n\u001b[1;32m     12\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m sig\u001b[38;5;241m.\u001b[39mutils \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpyg_utils\u001b[39;00m\n",
      "\u001b[0;31mImportError\u001b[0m: attempted relative import with no known parent package"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "from synthesize_graph import syn_ba_house, syn_node_feat\n",
    "from sig.utils.synthetic_graph import build_house, build_cycle\n",
    "from sig.utils.pyg_utils import get_pyg_edge_index\n",
    "from sig.nn.loss.regularization_loss import reg_sig_loss\n",
    "from sig.nn.loss.classification_loss import cross_entropy_loss\n",
    "from sig.nn.models.sigcn import SIGCN\n",
    "from sig.explainers.sig_explainer import SIGExplainer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Prepare the node classification graph\n",
    "First, prepare a graph using `networkx`. Store the node class labels as a list and the node feature matrix as a NumPy array."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthesize the example graph and its node labels; the third return value is unused here.\n",
    "graph, labels, _ = syn_ba_house()\n",
    "# Generate node features from the labels; the second return value is unused here.\n",
    "node_feat, _ = syn_node_feat(labels, sigma_scale=0.1)\n",
    "# The number of classes is the number of distinct label values.\n",
    "distinct_labels = set(labels)\n",
    "num_class = len(distinct_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "graph: <class 'networkx.classes.graph.Graph'>\n",
      "labels: <class 'list'>\n",
      "node_feat: <class 'numpy.ndarray'>\n"
     ]
    }
   ],
   "source": [
    "# Report the Python type of each object produced above.\n",
    "type_report = (\n",
    "    ('graph: {}', graph),\n",
    "    ('labels: {}', labels),\n",
    "    ('node_feat: {}', node_feat),\n",
    ")\n",
    "for template, obj in type_report:\n",
    "    print(template.format(type(obj)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "graph num nodes: 700\n",
      "graph num edges: 2238\n",
      "labels len: 700\n",
      "node_feat shape: (700, 50)\n"
     ]
    }
   ],
   "source": [
    "# Report the size of the graph, the label list, and the feature matrix.\n",
    "size_report = (\n",
    "    ('graph num nodes: {}', graph.number_of_nodes()),\n",
    "    ('graph num edges: {}', graph.number_of_edges()),\n",
    "    ('labels len: {}', len(labels)),\n",
    "    ('node_feat shape: {}', node_feat.shape),\n",
    ")\n",
    "for template, value in size_report:\n",
    "    print(template.format(value))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Format the graph data for PyTorch Geometric\n",
    "Under the hood, the APIs are implemented using PyTorch Geometric. Hence the graph data should be converted into an instance of `torch_geometric.data.Data` (https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html#data-handling-of-graphs). As a side note, for graph-level classification, the graphs should be converted into an instance of `torch_geometric.data.Dataset` or a list of `torch_geometric.data.Data` for mini-batching via a data loader (https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html#mini-batches)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the feature array and the label list to tensors first.\n",
    "feature_tensor = torch.Tensor(node_feat)\n",
    "label_tensor = torch.LongTensor(labels)\n",
    "\n",
    "# Assemble the PyTorch Geometric container attribute by attribute.\n",
    "data = Data()\n",
    "data.x = feature_tensor\n",
    "data.y = label_tensor\n",
    "data.edge_index = get_pyg_edge_index(graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data(edge_index=[2, 4476], x=[700, 50], y=[700])\n"
     ]
    }
   ],
   "source": [
    "# Print the summary of the assembled Data object.\n",
    "print(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Initialize the model and optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size the model from the data: input width is the number of feature columns,\n",
    "# output width is the number of classes.\n",
    "num_features = data.x.size(1)\n",
    "model = SIGCN(\n",
    "    input_size=num_features,\n",
    "    output_size=num_class,\n",
    "    hidden_conv_sizes=(8, 8),\n",
    "    hidden_dropout_probs=(0, 0),\n",
    ")\n",
    "\n",
    "# Adam over all model parameters, with weight decay.\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.05, weight_decay=0.001)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    \"\"\"Run one full-graph optimization step and return the scalar loss.\n",
    "\n",
    "    Uses the notebook-level `model`, `optimizer`, and `data`. The loss is the\n",
    "    cross-entropy classification loss plus the four regularization terms\n",
    "    returned by `reg_sig_loss` for the node-feature and edge probabilities.\n",
    "\n",
    "    Returns:\n",
    "        float: The total loss value for this step.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    out, x_prob, edge_prob = model(\n",
    "        data.x,\n",
    "        data.edge_index,\n",
    "        return_probs=True,\n",
    "    )\n",
    "    reg_coeffs = {\n",
    "        'x_l1': 0.01,\n",
    "        'x_ent': 0.05,\n",
    "        'edge_l1': 0.01,\n",
    "        'edge_ent': 0.05,\n",
    "    }\n",
    "    loss_x_l1, loss_x_ent, loss_edge_l1, loss_edge_ent = reg_sig_loss(\n",
    "        x_prob,\n",
    "        edge_prob,\n",
    "        coeffs=reg_coeffs,\n",
    "    )\n",
    "    loss = (\n",
    "        cross_entropy_loss(out, data.y)\n",
    "        + loss_x_l1\n",
    "        + loss_x_ent\n",
    "        + loss_edge_l1\n",
    "        + loss_edge_ent\n",
    "    )\n",
    "    # The forward pass above rebuilds the autograd graph on every call, so the\n",
    "    # graph does not need to be retained after backpropagation.\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    return loss.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished training for 10 epochs\n",
      "Finished training for 20 epochs\n",
      "Finished training for 30 epochs\n",
      "Finished training for 40 epochs\n",
      "Finished training for 50 epochs\n"
     ]
    }
   ],
   "source": [
    "# Train for a fixed number of epochs, reporting progress periodically.\n",
    "num_epochs = 50\n",
    "report_every = 10\n",
    "for epoch in range(1, num_epochs + 1):\n",
    "    train()\n",
    "    if epoch % report_every:\n",
    "        continue\n",
    "    print('Finished training for %d epochs' % epoch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. Run the explainer on a node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train() leaves the model in training mode; switch to evaluation mode before explaining.\n",
    "model.eval()\n",
    "explainer = SIGExplainer(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 5.1 Node feature and edge probability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Probability-based explanation for node 15: gradients are disabled, so no\n",
    "# labels or loss function are passed.\n",
    "prob_explain_args = dict(\n",
    "    node_index=15,\n",
    "    x=data.x,\n",
    "    edge_index=data.edge_index,\n",
    "    use_grad=False,\n",
    "    y=None,\n",
    "    loss_fn=None,\n",
    "    take_abs=False,\n",
    "    pred_for_grad=False,\n",
    ")\n",
    "node_feat_prob, edge_prob = explainer.explain_node(**prob_explain_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "node_feat_prob shape: torch.Size([50])\n",
      "tensor([0.6073, 0.5938, 0.5674, 0.6318, 0.6238, 0.6164, 0.6025, 0.5960, 0.5018,\n",
      "        0.5844, 0.6137, 0.6091, 0.5878, 0.5704, 0.6085, 0.5630, 0.6246, 0.5523,\n",
      "        0.6082, 0.5521, 0.5457, 0.6023, 0.5746, 0.6129, 0.5883, 0.5937, 0.4954,\n",
      "        0.6117, 0.6076, 0.6095, 0.5415, 0.5882, 0.5887, 0.5526, 0.5419, 0.5848,\n",
      "        0.5249, 0.5669, 0.5302, 0.6130, 0.5383, 0.5244, 0.5030, 0.5338, 0.5114,\n",
      "        0.5007, 0.5078, 0.5055, 0.4957, 0.5635], grad_fn=<ClampBackward>)\n"
     ]
    }
   ],
   "source": [
    "# Show the shape, then the values, of the node-feature probabilities.\n",
    "shape_line = 'node_feat_prob shape: {}'.format(node_feat_prob.shape)\n",
    "print(shape_line)\n",
    "print(node_feat_prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "edge_prob shape: torch.Size([4476])\n",
      "tensor([1.0000e-05, 1.0000e-05, 1.0000e-05,  ..., 0.0000e+00, 0.0000e+00,\n",
      "        0.0000e+00], grad_fn=<IndexPutBackward>)\n"
     ]
    }
   ],
   "source": [
    "# Show the shape, then the values, of the edge probabilities.\n",
    "shape_line = 'edge_prob shape: {}'.format(edge_prob.shape)\n",
    "print(shape_line)\n",
    "print(edge_prob)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 5.2 Node feature and edge probability gradient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradient-based explanation for the same node: labels and the loss function\n",
    "# are supplied because gradients are enabled.\n",
    "grad_explain_args = dict(\n",
    "    node_index=15,\n",
    "    x=data.x,\n",
    "    edge_index=data.edge_index,\n",
    "    use_grad=True,\n",
    "    y=data.y,\n",
    "    loss_fn=cross_entropy_loss,\n",
    "    take_abs=False,\n",
    "    pred_for_grad=True,\n",
    ")\n",
    "node_feat_score, edge_score = explainer.explain_node(**grad_explain_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "node_feat_score shape: torch.Size([50])\n",
      "tensor([ 1.8113e-07,  1.7804e-07,  1.3067e-07,  1.4501e-07,  9.5189e-08,\n",
      "         2.6519e-07,  8.1901e-08,  9.8601e-08, -3.8628e-09,  1.9757e-07,\n",
      "         1.3460e-07,  1.6377e-07,  1.0515e-07,  1.6313e-07,  1.9874e-07,\n",
      "         4.8434e-08,  2.7607e-07,  1.5328e-07,  3.0147e-07,  1.1535e-07,\n",
      "         2.2737e-08,  9.6002e-08,  1.7883e-07,  1.8335e-07,  1.4471e-07,\n",
      "         3.5769e-08, -1.9045e-07,  9.6916e-08,  2.2862e-07,  2.6558e-07,\n",
      "         3.1914e-08,  7.0649e-08,  1.7684e-07,  2.5435e-07,  1.2272e-07,\n",
      "         2.2647e-08, -7.0268e-09, -7.1459e-08,  8.7770e-08,  2.1514e-07,\n",
      "         4.7412e-10, -4.0880e-09, -1.6669e-09, -1.2434e-08,  4.4608e-09,\n",
      "        -1.2842e-08,  1.9108e-09, -1.0655e-08, -1.7267e-08, -6.8925e-09])\n"
     ]
    }
   ],
   "source": [
    "# Show the shape, then the values, of the node-feature scores.\n",
    "shape_line = 'node_feat_score shape: {}'.format(node_feat_score.shape)\n",
    "print(shape_line)\n",
    "print(node_feat_score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "edge_score shape: torch.Size([4476])\n",
      "tensor([-4.8283e-11, -4.0702e-11, -4.8302e-11,  ...,  0.0000e+00,\n",
      "         0.0000e+00,  0.0000e+00])\n"
     ]
    }
   ],
   "source": [
    "# Show the shape, then the values, of the edge scores.\n",
    "shape_line = 'edge_score shape: {}'.format(edge_score.shape)\n",
    "print(shape_line)\n",
    "print(edge_score)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
