{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Instructions\n",
    "Run this notebook from top to bottom to obtain both quantitative and qualitative results.\n",
    "\n",
    "To reproduce the results with our pretrained weights, set `train_explainer` to `False`. To retrain the explainer, set it to `True`; note that this overwrites our original weights.\n",
    "\n",
    "The variable `iterations` sets how many times the experiment is repeated; the reported scores are averaged over these runs. Both the original paper and our report used 10 iterations; using fewer iterations gives results faster. Using more than 10 iterations is not possible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_explainer = False\n",
    "iterations = 10\n",
    "import torch\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports \n",
    "import sys\n",
    "sys.path.append('./codes/')\n",
    "import time\n",
    "from config import args\n",
    "from tqdm import tqdm\n",
    "from utils import *\n",
    "from models import GCN2 as GCN\n",
    "from metrics import *\n",
    "import numpy as np\n",
    "from Extractor import Extractor\n",
    "from Explainer import Explainer\n",
    "from scipy.sparse import coo_matrix,csr_matrix\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import roc_auc_score\n",
    "import torch\n",
    "import torch.optim\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration for this dataset\n",
    "save_map = \"LISA_TEST_LOGS/BA_SHAPES/\"  # directory prefix for the per-run explainer weights\n",
    "\n",
    "args.dataset = 'syn1'    # selects ./dataset/<name>.pkl and the dataset-specific branches below\n",
    "args.eepochs = 10        # number of explainer training epochs\n",
    "args.elr = 0.003         # learning rate of the explainer's Adam optimizer\n",
    "\n",
    "# The remaining options are not read in this notebook; presumably the Explainer\n",
    "# consumes them through the shared `args` object -- verify in codes/Explainer.py.\n",
    "args.coff_size = 0.05\n",
    "args.coff_ent = 1.0\n",
    "args.budget = -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plots the explanation given by the explainer model\n",
    "def plot(node, label, iteration):\n",
    "    \"\"\"Draw the explanation currently held in ``explainer.masked_adj``.\n",
    "\n",
    "    Reads the global ``explainer`` (the mask left by its latest forward pass) and the\n",
    "    global ``args.dataset`` (to choose the colour palette). Only the connected\n",
    "    component that contains sub-graph node 0 is drawn, and node 0 is drawn enlarged.\n",
    "\n",
    "    Args:\n",
    "        node: id of the explained node; accepted for the caller's convenience, unused here.\n",
    "        label: integer class label per sub-graph node, indexable by node index.\n",
    "        iteration: index of the current run; unused here.\n",
    "\n",
    "    Returns:\n",
    "        None. The figure is shown with ``plt.show()`` and then cleared.\n",
    "    \"\"\"\n",
    "    after_adj_dense = explainer.masked_adj.cpu().detach().numpy()\n",
    "    after_adj = coo_matrix(after_adj_dense)\n",
    "    edge_weights = after_adj.data\n",
    "    n_weights = edge_weights.shape[0]\n",
    "    if n_weights == 0:\n",
    "        # Without non-zero mask entries there is no threshold to index and nothing to draw.\n",
    "        return\n",
    "\n",
    "    sorted_edge_weights = np.sort(edge_weights)\n",
    "    # Entries at or above `thres` (roughly the 12 largest) are highlighted.\n",
    "    thres_index = max(int(n_weights - 12), 0)\n",
    "    thres = sorted_edge_weights[thres_index]\n",
    "    # Entries above `filter_thres` form the context graph that is drawn in grey.\n",
    "    filter_thres_index = min(thres_index, max(int(n_weights - n_weights / 2), n_weights - 100))\n",
    "    filter_thres = sorted_edge_weights[filter_thres_index]\n",
    "\n",
    "    pos_edges = []\n",
    "    filter_edges = []\n",
    "    for r, c, d in zip(after_adj.row, after_adj.col, edge_weights):\n",
    "        r = int(r)\n",
    "        c = int(c)\n",
    "        if d >= thres:\n",
    "            pos_edges.append((r, c))\n",
    "        if d > filter_thres:\n",
    "            filter_edges.append((r, c))\n",
    "\n",
    "    # Size the graph from the mask being plotted rather than from a leftover global `sub_adj`.\n",
    "    num_nodes = after_adj_dense.shape[0]\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(range(num_nodes))\n",
    "    G.add_edges_from(filter_edges)\n",
    "\n",
    "    # Keep only the component around node 0 (the index the callers pass as `nodeid`).\n",
    "    for cc in nx.connected_components(G):\n",
    "        if 0 in cc:\n",
    "            G = G.subgraph(cc).copy()\n",
    "            break\n",
    "\n",
    "    pos_edges = [(u, v) for (u, v) in pos_edges if u in G.nodes() and v in G.nodes()]\n",
    "    pos = nx.kamada_kawai_layout(G)\n",
    "\n",
    "    colors = ['orange', 'red', 'green', 'blue', 'maroon', 'brown', 'darkslategray', 'paleturquoise', 'darksalmon',\n",
    "              'slategray', 'mediumseagreen', 'mediumblue', 'orchid', ]\n",
    "    if args.dataset=='syn3':\n",
    "        colors = ['orange', 'blue']\n",
    "\n",
    "    if args.dataset=='syn4':\n",
    "        colors = ['orange', 'black','black','black','blue']\n",
    "\n",
    "    labels = label\n",
    "    max_label = np.max(labels)+1\n",
    "\n",
    "    # Group the sub-graph nodes by class so that each class gets its own colour.\n",
    "    label2nodes = [[] for _ in range(max_label)]\n",
    "    for i in range(num_nodes):\n",
    "        label2nodes[labels[i]].append(i)\n",
    "\n",
    "    for i in range(max_label):\n",
    "        node_filter = [n for n in label2nodes[i] if n in G.nodes()]\n",
    "        nx.draw_networkx_nodes(G, pos,\n",
    "                               nodelist=node_filter,\n",
    "                               node_color=colors[i % len(colors)],\n",
    "                               node_size=500)\n",
    "\n",
    "    # Wrap the palette index like the per-class loop does, so a label beyond the\n",
    "    # palette length cannot raise IndexError.\n",
    "    nx.draw_networkx_nodes(G, pos,\n",
    "                           nodelist=[0],\n",
    "                           node_color=colors[labels[0] % len(colors)],\n",
    "                           node_size=1000)\n",
    "\n",
    "    nx.draw_networkx_edges(G, pos, width=7, alpha=0.5, edge_color='grey')\n",
    "\n",
    "    nx.draw_networkx_edges(G, pos,\n",
    "                           edgelist=pos_edges,\n",
    "                           width=7, alpha=0.5)\n",
    "\n",
    "    plt.axis('off')\n",
    "    plt.show()\n",
    "    plt.clf()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the explainer model. Saves best weights to be used for inference\n",
    "def train(iteration):\n",
    "    \"\"\"Train the global ``explainer`` for ``args.eepochs`` epochs and save its weights.\n",
    "\n",
    "    Works entirely on notebook globals: ``model``, ``explainer``, ``optimizer``,\n",
    "    ``allnodes``, the ``sub_*`` lists, ``clip_value_max``, ``save_map`` and ``args``.\n",
    "    After every epoch all nodes are re-explained and the edge-level ROC AUC over\n",
    "    the rebuilt global ``reals``/``preds`` decides whether the weights are saved\n",
    "    as the new ``BEST`` checkpoint; the final weights are saved as ``LAST``.\n",
    "\n",
    "    Args:\n",
    "        iteration: index of the current run; selects the sub-directory of\n",
    "            ``save_map`` that receives the checkpoints.\n",
    "    \"\"\"\n",
    "    import os  # local import: only needed to create the checkpoint directories\n",
    "\n",
    "    global reals\n",
    "    global preds\n",
    "\n",
    "    epochs = args.eepochs\n",
    "    run_dir = save_map + str(iteration) + '/'\n",
    "    # torch.save does not create missing directories, so make sure both targets exist.\n",
    "    os.makedirs('model_weights', exist_ok=True)\n",
    "    os.makedirs(run_dir, exist_ok=True)\n",
    "\n",
    "    model.eval()\n",
    "    explainer.train()\n",
    "    best_auc = 0\n",
    "    for epoch in range(epochs):\n",
    "        # Accumulated for inspection only; none of these three is reported by this function.\n",
    "        train_accs = []\n",
    "        pred_loss = 0\n",
    "        lap_loss = 0\n",
    "        loss = 0\n",
    "        # Temperature decays from 1.0 towards 0.05 over the epochs; this schedule is\n",
    "        # hard-coded and does not read args.coff_t0 / args.coff_te.\n",
    "        tmp = float(1.0*np.power(0.05,epoch/epochs))\n",
    "        for i in range(len(allnodes)):\n",
    "            # Do not train the classifier model\n",
    "            with torch.no_grad():\n",
    "                output = model((sub_features[i],sub_support_tensors[i]), training=False)\n",
    "            train_acc = accuracy(output, sub_label_tensors[i])\n",
    "            train_accs.append(float(train_acc))\n",
    "            pred_label = torch.argmax(output, 1)\n",
    "\n",
    "            # Make explainer prediction\n",
    "            x = sub_features[i]\n",
    "            adj = sub_adjs[i]\n",
    "            nodeid = 0\n",
    "            embed = torch.Tensor(sub_embeds[i])\n",
    "            pred = explainer((x,adj,nodeid,embed,tmp), training=True)\n",
    "\n",
    "            # Calculate loss against the labels of *this* sub-graph, not the global\n",
    "            # `sub_label_tensor` left over from the last data-preparation step.\n",
    "            l,pl,ll = explainer.loss(pred, pred_label, sub_label_tensors[i], 0)\n",
    "            loss = loss + l\n",
    "            pred_loss = pred_loss + pl\n",
    "            lap_loss = lap_loss + ll\n",
    "\n",
    "        # Gradient update (one step per epoch over the summed loss)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_value_(explainer.parameters(), clip_value_max)\n",
    "        optimizer.step()\n",
    "\n",
    "        # Re-evaluate all nodes; acc() refills the global reals/preds lists.\n",
    "        reals = []\n",
    "        preds = []\n",
    "        for node in allnodes:\n",
    "            # No plotting here: it would draw one figure per node in every epoch.\n",
    "            explain_test(explainer, node, iteration, needplot=False)\n",
    "        auc = roc_auc_score(reals, preds)\n",
    "        explainer.train()  # explain_test switched the explainer to eval mode\n",
    "\n",
    "        if auc > best_auc:\n",
    "            best_auc = auc\n",
    "            torch.save(explainer.state_dict(), 'model_weights/BA-shapes_BEST.pt')\n",
    "            torch.save(explainer.state_dict(), run_dir + 'BA-shapes_BEST.pt')\n",
    "\n",
    "    torch.save(explainer.state_dict(), 'model_weights/BA-shapes_LAST.pt')\n",
    "    torch.save(explainer.state_dict(), run_dir + 'BA-shapes_LAST.pt')\n",
    "\n",
    "# Notebook-wide accumulators: acc() appends the edges of every evaluated\n",
    "# sub-graph here so that callers can score all explained nodes together.\n",
    "reals, preds = [], []\n",
    "\n",
    "def acc(sub_adj,sub_edge_label):\n",
    "    \"\"\"Score the current explanation mask against the edge labels of one sub-graph.\n",
    "\n",
    "    For every (row, col) entry of ``sub_adj`` the target is 1 when\n",
    "    ``sub_edge_label`` is non-zero in either direction, else 0; the score is the\n",
    "    sum of both directions of the global ``explainer.masked_adj``. Targets and\n",
    "    scores are also appended to the global ``reals`` and ``preds`` lists.\n",
    "\n",
    "    Returns:\n",
    "        The ROC AUC for this sub-graph, or -1 when the targets or the scores\n",
    "        contain a single distinct value.\n",
    "    \"\"\"\n",
    "    dense_labels = sub_edge_label.todense()\n",
    "    mask = explainer.masked_adj.cpu().detach().numpy()\n",
    "\n",
    "    real = []\n",
    "    pred = []\n",
    "    for r, c in zip(sub_adj.row, sub_adj.col):\n",
    "        labelled = (dense_labels[r,c] + dense_labels[c,r]) != 0\n",
    "        real.append(1 if labelled else 0)\n",
    "        pred.append(mask[r][c] + mask[c][r])\n",
    "\n",
    "    reals.extend(real)\n",
    "    preds.extend(pred)\n",
    "\n",
    "    one_target_value = len(np.unique(real)) == 1\n",
    "    one_score_value = len(np.unique(pred)) == 1\n",
    "    if one_target_value or one_score_value:\n",
    "        return -1\n",
    "    return roc_auc_score(real,pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def explain_test(explainer,node,iteration,needplot=True):\n",
    "    \"\"\"Explain one node, record its edge scores and return the inference time.\n",
    "\n",
    "    Looks the node's sub-graph up in the global ``remap`` / ``sub_*`` lists, runs\n",
    "    the explainer on it in eval mode, optionally plots the result and calls\n",
    "    ``acc`` so the edge targets/scores are appended to the global ``reals`` /\n",
    "    ``preds``.\n",
    "\n",
    "    Args:\n",
    "        explainer: explainer model to run; left in eval mode on return.\n",
    "        node: id of the node to explain; must be a key of ``remap``.\n",
    "        iteration: index of the current run, forwarded to ``plot``.\n",
    "        needplot: when true, draw the explanation with ``plot``.\n",
    "\n",
    "    Returns:\n",
    "        Elapsed seconds of the explainer forward pass alone.\n",
    "    \"\"\"\n",
    "    newid = remap[node]\n",
    "    sub_adj = sub_adjs[newid]\n",
    "    sub_feature = sub_features[newid]\n",
    "    sub_embed = sub_embeds[newid]\n",
    "    sub_label = sub_labels[newid]\n",
    "    sub_edge_label = sub_edge_labels[newid]\n",
    "\n",
    "    explainer.eval()\n",
    "    nodeid = 0\n",
    "    # Inference only: skip autograd bookkeeping, and time with a monotonic\n",
    "    # high-resolution clock instead of the wall clock.\n",
    "    tik = time.perf_counter()\n",
    "    with torch.no_grad():\n",
    "        explainer((sub_feature,sub_adj,nodeid,sub_embed,1.0))\n",
    "    tok = time.perf_counter()\n",
    "\n",
    "    label = np.argmax(sub_label,-1)\n",
    "    if needplot:\n",
    "        plot(node,label,iteration)\n",
    "    acc(sub_adj,sub_edge_label)\n",
    "    return tok - tik"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished inference iteration 0 with an AUC score of 0.99965378        taking average time of 3.815 milliseconds\n",
      "Finished inference iteration 1 with an AUC score of 0.99884123        taking average time of 3.954 milliseconds\n",
      "Finished inference iteration 2 with an AUC score of 0.99967193        taking average time of 3.9560000000000004 milliseconds\n",
      "Finished inference iteration 3 with an AUC score of 0.99946546        taking average time of 3.467 milliseconds\n",
      "Finished inference iteration 4 with an AUC score of 0.99965436        taking average time of 4.225 milliseconds\n",
      "Finished inference iteration 5 with an AUC score of 0.99949451        taking average time of 4.0280000000000005 milliseconds\n",
      "Finished inference iteration 6 with an AUC score of 0.99935677        taking average time of 3.544 milliseconds\n",
      "Finished inference iteration 7 with an AUC score of 0.99967572        taking average time of 4.186999999999999 milliseconds\n",
      "Finished inference iteration 8 with an AUC score of 0.99862236        taking average time of 4.051 milliseconds\n",
      "Finished inference iteration 9 with an AUC score of 0.99938048        taking average time of 3.669 milliseconds\n",
      "\n",
      "Experiments are finished;      \n",
      " the final average auc score over 10 runs is: 0.999      \n",
      " the final average inference over 10 time is: 3.669\n"
     ]
    }
   ],
   "source": [
    "aucs, all_times = [], []\n",
    "# Run the experiments\n",
    "# NOTE(review): `pkl` and `random` are not imported in this notebook's import cell;\n",
    "# presumably they arrive through the star imports (`from utils import *`) -- verify.\n",
    "for iteration in range(iterations):\n",
    "\n",
    "    # NOTE: unpickling can execute arbitrary code; only load dataset files you trust.\n",
    "    with open('./dataset/' + args.dataset + '.pkl', 'rb') as fin:\n",
    "        adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask, edge_label_matrix  = pkl.load(fin)\n",
    "    # Set up dataset\n",
    "    adj = csr_matrix(adj)\n",
    "    support = preprocess_adj(adj)\n",
    "\n",
    "    features_tensor = torch.tensor(features).type(torch.float32)\n",
    "    i = torch.LongTensor([*support[0]])\n",
    "    v = torch.FloatTensor([*support[1]])\n",
    "    support_tensor = torch.sparse.FloatTensor(i.t(), v, torch.Size([*support[2]]))\n",
    "    support_tensor = support_tensor.type(torch.float32)\n",
    "    \n",
    "    # Create models \n",
    "    model = GCN(input_dim=features.shape[1], output_dim=y_train.shape[1], device=device)\n",
    "    model.to(device)\n",
    "    # map_location lets weights saved on a GPU be loaded on a CPU-only machine.\n",
    "    model.load_state_dict(torch.load('model_weights/GCN_syn1_BEST.pt', map_location=device))\n",
    "\n",
    "    explainer = Explainer(model=model)\n",
    "    explainer.to(device)\n",
    "    embeds = model.embedding((features_tensor,support_tensor)).cpu().detach().numpy()\n",
    "    \n",
    "    # Set up which part of the dataset to use for experiments\n",
    "    all_label = np.logical_or(y_train,np.logical_or(y_val,y_test))\n",
    "    single_label = np.argmax(all_label,axis=-1)\n",
    "    hops = len(args.hiddens.split('-'))\n",
    "    extractor = Extractor(adj,features,edge_label_matrix,embeds,all_label,hops)\n",
    "    if args.setting==1:\n",
    "        if args.dataset=='syn3':\n",
    "            allnodes = [i for i in range(511,871,6)]\n",
    "        elif args.dataset=='syn4':\n",
    "            allnodes = [i for i in range(511,800,1)]\n",
    "        else:\n",
    "            allnodes = [i for i in range(400,700,5)] # setting from their original paper\n",
    "    elif args.setting==2:\n",
    "        allnodes = [i for i in range(single_label.shape[0]) if single_label[i] ==1]\n",
    "    elif args.setting==3:\n",
    "        if args.dataset == 'syn2':\n",
    "            allnodes = [i for i in range(single_label.shape[0]) if single_label[i] != 0 and single_label[i] != 4]\n",
    "        else:\n",
    "            allnodes = [i for i in range(single_label.shape[0]) if single_label[i] != 0]\n",
    "    else:\n",
    "        # Fail loudly instead of continuing with an undefined or stale `allnodes`.\n",
    "        raise ValueError(f'Unsupported args.setting: {args.setting!r} (expected 1, 2 or 3)')\n",
    "    \n",
    "    optimizer = torch.optim.Adam(explainer.parameters(), lr=args.elr)\n",
    "    clip_value_min = -2.0\n",
    "    clip_value_max = 2.0  # read by train() as the gradient clipping value\n",
    "    \n",
    "    # More data preparation: one sub-graph per explained node; `remap` maps the\n",
    "    # node id to its position in the sub_* lists.\n",
    "    sub_support_tensors = []\n",
    "    sub_label_tensors = []\n",
    "    sub_features = []\n",
    "    sub_embeds = []\n",
    "    sub_adjs = []\n",
    "    sub_edge_labels = []\n",
    "    sub_labels = []\n",
    "    remap = {}\n",
    "    for node in allnodes:\n",
    "        sub_adj,sub_feature, sub_embed, sub_label,sub_edge_label_matrix = extractor.subgraph(node)\n",
    "        remap[node]=len(sub_adjs)\n",
    "        sub_support = preprocess_adj(sub_adj)\n",
    "        i = torch.LongTensor([*sub_support[0]])\n",
    "        v = torch.FloatTensor([*sub_support[1]])\n",
    "        sub_support_tensor = torch.sparse.FloatTensor(i.t(), v, torch.Size([*sub_support[2]])).type(torch.float32) \n",
    "        sub_label_tensor = torch.Tensor(sub_label).type(torch.float32)\n",
    "\n",
    "        sub_adjs.append(sub_adj)\n",
    "        sub_features.append(torch.Tensor(sub_feature).type(torch.float32))\n",
    "        sub_embeds.append(sub_embed)\n",
    "        sub_labels.append(sub_label)\n",
    "        sub_edge_labels.append(sub_edge_label_matrix)\n",
    "        sub_label_tensors.append(sub_label_tensor)\n",
    "        sub_support_tensors.append(sub_support_tensor)\n",
    "    best_auc = 0.0\n",
    "    \n",
    "    # Set seeds\n",
    "    # NOTE(review): seeding happens after the explainer was constructed above, so\n",
    "    # its initial weights are not governed by these per-run seeds.\n",
    "    random.seed(iteration)\n",
    "    np.random.seed(iteration)\n",
    "    torch.manual_seed(iteration)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed(iteration)\n",
    "        torch.cuda.manual_seed_all(iteration)\n",
    "        torch.backends.cudnn.deterministic = True\n",
    "        torch.backends.cudnn.benchmark = False\n",
    "    \n",
    "    # Train explainer model if needed\n",
    "    if train_explainer:\n",
    "        train(iteration)\n",
    "    explainer.load_state_dict(torch.load(save_map + str(iteration) + \"/\" + 'BA-shapes_BEST.pt', map_location=device))\n",
    "    \n",
    "    # Make predictions and present score\n",
    "    reals = []\n",
    "    preds = []\n",
    "    ts = []\n",
    "    for node in allnodes:\n",
    "        t = explain_test(explainer,node, iteration, needplot=False)\n",
    "        ts.append(t)\n",
    "    # Score once, after every node has added its edges to reals/preds; scoring the\n",
    "    # partial lists inside the loop was wasted work and could fail on a single class.\n",
    "    auc = roc_auc_score(reals, preds)\n",
    "    avg_time = np.mean(ts)\n",
    "    all_times.append(avg_time)\n",
    "    aucs.append(auc)\n",
    "    # Convert to milliseconds before rounding to avoid float artifacts in the output.\n",
    "    print(f'Finished inference iteration {iteration} with an AUC score of {round(auc,8)}\\\n",
    "        taking average time of {round(avg_time * 1000, 3)} milliseconds')\n",
    "print()\n",
    "# Average the timing over all runs (all_times), not just the last run's avg_time.\n",
    "print(f'Experiments are finished;\\\n",
    "      \\n the final average auc score over {iterations} runs is: {round(np.mean(aucs), 3)}\\\n",
    "      \\n the final average inference time over {iterations} runs is: {round(np.mean(all_times) * 1000, 3)} milliseconds')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run cell below to obtain qualitative results\n",
    "Set `plot_graph_num` to the number of graphs to be shown. For this experiment a 'house' motif is the ground truth."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_graph_num = 50\n",
    "# Uses `explainer`, `allnodes` and `iteration` as left by the last run of the experiment loop.\n",
    "# Slicing (instead of indexing) caps the count at the number of explained nodes,\n",
    "# so asking for more graphs than exist cannot raise IndexError.\n",
    "for node in allnodes[:plot_graph_num]:\n",
    "    explain_test(explainer, node, iteration, needplot=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
