{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('./codes/forgraph/')\n",
    "from config import args\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from models import GCN2 as GCN\n",
    "from metrics import *  # NOTE(review): star import kept as-is; unclear which names are used\n",
    "import numpy as np\n",
    "from Explainer import Explainer\n",
    "from scipy.sparse import coo_matrix,csr_matrix\n",
    "import networkx as nx\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle as pkl\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim\n",
    "import time\n",
    "\n",
    "# ---- Experiment configuration ----\n",
    "# Motif size: acc() labels nodes in [insert, insert + skip) as ground-truth motif nodes.\n",
    "skip = 5\n",
    "# Number of highest-weighted explanation edges highlighted by plot().\n",
    "topk = 5\n",
    "# Root directory for per-iteration logs, plots and checkpoints.\n",
    "save_map = \"LISA_TEST_LOGS/BA_2MOTIF/\"\n",
    "\n",
    "# Explainer hyper-parameters (override the defaults from config.args).\n",
    "args.elr = 0.00015\n",
    "# train() anneals the temperature geometrically from coff_t0 to coff_te.\n",
    "# coff_t0 used to be assigned twice (5.0, then 0.5), so the second value\n",
    "# silently overwrote the start temperature; it is taken to be the end value.\n",
    "args.coff_t0 = 5.0\n",
    "args.coff_te = 0.5\n",
    "args.coff_size = 0.01\n",
    "args.coff_ent = 0\n",
    "# args.concat = True\n",
    "# args.bn = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def acc(adj, insert):\n",
    "    \"\"\"Record ground-truth labels and explainer scores for every edge of `adj`.\n",
    "\n",
    "    An edge (r, c) is labelled 1 when both endpoints lie in the motif node range\n",
    "    [insert, insert + skip), otherwise 0. Its score is the matching entry of the\n",
    "    explainer's current `masked_adj`. Labels go to the global `reals` list and\n",
    "    scores to the global `preds` list, in the edge order of coo_matrix(adj).\n",
    "    \"\"\"\n",
    "    edge_scores = explainer.masked_adj.cpu().detach().numpy()\n",
    "    edges = coo_matrix(adj)\n",
    "    rows, cols = edges.row, edges.col\n",
    "    lo, hi = insert, insert + skip\n",
    "    in_motif = (rows >= lo) & (rows < hi) & (cols >= lo) & (cols < hi)\n",
    "    reals.extend(int(flag) for flag in in_motif)\n",
    "    preds.extend(edge_scores[r][c] for r, c in zip(rows, cols))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot(adj,label,graphid, iteration):\n",
    "    \"\"\"Draw graph `adj` with its top-`topk` explanation edges and save it as a PNG.\n",
    "\n",
    "    Reads the explainer's current `masked_adj`, so the explainer must have just been\n",
    "    run on this graph. The `topk` highest-weighted edges are drawn with thick lines\n",
    "    on top of all graph edges, and the figure is written to\n",
    "    `<save_map><iteration>/<graphid>.png`.\n",
    "\n",
    "    Args:\n",
    "        adj: dense adjacency matrix of the graph (numpy array).\n",
    "        label: unused; kept for call compatibility.\n",
    "        graphid: graph index, used as the output file name.\n",
    "        iteration: experiment iteration, used as the output sub-directory.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    # .cpu() is required before .numpy() when the explainer lives on a GPU.\n",
    "    after_adj_dense = explainer.masked_adj.cpu().detach().numpy()\n",
    "    after_adj = coo_matrix(after_adj_dense)\n",
    "\n",
    "    # Threshold = topk-th largest weight in the lower triangle (incl. diagonal),\n",
    "    # i.e. the same half of the matrix that the edge selection below keeps.\n",
    "    edge_weights = np.tril(after_adj_dense).flatten()\n",
    "    sorted_edge_weights = np.sort(edge_weights)\n",
    "    thres_index = max(int(edge_weights.shape[0] - topk), 0)\n",
    "    thres = sorted_edge_weights[thres_index]\n",
    "\n",
    "    pos_edges = [\n",
    "        (int(r), int(c))\n",
    "        for r, c, d in zip(after_adj.row, after_adj.col, after_adj.data)\n",
    "        if r >= c and d >= thres\n",
    "    ]\n",
    "\n",
    "    # from_numpy_matrix was removed in networkx 3.0; from_numpy_array exists since 2.0.\n",
    "    G = nx.from_numpy_array(adj)\n",
    "    pos = nx.kamada_kawai_layout(G)\n",
    "\n",
    "    node_filter = [node for node in range(after_adj_dense.shape[0]) if node in G]\n",
    "    nx.draw_networkx_nodes(G, pos, nodelist=node_filter, node_color='orange', node_size=300)\n",
    "\n",
    "    nx.draw_networkx_edges(G, pos, width=2, edge_color='grey')\n",
    "    nx.draw_networkx_edges(G, pos, edgelist=pos_edges, width=7)\n",
    "\n",
    "    plt.axis('off')\n",
    "    out_dir = save_map + str(iteration) + \"/\"\n",
    "    # savefig fails if the per-iteration directory does not exist yet.\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    plt.savefig(out_dir + str(graphid) + \".png\")\n",
    "    plt.clf()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def explain_graph(gid, insert=20):\n",
    "    \"\"\"Run the explainer on graph `gid` in eval mode and record its edge scores via acc().\n",
    "\n",
    "    Args:\n",
    "        gid: index into the global features / embs / adjs / labels arrays.\n",
    "        insert: first node index of the ground-truth motif (default 20, as before).\n",
    "    \"\"\"\n",
    "    fea, emb = features[gid], embs[gid]\n",
    "    # Same float32 conversion as train(), so training and evaluation see identical dtypes.\n",
    "    adj = torch.Tensor(adjs[gid])\n",
    "    label = torch.Tensor(labels[gid])\n",
    "    explainer.eval()\n",
    "    # Inference only: no autograd graph is needed to read masked_adj afterwards.\n",
    "    with torch.no_grad():\n",
    "        explainer((fea, emb, adj, 1.0, label))\n",
    "    acc(adj, insert)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test():\n",
    "    \"\"\"Explain every graph in `allnodes` and return the edge-level ROC-AUC.\n",
    "\n",
    "    Resets the global `preds` / `reals` lists that acc() appends to, then scores\n",
    "    them with sklearn's roc_auc_score.\n",
    "    \"\"\"\n",
    "    global preds, reals\n",
    "    preds, reals = [], []\n",
    "    for graph_index in allnodes:\n",
    "        explain_graph(graph_index)\n",
    "    return roc_auc_score(reals, preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def train(iteration):\n",
    "    \"\"\"Train the global `explainer` and checkpoint the weights with the best AUC.\n",
    "\n",
    "    Each epoch sums the explainer loss over all graphs (full batch), takes one\n",
    "    optimizer step, then evaluates with test(). The temperature is annealed\n",
    "    geometrically from args.coff_t0 to args.coff_te over args.eepochs epochs.\n",
    "    Best-AUC weights are saved to 'model_weights/BA2motif_BESTAUC.pt' and to\n",
    "    '<save_map><iteration>/BA2motif_BESTAUC.pt'.\n",
    "\n",
    "    Args:\n",
    "        iteration: experiment iteration, used as the checkpoint sub-directory.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    epochs = args.eepochs\n",
    "    t0 = args.coff_t0\n",
    "    t1 = args.coff_te\n",
    "    # -inf (not 0) guarantees the first evaluation is checkpointed, so the caller\n",
    "    # never reloads a stale checkpoint left over from an earlier run.\n",
    "    best_auc = float('-inf')\n",
    "    run_dir = save_map + str(iteration) + \"/\"\n",
    "    os.makedirs('model_weights', exist_ok=True)\n",
    "    os.makedirs(run_dir, exist_ok=True)\n",
    "    for epoch in range(epochs):\n",
    "        # test() -> explain_graph() switches the explainer to eval mode, so training\n",
    "        # mode has to be restored at the start of every epoch, not only the first.\n",
    "        explainer.train()\n",
    "        temperature = float(t0 * np.power(t1 / t0, epoch / epochs))\n",
    "        train_instances = list(range(adjs.shape[0]))\n",
    "        np.random.shuffle(train_instances)\n",
    "        loss = 0\n",
    "        for gid in train_instances:\n",
    "            pred = explainer((features[gid], embs[gid], torch.Tensor(adjs[gid]), temperature, torch.Tensor(labels[gid])))\n",
    "            loss = loss + explainer.loss(pred, pred_label[gid])\n",
    "\n",
    "        # NOTE(review): `optimizer` is the global built in the main cell; it decides\n",
    "        # which parameters are updated (only the 'elayers' ones are intended).\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        auc = test()\n",
    "        if auc > best_auc:\n",
    "            best_auc = auc\n",
    "            torch.save(explainer.state_dict(), 'model_weights/BA2motif_BESTAUC.pt')\n",
    "            torch.save(explainer.state_dict(), run_dir + 'BA2motif_BESTAUC.pt')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# The dataset is identical for every iteration, so load it once.\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.\n",
    "with open('./dataset/BA-2motif.pkl','rb') as fin:\n",
    "    adjs, features, labels = pkl.load(fin)\n",
    "\n",
    "for iteration in range(10):\n",
    "    print(\"Starting iteration: {}\".format(iteration))\n",
    "    run_dir = save_map + str(iteration) + \"/\"\n",
    "    # The log file, plots and checkpoints below fail if this directory is missing.\n",
    "    os.makedirs(run_dir, exist_ok=True)\n",
    "\n",
    "    # Seed before constructing the model and explainer so the explainer's\n",
    "    # parameter initialisation is covered by the seed as well.\n",
    "    np.random.seed(iteration)\n",
    "    torch.manual_seed(iteration)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed(iteration)\n",
    "        torch.cuda.manual_seed_all(iteration)\n",
    "        torch.backends.cudnn.deterministic = True\n",
    "        torch.backends.cudnn.benchmark = False\n",
    "\n",
    "    # NOTE(review): the model is moved to `device`, but the inputs below are built\n",
    "    # on the CPU, so this cell presumably only works on CPU as written.\n",
    "    device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "    model = GCN(input_dim=features.shape[1:], output_dim=labels.shape[1], device=device)\n",
    "    model.to(device)\n",
    "\n",
    "    if args.bn and args.concat:\n",
    "        weights_path = 'model_weights/GCN_BA2motif_bn_concat.pt'\n",
    "    elif args.bn:\n",
    "        weights_path = 'model_weights/GCN_BA2motif_bn.pt'\n",
    "    elif args.concat:\n",
    "        weights_path = 'model_weights/GCN_BA2motif_concat.pt'\n",
    "    else:\n",
    "        weights_path = 'model_weights/GCN_BA2motif_BEST.pt'\n",
    "    # map_location lets GPU-saved checkpoints load on a CPU-only machine.\n",
    "    model.load_state_dict(torch.load(weights_path, map_location=device))\n",
    "    model.eval()\n",
    "    # Freeze the GCN being explained. If Explainer keeps `model` as a submodule\n",
    "    # (presumed; not visible here), the Adam optimizer over explainer.parameters()\n",
    "    # would otherwise also update the GCN's weights. Gradients still flow through\n",
    "    # the frozen model into the explainer's own layers.\n",
    "    for param in model.parameters():\n",
    "        param.requires_grad_(False)\n",
    "\n",
    "    features_t = torch.tensor(features).type(torch.float32)\n",
    "    adjs_t = torch.tensor(adjs).type(torch.float32)\n",
    "    with torch.no_grad():\n",
    "        embs = model.getNodeEmb((features_t, adjs_t), training=False)\n",
    "        output = model((features_t, adjs_t), training=False)\n",
    "    pred_label = torch.argmax(output, 1)\n",
    "\n",
    "    if args.setting == 1:\n",
    "        allnodes = list(range(0, 100))\n",
    "    elif args.setting == 2:\n",
    "        allnodes = list(range(0, 100)) + list(range(500, 600))\n",
    "    elif args.setting == 3:\n",
    "        allnodes = list(range(1000))\n",
    "    else:\n",
    "        # Without this, a stale `allnodes` from an earlier run would be reused silently.\n",
    "        raise ValueError(\"Unsupported args.setting: {}\".format(args.setting))\n",
    "\n",
    "    explainer = Explainer(model=model, nodesize=adjs.shape[1])\n",
    "    explainer.to(device)\n",
    "    optimizer = torch.optim.Adam(explainer.parameters(), lr=args.elr)\n",
    "\n",
    "    # train() checkpoints its best-AUC weights to this path; reload them for evaluation.\n",
    "    train(iteration)\n",
    "    explainer.load_state_dict(torch.load('model_weights/BA2motif_BESTAUC.pt', map_location=device))\n",
    "\n",
    "    # acc() (called via explain_graph) appends to these module-level lists.\n",
    "    preds = []\n",
    "    reals = []\n",
    "\n",
    "    with open(run_dir + \"LOG.txt\", \"w\") as f:\n",
    "        tik = time.time()\n",
    "        for gid in allnodes:\n",
    "            explain_graph(gid)\n",
    "            # Running AUC over all graphs explained so far.\n",
    "            auc = roc_auc_score(reals, preds)\n",
    "            tok = time.time()\n",
    "            f.write(\"gid,{}\".format(gid) + \",auc,{}\".format(auc) + \",time,{}\".format(tok-tik) + \"\\n\")\n",
    "        tok = time.time()\n",
    "        f.write(\"time,{}\".format(tok-tik) + \"\\n\")\n",
    "\n",
    "    for gid in allnodes:\n",
    "        fea, emb, adj, label, graphid = features[gid], embs[gid], adjs[gid], torch.Tensor(labels[gid]), gid\n",
    "        with torch.no_grad():\n",
    "            explainer((fea, emb, torch.Tensor(adj), 1.0, label))\n",
    "        plot(adj, label, graphid, iteration)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
