{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle as pkl\n",
    "import random\n",
    "import sys\n",
    "import time\n",
    "\n",
    "sys.path.append('./codes/')\n",
    "from config import args\n",
    "save_map = \"LISA_TEST_LOGS/TREE_CYCLE/\"\n",
    "\n",
    "args.dataset='syn3'\n",
    "args.elr = 0.003\n",
    "args.eepochs = 20\n",
    "args.coff_size = 0.0001\n",
    "args.budget = -1\n",
    "args.coff_ent = 0.01\n",
    "\n",
    "# import tensorflow as tf\n",
    "from utils import *\n",
    "from models import GCN2 as GCN\n",
    "from metrics import *\n",
    "import numpy as np\n",
    "from Extractor import Extractor\n",
    "from Explainer import Explainer\n",
    "from scipy.sparse import coo_matrix,csr_matrix\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "import torch\n",
    "import torch.optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def plot(node, label, iteration):\n",
    "    \"\"\"Draw the explanation subgraph for one node and save it as a PNG.\n",
    "\n",
    "    The edge mask is read from the notebook-global ``explainer.masked_adj``.\n",
    "    Only the connected component containing node 0 is drawn; mask entries at\n",
    "    or above the 12th-largest weight are overlaid as highlighted edges.\n",
    "\n",
    "    Args:\n",
    "        node: id of the explained node; only used for the output file name.\n",
    "        label: per-node integer class labels, indexable by subgraph node id.\n",
    "        iteration: run index; selects the sub-directory of ``save_map``.\n",
    "    \"\"\"\n",
    "    after_adj_dense = explainer.masked_adj.cpu().detach().numpy()\n",
    "    after_adj = coo_matrix(after_adj_dense)\n",
    "    nmb_nodes = after_adj_dense.shape[0]\n",
    "\n",
    "    edge_weights = after_adj.data\n",
    "    num_entries = edge_weights.shape[0]\n",
    "    pos_edges = []\n",
    "    filter_edges = []\n",
    "\n",
    "    # An all-zero mask has nothing to threshold; only node 0 is drawn then.\n",
    "    if num_entries > 0:\n",
    "        sorted_edge_weights = np.sort(edge_weights)\n",
    "        # Highlight threshold: the 12th-largest weight (or the smallest one).\n",
    "        thres_index = max(num_entries - 12, 0)\n",
    "        thres = sorted_edge_weights[thres_index]\n",
    "        # Context threshold: never looser than the highlight threshold.\n",
    "        filter_thres_index = min(thres_index, max(int(num_entries - num_entries / 2), num_entries - 100))\n",
    "        filter_thres = sorted_edge_weights[filter_thres_index]\n",
    "\n",
    "        for r, c, d in zip(after_adj.row, after_adj.col, edge_weights):\n",
    "            r = int(r)\n",
    "            c = int(c)\n",
    "            if d >= thres:\n",
    "                pos_edges.append((r, c))\n",
    "            if d > filter_thres:\n",
    "                filter_edges.append((r, c))\n",
    "\n",
    "    # Size the graph from the mask itself rather than from the notebook-global\n",
    "    # ``sub_adj``, which holds whichever subgraph was extracted last.\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(range(nmb_nodes))\n",
    "    G.add_edges_from(filter_edges)\n",
    "\n",
    "    # Keep only the component around node 0 (the node id passed to the explainer).\n",
    "    for cc in nx.connected_components(G):\n",
    "        if 0 in cc:\n",
    "            G = G.subgraph(cc).copy()\n",
    "            break\n",
    "\n",
    "    kept_nodes = set(G.nodes())\n",
    "    pos_edges = [(u, v) for (u, v) in pos_edges if u in kept_nodes and v in kept_nodes]\n",
    "    pos = nx.kamada_kawai_layout(G)\n",
    "\n",
    "    colors = ['orange', 'red', 'green', 'blue', 'maroon', 'brown', 'darkslategray', 'paleturquoise', 'darksalmon',\n",
    "              'slategray', 'mediumseagreen', 'mediumblue', 'orchid', ]\n",
    "    if args.dataset == 'syn3':\n",
    "        colors = ['orange', 'blue']\n",
    "    if args.dataset == 'syn4':\n",
    "        colors = ['orange', 'black', 'black', 'black', 'blue']\n",
    "\n",
    "    # Group the surviving nodes by class so each class is drawn in one colour.\n",
    "    labels = label\n",
    "    max_label = np.max(labels) + 1\n",
    "    label2nodes = [[] for _ in range(max_label)]\n",
    "    for i in range(nmb_nodes):\n",
    "        if i in kept_nodes:\n",
    "            label2nodes[labels[i]].append(i)\n",
    "\n",
    "    for i in range(max_label):\n",
    "        nx.draw_networkx_nodes(G, pos,\n",
    "                               nodelist=label2nodes[i],\n",
    "                               node_color=colors[i % len(colors)],\n",
    "                               node_size=500)\n",
    "\n",
    "    # Draw node 0 again, larger, using the same modulo colour lookup as above.\n",
    "    nx.draw_networkx_nodes(G, pos,\n",
    "                           nodelist=[0],\n",
    "                           node_color=colors[labels[0] % len(colors)],\n",
    "                           node_size=1000)\n",
    "\n",
    "    nx.draw_networkx_edges(G, pos, width=7, alpha=0.5, edge_color='grey')\n",
    "\n",
    "    nx.draw_networkx_edges(G, pos,\n",
    "                           edgelist=pos_edges,\n",
    "                           width=7, alpha=0.5)\n",
    "\n",
    "    plt.axis('off')\n",
    "    # Create the run directory on demand so a fresh checkout can save figures.\n",
    "    out_dir = save_map + str(iteration) + '/'\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    plt.savefig(out_dir + str(node) + '.png')\n",
    "    plt.clf()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(iteration):\n",
    "    \"\"\"Fit the explainer on every node in ``allnodes`` and checkpoint it.\n",
    "\n",
    "    After each epoch the explainer is scored with the ROC AUC over the edge\n",
    "    scores that ``explain_test`` accumulates in the globals ``reals`` and\n",
    "    ``preds``. The best-scoring and the final weights are saved both under\n",
    "    ``model_weights/`` and under ``save_map/<iteration>/``.\n",
    "\n",
    "    Relies on notebook globals created by the driver cell: ``model``,\n",
    "    ``explainer``, ``optimizer``, ``allnodes``, ``clip_value_max`` and the\n",
    "    ``sub_*`` lists.\n",
    "\n",
    "    Args:\n",
    "        iteration: run index; selects the sub-directory of ``save_map``.\n",
    "    \"\"\"\n",
    "    global reals\n",
    "    global preds\n",
    "\n",
    "    # Hard-coded in this notebook: these values take precedence over\n",
    "    # args.eepochs and the args.coff_t0 -> args.coff_te temperature schedule.\n",
    "    epochs = 10\n",
    "    tmp = 5.0\n",
    "\n",
    "    run_dir = save_map + str(iteration) + '/'\n",
    "    os.makedirs(run_dir, exist_ok=True)\n",
    "    os.makedirs('model_weights', exist_ok=True)\n",
    "\n",
    "    model.eval()\n",
    "    explainer.train()\n",
    "    # Start below any possible AUC so a BEST checkpoint is always written.\n",
    "    best_auc = float('-inf')\n",
    "    for epoch in range(epochs):\n",
    "        loss = 0\n",
    "        for i in range(len(allnodes)):\n",
    "            with torch.no_grad():\n",
    "                output = model((sub_features[i], sub_support_tensors[i]), training=False)\n",
    "            pred_label = torch.argmax(output, 1)\n",
    "\n",
    "            x = sub_features[i]\n",
    "            adj = sub_adjs[i]\n",
    "            nodeid = 0\n",
    "            embed = torch.Tensor(sub_embeds[i])\n",
    "            pred = explainer((x, adj, nodeid, embed, tmp), training=True)\n",
    "            # Use the labels of *this* subgraph; the global ``sub_label_tensor``\n",
    "            # only holds the labels of the last subgraph built by the driver.\n",
    "            l, _, _ = explainer.loss(pred, pred_label, sub_label_tensors[i], 0)\n",
    "            loss = loss + l\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_value_(explainer.parameters(), clip_value_max)\n",
    "        optimizer.step()\n",
    "\n",
    "        # Score the current explainer. Plotting is skipped here because the\n",
    "        # driver's final pass writes the same figure files again.\n",
    "        reals = []\n",
    "        preds = []\n",
    "        for node in allnodes:\n",
    "            explain_test(node, iteration, needplot=False)\n",
    "        auc = roc_auc_score(reals, preds)\n",
    "        explainer.train()  # explain_test switched the explainer to eval mode\n",
    "        print(\"epoch,{}\".format(epoch) + \",auc,{}\".format(auc))\n",
    "        if auc > best_auc:\n",
    "            print(\"better auc\")\n",
    "            best_auc = auc\n",
    "            torch.save(explainer.state_dict(), 'model_weights/Tree-Cycles_BEST.pt')\n",
    "            torch.save(explainer.state_dict(), run_dir + 'Tree_Cycles_BEST.pt')\n",
    "\n",
    "    torch.save(explainer.state_dict(), 'model_weights/Tree-Cycles_LAST.pt')\n",
    "    torch.save(explainer.state_dict(), run_dir + 'Tree_Cycles_LAST.pt')\n",
    "            \n",
    "# Edge-level ground truth and mask scores accumulated across acc() calls;\n",
    "# callers reset both lists before starting an evaluation pass.\n",
    "reals = []\n",
    "preds = []\n",
    "\n",
    "\n",
    "def acc(sub_adj, sub_edge_label):\n",
    "    \"\"\"Score the current explainer mask against one subgraph's edge labels.\n",
    "\n",
    "    For every (row, col) entry of ``sub_adj`` the ground truth is 1 when the\n",
    "    edge is labelled in either direction, and the prediction is the sum of the\n",
    "    two directed entries of ``explainer.masked_adj``. Both are appended to the\n",
    "    module-level ``reals`` / ``preds`` lists.\n",
    "\n",
    "    Args:\n",
    "        sub_adj: sparse adjacency exposing ``row`` and ``col`` index arrays.\n",
    "        sub_edge_label: sparse edge-label matrix supporting ``todense()``.\n",
    "\n",
    "    Returns:\n",
    "        The ROC AUC for this subgraph, or -1 when it is undefined (fewer than\n",
    "        two distinct labels or scores, including a subgraph with no edges).\n",
    "    \"\"\"\n",
    "    edge_labels = np.asarray(sub_edge_label.todense())\n",
    "    mask = explainer.masked_adj.cpu().detach().numpy()\n",
    "    rows = np.asarray(sub_adj.row)\n",
    "    cols = np.asarray(sub_adj.col)\n",
    "\n",
    "    # Vectorised form of the per-edge lookup in both directions.\n",
    "    real = ((edge_labels[rows, cols] + edge_labels[cols, rows]) != 0).astype(int).tolist()\n",
    "    pred = (mask[rows, cols] + mask[cols, rows]).tolist()\n",
    "    reals.extend(real)\n",
    "    preds.extend(pred)\n",
    "\n",
    "    # roc_auc_score raises on empty or single-class input, so bail out first.\n",
    "    if np.unique(real).size < 2 or np.unique(pred).size < 2:\n",
    "        return -1\n",
    "    return roc_auc_score(real, pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def explain_test(node, iteration, needplot=True):\n",
    "    \"\"\"Run the explainer on one node's subgraph and record its edge scores.\n",
    "\n",
    "    Switches the global ``explainer`` to eval mode, runs a forward pass on the\n",
    "    stored subgraph of ``node``, optionally saves a figure through ``plot``\n",
    "    and appends the edge scores to ``reals`` / ``preds`` through ``acc``.\n",
    "\n",
    "    Args:\n",
    "        node: original node id; must be a key of the global ``remap``.\n",
    "        iteration: run index, forwarded to ``plot`` for the output path.\n",
    "        needplot: when true, save the explanation figure for this node.\n",
    "\n",
    "    Returns:\n",
    "        Whatever ``acc`` returns for this subgraph.\n",
    "    \"\"\"\n",
    "    newid = remap[node]\n",
    "    sub_adj = sub_adjs[newid]\n",
    "    sub_feature = sub_features[newid]\n",
    "    sub_embed = sub_embeds[newid]\n",
    "    sub_label = sub_labels[newid]\n",
    "    sub_edge_label = sub_edge_labels[newid]\n",
    "\n",
    "    nodeid = 0\n",
    "    explainer.eval()\n",
    "    # Only the resulting mask is read afterwards, so no autograd graph is needed.\n",
    "    with torch.no_grad():\n",
    "        explainer((sub_feature, sub_adj, nodeid, sub_embed, 1.0), training=False)\n",
    "    label = np.argmax(sub_label, -1)\n",
    "    if needplot:\n",
    "        plot(node, label, iteration)\n",
    "    return acc(sub_adj, sub_edge_label)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def _to_sparse_tensor(support):\n",
    "    \"\"\"Convert a ``preprocess_adj`` result into a float32 sparse tensor.\n",
    "\n",
    "    ``support`` is indexed as (indices, values, shape).\n",
    "    \"\"\"\n",
    "    indices = torch.LongTensor([*support[0]])\n",
    "    values = torch.FloatTensor([*support[1]])\n",
    "    # NOTE: the index tensor must be transposed to build a sparse tensor in PyTorch.\n",
    "    sparse = torch.sparse.FloatTensor(indices.t(), values, torch.Size([*support[2]]))\n",
    "    return sparse.type(torch.float32)\n",
    "\n",
    "\n",
    "for iteration in range(10):\n",
    "    print(\"Starting iteration: {}\".format(iteration))\n",
    "    device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "    # Seed every RNG with the run index so each run is repeatable.\n",
    "    random.seed(iteration)\n",
    "    np.random.seed(iteration)\n",
    "    torch.manual_seed(iteration)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed(iteration)\n",
    "        torch.cuda.manual_seed_all(iteration)\n",
    "        torch.backends.cudnn.deterministic = True\n",
    "        torch.backends.cudnn.benchmark = False\n",
    "\n",
    "    # Create the per-run output directory before anything writes into it.\n",
    "    run_dir = save_map + str(iteration) + \"/\"\n",
    "    os.makedirs(run_dir, exist_ok=True)\n",
    "\n",
    "    # CELL 1: load the dataset and the pre-trained GCN\n",
    "    # NOTE(review): pickle can execute arbitrary code; only load trusted dataset files.\n",
    "    with open('./dataset/' + args.dataset + '.pkl', 'rb') as fin:\n",
    "        adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask, edge_label_matrix  = pkl.load(fin)\n",
    "\n",
    "    adj = csr_matrix(adj)\n",
    "    support = preprocess_adj(adj)\n",
    "\n",
    "    features_tensor = torch.tensor(features).type(torch.float32)\n",
    "    support_tensor = _to_sparse_tensor(support)\n",
    "\n",
    "    model = GCN(input_dim=features.shape[1], output_dim=y_train.shape[1], device=device)\n",
    "    model.to(device)\n",
    "    # map_location lets weights that were saved on a GPU load on a CPU-only machine.\n",
    "    model.load_state_dict(torch.load('model_weights/GCN_syn3_BEST.pt', map_location=device))\n",
    "\n",
    "    explainer = Explainer(model=model)\n",
    "    explainer.to(device)\n",
    "    embeds = model.embedding((features_tensor,support_tensor)).cpu().detach().numpy()\n",
    "\n",
    "    all_label = np.logical_or(y_train,np.logical_or(y_val,y_test))\n",
    "    single_label = np.argmax(all_label,axis=-1)\n",
    "    hops = len(args.hiddens.split('-'))\n",
    "    extractor = Extractor(adj,features,edge_label_matrix,embeds,all_label,hops)\n",
    "\n",
    "    # Choose which nodes to explain.\n",
    "    if args.setting==1:\n",
    "        if args.dataset=='syn3':\n",
    "            allnodes = [i for i in range(511,871,6)]\n",
    "        elif args.dataset=='syn4':\n",
    "            allnodes = [i for i in range(511,800,1)]\n",
    "        else:\n",
    "            allnodes = [i for i in range(400,700,5)] # setting from their original paper\n",
    "    elif args.setting==2:\n",
    "        allnodes = [i for i in range(single_label.shape[0]) if single_label[i] ==1]\n",
    "    elif args.setting==3:\n",
    "        if args.dataset == 'syn2':\n",
    "            allnodes = [i for i in range(single_label.shape[0]) if single_label[i] != 0 and single_label[i] != 4]\n",
    "        else:\n",
    "            allnodes = [i for i in range(single_label.shape[0]) if single_label[i] != 0]\n",
    "    else:\n",
    "        # Fail here instead of hitting a NameError (or a stale node list) later.\n",
    "        raise ValueError('Unsupported args.setting: {}'.format(args.setting))\n",
    "\n",
    "    optimizer = torch.optim.Adam(explainer.parameters(), lr=args.elr)\n",
    "    clip_value_min = -2.0\n",
    "    clip_value_max = 2.0\n",
    "\n",
    "    sub_support_tensors = []\n",
    "    sub_label_tensors = []\n",
    "    sub_features = []\n",
    "    sub_embeds = []\n",
    "    sub_adjs = []\n",
    "    sub_edge_labels = []\n",
    "    sub_labels = []\n",
    "    remap = {}\n",
    "\n",
    "    # CELL 2: extract one subgraph per node; remap maps node id -> list position\n",
    "    for node in allnodes:\n",
    "        sub_adj,sub_feature, sub_embed, sub_label,sub_edge_label_matrix = extractor.subgraph(node)\n",
    "        remap[node]=len(sub_adjs)\n",
    "        sub_support_tensor = _to_sparse_tensor(preprocess_adj(sub_adj))\n",
    "        sub_label_tensor = torch.Tensor(sub_label).type(torch.float32)\n",
    "\n",
    "        sub_adjs.append(sub_adj)\n",
    "        sub_features.append(torch.Tensor(sub_feature).type(torch.float32))\n",
    "        sub_embeds.append(sub_embed)\n",
    "        sub_labels.append(sub_label)\n",
    "        sub_edge_labels.append(sub_edge_label_matrix)\n",
    "        sub_label_tensors.append(sub_label_tensor)\n",
    "        sub_support_tensors.append(sub_support_tensor)\n",
    "\n",
    "    # TRAIN LOOP\n",
    "    train(iteration)\n",
    "\n",
    "    # Evaluate with the best checkpoint written by train().\n",
    "    explainer.load_state_dict(torch.load(run_dir + 'Tree_Cycles_BEST.pt', map_location=device))\n",
    "\n",
    "    reals= []\n",
    "    preds = []\n",
    "\n",
    "    # The context manager closes the log even if an explanation step raises.\n",
    "    with open(run_dir + \"LOG.txt\", \"w\") as f:\n",
    "        tik = time.time()\n",
    "        for node in allnodes:\n",
    "            explain_test(node, iteration, needplot=True)\n",
    "            # Running AUC over all edges scored so far, not a per-node AUC.\n",
    "            auc = roc_auc_score(reals, preds)\n",
    "            tok = time.time()\n",
    "            f.write(\"node,{}\".format(node) + \",auc,{}\".format(auc) + \",time,{}\".format(tok-tik) + \"\\n\")\n",
    "            print(\"node,{}\".format(node) + \",auc,{}\".format(auc))\n",
    "\n",
    "        tok = time.time()\n",
    "        print(\"time,{}\".format(tok-tik))\n",
    "        f.write(\"time,{}\".format(tok-tik) + \"\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
