{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['/home/mohit/anaconda3/envs/py36/lib/python36.zip', '/home/mohit/anaconda3/envs/py36/lib/python3.6', '/home/mohit/anaconda3/envs/py36/lib/python3.6/lib-dynload', '', '/home/mohit/anaconda3/envs/py36/lib/python3.6/site-packages', '/home/mohit/anaconda3/envs/py36/lib/python3.6/site-packages/IPython/extensions', '/home/mohit/.ipython']\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "print(sys.path)\n",
    "\n",
    "# Root of the checkout that contains the `dnn_invariant` package. Set the\n",
    "# DNN_INVARIANT_ROOT environment variable to override the author's local path.\n",
    "REPO_ROOT = os.environ.get(\"DNN_INVARIANT_ROOT\", \"/home/mohit/Mohit/model_interpretation/ai-adversarial-detection\")\n",
    "# Insert only once so re-running this cell does not stack duplicate entries.\n",
    "if REPO_ROOT not in sys.path:\n",
    "    sys.path.insert(0, REPO_ROOT)\n",
    "\n",
    "# NOTE(review): VGG19, torch and nn are used in later cells without an explicit\n",
    "# import, so they presumably come from these star imports.\n",
    "from dnn_invariant.models.models4invariant import *\n",
    "from dnn_invariant.utilities.trainer import *\n",
    "from dnn_invariant.utilities.datasets import *\n",
    "from dnn_invariant.utilities.environ import *\n",
    "import numpy as np\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the checkpoint onto the CPU so this works without a GPU (the cell below\n",
    "# maps the same file to cuda:0); load_state_dict then copies the values into\n",
    "# the model's own parameters.\n",
    "model = VGG19(num_classes_=2)\n",
    "model.load_state_dict(torch.load(\"./dnn_invariant/models/VGG19.mdl\", map_location=\"cpu\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the same file is passed to load_state_dict above, so this is\n",
    "# presumably the raw state dict rather than a model object. map_location='cuda:0'\n",
    "# requires a CUDA device.\n",
    "model2 = torch.load(\"./dnn_invariant/models/VGG19.mdl\", map_location='cuda:0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the repr of the VGG19 instance built above.\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the model's state dict and list its entry names.\n",
    "s1 = model.state_dict()\n",
    "s1.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'BatchNorm1d'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check the class name reported by a BatchNorm1d layer.\n",
    "m = nn.BatchNorm1d(100)\n",
    "type(m).__name__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic data 1: graphs labelled by 3 planted patterns (sub_label 0, 1, 2).\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "# Row count of the padded adjacency/feature arrays; the pattern checks below\n",
    "# scan node indices 0..max_nodes-1 through this module-level name.\n",
    "max_nodes = 20\n",
    "def makesPattern1(adj, feats, p, n):\n",
    "    \"\"\"Return True if an edge p-n would be part of an A-B-C-D ring.\n",
    "\n",
    "    Node labels are the argmax of each row of feats (0..5 -> A..F, matching\n",
    "    drawGraph's node_labels); only A-D (0-3) take part. p and n must have\n",
    "    cyclically consecutive labels (mod 4). The function then looks for a\n",
    "    neighbour c_p of p and a neighbour x of n that are adjacent to each other\n",
    "    and whose labels continue the cyclic order, i.e. a 4-cycle\n",
    "    c_p - p - n - x - c_p. It does not check whether the p-n edge exists.\n",
    "\n",
    "    NOTE(review): scans indices 0..max_nodes-1 via the module-level\n",
    "    max_nodes, so adj/feats are assumed to have at least that many rows.\n",
    "    \"\"\"\n",
    "    l_p = np.argmax(feats[p]) \n",
    "    l_n = np.argmax(feats[n]) \n",
    "    if l_p not in [0,1,2,3] or l_n not in [0,1,2,3]:\n",
    "        return False\n",
    "    # n2p: p's label follows n's (l_p == l_n+1); p2n: p's label precedes n's\n",
    "    # (l_p == l_n-1), both mod 4. At most one of them can hold.\n",
    "    n2p = True\n",
    "    p2n = True\n",
    "    if l_p != (l_n+1)%4:\n",
    "        n2p = False\n",
    "    if l_p != (l_n-1)%4:\n",
    "        p2n = False\n",
    "    if p2n == False and n2p == False:\n",
    "        return False\n",
    "    contend_p = []\n",
    "    if p2n == True:\n",
    "        # Cyclic order c_p(l_p-1) - p - n(l_p+1) - x(l_n+1): collect c_p\n",
    "        # candidates, then need an x next to n that touches one of them.\n",
    "        \n",
    "        for nix in range(max_nodes):\n",
    "            l_x = np.argmax(feats[nix])\n",
    "            if adj[nix,p] > 0. and l_x == (l_p-1)%4:\n",
    "                contend_p.append(nix)\n",
    "        if len(contend_p) == 0:\n",
    "            return False\n",
    "        for nix in range(max_nodes):\n",
    "            l_x = np.argmax(feats[nix])\n",
    "            if adj[nix,n] > 0. and l_x == (l_n+1)%4:\n",
    "                for c_p in contend_p:\n",
    "                    if adj[nix,c_p] > 0.:\n",
    "                        return True\n",
    "\n",
    "        \n",
    "        \n",
    "    else:\n",
    "        # n2p case, mirror image: c_p labelled l_p+1, x labelled l_n-1.\n",
    "        for nix in range(max_nodes):\n",
    "            l_x = np.argmax(feats[nix])\n",
    "            if adj[nix,p] > 0. and l_x == (l_p+1)%4:\n",
    "                contend_p.append(nix)\n",
    "        if len(contend_p) == 0:\n",
    "            return False\n",
    "        for nix in range(max_nodes):\n",
    "            l_x = np.argmax(feats[nix])\n",
    "            if adj[nix,n] > 0. and l_x == (l_n-1)%4:\n",
    "                for c_p in contend_p:\n",
    "                    if adj[nix,c_p] > 0.:\n",
    "                        return True\n",
    "    return False\n",
    "                \n",
    "        \n",
    "def makesPattern2(adj, feats, p, n):\n",
    "    \"\"\"Return True if an edge p-n would give an A node neighbours B, C and D.\n",
    "\n",
    "    One endpoint must be labelled A (0) and the other B, C or D (1-3). The\n",
    "    A endpoint's existing neighbours (scanned over max_nodes indices) must\n",
    "    supply the two remaining labels from {B, C, D}.\n",
    "    \"\"\"\n",
    "    label_p = np.argmax(feats[p])\n",
    "    label_n = np.argmax(feats[n])\n",
    "    if label_p not in (0, 1, 2, 3) or label_n not in (0, 1, 2, 3):\n",
    "        return False\n",
    "    # Exactly one endpoint has to be the A hub.\n",
    "    if 0 not in (label_p, label_n) or label_p == label_n:\n",
    "        return False\n",
    "    hub, leaf_label = (p, label_n) if label_p == 0 else (n, label_p)\n",
    "    missing = {1, 2, 3} - {leaf_label}\n",
    "    for nbr in range(max_nodes):\n",
    "        if not adj[hub, nbr] > 0.:\n",
    "            continue\n",
    "        nbr_label = np.argmax(feats[nbr])\n",
    "        if nbr_label in missing:\n",
    "            missing.remove(nbr_label)\n",
    "            if not missing:\n",
    "                return True\n",
    "    return False\n",
    "\n",
    "def makesPattern3(adj, feats, p, n):\n",
    "    \"\"\"Return True if an edge p-n would give a C node neighbours D, E and F.\n",
    "\n",
    "    One endpoint must be labelled C (2) and the other D, E or F (3-5). The\n",
    "    C endpoint's existing neighbours (scanned over max_nodes indices) must\n",
    "    supply the two remaining labels from {D, E, F}.\n",
    "    \"\"\"\n",
    "    label_p = np.argmax(feats[p])\n",
    "    label_n = np.argmax(feats[n])\n",
    "    if label_p not in (2, 3, 4, 5) or label_n not in (2, 3, 4, 5):\n",
    "        return False\n",
    "    # Exactly one endpoint has to be the C hub.\n",
    "    if 2 not in (label_p, label_n) or label_p == label_n:\n",
    "        return False\n",
    "    hub, leaf_label = (p, label_n) if label_p == 2 else (n, label_p)\n",
    "    missing = {3, 4, 5} - {leaf_label}\n",
    "    for nbr in range(max_nodes):\n",
    "        if not adj[hub, nbr] > 0.:\n",
    "            continue\n",
    "        nbr_label = np.argmax(feats[nbr])\n",
    "        if nbr_label in missing:\n",
    "            missing.remove(nbr_label)\n",
    "            if not missing:\n",
    "                return True\n",
    "    return False\n",
    "                \n",
    "        \n",
    "        \n",
    "\n",
    "def makesPattern(adj, feats, p, n):\n",
    "    \"\"\"Check an edge p-n against every planted pattern.\n",
    "\n",
    "    Returns (matched_any, ids), where ids lists the matching pattern indices:\n",
    "    0 = A-B-C-D ring, 1 = A hub with B/C/D, 2 = C hub with D/E/F.\n",
    "    All three checks are always evaluated.\n",
    "    \"\"\"\n",
    "    checks = (makesPattern1, makesPattern2, makesPattern3)\n",
    "    matched = [ix for ix, check in enumerate(checks) if check(adj, feats, p, n) is True]\n",
    "    return (len(matched) > 0, matched)\n",
    "\n",
    "def addPattern(adj,feats,nodes, sub_label):\n",
    "    \"\"\"Plant a 4-node motif at rows nodes..nodes+3 and wire it into the graph.\n",
    "\n",
    "    adj and feats are modified in place (and also returned), so they need at\n",
    "    least nodes+4 rows. sub_label selects the motif:\n",
    "      0     -> ring A-B-C-D\n",
    "      1     -> A hub joined to B, C and D\n",
    "      2     -> C hub joined to D, E and F\n",
    "      other -> decoy (the caller passes -1): a ring or hub with random labels,\n",
    "               resampled until it is not one of the real motifs above.\n",
    "    Each motif node is then linked to existing nodes, skipping links that\n",
    "    makesPattern reports would form a motif. Some motif nodes are swapped\n",
    "    with random existing positions so the motif is not always at the end.\n",
    "\n",
    "    Returns:\n",
    "        (adj, feats, positions) where positions[i] is the final row index of\n",
    "        motif node i.\n",
    "    \"\"\"\n",
    "    def formsRealRing(lbls):\n",
    "        # A ring is the real motif when its labels, read around the cycle,\n",
    "        # step through A..D by +1 (or by -1) mod 4, in any rotation.\n",
    "        if max(lbls) > 3:\n",
    "            return False\n",
    "        steps = {(lbls[(k+1)%4] - lbls[k]) % 4 for k in range(4)}\n",
    "        return steps == {1} or steps == {3}\n",
    "\n",
    "    def formsRealHub(lbls):\n",
    "        # Hub lbls[0] is joined to lbls[1:]; reject A->{B,C,D} and C->{D,E,F}.\n",
    "        leaves = set(lbls[1:])\n",
    "        return (lbls[0] == 0 and {1,2,3} <= leaves) or (lbls[0] == 2 and {3,4,5} <= leaves)\n",
    "\n",
    "    degree_sum = np.sum(adj,axis=1)\n",
    "    avg_deg = int(np.sum(degree_sum)/nodes)\n",
    "    if sub_label == 0 or sub_label == 1: # motifs over A,B,C,D\n",
    "        lbls = [0,1,2,3]\n",
    "        for i in range(4):\n",
    "            feats[nodes+i,lbls[i]] = 1.\n",
    "        if sub_label == 0: # ring\n",
    "            for i in range(3):\n",
    "                adj[nodes+i,nodes+i+1] = 1.0\n",
    "                adj[nodes+i+1, nodes+i] = 1.0\n",
    "            adj[nodes+3,nodes] = 1.0\n",
    "            adj[nodes,nodes+3] = 1.0\n",
    "        else: # hub nodes+0 joined to the other three\n",
    "            for i in range(1,4):\n",
    "                adj[nodes,nodes+i] = 1.0\n",
    "                adj[nodes+i,nodes] = 1.0\n",
    "    elif sub_label == 2: # C hub joined to D, E, F\n",
    "        lbls = [2,3,4,5]\n",
    "        for i in range(4):\n",
    "            feats[nodes+i,lbls[i]] = 1.\n",
    "        for i in range(1,4):\n",
    "            adj[nodes,nodes+i] = 1.0\n",
    "            adj[nodes+i,nodes] = 1.0\n",
    "    else: #-1 fake pattern\n",
    "        fake_sub = np.random.randint(2)\n",
    "        if fake_sub == 0: # decoy ring\n",
    "            # The previous check only rejected strictly increasing runs such as\n",
    "            # [0,1,2,3]; rotations/reversals like [1,2,3,0] slipped through and\n",
    "            # planted a real ring in a label-0 graph.\n",
    "            lbls = np.random.randint(0,6,size=(4)).tolist()\n",
    "            while formsRealRing(lbls):\n",
    "                lbls = np.random.randint(0,6,size=(4)).tolist()\n",
    "            for i in range(4):\n",
    "                feats[nodes+i,lbls[i]] = 1.\n",
    "            for i in range(3):\n",
    "                adj[nodes+i,nodes+i+1] = 1.0\n",
    "                adj[nodes+i+1, nodes+i] = 1.0\n",
    "            adj[nodes+3,nodes] = 1.0\n",
    "            adj[nodes,nodes+3] = 1.0\n",
    "        else: # decoy hub\n",
    "            lbls = np.random.randint(0,6,size=(4)).tolist()\n",
    "            while formsRealHub(lbls):\n",
    "                lbls = np.random.randint(0,6,size=(4)).tolist()\n",
    "            for i in range(4):\n",
    "                feats[nodes+i,lbls[i]] = 1.\n",
    "            for i in range(1,4):\n",
    "                adj[nodes,nodes+i] = 1.0\n",
    "                adj[nodes+i,nodes] = 1.0\n",
    "\n",
    "    # Link each motif node to the existing graph without forming a motif.\n",
    "    for i in range(4):\n",
    "        deg_exp = np.random.randint(avg_deg-2, avg_deg+2)\n",
    "        dest_n = np.random.randint(nodes)\n",
    "        while_count = 0\n",
    "        skip = False\n",
    "        while makesPattern(adj, feats, dest_n, nodes+i)[0]:\n",
    "            dest_n = np.random.randint(nodes)\n",
    "            while_count += 1\n",
    "            if while_count == 5: # give up on this node after 5 resamples\n",
    "                skip = True\n",
    "                break\n",
    "        if skip:\n",
    "            continue\n",
    "        adj[nodes+i, dest_n] = 1.0\n",
    "        adj[dest_n, nodes+i] = 1.0\n",
    "        deg = int(np.sum(adj[nodes+i,:]))\n",
    "        if deg_exp > deg:\n",
    "            for e in range(deg_exp-deg):\n",
    "                dest_n = np.random.randint(nodes+4)\n",
    "                # dest_n can be nodes+i itself; skip it to avoid a self-loop.\n",
    "                if dest_n == nodes+i or adj[nodes+i, dest_n] > 0:\n",
    "                    continue\n",
    "                if makesPattern(adj, feats, dest_n, nodes+i)[0]:\n",
    "                    continue\n",
    "                adj[nodes+i, dest_n] = 1.0\n",
    "                adj[dest_n, nodes+i] = 1.0\n",
    "\n",
    "    # Move some motif nodes to random existing positions (each position is\n",
    "    # used at most once, tracked in pos_covered).\n",
    "    pos_t = list(range(nodes+4))\n",
    "    pos_covered = []\n",
    "    for i in range(4):\n",
    "        if np.random.rand() < 0.4:\n",
    "            continue\n",
    "        dest_pos = np.random.randint(nodes)\n",
    "        if dest_pos in pos_covered:\n",
    "            continue\n",
    "        src = nodes + i\n",
    "        # Swap via fancy indexing (the right-hand side is a copy). The previous\n",
    "        # element-wise swap dropped the src-dest_pos edge and left a self-loop\n",
    "        # on dest_pos whenever the two nodes were adjacent.\n",
    "        feats[[src, dest_pos], :] = feats[[dest_pos, src], :]\n",
    "        adj[[src, dest_pos], :] = adj[[dest_pos, src], :]\n",
    "        adj[:, [src, dest_pos]] = adj[:, [dest_pos, src]]\n",
    "        pos_t[src] = dest_pos\n",
    "        pos_t[dest_pos] = src\n",
    "        pos_covered.append(dest_pos)\n",
    "\n",
    "    return adj, feats, pos_t[nodes:]\n",
    "                            \n",
    "            \n",
    "def drawGraph(adj, feats, nodes, highlight_nodes=None):\n",
    "    \"\"\"Draw the first `nodes` nodes, lettered A-F by their one-hot feature.\n",
    "\n",
    "    Nodes whose indices appear in highlight_nodes are drawn red; all others\n",
    "    are light grey.\n",
    "    \"\"\"\n",
    "    letters = ['A','B','C','D','E','F']\n",
    "    graph = nx.from_numpy_array(adj[:nodes,:nodes])\n",
    "    _, ax = plt.subplots(1,1, figsize=(15,10))\n",
    "\n",
    "    node_colors = [(0.9,0.9,0.9)] * nodes\n",
    "    if highlight_nodes is not None:\n",
    "        for idx in highlight_nodes:\n",
    "            node_colors[idx] = (0.9,0.1,0.1)\n",
    "    letter_of = {idx: letters[np.argmax(feats[idx,:])] for idx in range(nodes)}\n",
    "\n",
    "    nx.draw_networkx(graph, labels=letter_of, ax=ax, node_color=node_colors)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Generation settings.\n",
    "SEED = 0  # fixed seed so the saved dataset can be regenerated exactly\n",
    "DRAW_GRAPHS = False  # set True to plot each graph (motif nodes in red)\n",
    "OUT_PATH = \"../../gcn_interpretation/data/synthetic_data_2label_3sublabel/synthetic_data.p\"\n",
    "np.random.seed(SEED)  # addPattern also draws from the global numpy RNG\n",
    "\n",
    "len_data = 4000\n",
    "max_nodes = 20\n",
    "feat_dim = 6\n",
    "feats_data = np.zeros((len_data, max_nodes, feat_dim))\n",
    "adjs_data = np.zeros((len_data, max_nodes, max_nodes))\n",
    "labels_data = np.zeros((len_data),dtype=np.int32)\n",
    "sub_labels_data = np.zeros((len_data),dtype=np.int32) - 1\n",
    "sub_label_nodes = np.zeros((len_data,4),dtype=np.int32) - 1\n",
    "num_nodes_data = np.zeros((len_data),dtype=np.int32)\n",
    "\n",
    "# Node labels are one-hot over A/B/C/D/E/F.\n",
    "# label 1: a real motif (sub_label 0, 1 or 2) is planted.\n",
    "# label 0: no real motif; about half of these get a decoy (sub_label -1).\n",
    "for i in range(len_data):\n",
    "    adj = np.zeros((max_nodes, max_nodes))\n",
    "    feats = np.zeros((max_nodes, feat_dim))\n",
    "    label = np.random.randint(2)\n",
    "    add_fake = False\n",
    "    if label == 0:\n",
    "        if np.random.rand() < 0.5:\n",
    "            add_fake = True\n",
    "            # Leave room for the 4 decoy nodes.\n",
    "            nodes = np.random.randint(10-4,max_nodes+1-4)\n",
    "        else:\n",
    "            nodes = np.random.randint(10,max_nodes+1)\n",
    "    else:\n",
    "        # Leave room for the 4 motif nodes.\n",
    "        nodes = np.random.randint(10-4,max_nodes+1-4)\n",
    "\n",
    "    # Random tree: attach each node to an earlier one without forming a motif.\n",
    "    for n_ix in range(nodes):\n",
    "        l_n = np.random.randint(0,feat_dim)\n",
    "        feats[n_ix, l_n] = 1.0\n",
    "        if n_ix == 0:\n",
    "            continue\n",
    "        p_node = np.random.randint(n_ix)\n",
    "        while makesPattern(adj, feats, p_node, n_ix)[0]:\n",
    "            p_node = np.random.randint(n_ix)\n",
    "        adj[p_node,n_ix] = 1.\n",
    "        adj[n_ix, p_node] = 1.\n",
    "\n",
    "    # Extra random edges, skipping self-loops and motif-forming edges.\n",
    "    max_edges = min(26,int((nodes*nodes-1)/6))\n",
    "    if max_edges > nodes:\n",
    "        edge_total = np.random.randint(nodes, max_edges)\n",
    "        for e_ix in range(edge_total-nodes+1):\n",
    "            rand_nix = np.random.randint(nodes)\n",
    "            rand_pix = np.random.randint(nodes)\n",
    "            if rand_nix == rand_pix:\n",
    "                continue\n",
    "            if not makesPattern(adj, feats, rand_pix, rand_nix)[0]:\n",
    "                adj[rand_nix, rand_pix] = 1.0\n",
    "                adj[rand_pix, rand_nix] = 1.0\n",
    "\n",
    "    highlight_nodes = None\n",
    "    if label == 0:\n",
    "        sub_label = -1\n",
    "        if add_fake:\n",
    "            adj, feats, highlight_nodes = addPattern(adj,feats,nodes, sub_label)\n",
    "            nodes = nodes + 4\n",
    "    else:\n",
    "        sub_label = np.random.randint(3)\n",
    "        adj, feats, highlight_nodes = addPattern(adj,feats,nodes, sub_label)\n",
    "        nodes = nodes + 4\n",
    "\n",
    "    # Occasional progress line instead of one print per graph.\n",
    "    if i % 500 == 0:\n",
    "        print(i, label, sub_label)\n",
    "    if DRAW_GRAPHS:\n",
    "        drawGraph(adj, feats, nodes, highlight_nodes=highlight_nodes)\n",
    "\n",
    "    feats_data[i] = feats\n",
    "    labels_data[i] = label\n",
    "    sub_labels_data[i] = sub_label\n",
    "    if highlight_nodes is not None:\n",
    "        sub_label_nodes[i] = np.array(highlight_nodes)\n",
    "    num_nodes_data[i] = nodes\n",
    "    adjs_data[i] = adj\n",
    "\n",
    "synthetic_data = {}\n",
    "synthetic_data['adj'] = adjs_data\n",
    "synthetic_data['feat'] = feats_data\n",
    "synthetic_data['label'] = labels_data\n",
    "synthetic_data['sub_label'] = sub_labels_data\n",
    "synthetic_data['sub_label_nodes'] = sub_label_nodes\n",
    "synthetic_data['num_nodes'] = num_nodes_data\n",
    "\n",
    "# Use a context manager so the file is closed (and flushed) deterministically.\n",
    "with open(OUT_PATH, \"wb\") as out_file:\n",
    "    pickle.dump(synthetic_data, out_file)\n",
    "        \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "        \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0., 0., 0., 0., 1., 0.],\n",
       "       [0., 0., 0., 0., 0., 1.],\n",
       "       [1., 0., 0., 0., 0., 0.],\n",
       "       [0., 1., 0., 0., 0., 0.],\n",
       "       [0., 0., 1., 0., 0., 0.],\n",
       "       [0., 0., 0., 1., 0., 0.],\n",
       "       [0., 0., 0., 0., 1., 0.],\n",
       "       [0., 0., 1., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0., 1.],\n",
       "       [0., 0., 0., 0., 1., 0.],\n",
       "       [0., 0., 0., 0., 1., 0.],\n",
       "       [0., 1., 0., 0., 0., 0.],\n",
       "       [1., 0., 0., 0., 0., 0.],\n",
       "       [0., 1., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0., 0.]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One-hot node features of the first generated graph; rows at or past its\n",
    "# num_nodes entry stay all-zero padding.\n",
    "synthetic_data['feat'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "\n",
    "# Graph label is 1 when one of 3 motifs is planted:\n",
    "#   ABCD form a ring, or\n",
    "#   C has D, E and F among its neighbors, or\n",
    "#   A has B, C, D among its neighbors.\n",
    "# Split the graphs from the generation cell above (relies on `synthetic_data`)\n",
    "# into train (first train_idx graphs) and validation (the rest) tensors.\n",
    "data = synthetic_data\n",
    "train_idx = 3000\n",
    "X_train_feat = torch.from_numpy(data['feat'][:train_idx]).float()\n",
    "X_train_adj = torch.from_numpy(data['adj'][:train_idx]).float()\n",
    "X_train_nodes = torch.from_numpy(data['num_nodes'][:train_idx])\n",
    "y_train_label = torch.from_numpy(data['label'][:train_idx])\n",
    "# data['sub_label'] and data['sub_label_nodes'] are not exported here.\n",
    "\n",
    "X_val_feat = torch.from_numpy(data['feat'][train_idx:]).float()\n",
    "X_val_adj = torch.from_numpy(data['adj'][train_idx:]).float()\n",
    "X_val_nodes = torch.from_numpy(data['num_nodes'][train_idx:])\n",
    "y_val_label = torch.from_numpy(data['label'][train_idx:])\n",
    "\n",
    "# Saved tuple order: (adj, feat, label, num_nodes).\n",
    "tensor_data_train = (X_train_adj, X_train_feat, y_train_label, X_train_nodes)\n",
    "tensor_data_val = (X_val_adj, X_val_feat, y_val_label, X_val_nodes)\n",
    "\n",
    "# torch.save does not create missing directories.\n",
    "os.makedirs(\"./data\", exist_ok=True)\n",
    "torch.save(tensor_data_train, \"./data/synthetic_train.pth\")\n",
    "torch.save(tensor_data_val, \"./data/synthetic_val.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic data 1: graphs labelled by 3 planted patterns (sub_label 0, 1, 2).\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "# Row count of the padded adjacency/feature arrays; the pattern checks below\n",
    "# scan node indices 0..max_nodes-1 through this module-level name.\n",
    "max_nodes = 20\n",
    "def makesPattern1(adj, feats, p, n):\n",
    "    \"\"\"Return True if an edge p-n would be part of an A-B-C-D ring.\n",
    "\n",
    "    Node labels are the argmax of each row of feats (0..5 -> A..F, matching\n",
    "    drawGraph's node_labels); only A-D (0-3) take part. p and n must have\n",
    "    cyclically consecutive labels (mod 4). The function then looks for a\n",
    "    neighbour c_p of p and a neighbour x of n that are adjacent to each other\n",
    "    and whose labels continue the cyclic order, i.e. a 4-cycle\n",
    "    c_p - p - n - x - c_p. It does not check whether the p-n edge exists.\n",
    "\n",
    "    NOTE(review): scans indices 0..max_nodes-1 via the module-level\n",
    "    max_nodes, so adj/feats are assumed to have at least that many rows.\n",
    "    \"\"\"\n",
    "    l_p = np.argmax(feats[p]) \n",
    "    l_n = np.argmax(feats[n]) \n",
    "    if l_p not in [0,1,2,3] or l_n not in [0,1,2,3]:\n",
    "        return False\n",
    "    # n2p: p's label follows n's (l_p == l_n+1); p2n: p's label precedes n's\n",
    "    # (l_p == l_n-1), both mod 4. At most one of them can hold.\n",
    "    n2p = True\n",
    "    p2n = True\n",
    "    if l_p != (l_n+1)%4:\n",
    "        n2p = False\n",
    "    if l_p != (l_n-1)%4:\n",
    "        p2n = False\n",
    "    if p2n == False and n2p == False:\n",
    "        return False\n",
    "    contend_p = []\n",
    "    if p2n == True:\n",
    "        # Cyclic order c_p(l_p-1) - p - n(l_p+1) - x(l_n+1): collect c_p\n",
    "        # candidates, then need an x next to n that touches one of them.\n",
    "        \n",
    "        for nix in range(max_nodes):\n",
    "            l_x = np.argmax(feats[nix])\n",
    "            if adj[nix,p] > 0. and l_x == (l_p-1)%4:\n",
    "                contend_p.append(nix)\n",
    "        if len(contend_p) == 0:\n",
    "            return False\n",
    "        for nix in range(max_nodes):\n",
    "            l_x = np.argmax(feats[nix])\n",
    "            if adj[nix,n] > 0. and l_x == (l_n+1)%4:\n",
    "                for c_p in contend_p:\n",
    "                    if adj[nix,c_p] > 0.:\n",
    "                        return True\n",
    "\n",
    "        \n",
    "        \n",
    "    else:\n",
    "        # n2p case, mirror image: c_p labelled l_p+1, x labelled l_n-1.\n",
    "        for nix in range(max_nodes):\n",
    "            l_x = np.argmax(feats[nix])\n",
    "            if adj[nix,p] > 0. and l_x == (l_p+1)%4:\n",
    "                contend_p.append(nix)\n",
    "        if len(contend_p) == 0:\n",
    "            return False\n",
    "        for nix in range(max_nodes):\n",
    "            l_x = np.argmax(feats[nix])\n",
    "            if adj[nix,n] > 0. and l_x == (l_n-1)%4:\n",
    "                for c_p in contend_p:\n",
    "                    if adj[nix,c_p] > 0.:\n",
    "                        return True\n",
    "    return False\n",
    "                \n",
    "        \n",
    "def makesPattern2(adj, feats, p, n):\n",
    "    \"\"\"Return True if an edge p-n would give an A node neighbours B, C and D.\n",
    "\n",
    "    One endpoint must be labelled A (0) and the other B, C or D (1-3). The\n",
    "    A endpoint's existing neighbours (scanned over max_nodes indices) must\n",
    "    supply the two remaining labels from {B, C, D}.\n",
    "    \"\"\"\n",
    "    label_p = np.argmax(feats[p])\n",
    "    label_n = np.argmax(feats[n])\n",
    "    if label_p not in (0, 1, 2, 3) or label_n not in (0, 1, 2, 3):\n",
    "        return False\n",
    "    # Exactly one endpoint has to be the A hub.\n",
    "    if 0 not in (label_p, label_n) or label_p == label_n:\n",
    "        return False\n",
    "    hub, leaf_label = (p, label_n) if label_p == 0 else (n, label_p)\n",
    "    missing = {1, 2, 3} - {leaf_label}\n",
    "    for nbr in range(max_nodes):\n",
    "        if not adj[hub, nbr] > 0.:\n",
    "            continue\n",
    "        nbr_label = np.argmax(feats[nbr])\n",
    "        if nbr_label in missing:\n",
    "            missing.remove(nbr_label)\n",
    "            if not missing:\n",
    "                return True\n",
    "    return False\n",
    "\n",
    "def makesPattern3(adj, feats, p, n):\n",
    "    \"\"\"Return True if an edge p-n would give a C node neighbours D, E and F.\n",
    "\n",
    "    One endpoint must be labelled C (2) and the other D, E or F (3-5). The\n",
    "    C endpoint's existing neighbours (scanned over max_nodes indices) must\n",
    "    supply the two remaining labels from {D, E, F}.\n",
    "    \"\"\"\n",
    "    label_p = np.argmax(feats[p])\n",
    "    label_n = np.argmax(feats[n])\n",
    "    if label_p not in (2, 3, 4, 5) or label_n not in (2, 3, 4, 5):\n",
    "        return False\n",
    "    # Exactly one endpoint has to be the C hub.\n",
    "    if 2 not in (label_p, label_n) or label_p == label_n:\n",
    "        return False\n",
    "    hub, leaf_label = (p, label_n) if label_p == 2 else (n, label_p)\n",
    "    missing = {3, 4, 5} - {leaf_label}\n",
    "    for nbr in range(max_nodes):\n",
    "        if not adj[hub, nbr] > 0.:\n",
    "            continue\n",
    "        nbr_label = np.argmax(feats[nbr])\n",
    "        if nbr_label in missing:\n",
    "            missing.remove(nbr_label)\n",
    "            if not missing:\n",
    "                return True\n",
    "    return False\n",
    "                \n",
    "        \n",
    "        \n",
    "\n",
    "def makesPattern(adj, feats, p, n):\n",
    "    \"\"\"Check an edge p-n against every planted pattern.\n",
    "\n",
    "    Returns (matched_any, ids), where ids lists the matching pattern indices:\n",
    "    0 = A-B-C-D ring, 1 = A hub with B/C/D, 2 = C hub with D/E/F.\n",
    "    All three checks are always evaluated.\n",
    "    \"\"\"\n",
    "    checks = (makesPattern1, makesPattern2, makesPattern3)\n",
    "    matched = [ix for ix, check in enumerate(checks) if check(adj, feats, p, n) is True]\n",
    "    return (len(matched) > 0, matched)\n",
    "\n",
    "def addPattern(adj,feats,nodes, sub_label):\n",
    "    \"\"\"Plant a 4-node motif at rows nodes..nodes+3 and wire it into the graph.\n",
    "\n",
    "    adj and feats are modified in place (and also returned), so they need at\n",
    "    least nodes+4 rows. sub_label selects the motif:\n",
    "      0     -> ring A-B-C-D\n",
    "      1     -> A hub joined to B, C and D\n",
    "      2     -> C hub joined to D, E and F\n",
    "      other -> decoy (the caller passes -1): a ring or hub with random labels,\n",
    "               resampled until it is not one of the real motifs above.\n",
    "    Each motif node is then linked to existing nodes, skipping links that\n",
    "    makesPattern reports would form a motif. Some motif nodes are swapped\n",
    "    with random existing positions so the motif is not always at the end.\n",
    "\n",
    "    Returns:\n",
    "        (adj, feats, positions) where positions[i] is the final row index of\n",
    "        motif node i.\n",
    "    \"\"\"\n",
    "    def formsRealRing(lbls):\n",
    "        # A ring is the real motif when its labels, read around the cycle,\n",
    "        # step through A..D by +1 (or by -1) mod 4, in any rotation.\n",
    "        if max(lbls) > 3:\n",
    "            return False\n",
    "        steps = {(lbls[(k+1)%4] - lbls[k]) % 4 for k in range(4)}\n",
    "        return steps == {1} or steps == {3}\n",
    "\n",
    "    def formsRealHub(lbls):\n",
    "        # Hub lbls[0] is joined to lbls[1:]; reject A->{B,C,D} and C->{D,E,F}.\n",
    "        leaves = set(lbls[1:])\n",
    "        return (lbls[0] == 0 and {1,2,3} <= leaves) or (lbls[0] == 2 and {3,4,5} <= leaves)\n",
    "\n",
    "    degree_sum = np.sum(adj,axis=1)\n",
    "    avg_deg = int(np.sum(degree_sum)/nodes)\n",
    "    if sub_label == 0 or sub_label == 1: # motifs over A,B,C,D\n",
    "        lbls = [0,1,2,3]\n",
    "        for i in range(4):\n",
    "            feats[nodes+i,lbls[i]] = 1.\n",
    "        if sub_label == 0: # ring\n",
    "            for i in range(3):\n",
    "                adj[nodes+i,nodes+i+1] = 1.0\n",
    "                adj[nodes+i+1, nodes+i] = 1.0\n",
    "            adj[nodes+3,nodes] = 1.0\n",
    "            adj[nodes,nodes+3] = 1.0\n",
    "        else: # hub nodes+0 joined to the other three\n",
    "            for i in range(1,4):\n",
    "                adj[nodes,nodes+i] = 1.0\n",
    "                adj[nodes+i,nodes] = 1.0\n",
    "    elif sub_label == 2: # C hub joined to D, E, F\n",
    "        lbls = [2,3,4,5]\n",
    "        for i in range(4):\n",
    "            feats[nodes+i,lbls[i]] = 1.\n",
    "        for i in range(1,4):\n",
    "            adj[nodes,nodes+i] = 1.0\n",
    "            adj[nodes+i,nodes] = 1.0\n",
    "    else: #-1 fake pattern\n",
    "        fake_sub = np.random.randint(2)\n",
    "        if fake_sub == 0: # decoy ring\n",
    "            # The previous check only rejected strictly increasing runs such as\n",
    "            # [0,1,2,3]; rotations/reversals like [1,2,3,0] slipped through and\n",
    "            # planted a real ring in a label-0 graph.\n",
    "            lbls = np.random.randint(0,6,size=(4)).tolist()\n",
    "            while formsRealRing(lbls):\n",
    "                lbls = np.random.randint(0,6,size=(4)).tolist()\n",
    "            for i in range(4):\n",
    "                feats[nodes+i,lbls[i]] = 1.\n",
    "            for i in range(3):\n",
    "                adj[nodes+i,nodes+i+1] = 1.0\n",
    "                adj[nodes+i+1, nodes+i] = 1.0\n",
    "            adj[nodes+3,nodes] = 1.0\n",
    "            adj[nodes,nodes+3] = 1.0\n",
    "        else: # decoy hub\n",
    "            lbls = np.random.randint(0,6,size=(4)).tolist()\n",
    "            while formsRealHub(lbls):\n",
    "                lbls = np.random.randint(0,6,size=(4)).tolist()\n",
    "            for i in range(4):\n",
    "                feats[nodes+i,lbls[i]] = 1.\n",
    "            for i in range(1,4):\n",
    "                adj[nodes,nodes+i] = 1.0\n",
    "                adj[nodes+i,nodes] = 1.0\n",
    "\n",
    "    # Link each motif node to the existing graph without forming a motif.\n",
    "    for i in range(4):\n",
    "        deg_exp = np.random.randint(avg_deg-2, avg_deg+2)\n",
    "        dest_n = np.random.randint(nodes)\n",
    "        while_count = 0\n",
    "        skip = False\n",
    "        while makesPattern(adj, feats, dest_n, nodes+i)[0]:\n",
    "            dest_n = np.random.randint(nodes)\n",
    "            while_count += 1\n",
    "            if while_count == 5: # give up on this node after 5 resamples\n",
    "                skip = True\n",
    "                break\n",
    "        if skip:\n",
    "            continue\n",
    "        adj[nodes+i, dest_n] = 1.0\n",
    "        adj[dest_n, nodes+i] = 1.0\n",
    "        deg = int(np.sum(adj[nodes+i,:]))\n",
    "        if deg_exp > deg:\n",
    "            for e in range(deg_exp-deg):\n",
    "                dest_n = np.random.randint(nodes+4)\n",
    "                # dest_n can be nodes+i itself; skip it to avoid a self-loop.\n",
    "                if dest_n == nodes+i or adj[nodes+i, dest_n] > 0:\n",
    "                    continue\n",
    "                if makesPattern(adj, feats, dest_n, nodes+i)[0]:\n",
    "                    continue\n",
    "                adj[nodes+i, dest_n] = 1.0\n",
    "                adj[dest_n, nodes+i] = 1.0\n",
    "\n",
    "    # Move some motif nodes to random existing positions (each position is\n",
    "    # used at most once, tracked in pos_covered).\n",
    "    pos_t = list(range(nodes+4))\n",
    "    pos_covered = []\n",
    "    for i in range(4):\n",
    "        if np.random.rand() < 0.4:\n",
    "            continue\n",
    "        dest_pos = np.random.randint(nodes)\n",
    "        if dest_pos in pos_covered:\n",
    "            continue\n",
    "        src = nodes + i\n",
    "        # Swap via fancy indexing (the right-hand side is a copy). The previous\n",
    "        # element-wise swap dropped the src-dest_pos edge and left a self-loop\n",
    "        # on dest_pos whenever the two nodes were adjacent.\n",
    "        feats[[src, dest_pos], :] = feats[[dest_pos, src], :]\n",
    "        adj[[src, dest_pos], :] = adj[[dest_pos, src], :]\n",
    "        adj[:, [src, dest_pos]] = adj[:, [dest_pos, src]]\n",
    "        pos_t[src] = dest_pos\n",
    "        pos_t[dest_pos] = src\n",
    "        pos_covered.append(dest_pos)\n",
    "\n",
    "    return adj, feats, pos_t[nodes:]\n",
    "                            \n",
    "            \n",
    "def drawGraph(adj, feats, nodes, highlight_nodes=None):\n",
    "    \"\"\"Plot the first `nodes` nodes of a synthetic graph.\n",
    "\n",
    "    Nodes are drawn light grey and labelled A-F by the argmax of their feature\n",
    "    row; every index listed in `highlight_nodes` (if given) is drawn red.\n",
    "    Opens a 15x10 matplotlib figure and shows it.\n",
    "    \"\"\"\n",
    "    letters = ['A','B','C','D','E','F']\n",
    "    graph = nx.from_numpy_array(adj[:nodes,:nodes])\n",
    "    fig, ax_l = plt.subplots(1,1, figsize=(15,10))\n",
    "\n",
    "    palette = [(0.9,0.9,0.9) for _ in range(nodes)]\n",
    "    for h_n in ([] if highlight_nodes is None else highlight_nodes):\n",
    "        palette[h_n] = (0.9,0.1,0.1)\n",
    "\n",
    "    names = {ix: letters[np.argmax(feats[ix,:])] for ix in range(nodes)}\n",
    "    nx.draw_networkx(graph, labels=names, ax=ax_l, node_color=palette)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "len_data = 8000\n",
    "max_nodes = 20\n",
    "feat_dim = 6\n",
    "feats_data = np.zeros((len_data, max_nodes, feat_dim))\n",
    "adjs_data = np.zeros((len_data, max_nodes, max_nodes))\n",
    "labels_data = np.zeros((len_data),dtype=np.int32)\n",
    "sub_labels_data = np.zeros((len_data),dtype=np.int32) - 1\n",
    "sub_label_nodes = np.zeros((len_data,4),dtype=np.int32) - 1\n",
    "num_nodes_data = np.zeros((len_data),dtype=np.int32)\n",
    "\n",
    "#A/B/C/D/E/F\n",
    "# Generate len_data graphs. Each gets random base nodes (one-hot labels A-F), a random\n",
    "# tree plus extra random edges, all chosen so makesPattern reports no motif, and then,\n",
    "# depending on its label, a 4-node motif appended by addPattern (defined above).\n",
    "# label 1 = one real motif (sub_label 0/1/2); label 0 = no motif, a fake motif\n",
    "# (sub_label -1), or a real motif that was distorted afterwards.\n",
    "for i in range(len_data):\n",
    "    adj = np.zeros((max_nodes, max_nodes))\n",
    "    feats = np.zeros((max_nodes, feat_dim))\n",
    "    # NOTE(review): this draw is overwritten just below; it only advances the RNG.\n",
    "    label = np.random.randint(2)\n",
    "    # about 10% of graphs start as label 0; the rest start with a real motif\n",
    "    if np.random.rand() < 0.1:\n",
    "        label = 0\n",
    "    else:\n",
    "        label = 1 #all with patterns and then make minor changes to change the pattern\n",
    "    \n",
    "    # graphs that will receive a motif leave 4 of the max_nodes slots free for it\n",
    "    if label == 0:\n",
    "        if np.random.rand() < 0.5:\n",
    "            add_fake = True\n",
    "            nodes = np.random.randint(10-4,max_nodes+1-4)\n",
    "        else:\n",
    "            add_fake = False\n",
    "            nodes = np.random.randint(10,max_nodes+1)\n",
    "    else:\n",
    "        nodes = np.random.randint(10-4,max_nodes+1-4) #4 nodes less because they are added later\n",
    "\n",
    "#     n_dict = {}\n",
    "    # each new node is attached to a random earlier node, redrawn while that edge\n",
    "    # would form a motif (NOTE(review): this retry loop has no upper bound)\n",
    "    for n_ix in range(nodes): # add nodes\n",
    "        l_n = np.random.randint(0,feat_dim)\n",
    "        feats[n_ix, l_n] = 1.0\n",
    "#         if l_n not in n_dict:\n",
    "#             n_dict[l_n] = [n_ix]\n",
    "#         else:\n",
    "#             n_dict[l_n].append(n_ix)\n",
    "        if n_ix == 0:\n",
    "            continue\n",
    "        p_node = np.random.randint(n_ix) \n",
    "        while makesPattern(adj, feats, p_node, n_ix)[0] == True: #connect node to graph such that\n",
    "            p_node = np.random.randint(n_ix)                    #it doesn't result in any pattern\n",
    "\n",
    "        adj[p_node,n_ix] = 1.\n",
    "        adj[n_ix, p_node] = 1.\n",
    "        \n",
    "    # add random extra edges between base nodes, skipping self-pairs and motif-forming edges\n",
    "    max_edges = min(26,int((nodes*nodes-1)/6))\n",
    "    if max_edges > nodes:\n",
    "        edge_total = np.random.randint(nodes, max_edges)\n",
    "        for e_ix in range(edge_total-nodes+1):\n",
    "            rand_nix = np.random.randint(nodes)\n",
    "            rand_pix = np.random.randint(nodes)\n",
    "            if rand_nix == rand_pix:\n",
    "                continue\n",
    "            if not makesPattern(adj, feats, rand_pix, rand_nix)[0]: #add more random edges\n",
    "                adj[rand_nix, rand_pix] = 1.0\n",
    "                adj[rand_pix, rand_nix] = 1.0\n",
    "\n",
    "        \n",
    "    #0 or 1\n",
    "    highlight_nodes = None\n",
    "\n",
    "    if label == 0:\n",
    "        sub_label = -1\n",
    "        if add_fake:\n",
    "            adj, feats, highlight_nodes = addPattern(adj,feats,nodes, sub_label)\n",
    "            nodes = nodes + 4\n",
    "\n",
    "        #done\n",
    "    else:\n",
    "        sub_label = np.random.randint(3)\n",
    "        adj, feats, highlight_nodes = addPattern(adj,feats,nodes, sub_label)\n",
    "        nodes = nodes + 4\n",
    "        ####this code will alter the pattern making few changes###########\n",
    "        # with probability 0.4 break the motif; the graph is then relabelled 0 below\n",
    "        if np.random.rand()  < 0.4:\n",
    "            if sub_label == 1 or sub_label == 2:\n",
    "                # star: cut one centre-leaf edge, re-attach the centre to another node and the\n",
    "                # cut leaf to another leaf; each re-attachment gives up after a few redraws\n",
    "                break_node = np.random.randint(1,4)\n",
    "                assert( adj[highlight_nodes[0], highlight_nodes[break_node]] > 0.0)\n",
    "                adj[highlight_nodes[0], highlight_nodes[break_node]] = 0.\n",
    "                adj[highlight_nodes[break_node], highlight_nodes[0]] = 0.\n",
    "                dest_ix = np.random.randint(nodes)\n",
    "                while_count = 0\n",
    "                found = True\n",
    "                while(dest_ix == highlight_nodes[0]) or (makesPattern(adj, feats,highlight_nodes[0],dest_ix)[0]):\n",
    "                    dest_ix = np.random.randint(nodes)\n",
    "                    if while_count == 5:\n",
    "                        found = False\n",
    "                        break\n",
    "                    while_count += 1\n",
    "                if found:\n",
    "                    adj[highlight_nodes[0], dest_ix] = 1.0\n",
    "                    adj[dest_ix, highlight_nodes[0]] = 1.0\n",
    "                h_ix = np.random.randint(1,4)\n",
    "                while_count = 0\n",
    "                found = True\n",
    "                while(h_ix == break_node) or (makesPattern(adj, feats,highlight_nodes[h_ix],highlight_nodes[break_node])[0]):\n",
    "                    h_ix = np.random.randint(1,4)\n",
    "                    if while_count == 5:\n",
    "                        found = False\n",
    "                        break\n",
    "                    while_count += 1\n",
    "                if found:\n",
    "                    adj[highlight_nodes[h_ix],highlight_nodes[break_node]] = 1.0\n",
    "                    adj[highlight_nodes[break_node],highlight_nodes[h_ix]] = 1.0\n",
    "                \n",
    "            else:\n",
    "                # ring: cut one ring edge, re-attach one endpoint to another node and the other\n",
    "                # endpoint to another ring node; each gives up after a few redraws\n",
    "                break_node = np.random.randint(0,4)\n",
    "                assert(adj[highlight_nodes[break_node], highlight_nodes[(break_node+1)%4]] > 0.0)\n",
    "                adj[highlight_nodes[break_node], highlight_nodes[(break_node+1)%4]] = 0.\n",
    "                adj[highlight_nodes[(break_node+1)%4], highlight_nodes[break_node]] = 0.\n",
    "                dest_ix = np.random.randint(nodes)\n",
    "                \n",
    "                while_count = 0\n",
    "                found = True\n",
    "                while(dest_ix == highlight_nodes[break_node]) or (makesPattern(adj, feats,highlight_nodes[break_node],dest_ix)[0]):\n",
    "                    dest_ix = np.random.randint(nodes)\n",
    "                    if while_count == 5:\n",
    "                        found = False\n",
    "                        break\n",
    "                    while_count += 1\n",
    "                if found:\n",
    "                    adj[highlight_nodes[break_node], dest_ix] = 1.0\n",
    "                    adj[dest_ix, highlight_nodes[break_node]] = 1.0\n",
    "                h_ix = np.random.randint(0,4)\n",
    "                while_count = 0\n",
    "                found = True\n",
    "                break_node = (break_node+1)%4\n",
    "                while(h_ix == break_node) or (makesPattern(adj, feats,highlight_nodes[h_ix],highlight_nodes[break_node])[0]):\n",
    "                    h_ix = np.random.randint(0,4)\n",
    "                    if while_count == 5:\n",
    "                        found = False\n",
    "                        break\n",
    "                    while_count += 1\n",
    "                if found:\n",
    "                    adj[highlight_nodes[h_ix],highlight_nodes[break_node]] = 1.0\n",
    "                    adj[highlight_nodes[break_node],highlight_nodes[h_ix]] = 1.0\n",
    "                \n",
    "                \n",
    "                \n",
    "            sub_label = -1\n",
    "            label = 0\n",
    "        \n",
    "        \n",
    "    # NOTE(review): prints one line per graph; consider tqdm or removing it.\n",
    "    print(i, label, sub_label)\n",
    "    \n",
    "#     if highlight_nodes != None:\n",
    "#         drawGraph(adj, feats, nodes, highlight_nodes=highlight_nodes)\n",
    "#     else:\n",
    "#         drawGraph(adj, feats, nodes)\n",
    "    # motif node ids are stored even when the motif is fake or distorted\n",
    "    feats_data[i] = feats\n",
    "    labels_data[i] = label\n",
    "    sub_labels_data[i] = sub_label\n",
    "    if highlight_nodes is not None:\n",
    "        sub_label_nodes[i] = np.array(highlight_nodes)\n",
    "    num_nodes_data[i] = nodes\n",
    "    adjs_data[i] = adj\n",
    "    \n",
    "\n",
    "# Bundle the generated arrays under the keys read back later and pickle them.\n",
    "synthetic_data = {\n",
    "    'adj': adjs_data,\n",
    "    'feat': feats_data,\n",
    "    'label': labels_data,\n",
    "    'sub_label': sub_labels_data,\n",
    "    'sub_label_nodes': sub_label_nodes,\n",
    "    'num_nodes': num_nodes_data,\n",
    "}\n",
    "\n",
    "# A context manager guarantees the file is flushed and closed, even if pickling fails.\n",
    "out_path = \"../../gcn_interpretation/data/synthetic_data_2label_3sublabel/synthetic_data_hard_1_8000.p\"\n",
    "with open(out_path, \"wb\") as out_file:\n",
    "    pickle.dump(synthetic_data, out_file)\n",
    "        \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "        \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#add even or odd number patterns\n",
    "#synthetic data 1: with 4 patterns \n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "# Row bound scanned by makesPattern1/2/3 below; they read this module-level name at\n",
    "# call time, and the generation config further down rebinds it before generating.\n",
    "max_nodes = 20\n",
    "def makesPattern1(adj, feats, p, n):\n",
    "    \"\"\"Return True if adding edge p-n would close an A-B-C-D ring (motif 0).\n",
    "\n",
    "    Only labels 0..3 (A..D) take part; a node's label is argmax of its feature\n",
    "    row. The ring is a 4-cycle whose labels step consistently by +1 (or\n",
    "    consistently by -1) mod 4, so p and n must be label-neighbours. The check\n",
    "    then looks for a neighbour x of p and a neighbour y of n, carrying the\n",
    "    labels that continue the sequence outward, such that x and y are already\n",
    "    adjacent. Scans the first `max_nodes` rows (module-level global).\n",
    "    \"\"\"\n",
    "    l_p = np.argmax(feats[p])\n",
    "    l_n = np.argmax(feats[n])\n",
    "    if l_p not in [0,1,2,3] or l_n not in [0,1,2,3]:\n",
    "        return False\n",
    "    # direction of the label sequence when walking from p to n\n",
    "    if l_n == (l_p + 1) % 4:\n",
    "        step = 1\n",
    "    elif l_n == (l_p - 1) % 4:\n",
    "        step = -1\n",
    "    else:\n",
    "        return False\n",
    "    want_p = (l_p - step) % 4   # label expected on the far side of p\n",
    "    want_n = (l_n + step) % 4   # label expected on the far side of n\n",
    "    labels = [np.argmax(feats[ix]) for ix in range(max_nodes)]\n",
    "    p_side = [ix for ix in range(max_nodes) if adj[ix, p] > 0. and labels[ix] == want_p]\n",
    "    n_side = [ix for ix in range(max_nodes) if adj[ix, n] > 0. and labels[ix] == want_n]\n",
    "    return any(adj[y, x] > 0. for y in n_side for x in p_side)\n",
    "                \n",
    "        \n",
    "def makesPattern2(adj, feats, p, n):\n",
    "    \"\"\"Return True if adding edge p-n would complete motif 1: an A node whose\n",
    "    neighbours include B, C and D.\n",
    "\n",
    "    One endpoint must be A (label 0) and the other one of B/C/D (labels 1..3);\n",
    "    the A endpoint is the centre. Returns True when the centre's current\n",
    "    neighbours (scanning the first `max_nodes` rows, a module-level global)\n",
    "    already supply the two leaf labels the new edge does not provide.\n",
    "    \"\"\"\n",
    "    l_p = np.argmax(feats[p])\n",
    "    l_n = np.argmax(feats[n])\n",
    "    if l_p not in [0,1,2,3] or l_n not in [0,1,2,3]:\n",
    "        return False\n",
    "    if 0 not in (l_p, l_n) or l_p == l_n:\n",
    "        return False\n",
    "    centre, leaf_label = (p, l_n) if l_p == 0 else (n, l_p)\n",
    "    missing = {1, 2, 3} - {int(leaf_label)}\n",
    "    for nix in range(max_nodes):\n",
    "        if adj[centre, nix] > 0.:\n",
    "            found_label = int(np.argmax(feats[nix]))\n",
    "            if found_label in missing:\n",
    "                missing.discard(found_label)\n",
    "                if not missing:\n",
    "                    return True\n",
    "    return False\n",
    "\n",
    "def makesPattern3(adj, feats, p, n):\n",
    "    \"\"\"Return True if adding edge p-n would complete motif 2: a C node whose\n",
    "    neighbours include D, E and F.\n",
    "\n",
    "    One endpoint must be C (label 2) and the other one of D/E/F (labels 3..5);\n",
    "    the C endpoint is the centre. Returns True when the centre's current\n",
    "    neighbours (scanning the first `max_nodes` rows, a module-level global)\n",
    "    already supply the two leaf labels the new edge does not provide.\n",
    "    \"\"\"\n",
    "    l_p = np.argmax(feats[p])\n",
    "    l_n = np.argmax(feats[n])\n",
    "    if l_p not in [2,3,4,5] or l_n not in [2,3,4,5]:\n",
    "        return False\n",
    "    if 2 not in (l_p, l_n) or l_p == l_n:\n",
    "        return False\n",
    "    centre, leaf_label = (p, l_n) if l_p == 2 else (n, l_p)\n",
    "    missing = {3, 4, 5} - {int(leaf_label)}\n",
    "    for nix in range(max_nodes):\n",
    "        if adj[centre, nix] > 0.:\n",
    "            found_label = int(np.argmax(feats[nix]))\n",
    "            if found_label in missing:\n",
    "                missing.discard(found_label)\n",
    "                if not missing:\n",
    "                    return True\n",
    "    return False\n",
    "                \n",
    "        \n",
    "        \n",
    "\n",
    "def makesPattern(adj, feats, p, n):\n",
    "    \"\"\"Check edge p-n against every motif detector.\n",
    "\n",
    "    Returns (any_match, matched): `matched` lists, in order, the ids of the\n",
    "    detectors that returned True -- 0 = ring (makesPattern1), 1 = A-star\n",
    "    (makesPattern2), 2 = C-star (makesPattern3). All three are always run.\n",
    "    \"\"\"\n",
    "    detectors = (makesPattern1, makesPattern2, makesPattern3)\n",
    "    matched = []\n",
    "    for motif_id, detector in enumerate(detectors):\n",
    "        if detector(adj, feats, p, n) is True:\n",
    "            matched.append(motif_id)\n",
    "    return (len(matched) > 0, matched)\n",
    "\n",
    "def _ringLabelsFormPattern(lbls):\n",
    "    \"\"\"Return True if a 4-ring labelled `lbls` (in ring order) must not be a fake motif.\n",
    "\n",
    "    Rejected are (a) ascending runs such as [0,1,2,3] or [2,3,4,5] (the original\n",
    "    sampling rule) and (b) every labelling that makesPattern1 treats as the real\n",
    "    A-B-C-D ring: all labels in 0..3 and every step around the ring (including\n",
    "    the wrap from the last node back to the first) equal to +1 mod 4, or every\n",
    "    step equal to -1 mod 4, e.g. [1,2,3,0] or [3,2,1,0].\n",
    "    \"\"\"\n",
    "    if all(lbls[j] + 1 == lbls[j + 1] for j in range(3)):\n",
    "        return True\n",
    "    if any(l not in (0, 1, 2, 3) for l in lbls):\n",
    "        return False\n",
    "    steps = {(lbls[(j + 1) % 4] - lbls[j]) % 4 for j in range(4)}\n",
    "    return steps == {1} or steps == {3}\n",
    "\n",
    "\n",
    "def addPattern(adj,feats,nodes, sub_label):\n",
    "    \"\"\"Append a 4-node motif at indices nodes..nodes+3 and wire it into the graph.\n",
    "\n",
    "    `adj` and `feats` are modified in place and also returned. Motif by `sub_label`:\n",
    "      0  -> ring A-B-C-D (4-cycle, labels 0,1,2,3 in ring order)\n",
    "      1  -> star: centre A (index nodes) with leaves B, C, D\n",
    "      2  -> star: centre C (index nodes) with leaves D, E, F\n",
    "      other (used with -1) -> fake motif: a ring or star with random labels,\n",
    "            resampled until it is not one of the real motifs above\n",
    "    Each motif node is then linked to the graph with random edges that\n",
    "    makesPattern reports as not creating a motif: first to one of the existing\n",
    "    `nodes` nodes (skipped after repeated rejections), then topped up towards a\n",
    "    random degree near the current average degree. Self-loops are never added.\n",
    "\n",
    "    Returns (adj, feats, [nodes, nodes+1, nodes+2, nodes+3]).\n",
    "    Assumes adj/feats have at least nodes+4 rows -- TODO confirm at call sites.\n",
    "    \"\"\"\n",
    "    degree_sum = np.sum(adj,axis=1)\n",
    "    avg_deg = int(np.sum(degree_sum)/max(nodes, 1))  # max() only guards nodes == 0\n",
    "    if sub_label == 0 or sub_label == 1: #has to do with A,B,C,D\n",
    "        lbls = [0,1,2,3]\n",
    "        star = (sub_label == 1)\n",
    "    elif sub_label == 2:\n",
    "        lbls = [2,3,4,5]\n",
    "        star = True\n",
    "    else: #-1 fake pattern\n",
    "        star = (np.random.randint(2) == 1)\n",
    "        lbls = np.random.randint(0,6,size=(4)).tolist()\n",
    "        if star:\n",
    "            # resample while the star would be a real A-star or C-star\n",
    "            while ((lbls[0] == 0 and {1, 2, 3} <= set(lbls)) or\n",
    "                   (lbls[0] == 2 and {3, 4, 5} <= set(lbls))):\n",
    "                lbls = np.random.randint(0,6,size=(4)).tolist()\n",
    "        else:\n",
    "            # resample while the ring would be a real A-B-C-D ring (any rotation/direction)\n",
    "            while _ringLabelsFormPattern(lbls):\n",
    "                lbls = np.random.randint(0,6,size=(4)).tolist()\n",
    "\n",
    "    for i in range(4): #add 4 nodes\n",
    "        feats[nodes+i,lbls[i]] = 1.\n",
    "    if star:\n",
    "        for i in range(1,4): #centre nodes, leaves nodes+1..nodes+3\n",
    "            adj[nodes,nodes+i] = 1.0\n",
    "            adj[nodes+i,nodes] = 1.0\n",
    "    else:\n",
    "        for i in range(4): #ring nodes -> nodes+1 -> nodes+2 -> nodes+3 -> nodes\n",
    "            j = nodes + (i + 1) % 4\n",
    "            adj[nodes+i,j] = 1.0\n",
    "            adj[j,nodes+i] = 1.0\n",
    "\n",
    "    for i in range(4):\n",
    "        new_node = nodes + i\n",
    "        deg_exp = np.random.randint(avg_deg-2, avg_deg+2)\n",
    "        dest_n = np.random.randint(nodes)\n",
    "        while_count = 0\n",
    "        skip = False\n",
    "        while makesPattern(adj, feats, dest_n, new_node)[0]:\n",
    "            dest_n = np.random.randint(nodes)\n",
    "            while_count += 1\n",
    "            if while_count == 5:\n",
    "                skip = True  # give up on connecting this motif node\n",
    "                break\n",
    "        if skip:\n",
    "            continue\n",
    "\n",
    "        adj[new_node, dest_n] = 1.0\n",
    "        adj[dest_n, new_node] = 1.0\n",
    "        deg = int(np.sum(adj[new_node,:]))\n",
    "        if deg_exp > deg:\n",
    "            for e in range(deg_exp-deg):\n",
    "                dest_n = np.random.randint(nodes+4)\n",
    "                if dest_n == new_node:\n",
    "                    continue  # never create a self-loop\n",
    "                if adj[new_node, dest_n] > 0:\n",
    "                    continue\n",
    "                if makesPattern(adj, feats, dest_n, new_node)[0]:\n",
    "                    continue\n",
    "                adj[new_node, dest_n] = 1.0\n",
    "                adj[dest_n, new_node] = 1.0\n",
    "\n",
    "    return adj, feats, list(range(nodes, nodes+4))\n",
    "                            \n",
    "            \n",
    "def drawGraph(adj, feats, nodes, highlight_nodes=None, sublabel_d=None):\n",
    "    \"\"\"Plot the first `nodes` nodes of a synthetic graph, colouring motif nodes.\n",
    "\n",
    "    Nodes are labelled A-F by the argmax of their feature row and drawn light grey.\n",
    "    `highlight_nodes` is indexed as highlight_nodes[ix, gx] -> node ids, with -1 as\n",
    "    padding (the layout of one row of `sub_label_nodes`). Each group ix gets one\n",
    "    random colour unless sublabel_d[ix] == -1 (fake or distorted motif). When\n",
    "    `sublabel_d` is None every group is coloured instead of raising TypeError.\n",
    "    Opens a 15x10 matplotlib figure and shows it.\n",
    "\n",
    "    NOTE(review): colours come from the global numpy RNG, so calling this inside\n",
    "    the generation loop changes the data generated afterwards.\n",
    "    \"\"\"\n",
    "    node_labels = ['A','B','C','D','E','F']\n",
    "    G_class = nx.from_numpy_array(adj[:nodes,:nodes])\n",
    "\n",
    "    fig, ax_l = plt.subplots(1,1, figsize=(15,10))\n",
    "    colors = [(0.9,0.9,0.9) for _ in range(nodes)]\n",
    "    if highlight_nodes is not None:\n",
    "        for ix in range(highlight_nodes.shape[0]):\n",
    "            # draw the colour before the skip check so RNG use per group is unchanged\n",
    "            group_color = (np.random.rand(), np.random.rand(), np.random.rand())\n",
    "            if sublabel_d is not None and sublabel_d[ix] == -1:\n",
    "                continue\n",
    "            for gx in range(highlight_nodes.shape[1]):\n",
    "                for h_n in highlight_nodes[ix,gx]:\n",
    "                    if h_n == -1:\n",
    "                        continue  # unused padding slot\n",
    "                    colors[h_n] = group_color\n",
    "    labels_dict = {n: node_labels[np.argmax(feats[n,:])] for n in range(nodes)}\n",
    "    nx.draw_networkx(G_class, labels=labels_dict, ax=ax_l, node_color=colors)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "total_adds = 4\n",
    "len_data = 4000\n",
    "max_repeat = 1\n",
    "max_nodes = total_adds*max_repeat*4 + 8\n",
    "feat_dim = 6\n",
    "feats_data = np.zeros((len_data, max_nodes, feat_dim))\n",
    "adjs_data = np.zeros((len_data, max_nodes, max_nodes))\n",
    "labels_data = np.zeros((len_data),dtype=np.int32)\n",
    "sub_labels_data = np.zeros((len_data,total_adds),dtype=np.int32) - 1\n",
    "sub_label_nodes = np.zeros((len_data,total_adds,max_repeat,4),dtype=np.int32) - 1\n",
    "num_nodes_data = np.zeros((len_data),dtype=np.int32)\n",
    "\n",
    "#A/B/C/D/E/F\n",
    "# Generate len_data graphs. graph_label (0..3) is the number of distinct real motifs\n",
    "# (sub_labels 0/1/2, each used at most once per graph) placed among total_adds slots;\n",
    "# the remaining slots get a fake motif or a real motif that is then distorted.\n",
    "for i in range(len_data):\n",
    "# for i in range(20):\n",
    "    adj = np.zeros((max_nodes, max_nodes))\n",
    "    feats = np.zeros((max_nodes, feat_dim))\n",
    "    graph_label = np.random.randint(4) #distinct_patterns\n",
    "    sub_labels_covered = []\n",
    "    label_count = 0\n",
    "    \n",
    "    nodes = np.random.randint(4, 8)\n",
    "    \n",
    "    # each new node is attached to a random earlier node, redrawn while that edge\n",
    "    # would form a motif (NOTE(review): this retry loop has no upper bound)\n",
    "    for n_ix in range(nodes): # add nodes\n",
    "        l_n = np.random.randint(0,feat_dim)\n",
    "        feats[n_ix, l_n] = 1.0\n",
    "\n",
    "        if n_ix == 0:\n",
    "            continue\n",
    "        p_node = np.random.randint(n_ix) \n",
    "        while makesPattern(adj, feats, p_node, n_ix)[0] == True: #connect node to graph such that\n",
    "            p_node = np.random.randint(n_ix)                    #it doesn't result in any pattern\n",
    "\n",
    "        adj[p_node,n_ix] = 1.\n",
    "        adj[n_ix, p_node] = 1.\n",
    "\n",
    "    # add random extra edges between base nodes, skipping self-pairs and motif-forming edges\n",
    "    max_edges = min(26,int((nodes*nodes-1)/4))\n",
    "    if max_edges > nodes:\n",
    "        edge_total = np.random.randint(nodes, max_edges)\n",
    "        for e_ix in range(edge_total-nodes+1):\n",
    "            rand_nix = np.random.randint(nodes)\n",
    "            rand_pix = np.random.randint(nodes)\n",
    "            if rand_nix == rand_pix:\n",
    "                continue\n",
    "            if not makesPattern(adj, feats, rand_pix, rand_nix)[0]: #add more random edges\n",
    "                adj[rand_nix, rand_pix] = 1.0\n",
    "                adj[rand_pix, rand_nix] = 1.0\n",
    "    \n",
    "    \n",
    "    for g_i in range(total_adds):\n",
    "        # no more real motifs once graph_label is reached; force one when the\n",
    "        # remaining slots equal the number of real motifs still owed\n",
    "        if label_count == graph_label:\n",
    "            label = 0\n",
    "        elif total_adds-g_i == graph_label-label_count:\n",
    "            label = 1\n",
    "            \n",
    "        else:\n",
    "            label = np.random.randint(2)\n",
    "       \n",
    "        \n",
    "        if label == 1:\n",
    "            \n",
    "            label_count += 1\n",
    "            sub_label = np.random.randint(3)\n",
    "            while(sub_label in sub_labels_covered):\n",
    "                sub_label = np.random.randint(3)\n",
    "            sub_labels_covered.append(sub_label)\n",
    "        else:\n",
    "            sub_label = -1\n",
    "#         print(\"g_i: \", g_i, \"label: \", label, \"sub_label,: \", sub_label)\n",
    "        if max_repeat > 1:\n",
    "            repeat = np.random.randint(1,max_repeat)\n",
    "        else:\n",
    "            repeat = 1\n",
    "        assert(repeat > 0)\n",
    "        for rep in range(repeat):\n",
    "            \n",
    "            add_fake = False\n",
    "            distort_pattern = False\n",
    "            # a non-motif slot: 30% fake motif, otherwise a real motif that gets distorted\n",
    "            if label == 0:\n",
    "                if np.random.rand() < 0.3:\n",
    "                    add_fake = True\n",
    "                else:\n",
    "                    distort_pattern = True\n",
    "                    sub_label = np.random.randint(3)\n",
    "                    \n",
    "\n",
    "            \n",
    "\n",
    "\n",
    "            #0 or 1\n",
    "            highlight_nodes = None\n",
    "\n",
    "            # both branches append a motif; only the else branch may distort it\n",
    "            if label == 0 and add_fake:\n",
    "\n",
    "                adj, feats, highlight_nodes = addPattern(adj,feats,nodes, sub_label)\n",
    "                nodes = nodes + 4\n",
    "\n",
    "                #done\n",
    "            else:\n",
    "                \n",
    "                adj, feats, highlight_nodes = addPattern(adj,feats,nodes, sub_label)\n",
    "                nodes = nodes + 4\n",
    "                ####this code will alter the pattern making few changes###########\n",
    "                if distort_pattern:\n",
    "                    if sub_label == 1 or sub_label == 2:\n",
    "                        # star: cut one centre-leaf edge, re-attach the centre to another node and\n",
    "                        # the cut leaf to another leaf; each re-attachment gives up after a few redraws\n",
    "                        break_node = np.random.randint(1,4)\n",
    "                        assert( adj[highlight_nodes[0], highlight_nodes[break_node]] > 0.0)\n",
    "                        adj[highlight_nodes[0], highlight_nodes[break_node]] = 0.\n",
    "                        adj[highlight_nodes[break_node], highlight_nodes[0]] = 0.\n",
    "                        dest_ix = np.random.randint(nodes)\n",
    "                        while_count = 0\n",
    "                        found = True\n",
    "                        while(dest_ix == highlight_nodes[0]) or (makesPattern(adj, feats,highlight_nodes[0],dest_ix)[0]):\n",
    "                            dest_ix = np.random.randint(nodes)\n",
    "                            if while_count == 5:\n",
    "                                found = False\n",
    "                                break\n",
    "                            while_count += 1\n",
    "                        if found:\n",
    "                            adj[highlight_nodes[0], dest_ix] = 1.0\n",
    "                            adj[dest_ix, highlight_nodes[0]] = 1.0\n",
    "                        h_ix = np.random.randint(1,4)\n",
    "                        while_count = 0\n",
    "                        found = True\n",
    "                        while(h_ix == break_node) or (makesPattern(adj, feats,highlight_nodes[h_ix],highlight_nodes[break_node])[0]):\n",
    "                            h_ix = np.random.randint(1,4)\n",
    "                            if while_count == 5:\n",
    "                                found = False\n",
    "                                break\n",
    "                            while_count += 1\n",
    "                        if found:\n",
    "                            adj[highlight_nodes[h_ix],highlight_nodes[break_node]] = 1.0\n",
    "                            adj[highlight_nodes[break_node],highlight_nodes[h_ix]] = 1.0\n",
    "\n",
    "                    else:\n",
    "                        # ring: cut one ring edge, re-attach one endpoint to another node and the\n",
    "                        # other endpoint to another ring node; each gives up after a few redraws\n",
    "                        break_node = np.random.randint(0,4)\n",
    "                        assert(adj[highlight_nodes[break_node], highlight_nodes[(break_node+1)%4]] > 0.0)\n",
    "                        adj[highlight_nodes[break_node], highlight_nodes[(break_node+1)%4]] = 0.\n",
    "                        adj[highlight_nodes[(break_node+1)%4], highlight_nodes[break_node]] = 0.\n",
    "                        dest_ix = np.random.randint(nodes)\n",
    "\n",
    "                        while_count = 0\n",
    "                        found = True\n",
    "                        while(dest_ix == highlight_nodes[break_node]) or (makesPattern(adj, feats,highlight_nodes[break_node],dest_ix)[0]):\n",
    "                            dest_ix = np.random.randint(nodes)\n",
    "                            if while_count == 5:\n",
    "                                found = False\n",
    "                                break\n",
    "                            while_count += 1\n",
    "                        if found:\n",
    "                            adj[highlight_nodes[break_node], dest_ix] = 1.0\n",
    "                            adj[dest_ix, highlight_nodes[break_node]] = 1.0\n",
    "                        h_ix = np.random.randint(0,4)\n",
    "                        while_count = 0\n",
    "                        found = True\n",
    "                        break_node = (break_node+1)%4\n",
    "                        while(h_ix == break_node) or (makesPattern(adj, feats,highlight_nodes[h_ix],highlight_nodes[break_node])[0]):\n",
    "                            h_ix = np.random.randint(0,4)\n",
    "                            if while_count == 5:\n",
    "                                found = False\n",
    "                                break\n",
    "                            while_count += 1\n",
    "                        if found:\n",
    "                            adj[highlight_nodes[h_ix],highlight_nodes[break_node]] = 1.0\n",
    "                            adj[highlight_nodes[break_node],highlight_nodes[h_ix]] = 1.0\n",
    "\n",
    "\n",
    "\n",
    "                    sub_label = -1\n",
    "                    label = 0\n",
    "                    #24 => 48 added nodes\n",
    "            # NOTE(review): debug print, one line per real motif; consider removing it.\n",
    "            if sub_label != -1:\n",
    "                print(highlight_nodes)\n",
    "            # motif node ids are stored for every slot, including fake/distorted ones\n",
    "            sub_label_nodes[i, g_i, rep] = np.array(highlight_nodes)\n",
    "            \n",
    "        sub_labels_data[i,g_i] = sub_label\n",
    "    \n",
    "            \n",
    "\n",
    "\n",
    "    # NOTE(review): prints one line per graph; consider tqdm or removing it.\n",
    "    print(i, graph_label)\n",
    "            \n",
    "\n",
    "#     if highlight_nodes != None:\n",
    "#         drawGraph(adj, feats, nodes, highlight_nodes=sub_label_nodes[i], sublabel_d=sub_labels_data[i])\n",
    "#     else:\n",
    "#         drawGraph(adj, feats, nodes)\n",
    "    \n",
    "    # labels_data stores the number of real motifs (0..3), not a binary label\n",
    "    feats_data[i] = feats\n",
    "    labels_data[i] = graph_label\n",
    "    \n",
    "    num_nodes_data[i] = nodes\n",
    "    \n",
    "    adjs_data[i] = adj\n",
    "    \n",
    "\n",
    "synthetic_data = {}\n",
    "synthetic_data['adj'] = adjs_data\n",
    "synthetic_data['feat'] = feats_data\n",
    "synthetic_data['label'] = labels_data\n",
    "synthetic_data['sub_label'] = sub_labels_data\n",
    "synthetic_data['sub_label_nodes'] = sub_label_nodes\n",
    "synthetic_data['num_nodes'] = num_nodes_data\n",
    "\n",
    "# Write through a context manager so the file is flushed and closed\n",
    "# deterministically instead of whenever the handle is garbage-collected.\n",
    "with open(\"../../gcn_interpretation/data/synthetic_data_2label_3sublabel/synthetic_data_4000_comb_norep.p\", \"wb\") as fh:\n",
    "    pickle.dump(synthetic_data, fh)\n",
    "        \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "        \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "\n",
    "# Split the generated graphs into train / validation tensors and save them.\n",
    "# NOTE(review): relies on `synthetic_data` from an earlier cell; graphs\n",
    "# [0, train_idx) go to train and the remainder to validation.\n",
    "data = synthetic_data\n",
    "train_idx = 3000\n",
    "X_train_feat = torch.from_numpy(data['feat'][:train_idx]).float()\n",
    "X_train_adj = torch.from_numpy(data['adj'][:train_idx]).float()\n",
    "X_train_nodes = torch.from_numpy(data['num_nodes'][:train_idx])\n",
    "y_train_label = torch.from_numpy(data['label'][:train_idx])\n",
    "# data['sub_label'] is not included in the saved tensors\n",
    "\n",
    "X_val_feat = torch.from_numpy(data['feat'][train_idx:]).float()\n",
    "X_val_adj = torch.from_numpy(data['adj'][train_idx:]).float()\n",
    "X_val_nodes = torch.from_numpy(data['num_nodes'][train_idx:])\n",
    "y_val_label = torch.from_numpy(data['label'][train_idx:])\n",
    "\n",
    "# Saved tuple order: (adj, feat, label, num_nodes).\n",
    "tensor_data_train = (X_train_adj, X_train_feat, y_train_label, X_train_nodes)\n",
    "tensor_data_val = (X_val_adj, X_val_feat, y_val_label, X_val_nodes)\n",
    "\n",
    "# torch.save does not create missing parent directories.\n",
    "os.makedirs(\"./data/synthetic\", exist_ok=True)\n",
    "torch.save(tensor_data_train, \"./data/synthetic/synthetic_train_4000_norep.pth\")\n",
    "torch.save(tensor_data_val, \"./data/synthetic/synthetic_val_4000_norep.pth\")\n",
    "\n",
    "#3 cases\n",
    "#ABCD form a ring\n",
    "#or\n",
    "#C has D, E and F among its 3 neighbors\n",
    "#or\n",
    "#A has B,C,D among its 3 neighbors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "\n",
    "# Reload the generated dataset and collect, per graph, the node ids of all\n",
    "# real motifs; a sub_label of -1 marks a slot with no real motif.\n",
    "# The file is written by the generation cell earlier in this notebook, so\n",
    "# unpickling it is trusted (never pickle.load files from untrusted sources).\n",
    "with open(\"../../gcn_interpretation/data/synthetic_data_2label_3sublabel/synthetic_data_4000_comb_norep.p\", \"rb\") as fh:\n",
    "    synthetic_data = pickle.load(fh)\n",
    "sublabel_array = synthetic_data['sub_label']\n",
    "sublabel_nodes_array = synthetic_data['sub_label_nodes']\n",
    "label = synthetic_data['label']\n",
    "\n",
    "highlight_nodes = {}\n",
    "for ix in range(sublabel_array.shape[0]):\n",
    "    highlight_nodes[ix] = []\n",
    "    for hx in range(sublabel_array.shape[1]):\n",
    "        if sublabel_array[ix,hx] == -1:\n",
    "            continue\n",
    "        # Only repeat slot 0 is read. NOTE(review): assumes max_repeat == 1.\n",
    "        highlight_nodes[ix].extend(sublabel_nodes_array[ix,hx,0])\n",
    "    # Sanity check: each real motif contributes exactly 4 nodes and the\n",
    "    # graph label equals the number of real motifs.\n",
    "    assert len(highlight_nodes[ix]) == 4*label[ix]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfAAAAI4CAYAAACV/7uiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dedxVZb3//9cHbpGDCjigKbeJODEldwiiHY+JBipSilOa5ySGv/SoZY551E5kmuaQWllmOeTJo6ZfTDSHOCiVRRIoGIqoOSIORIKCEgjX74+9oM3NPYHse3PdvJ6Px36w17Wmz1r33rz3utbaa0dKCUmSlJd21S5AkiStOQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuNSEijouI31S7jhUi4l8i4r6IWBARd1W7HjUvIkZFxGMtnPb8iPhZpWtS22CAq1VExBciYkpELIyINyLiwYjYp9p1NSeldFtKaVi16yhzJLANsGVK6ahqF7OuRclpEfFURLwfEW9GxMSIOKZsmokRsTgiti9r+0xEvFw2/HJEvB0Rm5S1nRgRE5tY9y0RsaR4ja54TF/3W9m4lNJ3UkontuY6lS8DXBUXEWcC1wDfoRQ+Hwd+BBxazbqaExE11a6hATsAz6WUPqz0iiKifaXX0YDvA18DzgK2BLoDFwIH1ZtuEfCNZpbVHjh9Ddd/eUpp07JH/zWcf720nr6W9REZ4KqoiOgCXAScmlIam1JalFJamlK6L6V0TjHNxhFxTUTMKR7XRMTGxbj9ImJ2RJxbHFG9ERGHRcTwiHguIv4eEeeXrW9MRNwdEXdGxHsR8URE9C8bf15E/LUY90xEjCwbNyoi/hARV0fEPGBMefdncXR4dVHHuxHxl4jot2I7I+LWiJgbEa9ExIUR0a5suY9FxJUR8U5EvBQRBzexz3oXR5nzI+LpiPhc0f4t4L+BzxdHh6MbmLepfblaV25EpIjYuXh+S0T8OCIeiIhFwJBiPz9T7K/XI+LsRtY5f8W+KNq6RcQHEbF1RGwVEfcX0/w9In6/Yt/UW86uwCnAMSml8SmlD1JKy1JKj6WURtWb/PvAsRGxU2P7EbgCODsiujYxTYtExOeLv1vnYvjgonegWzGcIuKrEfFiRPwtIq5oaBuLaa+NiNeK19DUiPi3snFjIuIXxfMexXKPj4hXi+VeUDZtu7LX87yI+GVEbFFv3tER8SrwyEfdB1r/GOCqtL2BjsA9TUxzAbAXUAf0B/akdNS1wseKZXSnFGA/Bf4d2AP4N+AbEbFj2fSHAncBWwD/C/wqIjYqxv21mKcL8C3gFxGxbdm8g4EXKfUUXFKvzmHAvsCuxfxHA/OKcT8o2noCnwa+CJxQb7mzgK2Ay4EbIyLq74iizvuA3wBbA18BbouI3VJK36TUi3FncXR4Y/35aX5fNucLxXZvBjwG3AiclFLaDOhHA0GQUvoHMBY4tqz5aOC3KaW3KR1Nzwa6Udqv5wMN3cN5f+C1lNKUFtT5OqXXwbeamGYKMBFY7UPHmkop3Qn8Efh+RGxJab+cmFKaWzbZSGAgMIDSa/BLjSzuz5T+Piten3dFRMcmVr8PsBtwAPDfEdG7aP8KcBil19t2wDvAdfXm/TTQGziwBZup3KSUfPio2AM4DnizmWn+CgwvGz4QeLl4vh/wAdC+GN6M0n/+g8umnwocVjwfA/ypbFw74A3g3xpZ9zTg0OL5KODVeuNHAY8Vz/cHnqMUkO3KpmkPLAH6lLWdBEwsW8YLZeM6FdvwsQbq+TfgzXrLvx0YU7Z9v1jLfblyW8rGJ2Dn4vktwK31xr9abEvnZv6GnwH+Wjb8B+CLxfOLgHtXrKeJZVxY/rcr2mYD84HFwA5F20TgREofCBYAfYv1v1w238tFW79imm7FPBObWP8txXrmlz1+Xja+a7E//gL8pIH9eFDZ8CnAhMb2e7153wH6
1//7Aj2K5daWTTuZUg8FwEzggLJx2wJLgZqyeXtW4n3tY/14eASuSpsHbBVNn4PbDnilbPiVom3lMlJKy4rnHxT/vlU2/gNg07Lh11Y8SSktpxQC2wFExBcjYlrRnTuf0n/wWzU0b30ppUeAH1I6ynk7Im4oulS3AjZqYBu6lw2/Wbac94un5TWvsB2lo9DlTSyrKc3ty+bU3/4jgOHAKxHx24jYu5H5HgU6RcTgiOhB6QhzRa/LFcALwG+KLubzGlnGPEohtFJKqZbS/t0YiHrj5lL6e1zU2MaklGYA9wOrrDNKV3uvuFDt+rJRV6aUupY9ji9b1nxKPTv9gKsaWF35vmt0v0fE2RExM0rfJJhPqedmq4amLbxZ9vx9/vm62QG4p+y1PBNYRqmXo6Ga1MYY4Kq0ScA/KHX1NWYOpf+MVvh40ba2yq9ObgfUAnMiYgdK3a6nUbqKuyswg1WDocmf50spfT+ltAfQh1JX+jnA3ygd+dTfhtfXovY5wPb1zp+uybKa2peLKB39AxARH2tg/lW2P6X055TSoZS6838F/LKhlRYfsH5JqRv9WOD+lNJ7xbj3UkpnpZR6Ap8DzoyIAxpYzCNAbUQMbHYr/+kKYAil0ymN+Sbw/1H2ISiVrvZecaHayS1ZUUTUUeoWv53SOfj6ti973uBruDjffS6lUwybF6/BBdT7cNJCrwEH1/vA0TGlVP5a8ecm2zADXBWVUlpA6bz1dVG6+KxTRGxUXAR0eTHZ7cCFxYVPWxXT/+IjrHaPiDi8OOr/GqUPEH8CNqH0H9pcgIg4gdLRVItExKDiCHMjSmG4GFheFl6XRMRmxQeFM9dyGx6ndJR1brGf9gM+C9zRwvmb2pfTgb4RUVeccx3T1IIiokOUvgffJaW0FHgXWN7ELP8LfJ7SaZP/LVvOiIjYuTjnv4DSUeJqy0kpzQJ+AtwREUOj9J339sCnGlthcVR8FaVQbGyaF4A7ga82UXuTiv31C0rn708AukfEKfUmOyciNo/S19tOL9ZZ32bAh5RegzUR8d9A57Us63pKr7kdihq7RcR6/c0OrVsGuCoupXQVpUC7kNJ/XK9ROgr+VTHJxZQuOHqK0vnFJ4q2tXUvpSB5B/gP4PBUuvL9GUr/2U+i1AX/CUrnaluqM6Uj+HcodZHOo3QECKULihZRugDuMUoBdtOaFp5SWkIpsA+mdGT/I0rnkp9t4SIa3ZcppecodTf/H/B8UWdz/gN4OSLeBU6mFM6N1f44pX2wHfBg2ahdinUupLTvf5RSerSRxZxK6ej2e8DfKZ3++Dalv+erjcxzLaUPBU25iNIHuOacG6t+D/xvRfullE5t/DiVLtr7d+DiiNilbN57KV2PMQ34NaUL3ep7GHiI0rUUr1D6ELi23dzXAuMonZp4j9KH1MFruSxlKFKyh0VtR0SMoXSx1L9XuxZtOCIiAbsUR/tSq/AIXJKkDBngkiRlyC50SZIy5BG4JEkZyuIG91tttVXq0aNHtcuQJKnVTZ069W8ppW7127MI8B49ejBlSktujyxJUtsSEa801G4XuiRJGTLAJUnKkAEuSVKGsjgH3pClS5cye/ZsFi9eXO1S1AIdO3aktraWjTbaqPmJJUnNyjbAZ8+ezWabbUaPHj0o/UaC1lcpJebNm8fs2bPZcccdq12OJLUJ2XahL168mC233NLwzkBEsOWWW9pbIknrULYBDhjeGfFvJUnrVtYBLknShirbc+D19Tjv1+t0eS9fdkiz07z11lucccYZ/OlPf2LzzTenQ4cOnHvuuWy++eYMGTKEcePG8dnPfhaAESNGcPbZZ7Pffvux3377sXDhwpU3p5kyZQpnn302EydOXG0do0aN4re//S1dunQBoFOnTvzxj39c4+2ZOHEiV155Jffff3+j00yZMoVbb72V73//+2u8fElS6/IIfC2llDjssMPYd999efHFF5k6dSp33HEHs2fPBqC2tpZLLrmk0fnffvttHnzw
wRat64orrmDatGlMmzZtrcK7pQYOHLjOwnvZsmXrZDmSpIYZ4GvpkUceoUOHDpx88skr23bYYQe+8pWvANC/f3+6dOnC+PHjG5z/nHPOaTLgm3P66adz0UUXAfDwww+z7777snz5ckaNGsXJJ5/MwIED2XXXXRs84p48eTJ77703n/zkJ/nUpz7FrFmzgNJR+ogRIwAYM2YMX/rSl9hvv/3o2bPnKsH+i1/8gj333JO6ujpOOumklWG96aabctZZZ9G/f38mTZq01tsmSWqeAb6Wnn76aQYMGNDkNBdccAEXX3xxg+P23ntvOnTowKOPPtrsus455xzq6uqoq6vjuOOOA+DSSy/lzjvv5NFHH+WrX/0qN998M+3alf6cL7/8MpMnT+bXv/41J5988mpXf/fq1Yvf//73PPnkk1x00UWcf/75Da732Wef5eGHH2by5Ml861vfYunSpcycOZM777yTP/zhD0ybNo327dtz2223AbBo0SIGDx7M9OnT2WeffZrdLkkCuPbaa+nXrx99+/blmmuuAUr/7/Xq1Yvdd9+dkSNHMn/+/Abnfeihh9htt93Yeeedueyyy1qz7KozwNeRU089lf79+zNo0KCVbfvuuy8Ajz32WIPzXHjhhY0GfLnyLvQVYdmpUyd++tOfMnToUE477TR22mmnldMfffTRtGvXjl122YWePXvy7LPPrrK8BQsWcNRRR9GvXz/OOOMMnn766QbXe8ghh7Dxxhuz1VZbsfXWW/PWW28xYcIEpk6dyqBBg6irq2PChAm8+OKLALRv354jjjii2e2RpBVmzJjBT3/6UyZPnsz06dO5//77eeGFFxg6dCgzZszgqaeeYtddd+XSSy9dbd5ly5Zx6qmn8uCDD/LMM89w++2388wzz1RhK6rDAF9Lffv25Yknnlg5fN111zFhwgTmzp27ynRNHYXvv//+fPDBB/zpT39a2XbCCSdQV1fH8OHDm63hL3/5C1tuuSVz5sxZpb3+V7bqD3/jG99gyJAhzJgxg/vuu6/R72dvvPHGK5+3b9+eDz/8kJQSxx9//MoPFLNmzWLMmDFA6W5r7du3b7ZuSVph5syZDB48mE6dOlFTU8OnP/1pxo4dy7Bhw6ipKV1nvddee628vqjc5MmT2XnnnenZsycdOnTgmGOO4d57723tTagaA3wt7b///ixevJgf//jHK9vef//91aYbNmwY77zzDk899VSDy7nwwgu5/PLLVw7ffPPNTJs2jQceeKDJ9b/yyitcddVVPPnkkzz44IM8/vjjK8fdddddLF++nL/+9a+8+OKL7LbbbqvMu2DBArp37w7ALbfc0uy2ljvggAO4++67efvttwH4+9//ziuvNPhLd5LUrH79+vH73/+eefPm8f777/PAAw/w2muvrTLNTTfdxMEHH7zavK+//jrbb7/9yuHa2lpef/31ite8vmgzXyNryde+1qWI4Fe/+hVnnHEGl19+Od26dWOTTTbhu9/97mrTXnDBBRx66KENLmf48OF067ba77Sv4pxzzlnlKP7xxx9n9OjRXHnllWy33XbceOONjBo1ij//+c8AfPzjH2fPPffk3Xff5frrr6djx46rLO/cc8/l+OOP5+KLL+aQQ9Zsv/Xp04eLL76YYcOGsXz5cjbaaCOuu+46dthhhzVajiQB9O7dm69//esMGzaMTTbZhLq6ulV68i655BJqampWXv+jf4qUUrVraNbAgQPTiu9MrzBz5kx69+5dpYrWX6NGjWLEiBEceeSR1S5lNf7NJDXn/PPPp7a2llNOOYVbbrmFn/zkJ0yYMIFOnTqtNu2kSZMYM2YMDz/8MMDK8+T/9V//1ao1V1pETE0pDazfbhe6JKmqVpySe/XVVxk7dixf+MIXeOihh7j88ssZN25cg+ENMGjQIJ5//nleeukllixZwh133MHnPve51iy9qtpMF7pK1vSctiRV2xFHHMG8efNWnpLr2rUrp512Gv/4xz8YOnQo
ULqQ7frrr2fOnDmceOKJPPDAA9TU1PDDH/6QAw88kGXLlvGlL32Jvn37VnlrWo9d6Go1/s0kac3ZhS5JUhtigEuSlCEDXJKkDLWdi9jGdFnHy1vQ7CSzZ8/m1FNP5ZlnnmH58uWMGDGCK664gg4dOqzbWiSpUtb1/50buhZkx7riEfhaSilx+OGHc9hhh/H888/z3HPPsXDhQi644IIWL8Of3JQkrS0DfC098sgjdOzYkRNOOAEo3Sv86quv5qabbuJHP/oRp5122sppR4wYwcSJE4HVf3LzvPPOo0+fPuy+++6cffbZ1dgUSVKG2k4Xeit7+umn2WOPPVZp69y5Mx//+Mf58MMPG51vxU9uXnXVVcybN4/Ro0fz7LPPEhGN/lyeJEn1eQTeysp/crNLly507NiR0aNHM3bs2EbvNiRJUn0G+Frq06cPU6dOXaXt3Xff5dVXX6Vr164sX758ZXv5z3WW/+RmTU0NkydP5sgjj+T+++/noIMOap3iJUnZM8DX0gEHHMD777/PrbfeCpQuSDvrrLMYNWoUPXv2ZNq0aSxfvpzXXnuNyZMnN7iMhQsXsmDBAoYPH87VV1/N9OnTW3MTJEkZazvnwFvx0n0o/ZzoPffcwymnnMK3v/1tli9fzvDhw/nOd75Dhw4d2HHHHenTpw+9e/dmwIABDS7jvffe49BDD2Xx4sWklPje977XqtsgScpX2wnwKth+++257777Ghx32223Ndi+cOHClc+33XbbRo/OJUlqil3okiRlyACXJClDWQd4Dj+FqhL/VpK0bmUb4B07dmTevHkGQwZSSsybN4+OHTtWuxRJajOyvYittraW2bNnM3fu3GqXohbo2LEjtbW11S5DktqMbAN8o402Yscdd6x2GZIkVUW2XeiSJG3IDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGapogEfEGRHxdETMiIjbI6JjROwYEY9HxAsRcWdEdKhkDZIktUUVC/CI6A58FRiYUuoHtAeOAb4LXJ1S2hl4BxhdqRokSWqrKt2FXgP8S0TUAJ2AN4D9gbuL8T8HDqtwDZIktTkVC/CU0uvAlcCrlIJ7ATAVmJ9S+rCYbDbQvaH5I+LLETElIqbMnTu3UmVKkpSlSnahbw4cCuwIbAdsAhzU0vlTSjeklAamlAZ269atQlVKkpSnSnahfwZ4KaU0N6W0FBgL/CvQtehSB6gFXq9gDZIktUmVDPBXgb0iolNEBHAA8AzwKHBkMc3xwL0VrEGSpDapkufAH6d0sdoTwF+Kdd0AfB04MyJeALYEbqxUDZIktVU1zU+y9lJK3wS+Wa/5RWDPSq5XkqS2zjuxSZKUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJWkNzJo1i7q6upWPzp07c80113DXXXfRt29f2rVrx5QpUxqcd/Hixey5557079+fvn378s1vfrOVq1dbUlPtAiQpJ7vtthvTpk0DYNmyZXTv3p2RI0fy/vvvM3bsWE466aRG591444155JFH2HTTTVm6
dCn77LMPBx98MHvttVdrla82xACXpLU0YcIEdtppJ3bYYYcWTR8RbLrppgAsXbqUpUuXEhGVLFFtmF3okrSW7rjjDo499tg1mmfZsmXU1dWx9dZbM3ToUAYPHlyh6tTWGeCStBaWLFnCuHHjOOqoo9Zovvbt2zNt2jRmz57N5MmTmTFjRoUqVFtngEvSWnjwwQcZMGAA22yzzVrN37VrV4YMGcJDDz20jivThsIAl6S1cPvtt69x9/ncuXOZP38+AB988AHjx4+nV69elShPGwADXJLW0KJFixg/fjyHH374yrZ77rmH2tpaJk2axCGHHMKBBx4IwJw5cxg+fDgAb7zxBkOGDGH33Xdn0KBBDB06lBEjRlRlG5S/SClVu4ZmDRw4MDX2vUpJ0kcwpku1K2hbxixY54uMiKkppYH12z0ClyQpQwa4JEkZMsAlScqQd2KTlJUe5/262iW0KS93rHYFWlsegUuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1xtwvz58znyyCPp1asXvXv3ZtKkSZxzzjn06tWL3XffnZEjRzJ//vzV5ps1axZ1dXUrH507d+aaa66pwhZI0poxwNUmnH766Rx00EE8++yzTJ8+nd69ezN06FBmzJjBU089xa677sqll1662ny77bYb06ZNY9q0aUydOpVOnToxcuTIKmyBJK0ZA1zZW7BgAb/73e8YPXo0AB06dKBr164MGzaMmpoaAPbaay9mz57d5HImTJjATjvtxA477FDxmiXpozLAlb2XXnqJbt26ccIJJ/DJT36SE088kUWLFq0yzU033cTBBx/c5HLuuOMOjj322EqWKknrjAGu7H344Yc88cQT/Od//idPPvkkm2yyCZdddtnK8Zdccgk1NTUcd9xxjS5jyZIljBs3jqOOOqo1Spakj8wAV/Zqa2upra1l8ODBABx55JE88cQTANxyyy3cf//93HbbbUREo8t48MEHGTBgANtss02r1CxJH5UBrux97GMfY/vtt2fWrFlA6Vx2nz59eOihh7j88ssZN24cnTp1anIZt99+u93nkrJSU+0CpHXhBz/4AccddxxLliyhZ8+e3HzzzQwaNIh//OMfDB06FChdyHb99dczZ84cTjzxRB544AEAFi1axPjx4/nJT35SzU2QpDVigKtNqKurY8qUKau0vfDCCw1Ou912260Mb4BNNtmEefPmVbQ+SVrX7EKXJClDBrgkSRkywCVJypDnwPXRjelS7QraljELql2BpAx4BC5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlKGKBnhEdI2IuyPi2YiYGRF7R8QWETE+Ip4v/t28kjVIktQWVfoI/FrgoZRSL6A/MBM4D5iQUtoFmFAMS5KkNVCxAI+ILsC+wI0AKaUlKaX5wKHAz4vJfg4cVqkaJElqqyp5BL4jMBe4OSKejIifRcQmwDYppTeKad4Etmlo5oj4ckRMiYgpc+fOrWCZkiTlp5IBXgMMAH6cUvoksIh63eUppQSkhmZOKd2QUhqYUhrYrVu3CpYpSVJ+Khngs4HZKaXHi+G7KQX6WxGxLUDx79sVrEGSpDapYgGeUnoTeC0idiuaDgCeAcYBxxdtxwP3VqoGSZLaqpoKL/8rwG0R0QF4ETiB0oeGX0bEaOAV4OgK1yBJUptT0QBPKU0DBjYw6oBKrleSpLbOO7FJkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSA
S5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsCrpEePHnziE5+grq6OgQMHAnDXXXfRt29f2rVrx5QpUxqc77XXXmPIkCH06dOHvn37cu2117Zm2ZKk9URNSyeMiL2AMUBH4JqU0q8qVdSG4tFHH2WrrbZaOdyvXz/Gjh3LSSed1Og8NTU1XHXVVQwYMID33nuPPfbYg6FDh9KnT5/WKFmStJ5oNMAj4mMppTfLms4ERgIBPA4Y4OtY7969m51m2223ZdtttwVgs802o3fv3rz++usGuCRtYJrqQr8+Iv47IjoWw/OBIymF+LsVr6yNiwiGDRvGHnvswQ033LBWy3j55Zd58sknGTx48DquTpK0vmv0CDyldFhEfBa4PyJuBb4GfAHoBBzWSvW1WY899hjdu3fn7bffZujQofTq1Yt99923xfMvXLiQI444gmuuuYbOnTtXsFJJ0vqoyYvYUkr3AQcCXYB7gOdSSt9PKc1tjeLasu7duwOw9dZbM3LkSCZPntzieZcuXcoRRxzBcccdx+GHH16pEiVJ67FGAzwiPhcRjwIPATOAzwOHRsQdEbFTaxXYFi1atIj33ntv5fPf/OY39OvXr0XzppQYPXo0vXv35swzz6xkmZKk9VhTR+AXAwcDRwPfTSnNTymdBXwDuKQ1imur3nrrLfbZZx/69+/PnnvuySGHHMJBBx3EPffcQ21tLZMmTeKQQw7hwAMPBGDOnDkMHz4cgD/84Q/8z//8D4888gh1dXXU1dXxwAMPVHNzJElV0NTXyBYAh1M65/32isaU0vPAMRWuq03r2bMn06dPX6195MiRjBw5crX27bbbbmVI77PPPqSUKl6jJGn91tQR+EhgS0oh/4XWKUeSJLVEU1eh/w34QSvWIkmSWshbqUqSlKEW30q1Lelx3q+rXUKb8nLH5qeRJK1bzR6BR8RXImLz1ihGkiS1TEu60LcB/hwRv4yIgyIiKl2UJElqWrMBnlK6ENgFuBEYBTwfEd/xZi6SJFVPiy5iSxHQ2EsAAA5vSURBVKUvHr9ZPD4ENgfujojLK1ibJElqRLMXsUXE6cAXgb8BPwPOSSktjYh2wPPAuZUtUZIk1deSq9C3AA5PKb1S3phSWh4RIypTliRJakpLutAfBP6+YiAiOkfEYICU0sxKFSZJkhrXkgD/MbCwbHhh0SZJkqqkJQEeqezXM1JKy9lAbwAjSdL6oiUB/mJEfDUiNioepwMvVrowSZLUuJYE+MnAp4DXgdnAYODLlSxKkiQ1rdmu8JTS2/j735IkrVda8j3wjsBooC+w8mcrUkpfqmBdkiSpCS3pQv8f4GPAgcBvgVrgvUoWJUmSmtaSAN85pfQNYFFK6efAIZTOg0uSpCppSYAvLf6dHxH9gC7A1pUrSZIkNacl3+e+ofg98AuBccCmwDcqWpUkSWpSkwFe/GDJuymld4DfAT1bpSpJktSkJrvQi7uu+WtjkiStZ1pyDvz/IuLsiNg+IrZY8ah4ZZIkqVEtOQf++eLfU8vaEnanS5JUNS25E9uOrVGIJElquZbcie2LDbWnlG5d9+VIkqSWaEkX+qCy5x2BA4AnAANckqQqaUkX+lfKhyOiK3BHxSqSJEnNaslV6PUtAjwvLklSFbXkHPh9lK46h1Lg9wF+WcmiJElS01pyDvzKsucfAq+klGZXqB5JktQCLQnwV4E3UkqLASLiXyKiR0rp5YpWJkmSGtWSc+B3AcvLhpcVbZIkqUpaEuA1KaUlKwaK5x0qV5IkSWpOSwJ8bkR8bsVARBwK/K1yJUmSpOa05Bz4ycBtEfHDYng20ODd2SRJUutoyY1c/grsFRGbFsMLK16VJElqUrNd6BHxnYjomlJamFJaGBGbR8TFrVGc
JElqWEvOgR+cUpq/YiCl9A4wvHIlSZKk5rQkwNtHxMYrBiLiX4CNm5hekiRVWEsuYrsNmBARNxfDJ+AvkUmSVFUtuYjtuxExHfhM0fTtlNLDlS1LkiQ1pSVH4KSUHgIeAoiIfSLiupTSqRWtTJIkNapFAR4RnwSOBY4GXgLGVrIoSZLUtEYDPCJ2pRTax1K689qdQKSUhrRSbZIkqRFNHYE/C/weGJFSegEgIs5olaokSVKTmvoa2eHAG8CjEfHTiDgAiNYpS5IkNaXRAE8p/SqldAzQC3gU+BqwdUT8OCKGtVaBkiRpdc3eyCWltCil9L8ppc8CtcCTwNcrXpkkSWpUS+7EtlJK6Z2U0g0ppQMqVZAkSWreGgW4JElaPxjgkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZajiAR4R7SPiyYi4vxjeMSIej4gXIuLOiOhQ6RokSWprWuMI/HRgZtnwd4GrU0o7A+8Ao1uhBkmS2pSKBnhE1AKHAD8rhgPYH7i7mOTnwGGVrEGSpLao0kfg1wDnAsuL4S2B+SmlD4vh2UD3hmaMiC9HxJSImDJ37twKlylJUl4qFuARMQJ4O6U0dW3mL362dGBKaWC3bt3WcXWSJOWtpoLL/lfgcxExHOgIdAauBbpGRE1xFF4LvF7BGiRJapMqdgSeUvqvlFJtSqkHcAzwSErpOOBR4MhisuOBeytVgyRJbVU1vgf+deDMiHiB0jnxG6tQgyRJWatkF/pKKaWJwMTi+YvAnq2xXkmS2irvxCZJUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUoYoFeERsHxGPRsQzEfF0RJxetG8REeMj4vni380rVYMkSW1VJY/APwTOSin1AfYCTo2IPsB5wISU0i7AhGJYkiStgYoFeErpjZTSE8Xz94CZQHfgUODnxWQ/Bw6rVA2SJLVVrXIOPCJ6AJ8EHge2SSm9UYx6E9imkXm+HBFTImLK3LlzW6NMSZKyUfEAj4hNgf8HfC2l9G75uJRSAlJD86WUbkgpDUwpDezWrVuly5QkKSsVDfCI2IhSeN+WUhpbNL8VEdsW47cF3q5kDZIktUWVvAo9gBuBmSml75WNGgccXzw/Hri3UjVIktRW1VRw2f8K/Afwl4iYVrSdD1wG/DIiRgOvAEdXsAZJktqkigV4SukxIBoZfUCl1itJ0obAO7FJkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEu
SVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZagqAR4RB0XErIh4ISLOq0YNkiTlrNUDPCLaA9cBBwN9gGMjok9r1yFJUs6qcQS+J/BCSunFlNIS4A7g0CrUIUlStmqqsM7uwGtlw7OBwfUniogvA18uBhdGxKxWqE1rIWAr4G/VrqPN+FZUuwJtQHz/rmOVef/u0FBjNQK8RVJKNwA3VLsONS8ipqSUBla7DklrzvdvvqrRhf46sH3ZcG3RJkmSWqgaAf5nYJeI2DEiOgDHAOOqUIckSdlq9S70lNKHEXEa8DDQHrgppfR0a9ehdcpTHVK+fP9mKlJK1a5BkiStIe/EJklShgxwSZIyZIBvYCKia0Scso6WNS4iZqyLZUlq3rp4/0ZEh4i4ISKei4hnI+KIdVWfWpcBvuHpCnzkAI+Iw4GFH70cSWtgXbx/LwDeTintSul21r/9yFWpKgzwDc9lwE4RMS0iroiSKyJiRkT8JSI+DxAR+0XE7yLi18UPz1wfEe2KcZsCZwIXN7aSiPhsRDweEU9GxP9FxDYr5o2Im4t1PbXi03/xAzdPRMT0iJhQ8b0g5ekjv3+BLwGXAqSUlqeUVrsLW0TsGRGTivfvHyNit6K9fURcWazvqYj4StE+qJhuekRMjojNWmVvbOhSSj42oAfQA5hRNnwEMJ7SV/q2AV4FtgX2AxYDPYtx44Eji3muBkbWX1a99WzOP7/lcCJwVfH8u8A19abrRun2ujsWbVtUez/58LE+Pj7q+5fSEfxrwPeAJ4C7gG0aWE9noKZ4/hng/xXP/xO4u2zcFkAH4EVgUP15fVT24RG49gFuTyktSym9Rak7bVAxbnIq/ejMMuB2YJ+IqAN2Sind08xya4GHI+IvwDlA36L9M5R+jQ6AlNI7wF7A71JKLxVtf19H2ya1dWv0/qV0749a4I8ppQHAJODKBpbbBbiruMblalZ9//4kpfQhrHyv7ga8kVL6c9H27orxqiwDXE2pf5OABOwNDIyIl4HHgF0jYmID8/4A+GFK6RPASUDHCtYpaXUNvX/nAe8DY4u2u4ABDcz7beDRlFI/4LP4/l0vGeAbnveA8vNTvwc+X5zb6gbsC0wuxu1Z3PK2HfB54LGU0o9TStullHpQ+kT/XEppvwbW04V/3uP++LL28cCpKwYiYnPgT8C+EbFj0bbFR9xGqa36qO/fBNxHqYsd4ADgmQbWU/7+HVXWPh44KSJqYOV7dRawbUQMKto2WzFelWWAb2BSSvOAPxQXoVwB3AM8BUwHHgHOTSm9WUz+Z+CHwEzgpWLalhpDqQtuKqv+VOHFwObF+qcDQ1JKcyn9dOzYou3Otd5AqQ1bR+/frwNjIuIp4D+AsxpY1eXApRHxJKvecvtnlM6zP1W8V7+QUlpC6QPCD4q28XjE3iq8laoaFBH7
AWenlEZUuxZJa8b374bBI3BJkjLkEbgkSRnyCFySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScrQ/w8GGAPrRE5+sgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 504x576 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Grouped bar chart comparing top-4 / top-6 accuracy of GNN-Explainer and ours.\n",
    "labels = ['top4 acc', 'top6 acc']\n",
    "theirs = [51.2, 71.3]\n",
    "ours = [62.7, 92.0]\n",
    "\n",
    "x = np.arange(len(labels))  # one bar group per metric\n",
    "width = 0.35  # width of a single bar\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7,8))\n",
    "rects1 = ax.bar(x - width/2, theirs, width, label='GNN-Explainer')\n",
    "rects2 = ax.bar(x + width/2, ours, width, label='Ours')\n",
    "\n",
    "ax.set_ylabel('Accuracy %')\n",
    "ax.set_title('Comparison of ours vs GNN-Explainer')\n",
    "ax.set_xticks(x)\n",
    "ax.set_xticklabels(labels)\n",
    "ax.legend()\n",
    "\n",
    "def autolabel(rects):\n",
    "    \"\"\"Annotate every bar in *rects* with its height, 3 points above its top.\"\"\"\n",
    "    for bar in rects:\n",
    "        top = bar.get_height()\n",
    "        centre = bar.get_x() + bar.get_width() / 2\n",
    "        ax.annotate('{}'.format(top), xy=(centre, top), xytext=(0, 3),\n",
    "                    textcoords=\"offset points\", ha='center', va='bottom')\n",
    "\n",
    "for bar_group in (rects1, rects2):\n",
    "    autolabel(bar_group)\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([5, 0, 1, 2])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#add even or odd number patterns\n",
    "#synthetic data 1: with 4 patterns \n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "max_nodes = 20\n",
    "def makesPattern(adj, feats, p, n, num_nodes):\n",
    "    \"\"\"Check whether adding edge (p, n) would close a 4-node motif.\n",
    "\n",
    "    The motif is the cycle p - p_nbr - n_nbr - n, closed by the candidate\n",
    "    edge (n, p). Only motifs whose four nodes have pairwise-distinct labels\n",
    "    (argmax of each row of `feats`) are counted.\n",
    "\n",
    "    Args:\n",
    "        adj: adjacency matrix; entries > 0 are edges.\n",
    "        feats: node features; a node's label is the argmax of its row.\n",
    "        p, n: endpoints of the candidate edge.\n",
    "        num_nodes: only nodes 0..num_nodes-1 are scanned as neighbours.\n",
    "\n",
    "    Returns:\n",
    "        (found, pattern): pattern 0 = plain ring, 1 = ring with one\n",
    "        diagonal, 2 = ring with both diagonals (tetra), -1 = none found.\n",
    "    \"\"\"\n",
    "    l_p = np.argmax(feats[p])\n",
    "    l_n = np.argmax(feats[n])\n",
    "\n",
    "    p_neighbors = []\n",
    "    n_neighbors = []\n",
    "    for pix in range(num_nodes):\n",
    "        if p != pix and pix != n:\n",
    "            l_pix = np.argmax(feats[pix])\n",
    "            if l_pix == l_p or l_pix == l_n:\n",
    "                continue\n",
    "            # Independent checks (not if/elif): a common neighbour of p and n\n",
    "            # must appear in both lists, otherwise the diagonal cases below\n",
    "            # (patterns 1 and 2) are partly or entirely unreachable.\n",
    "            if adj[p,pix] > 0.0:\n",
    "                p_neighbors.append(pix)\n",
    "            if adj[n,pix] > 0.0:\n",
    "                n_neighbors.append(pix)\n",
    "\n",
    "    for p_nbr in p_neighbors:\n",
    "        l_pnbr = np.argmax(feats[p_nbr])\n",
    "        for n_nbr in n_neighbors:\n",
    "            l_nnbr = np.argmax(feats[n_nbr])\n",
    "            if l_pnbr == l_nnbr:\n",
    "                continue\n",
    "            if p_nbr != n_nbr and adj[p_nbr, n_nbr] > 0.:\n",
    "                if (adj[p,n_nbr] < 1. and adj[n, p_nbr] < 1.): # ring, no diagonal\n",
    "                    return True, 0\n",
    "                elif (adj[p,n_nbr] < 1. and adj[n, p_nbr] > 0.) or (adj[p,n_nbr] > 0. and adj[n, p_nbr] < 1.):\n",
    "                    return True, 1  # ring with one diagonal\n",
    "                elif (adj[p,n_nbr] > 0. and adj[n, p_nbr] > 0.):\n",
    "                    return True, 2  # ring with both diagonals\n",
    "    return False, -1\n",
    "        \n",
    "\n",
    "\n",
    "def addPattern(adj, feats, nodes, sub_label, n_labels=None):\n",
    "    \"\"\"Append a 4-node motif at indices nodes..nodes+3 and wire it into the graph.\n",
    "\n",
    "    sub_label 0/1/2 plants a real motif with 4 distinct node labels:\n",
    "    0 = 4-ring, 1 = 4-ring plus one random diagonal, 2 = 4-ring plus both\n",
    "    diagonals (tetra). sub_label -1 plants a fake motif of a random one of\n",
    "    those shapes whose labels always contain a repeat.\n",
    "\n",
    "    Each new node is then linked to the existing nodes 0..nodes-1, aiming for\n",
    "    a degree close to the graph's average, while rejecting edges that\n",
    "    makesPattern reports as closing a motif.\n",
    "\n",
    "    Args:\n",
    "        adj, feats: adjacency and one-hot feature matrices, modified in place;\n",
    "            both need at least nodes + 4 rows.\n",
    "        nodes: number of nodes currently in the graph (must be > 0).\n",
    "        sub_label: -1, 0, 1 or 2 (see above).\n",
    "        n_labels: number of node labels to sample from; defaults to\n",
    "            feats.shape[1].\n",
    "\n",
    "    Returns:\n",
    "        (adj, feats, [nodes, nodes+1, nodes+2, nodes+3]).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if sub_label is not one of -1, 0, 1, 2.\n",
    "    \"\"\"\n",
    "    if sub_label not in (-1, 0, 1, 2):\n",
    "        raise ValueError('sub_label must be -1, 0, 1 or 2, got {!r}'.format(sub_label))\n",
    "    if n_labels is None:\n",
    "        n_labels = feats.shape[1]\n",
    "\n",
    "    degree_sum = np.sum(adj, axis=1)\n",
    "    avg_deg = int(np.sum(degree_sum)/nodes)\n",
    "\n",
    "    if sub_label == -1:\n",
    "        # Fake motif: random shape; 3 labels drawn with replacement plus a copy\n",
    "        # of one of them, so it never has 4 distinct labels.\n",
    "        motif = np.random.randint(3)\n",
    "        lbls = np.random.choice(range(n_labels), 3, replace=True).tolist()\n",
    "        lbls.append(lbls[np.random.randint(3)])\n",
    "    else:\n",
    "        motif = sub_label\n",
    "        lbls = np.random.choice(range(n_labels), 4, replace=False).tolist()\n",
    "\n",
    "    for i in range(4): # one-hot label for each new node\n",
    "        feats[nodes+i, lbls[i]] = 1.\n",
    "\n",
    "    # Every motif starts as the ring nodes -> nodes+1 -> nodes+2 -> nodes+3 -> nodes.\n",
    "    for i in range(4):\n",
    "        a, b = nodes + i, nodes + (i + 1) % 4\n",
    "        adj[a, b] = 1.0\n",
    "        adj[b, a] = 1.0\n",
    "    if motif == 1: # one random diagonal\n",
    "        diag = np.random.randint(2)\n",
    "        adj[nodes+diag, nodes+diag+2] = 1.0\n",
    "        adj[nodes+diag+2, nodes+diag] = 1.0\n",
    "    elif motif == 2: # both diagonals -> tetra\n",
    "        adj[nodes, nodes+2] = 1.0\n",
    "        adj[nodes+2, nodes] = 1.0\n",
    "        adj[nodes+1, nodes+3] = 1.0\n",
    "        adj[nodes+3, nodes+1] = 1.0\n",
    "\n",
    "    for i in range(4):\n",
    "        new_node = nodes + i\n",
    "        deg_exp = np.random.randint(avg_deg-1, avg_deg+1)\n",
    "        dest_n = np.random.randint(nodes)\n",
    "        # Re-draw the anchor while the edge would close a motif; after 5\n",
    "        # re-draws give up and leave this node with only its motif edges.\n",
    "        while_count = 0\n",
    "        skip = False\n",
    "        while makesPattern(adj, feats, dest_n, new_node, nodes+4)[0]:\n",
    "            dest_n = np.random.randint(nodes)\n",
    "            while_count += 1\n",
    "            if while_count == 5:\n",
    "                skip = True\n",
    "                break\n",
    "        if skip:\n",
    "            continue\n",
    "\n",
    "        adj[new_node, dest_n] = 1.0\n",
    "        adj[dest_n, new_node] = 1.0\n",
    "        # Top up towards deg_exp with extra random edges (one attempt each),\n",
    "        # skipping existing edges and edges that would close a motif.\n",
    "        deg = int(np.sum(adj[new_node, :]))\n",
    "        for _ in range(deg_exp - deg):\n",
    "            dest_n = np.random.randint(nodes)\n",
    "            if adj[new_node, dest_n] > 0:\n",
    "                continue\n",
    "            if makesPattern(adj, feats, dest_n, new_node, nodes+4)[0]:\n",
    "                continue\n",
    "            adj[new_node, dest_n] = 1.0\n",
    "            adj[dest_n, new_node] = 1.0\n",
    "\n",
    "    return adj, feats, list(range(nodes, nodes + 4))\n",
    "                            \n",
    "            \n",
    "def drawGraph(adj, feats, nodes, highlight_nodes=None, sublabel_d=None):\n",
    "    \"\"\"Plot the first `nodes` nodes of a graph with networkx and show it.\n",
    "\n",
    "    Each node is labelled with the letter ('A'..'F') at the argmax of its\n",
    "    feature row and drawn light grey. When `highlight_nodes` is given, it is\n",
    "    indexed as highlight_nodes[ix, gx] -> iterable of node ids (-1 entries are\n",
    "    ignored); every node listed under the same ix gets one shared random\n",
    "    colour. `sublabel_d` is accepted for call compatibility but not used.\n",
    "    \"\"\"\n",
    "    letters = ['A','B','C','D','E','F']\n",
    "    graph = nx.from_numpy_array(adj[:nodes,:nodes])\n",
    "    fig, ax_l = plt.subplots(1,1, figsize=(15,10))\n",
    "\n",
    "    colors = [(0.9,0.9,0.9)] * nodes\n",
    "    if highlight_nodes is not None:\n",
    "        for ix in range(highlight_nodes.shape[0]):\n",
    "            # One random colour per ix; green is tied to red (1 - red).\n",
    "            red = np.random.rand()\n",
    "            blue = np.random.rand()\n",
    "            shade = (red, 1.0 - red, blue)\n",
    "            for gx in range(highlight_nodes.shape[1]):\n",
    "                for node_id in highlight_nodes[ix,gx]:\n",
    "                    if node_id != -1:\n",
    "                        colors[node_id] = shade\n",
    "\n",
    "    labels_dict = {n: letters[np.argmax(feats[n,:])] for n in range(nodes)}\n",
    "    nx.draw_networkx(graph, labels=labels_dict, ax=ax_l, node_color=colors)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# --- Generation parameters ---\n",
    "total_adds = 2       # number of sub-pattern slots per graph\n",
    "len_data = 8000      # number of graphs to generate\n",
    "max_repeat = 1       # max insertions of each chosen sub-pattern\n",
    "init_max_nodes = 12  # exclusive upper bound on the base-graph size\n",
    "max_nodes = total_adds*max_repeat*4 + init_max_nodes  # every pattern insertion adds 4 nodes\n",
    "feat_dim = 6         # one-hot node labels A/B/C/D/E/F\n",
    "num_draw = 20        # only print/draw the first num_draw graphs\n",
    "seed = 0\n",
    "\n",
    "# NOTE(review): assumes makesPattern/addPattern draw randomness from np.random only; confirm.\n",
    "np.random.seed(seed)\n",
    "\n",
    "feats_data = np.zeros((len_data, max_nodes, feat_dim))\n",
    "adjs_data = np.zeros((len_data, max_nodes, max_nodes))\n",
    "labels_data = np.zeros((len_data),dtype=np.int32)\n",
    "# -1 marks unused sub-label / pattern-node slots.\n",
    "sub_labels_data = np.zeros((len_data,total_adds),dtype=np.int32) - 1\n",
    "sub_label_nodes = np.zeros((len_data,total_adds,max_repeat,4),dtype=np.int32) - 1\n",
    "num_nodes_data = np.zeros((len_data),dtype=np.int32)\n",
    "\n",
    "# Candidate sub-labels per graph label; -1 means: add no pattern.\n",
    "label_list = [(0,1,-1),(1,2,-1),(2,0,-1)]\n",
    "\n",
    "# Generate all len_data rows: later cells slice rows up to 8000, so every row must be filled\n",
    "# (a debug loop over range(20) left all other rows as all-zero graphs with label 0).\n",
    "for i in range(len_data):\n",
    "    adj = np.zeros((max_nodes, max_nodes))\n",
    "    feats = np.zeros((max_nodes, feat_dim))\n",
    "    graph_label = np.random.randint(3)  # one of the 3 entries of label_list\n",
    "    nodes = np.random.randint(6, init_max_nodes)\n",
    "\n",
    "    # Random base tree: each node gets a random one-hot label and is attached to one\n",
    "    # earlier node, re-drawing the parent until the new edge creates no pattern.\n",
    "    for n_ix in range(nodes):\n",
    "        l_n = np.random.randint(0,feat_dim)\n",
    "        feats[n_ix, l_n] = 1.0\n",
    "\n",
    "        if n_ix == 0:\n",
    "            continue\n",
    "        p_node = np.random.randint(n_ix)\n",
    "        while makesPattern(adj, feats, p_node, n_ix, nodes)[0]:\n",
    "            p_node = np.random.randint(n_ix)\n",
    "\n",
    "        adj[p_node,n_ix] = 1.\n",
    "        adj[n_ix, p_node] = 1.\n",
    "\n",
    "    # Extra random edges; self-loops and pattern-creating edges are skipped.\n",
    "    max_edges = min(28,int((nodes*(nodes-1))/3))\n",
    "    if max_edges > nodes:\n",
    "        edge_total = np.random.randint(nodes, max_edges)\n",
    "        for _ in range(edge_total-nodes+1):\n",
    "            rand_nix = np.random.randint(nodes)\n",
    "            rand_pix = np.random.randint(nodes)\n",
    "            if rand_nix == rand_pix:\n",
    "                continue\n",
    "            if not makesPattern(adj, feats, rand_pix, rand_nix, nodes)[0]:\n",
    "                adj[rand_nix, rand_pix] = 1.0\n",
    "                adj[rand_pix, rand_nix] = 1.0\n",
    "\n",
    "    # Fill the total_adds slots with distinct sub-labels from this graph's label_list entry.\n",
    "    l_entry = label_list[graph_label]\n",
    "    sublabels_covered = []\n",
    "    for g_i in range(total_adds):\n",
    "        sublabel_ix = np.random.randint(total_adds)\n",
    "        while l_entry[sublabel_ix] in sublabels_covered:\n",
    "            sublabel_ix = np.random.randint(total_adds)\n",
    "        sub_label = l_entry[sublabel_ix]\n",
    "        sublabels_covered.append(sub_label)\n",
    "        if sub_label == -1:\n",
    "            # Never add a pattern for -1 (the old `np.random.rand() < 1.5` test was always true).\n",
    "            continue\n",
    "\n",
    "        # randint's upper bound is exclusive: +1 allows up to max_repeat insertions, matching\n",
    "        # the max_repeat slots in sub_label_nodes and the max_nodes budget.\n",
    "        repeat = np.random.randint(1, max_repeat + 1) if max_repeat > 1 else 1\n",
    "        assert repeat > 0\n",
    "        for rep in range(repeat):\n",
    "            adj, feats, highlight_nodes = addPattern(adj,feats,nodes, sub_label)\n",
    "            nodes = nodes + 4\n",
    "            sub_label_nodes[i, g_i, rep] = np.array(highlight_nodes)\n",
    "\n",
    "        sub_labels_data[i,g_i] = sub_label\n",
    "\n",
    "    # Visual spot-check of the first few graphs only (8000 figures would flood the output).\n",
    "    if i < num_draw:\n",
    "        print(i, graph_label)\n",
    "        drawGraph(adj, feats, nodes, highlight_nodes=sub_label_nodes[i], sublabel_d=sub_labels_data[i])\n",
    "\n",
    "    feats_data[i] = feats\n",
    "    labels_data[i] = graph_label\n",
    "    num_nodes_data[i] = nodes\n",
    "    adjs_data[i] = adj\n",
    "\n",
    "# Every row must have been generated (base graphs have >= 6 nodes) and fit the max_nodes budget.\n",
    "assert (num_nodes_data >= 6).all()\n",
    "assert (num_nodes_data <= max_nodes).all()\n",
    "\n",
    "synthetic_data = {}\n",
    "synthetic_data['adj'] = adjs_data\n",
    "synthetic_data['feat'] = feats_data\n",
    "synthetic_data['label'] = labels_data\n",
    "synthetic_data['sub_label'] = sub_labels_data\n",
    "synthetic_data['sub_label_nodes'] = sub_label_nodes\n",
    "synthetic_data['num_nodes'] = num_nodes_data\n",
    "\n",
    "# pickle.dump(synthetic_data, open(\"../../gcn_interpretation/data/synthetic_data_2label_3sublabel/synthetic_data_4000_comb_norep.p\", \"wb\"))\n",
    "# pickle.dump(synthetic_data, open(\"../../gcn_interpretation/data/synthetic_data_3label_3sublabel/synthetic_data_8000_comb_norep_max20_nofake.p\", \"wb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(8000, 20, 20)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: expected (len_data, max_nodes, max_nodes).\n",
    "np.shape(synthetic_data['adj'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import torch\n",
    "\n",
    "# Optionally reload a previously generated dataset instead of the in-memory `synthetic_data`:\n",
    "# synthetic_data = pickle.load(open(\"../../gcn_interpretation/data/synthetic_data_3label_3sublabel/synthetic_data_8000_comb_norep_max20_8dlbls.p\",\"rb\"))\n",
    "\n",
    "data = synthetic_data\n",
    "\n",
    "# Training split: rows [0, 3000).\n",
    "train_idx = 3000\n",
    "X_train_feat = torch.from_numpy(data['feat'][:train_idx]).float()\n",
    "X_train_adj = torch.from_numpy(data['adj'][:train_idx]).float()\n",
    "X_train_nodes = torch.from_numpy(data['num_nodes'][:train_idx])\n",
    "y_train_label = torch.from_numpy(data['label'][:train_idx])\n",
    "# data['sub_label']\n",
    "\n",
    "# Validation split: rows [7000, 8000).\n",
    "# NOTE(review): train_idx is reused here as the validation start, so rows 3000-6999 are in\n",
    "# neither split; confirm this gap is intended.\n",
    "train_idx = 7000\n",
    "val_idx = 8000\n",
    "X_val_feat = torch.from_numpy(data['feat'][train_idx:val_idx]).float()\n",
    "X_val_adj = torch.from_numpy(data['adj'][train_idx:val_idx]).float()\n",
    "X_val_nodes = torch.from_numpy(data['num_nodes'][train_idx:val_idx])\n",
    "y_val_label = torch.from_numpy(data['label'][train_idx:val_idx])\n",
    "\n",
    "# Both tuples are ordered (adj, feat, label, num_nodes).\n",
    "tensor_data_train = (X_train_adj, X_train_feat, y_train_label, X_train_nodes)\n",
    "tensor_data_val = (X_val_adj, X_val_feat, y_val_label, X_val_nodes)\n",
    "\n",
    "# torch.save fails when the parent directory is missing, so create it first.\n",
    "out_dir = \"./data/synthetic\"\n",
    "os.makedirs(out_dir, exist_ok=True)\n",
    "torch.save(tensor_data_train, os.path.join(out_dir, \"synthetic_train_4k8000_comb_12dlbls_nofake.pth\"))\n",
    "torch.save(tensor_data_val, os.path.join(out_dir, \"synthetic_val_4k8000_comb_12dlbls_nofake.pth\"))\n",
    "\n",
    "#3 cases\n",
    "#ABCD form a ring\n",
    "#or\n",
    "#C has D, E and F among its 3 neighbors\n",
    "#or\n",
    "#A has B,C,D among its 3 neighbors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Literal copy of a d_list (list of metric dicts), kept for reference; the live list is loaded below.\n",
    "# d_list = [{'Weight_rt': 0.2741, 'Weight_wg': 0.2924, 'Weight sum': 0.5665, 'Max weight acc': 0.5980, 'Max weight acc ngt': 0.2243},\n",
    "#  {'Weight_rt': 0.3306, 'Weight_wg': 0.2197, 'Weight sum': 0.5504, 'Max weight acc': 0.4324, 'Max weight acc ngt': 0.3290},\n",
    "#  {'Weight_rt': 0.2708, 'Weight_wg': 0.4982, 'Weight sum': 0.7690, 'Max weight acc': 0.3607, 'Max weight acc ngt': 0.5020},\n",
    "#  {'Weight_rt': 0.7569, 'Weight_wg': 0.1291, 'Weight sum': 0.8861, 'Max weight acc': 0.9559, 'Max weight acc ngt': 0.0250},\n",
    "#  {'Weight_rt': 0.3957, 'Weight_wg': 0.0570, 'Weight sum': 0.4526, 'Max weight acc': 0.4700, 'Max weight acc ngt': 0.0543}, \n",
    "#  {'Weight_rt': 0.9443, 'Weight_wg': 0.0200, 'Weight sum': 0.9649, 'Max weight acc': 0.9752, 'Max weight acc ngt': 0.0143}]\n",
    "\n",
    "import pickle\n",
    "\n",
    "# Only unpickle trusted files: pickle.load can execute arbitrary code.\n",
    "# The `with` block closes the file handle once loading is done.\n",
    "with open(\"./results_final_results.p\", \"rb\") as results_file:\n",
    "    d_list = pickle.load(results_file)\n",
    "\n",
    "# Overwrite entries 4-5 with entries 8-9 (assumes d_list has at least 10 entries).\n",
    "# NOTE(review): the reason for this substitution is not visible here; confirm it is intended.\n",
    "d_list[4:6] = d_list[8:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfAAAAI4CAYAAACV/7uiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3de5zVdZ348ddbBkS8i1gKEaipyG1Q8pKlq63mrYuXNZG84rquXWzbWvu15S3Ly26ppZsaVl4K3VX3AZriQ/GSa2vuuKAgWJISDt6QRDShAN+/P86BHXFmGGDOHD7D6/l4+HC+3+853/M+1uHF93u+c05kJpIkqSwb1XsASZK05gy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEurERFvtfjnnYhY3GJ57Dru+5CIeDgiFkXEM501c71FRO+IyIj4U/W/U3NEXBoRG1W3/0NE/G9E/CUirqnnPBHRJyJ+GhFzq/87PBERB9d6JmldNdR7AGl9l5mbrfg5IuYAp2fm/Z20+7eA64CtgS900j7XJ7tmZnNEDAUeBmYBPwOagfOBz6wH80wEfg+cV53rM8DtEbFbZr7YxfNJHeYRuLSOImKTiLg6Il6qHtn9S0T0rG47NCJmR8QFEfHHiHg+Iv5mxX0z89eZ+XNgTgcepyEibo+IVyJiYUQ8GBG7tti+aUT8ICJeiIg3qkf2DdVtfxURj1XXz42IE9p4jIERcXd11t9FxMkttl0SET+PiAkR8WZEPBURjR35b5SZTwP/DQyrLv9HZk4C/ria59ynesS8c4t1/atnQbaOiPdHxOTqf48FEfHAms6Tma9n5kWZOTcz38nMO4CXgVEd2ZdULwZcWncXACOA4cCewF8B/9Ri+yCgF/B+4G+BGyJi8Fo+1kRgp+q+ngFuaLHtB8BuwIeBbYBvAlmN313AvwB9qzM+3cb+/wP4LbA9cAJweUTs12L7UcBPgK2AKcAVHRk6IoYD+wJTO3L7FTLzbWASMKbF6uOBezPzdeCc6rzbVmc+f13niYgBVP43m7kms0pdzYBL624scF5mvpaZrwAXASe22L4MuCAz/1I99X4/cOyaPkhmLsvMGzPzrcxcQuUvDntV39/tCZwEfDEzX87M5Zn5SGYur85yZ2beXt3H/Mx8ctX9R8SHgJHANzLzz5nZROUvCC2fywOZeV91vzcBqzsCfzoiXgfuAH4I/GJNn3f1Pi0DfkKL/SwFdgAGVv/7/mpd5omIjavrfpSZz6/FrFKX8T1waR1ERFA5Gv5Di9V/APq3WJ5fDW7L7TusxWM1AJdSOQreFngHCCpH1T2ovJ5/38pdP9DG+lXtUJ118SqzfrzF8sstfn4b2Iz2Dc3M5g48dnvupXLWYiSwBPgQcGd123eAC4EHI2Ip8G+Z+f21maf63/cWYAHwlXWcWao5j8CldZCVr/N7Gfhgi9UDgXktlreNiN6rbF+bi6NOBQ4GDgS2pHK6HCoRf4nKkf5OrdzvhTbWr+pFoF9EbLLKrPPauH2XyMylwG1UjsJPAP5zxV8yMvONzDw7Mz8IHAN8c5VT/h1SvTr+RqAP8NnqGQZpvWbApXU3ATgvIvpGxHbAPwM3t9jeE/hWRPSKiIOoRPh2qISjGveelcWVp8NbszmVI9AFwKZUTtUDKyN3I3BlRLwvInpExEcjogeVU91HRsRR1Qvh+kXEiFb2PxuYDlwUERtHxB7Ayas8l05RnaM3lTMHParPu0c7d/kFlfe+x9DitHdEfCoidqyeCXkDWE7lzMSazBLA9cAA4KjM/MuaPRupPgy4tO7OpXLB09PANOBR4LIW2+dQOTp+mcoFYKdm5nPVbYcAi6m8J7tL9ec7ad31wPzqfqYD/7XK9i9ROVU+lUrkvw1EZs4GPg18g8pV303A0FV3Xj2b8DfA7tXHuBX4Wmau+jid4SIqz/XLwOnVn7/W
zu1/RSX2W1K5hmCFIcCDwJvV2/xrZv73Gs6yC3AKlYv/Xo3/+x3/Y9ZwP1KXisprVlItRMShwFWZufNqbyxJa8AjcEmSCmTAJUkqkKfQJUkqkEfgkiQVqIgPctl2221z0KBB9R5DkqQu98QTT7yWmf1WXV9EwAcNGkRTU1O9x5AkqctFxB9aW+8pdEmSCmTAJUkqkAGXJKlARbwH3pqlS5fS3NzMkiVLVn/j9Uzv3r0ZMGAAPXu29ZHXkiS1r9iANzc3s/nmmzNo0CAq30VQhsxkwYIFNDc3M3jw4HqPI0kqVLGn0JcsWULfvn2LijdARNC3b98izxxIktYfxQYcKC7eK5Q6tyRp/VF0wCVJ2lAV+x74qgZ9/Zedur85lxyx2tv06NGD4cOHs2zZMoYMGcINN9xAnz59OO2007jrrrvYbrvtmDFjRqfOJUkSeAS+TjbZZBOmTZvGjBkz6NWrF9dccw0Ap5xyCpMnT67zdJKk7syAd5KPfexjzJ49G4D999+fbbbZps4TSZK6MwPeCZYtW8Y999zD8OHD6z2KJBXhyiuvZNiwYQwdOpQrrrgCgK997WvstttujBgxgqOOOoqFCxe2ef/ly5czatQojjzyyJXrxo4dy6677sqwYcM47bTTWLp0ac2fRz0Z8HWwePFiGhsbGT16NAMHDmTcuHH1HkmS1nszZszgxz/+MY8//jhPPvkkd911F7Nnz+bggw9mxowZPPXUU+yyyy5cfPHFbe7jyiuvZMiQIe9aN3bsWJ555hmmT5/O4sWLGT9+fK2fSl0Z8HWw4j3wadOm8cMf/pBevXrVeyRJWu/NmjWLvffemz59+tDQ0MABBxzAHXfcwSGHHEJDQ+Xa6n322Yfm5uZW79/c3Mwvf/lLTj/99HetP/zww4kIIoK99tqrzft3FwZcktSlhg0bxiOPPMKCBQt4++23ufvuu3nhhRfedZuf/OQnHHbYYa3e/8tf/jKXXXYZG23UesKWLl3KTTfdxKGHHtrps69Pus2vkXXk1766ypgxY3jooYd47bXXGDBgABdccIGn1yWpasiQIZxzzjkccsghbLrppjQ2NtKjR4+V27/zne/Q0NDA2LFj33PfFb+iu+eee/LQQw+1uv+zzjqL/fffn4997GO1egrrhW4T8Hp46623Wl0/YcKELp5Eksoybty4lQc23/jGNxgwYAAAP/vZz7jrrruYMmVKq59a+eijjzJp0iTuvvtulixZwqJFi/jc5z7HzTffDMAFF1zA/Pnzufbaa7vuydSJp9AlSV3u1VdfBWDu3LnccccdnHDCCUyePJnLLruMSZMm0adPn1bvd/HFF9Pc3MycOXO45ZZbOOigg1bGe/z48dx7771MmDChzdPr3Un3f4aSpPXOMcccw+67784nP/lJrr76arbaaiu+8IUv8Oabb3LwwQfT2NjImWeeCcCLL77I4Ycfvtp9nnnmmbzyyivsu+++NDY2cuGFF9b6adRVZGa9Z1it0aNHZ1NT07vWzZo16z2/QlCS0ueXJHWNiHgiM0evut4jcEmSCmTAJUkqkAGXJKlA3efXyM7fspP398Zqb9La14kuWLCAk046iVdeeYWI4IwzzuDss8/u3NkkqbN09p+dG7oOtKOzeAS+Dlr7OtGGhga+973vMXPmTB577DGuvvpqZs6cWe9RJUndjAHvJCu+TnT77bdnjz32AGDzzTdnyJAhzJs3r87TSZK6GwPeCdr6OtE5c+YwdepU9t577zpNJknqrrrPe+B1sOLrRKFyBN7y887feustjjnmGK644gq22GKLeo0oSeqmDPg6WPEe+KqWLl3KMcccw9ixYzn66KPrMJkkqbvzFHony0zGjRvHkCFD+MpXvlLvcSRJ3VT3OQLvwkv32/Poo49y0003MXz48JWn17/73e926HN8JUnqqO4T8Dpo7etEP/rRj1LC58tLksrmKXRJkgpkwCVJKlDRAS/1VHWpc0uS1h/FBrx3794sWLCg
uBhmJgsWLKB37971HkWSVLBiL2IbMGAAzc3NzJ8/v96jrLHevXszYMCAeo8hSSpYsQHv2bMngwcPrvcYkiTVRbGn0CVJ2pAZcEmSCmTAJUkqkAGXJKlABlySpAIZcEmSCmTAJUkqkAGXJKlABlySpAIZcEmSCmTAJUkqkAGXJKlABlySpAIZcEmSCmTAJUkqkAGXJKlABlySpAIZcEmSCmTAJUkqkAGXJKlABlySpAIZ8Dq5/PLLGTp0KMOGDWPMmDEsWbKEcePGMXLkSEaMGMGxxx7LW2+99Z77zZkzh0022YTGxkYaGxs588wz6zC9JKneDHgdzJs3jx/84Ac0NTUxY8YMli9fzi233MLll1/Ok08+yVNPPcXAgQO56qqrWr3/TjvtxLRp05g2bRrXXHNNF08vSVofGPA6WbZsGYsXL2bZsmW8/fbb7LDDDmyxxRYAZCaLFy8mIuo8pSRpfWXA66B///589atfZeDAgWy//fZsueWWHHLIIQCceuqpvP/97+eZZ57hi1/8Yqv3f/755xk1ahQHHHAAjzzySFeOLklaTxjwOnj99deZOHEizz//PC+++CJ/+tOfuPnmmwH46U9/yosvvsiQIUO49dZb33Pf7bffnrlz5zJ16lS+//3vc8IJJ7Bo0aKufgqSpDoz4HVw//33M3jwYPr160fPnj05+uij+fWvf71ye48ePTj++OO5/fbb33PfjTfemL59+wKw5557stNOO/G73/2uy2aXJK0fDHgdDBw4kMcee4y3336bzGTKlCkMGTKE2bNnA5X3wCdNmsRuu+32nvvOnz+f5cuXA/Dcc8/x7LPPsuOOO3bp/JKk+muo9wAbor333ptjjz2WPfbYg4aGBkaNGsUZZ5zBQQcdxKJFi8hMRo4cyY9+9CMAJk2aRFNTExdeeCG/+tWvOPfcc+nZsycbbbQR11xzDdtss02dn5EkqatFZtZ7htUaPXp0NjU11XsMSep+zt+y3hN0L+e/0em7jIgnMnP0qus9hS5JUoEMuCRJBTLgkiQVaIO8iG3Q139Z7xG6lTmXHFHvESRpg+MRuCRJBTLgkiQVyIBLklQgAy5JUoEMuCRJBTLgkiQVyIBLklQgAy5JUoEMuCRJBTLgkiQVyIBLklQgAy5JUoEMuCRJBTLgkiQVyIBLklSgmgY8Iv4hIp6OiBkRMSEiekfE4Ij4TUTMjohbI6JXLWeQJKk7qlnAI6I/8CVgdGYOA3oAxwOXApdn5s7A68C4Ws0gSVJ3VetT6A3AJhHRAPQBXgIOAm6rbr8B+EyNZ5AkqdupWcAzcx7wr8BcKuF+A3gCWJiZy6o3awb6t3b/iDgjIpoiomn+/Pm1GlOSpCLV8hT61sCngcHADsCmwKEdvX9mXpeZozNzdL9+/Wo0pSRJZarlKfS/Bp7PzPmZuRS4A9gP2Kp6Sh1gADCvhjNIktQt1TLgc4F9IqJPRATwcWAm8CBwbPU2JwMTaziDJEndUi3fA/8NlYvV/heYXn2s64BzgK9ExGygL3B9rWaQJKm7alj9TdZeZp4HnLfK6ueAvWr5uJIkdXd+EpskSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIk
FciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFaimAY+IrSLitoh4JiJmRcS+EbFNRNwXEc9W/711LWeQJKk7qvUR+JXA5MzcDRgJzAK+DkzJzA8BU6rLkiRpDdQs4BGxJbA/cD1AZv4lMxcCnwZuqN7sBuAztZpBkqTuqpZH4IOB+cBPI2JqRIyPiE2B92XmS9XbvAy8r7U7R8QZEdEUEU3z58+v4ZiSJJWnlgFvAPYAfpSZo4A/scrp8sxMIFu7c2Zel5mjM3N0v379ajimJEnlqWXAm4HmzPxNdfk2KkF/JSK2B6j++9UaziBJUrdUs4Bn5svACxGxa3XVx4GZwCTg5Oq6k4GJtZpBkqTuqqHG+/8i8POI6AU8B5xK5S8N/x4R44A/AMfVeAZJkrqdmgY8M6cBo1vZ9PFaPq4kSd2dn8QmSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFajDAY+InSPi5oi4PSL2reVQkiSpfQ1tbYiI3pm5pMWqbwP/VP35TqCxloNJkqS2tXcEfmdEnNRieSkwCPggsLyWQ0mSpPa1F/BDgS0iYnJE7A98FfgEcBQwtiuGkyRJrWvzFHpmLgeuioibgG8Bfw98MzN/31XDSZKk1rX3HvjewNeAvwDfBRYD34mIecC3M3Nh14woSZJW1WbAgWuBw4HNgJ9m5n7A8RFxAHArldPpkiSpDtoL+DIqF61tSuUoHIDMfBh4uLZjSZKk9rQX8BOAv6MS75PauZ0kSepi7V3E9jvgH7twFkmS1EF+lKokSQUy4JIkFWi1AY+ITSNioxbLG0VEn9qOJUmS2tORI/ApQMtg9wHur804kiSpIzoS8N6Z+daKherPHoFLklRHHQn4nyJijxULEbEnlU9lkyRJddLe74Gv8GXgPyLiRSCA9wOfrelUkiSpXasNeGb+T0TsBuxaXfXbzFxa27EkSVJ7OnIV+ueBTTNzRmbOADaLiLNqP5okSWpLR94D/9uW3zyWma8Df1u7kSRJ0up0JOA9IiJWLERED6BX7UaSJEmr05GL2CYDt0bEtdXlv6uukyRJddKRgJ9DJdp/X12+Dxhfs4kkSdJqdeQq9HeAH1X/kSRJ64HVBjwiPgRcDOwO9F6xPjN3rOFckiSpHR25iO2nVI6+lwEHAjcCN9dyKEmS1L6OBHyTzJwCRGb+ITPPB46o7ViSJKk9HbmI7c/VrxN9NiK+AMwDNqvtWJIkqT0dOQI/m8q3j30J2BP4HHByLYeSJEnt69BnoVd/fAs4tbbjSJKkjujIEbgkSVrPGHBJkgpkwCVJKtBaBTwizu3sQSRJUset7RH46Z06hSRJWiNtXoUeEYva2gRsUptxJElSR7T3a2QLgQ9n5iurboiIF2o3kiRJWp32TqHfCHywjW2/qMEskiSpg9o8As/Mb7az7ZzajCNJkjqiI5+FTkQcDXwUSOC/MvM/azqVJElq12qvQo+IfwPOBKYDM4C/i4iraz2YJElqW0eOwA8ChmRmAkTEDcDTNZ1KkiS1qyO/Bz4bGNhi+QPVdZIkqU46cgS+OTArIh6vLn8YaIqISQCZ+alaDSdJklrXkYD7samSJK1nOvJ94A9H
xPuoHHkDPJ6Zr9Z2LEmS1J6OXIV+HPA48DfAccBvIuLYWg8mSZLa1pFT6P9M5SNVXwWIiH7A/cBttRxMkiS1rSNXoW+0yinzBR28nyRJqpGOHIFPjoh7gQnV5c8C99RuJEmStDoduYjtay0+ShXgOj9KVZKk+lptwCPi0uqXl9zRyjpJklQHHXkv++BW1h3W2YNIkqSOa/MIPCL+HjgL2DEinmqxaXPg0VoPJkmS2tbeKfRfULlY7WLg6y3Wv5mZf6zpVJIkqV1tBjwz3wDeAMZ03TiSJKkj/H1uSZIKZMAlSSqQAZckqUAGXJKkAhlwSZIKZMAlSSqQAZckqUAGXJKkAhlwSZIKZMAlSSqQAZckqUAGXJKkAtU84BHRIyKmRsRd1eXBEfGbiJgdEbdGRK9azyBJUnfTFUfgZwOzWixfClyemTsDrwPjumAGSZK6lZoGPCIGAEcA46vLARwE3Fa9yQ3AZ2o5gyRJ3VGtj8CvAP4JeKe63BdYmJnLqsvNQP/W7hgRZ0REU0Q0zZ8/v8ZjSpJUlpoFPCKOBF7NzCfW5v6ZeV1mjs7M0f369evk6SRJKltDDfe9H/CpiDgc6A1sAVwJbBURDdWj8AHAvBrOIElSt1SzI/DM/H+ZOSAzBwHHAw9k5ljgQeDY6s1OBibWagZJkrqrevwe+DnAVyJiNpX3xK+vwwySJBWtlqfQV8rMh4CHqj8/B+zVFY8rSVJ35SexSZJUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBahbwiPhARDwYETMj4umIOLu6fpuIuC8inq3+e+tazSBJUndVyyPwZcA/ZubuwD7A5yNid+DrwJTM/BAwpbosSZLWQM0CnpkvZeb/Vn9+E5gF9Ac+DdxQvdkNwGdqNYMkSd1Vl7wHHhGDgFHAb4D3ZeZL1U0vA+9r4z5nRERTRDTNnz+/K8aUJKkYNQ94RGwG3A58OTMXtdyWmQlka/fLzOsyc3Rmju7Xr1+tx5QkqSg1DXhE9KQS759n5h3V1a9ExPbV7dsDr9ZyBkmSuqNaXoUewPXArMz8fotNk4CTqz+fDEys1QySJHVXDTXc937AicD0iJhWXfcN4BLg3yNiHPAH4LgaziBJUrdUs4Bn5n8B0cbmj9fqcSVJ2hD4SWySJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXI
gEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4infaaaex3XbbMWzYsFa3//znP2fEiBEMHz6cj3zkIzz55JMALFmyhL322ouRI0cydOhQzjvvvK4cW5LWiQFX8U455RQmT57c5vbBgwfz8MMPM336dL71rW9xxhlnALDxxhvzwAMP8OSTTzJt2jQmT57MY4891lVjS9I6MeAq3v77788222zT5vaPfOQjbL311gDss88+NDc3AxARbLbZZgAsXbqUpUuXEhG1H1jFmzx5Mrvuuis777wzl1xyyXu2z507lwMPPJBRo0YxYsQI7r77bgDuu+8+9txzT4YPH86ee+7JAw880NWjqxsx4NqgXH/99Rx22GErl5cvX05jYyPbbbcdBx98MHvvvXcdp1MJli9fzuc//3nuueceZs6cyYQJE5g5c+a7bnPRRRdx3HHHMXXqVG655RbOOussALbddlvuvPNOpk+fzg033MCJJ55Yj6egbsKAa4Px4IMPcv3113PppZeuXNejRw+mTZtGc3Mzjz/+ODNmzKjjhCrB448/zs4778yOO+5Ir169OP7445k4ceK7bhMRLFq0CIA33niDHXbYAYBRo0at/Hno0KEsXryYP//5z137BNRtGHBtEJ566ilOP/10Jk6cSN++fd+zfauttuLAAw9s9710CWDevHl84AMfWLk8YMAA5s2b967bnH/++dx8880MGDCAww8/nB/+8Ifv2c/tt9/OHnvswcYbb1zzmdU9GXB1e3PnzuXoo4/mpptuYpdddlm5fv78+SxcuI1g8zMAAAX6SURBVBCAxYsXc99997HbbrvVa0x1IxMmTOCUU06hubmZu+++mxNPPJF33nln5fann36ac845h2uvvbaOU6p0DfUeQFpXY8aM4aGHHuK1115jwIABXHDBBSxduhSAM888kwsvvJAFCxasfB+yoaGBpqYmXnrpJU4++WSWL1/OO++8w3HHHceRRx5Zz6eiAvTv358XXnhh5XJzczP9+/d/122uv/76lWdz9t13X5YsWcJrr73GdtttR3NzM0cddRQ33ngjO+20U5fOru7FgKt4EyZMaHf7+PHjGT9+/HvWjxgxgqlTp9ZqLHVTH/7wh3n22Wd5/vnn6d+/P7fccgu/+MUv3nWbgQMHMmXKFE455RRmzZrFkiVL6NevHwsXLuSII47gkksuYb/99qvTM1B34Sl0SVoDDQ0NXHXVVXziE59gyJAhHHfccQwdOpRzzz2XSZMmAfC9732PH//4x4wcOZIxY8bws5/9jIjgqquuYvbs2Vx44YU0NjbS2NjIq6++WudnpFJFZtZ7htUaPXp0NjU1ddr+Bn39l522L8GcS46o9wiS1tb5W9Z7gu7l/Dc6fZcR8URmjl51vUfgkiQVyPfAte78G3znqsHf4CV1Px6BS5JUII/AJRXFa1g615ze9Z5Aa8sjcEmSClSXgEfEoRHx24iYHRFfr8cMkiSVrMsDHhE9gKuBw4DdgTERsXtXzyFJUsnqcQS+FzA7M5/LzL8AtwCfrsMckiQVqx4XsfUHXmix3Ay850uYI+IM4Izq4lsR8dsumE1rIWBb4LV6z9FtXBD1nkAbEF+/naw2r98PtrZyvb0KPTOvA66r9xxavYhoau1TgiSt/3z9lqse
p9DnAR9osTyguk6SJHVQPQL+P8CHImJwRPQCjgcm1WEOSZKK1eWn0DNzWUR8AbgX6AH8JDOf7uo51Kl8q0Mql6/fQhXxbWSSJOnd/CQ2SZIKZMAlSSqQAe+mImJ5REyLiKcj4smI+MeI2Ki67a8i4o3q9qci4v6I2K66bbeI+O+I+HNEfLWLZh0UEYur88yMiGtazDo5IhZGxF1dMYu0PliH1+/Y6rrpEfHriBjZBbO2+vqNiMbqnyVPV2f6bK1n2dAY8O5rcWY2ZuZQ4GAqH117Xovtj1S3j6DymwGfr67/I/Al4F/X5cEjYus1vMvvM7MRGEHlI3Y/U13/L8CJ6zKLVKC1ff0+DxyQmcOBb7OWF6h10uv3beCk6nM4FLgiIrZam3nUOgO+AcjMV6l8qt0XIuJdHxNUXd4ceH3FbTPzf4Cl6/iwn42IGdUjh35rMOsy4NfAztXlKcCb6ziLVKw1fP3+OjNfr25+jMrnbKyNdX79ZubvMvPZ6voXgVeBDu9Lq2fANxCZ+RyVX9vbrrrqYxExDZgL/DXwkzXZX0TcWj1ltuo/J1Uf7xoqRw19gF9FxG3Vb6Fr9/9zEdEH+DgwfY2eoNSNreXrdxxwT2v76+rXb0TsBfQCft/Bp6wOWG8/SlU190hmHgkQEecAlwFndvTOmbna97My8wXg2xFxEZU/DH4CNAGfauXmO1X/QEpgYma2+gePJGA1r9+IOJBKwD/a2p278vUbEdsDNwEnZ+Y7HXt66ggDvoGIiB2B5VROYw1ZZfMk4PY13N+twK6tbPp+Zt7Y4nZ7AadSeR/v34Eft7HLFe+hSVrFmrx+I2IEMB44LDMXtLG/Lnn9RsQWwC+Bf87Mx9q4r9aSAd8AVN/Duga4KjNzlbfRoPK39DU6tbW6v8FHxCFULoR7mcofJmdXvz5W0hpYk9dvRAwE7gBOzMzftbXPrnj9Vj8q+z+BGzPztjW5rzrGgHdfm1RPafUEllE5hfX9FttXvIcWwBvA6QAR8X4qp8m2AN6JiC8Du2fmojV8/AXAJzPzD+vyJCLiEWA3YLOIaAbGZea967JPqQBr9foFzgX6Av9WDf2ytfymsc54/R4H7A/0jYhTqutOycxp67BPteBHqUqSVCCvQpckqUAGXJKkAhlwSZIKZMAlSSqQAZckqUAGXJKkAhlwSZIK9P8BPypi1l0fjYgAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 504x576 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfAAAAI4CAYAAACV/7uiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dfZyVdZ3w8c8XEPH5ERGExJQUH3DEWc3XltvqUlYuarGakVILa7Xbpvaw0rYVPYrVtla263prOlqJrtbC2p3dhpnaWu6oaCqZVBijCESgYiqOfO8/zgU74AzMMHPO+Bs+79eLF+dc17mu8z3W4TPXdc6cE5mJJEkqy6D+HkCSJPWcAZckqUAGXJKkAhlwSZIKZMAlSSqQAZckqUAGXJKkAhlwaQsiYk2HP+si4rkO16f2ct9vjIifRMTTEfHLvpq5v0XEsIjIiHi2+u/UFhEXRcSgav35EXFvRKyNiEv7c56I2DEiroyI31X/O9wTEZPqPZPUW0P6ewDplS4zd15/OSIWAzMy80d9tPs1wGXAHsAH+mifryQHZ2ZbRBwG/ARYCFwFtAGzgFNfAfPMBX4NfKqa61Tgxog4JDOfaPB8Urd5BC71UkTsEBHfiIil1ZHdlyJiu2rdSRGxKCI+HRF/iIjfRsRfrd82M/87M78NLO7G/QyJiBsjYllErI6IH0fEwR3W7xQRX4uIJRHxVHVkP6Ra94aI+Fm1/HcR8c4u7uNVEfF/q1l/FRHTOqybHRHfjohrI+KZiHggIpq6898oMx8C7gIOr67/R2bOA/6whce8Y3XEfFCHZftVZ0H2iIh9I+Lm6r/Hyoi4tafzZOaqzPxcZv4uM9dl5neBJ4GjurMvqb8YcKn3Pg1MAI4AjgbeAPxDh/VjgaHAvsDfAC0RccBW3tdc4MBqX78EWjqs+xpwCPAnwJ7APwFZxe8m4EvAXtWMD3Wx//8AHgFGAu8E/iUi/rTD+tOAbwK7A/OBi7szdEQcARwH3Ned26+XmX8E5gFndlj8DuCHmbkKuKCad+9q5lm9nSciRlP73+zhnswqNZoBl3pvKvCpzPx9Zi4DPgec1WF9O/DpzFxbnXr/ETClp3eSme2ZeXVmrsnM56n94HBM9frudsDZwN9n5pOZ+VJm3pGZL1Wz/Fdm3ljtY0Vm3r/p/iNiHHAk8I+Z+UJmtlL7AaHjY7k1M2+p9nsNsKUj8IciYhXwXeDrwHd6+rirbToG/J0d9vMiMAp4VfXf9/bezBMR21fL/i0zf7sVs0oN42vgUi9ERFA7Gn6sw+LHgP06XF9RBbfj+lFbcV9DgIuoHQXvDawDgtpR9WBqz+dfd7LpmC6Wb2pUNetzm8x6YofrT3a4/EdgZzbvsMxs68Z9b84PqZ21OBJ4HhgH/Fe17vPAZ4AfR8SLwL9m5le2Zp7qv+8cYCXwoV7OLNWdR+BSL2Tt6/yeBPbvsPhVwOMdru8dEcM2Wb81b456DzAJ+HNgN2qny6EW8aXUjvQP7GS7JV0s39QTwPCI2GGTWR/v4vYNkZkvAjdQOwp/J/C99T9kZOZTmXluZu4PvB34p01O+XdL9e74q4EdgTOqMwzSK5oBl3rvWuBTEbFXROwDfBz4Vof12wGfiIihEXECtQjfCLVwVHHfrnZ1w+nwzuxC7Qh0JbATtVP1wIbIXQ18NSJGRMTgiHhdRAymdqr75Ig4rXoj3PCImNDJ/hcBvwA+FxHbR8REYNomj6VPVHMMo3bmYHD1uAdvZpPvUHvt+0w6nPaOiMkR8erqTMhTwEvUzkz0ZJYArgBGA6dl5tqePRqpfxhwqfc+Se0NTw8BC4CfAl/ssH4xtaPjJ6m9Aew9mfmbat0bgeeovSb7muryf9G5K4AV1X5+Ady5yfoPUjtVfh+1yH8WiMxcBJwC/CO1d323AodtuvPqbMJfAYdW93Ed8NHM3PR++sLnqD3W84AZ1eWP
bub2t1OL/W7U3kOw3njgx8Az1W2+nJl39XCW1wDvpvbmv+Xxv7/j//Ye7kdqqKg9ZyXVQ0ScBFySmQdt8caS1AMegUuSVCADLklSgTyFLklSgTwClySpQEV8kMvee++dY8eO7e8xJElquHvuuef3mTl80+VFBHzs2LG0trb29xiSJDVcRDzW2XJPoUuSVCADLklSgQy4JEkFKuI1cEmStuTFF1+kra2N559/fss3fgUaNmwYo0ePZrvtuvo6hI0ZcEnSgNDW1sYuu+zC2LFjqX1HTTkyk5UrV9LW1sYBBxzQrW08hS5JGhCef/559tprr+LiDRAR7LXXXj06e2DAJUkDRonxXq+nsxtwSZIK5GvgkqQBaezM7/fp/hbPfusWbzN48GCOOOII2tvbGT9+PC0tLey444789V//NTfddBP77LMPDz74YJ/M4xG4JEl9ZIcddmDBggU8+OCDDB06lEsvvRSAd7/73dx88819el8GXJKkOnj961/PokWLADj++OPZc889+3T/BrxBHnnkEZqamjb82XXXXbn44ov5wx/+wKRJkxg3bhyTJk1i1apVL9v2scceY+LEiTQ1NXHYYYdt+InumWee2Wife++9N+edd16jH5okaRPt7e384Ac/4IgjjqjbffgaeIMcfPDBLFiwAICXXnqJ/fbbj9NOO43Zs2dz4oknMnPmTGbPns3s2bO56KKLNtp25MiR3HXXXWy//fasWbOGww8/nMmTJzNq1KgN+wQ4+uijedvb3tbQxyVJ+l/PPfccTU1NQO0IfPr06XW7LwPeD+bPn8+BBx7I/vvvz9y5c7ntttsAmDZtGm94wxteFvChQ4duuPzCCy+wbt26l+3zV7/6FcuXL+f1r399XWeXJHVt/WvgjeAp9H4wZ84czjzzTACWLVvGyJEjAdh3331ZtmxZp9ssWbKECRMmMGbMGC644AJGjRr1sn2eccYZRf8OpCSp+zwCb7C1a9cyb948Lrzwwpeti4guAzxmzBgeeOABnnjiCU499VSmTJnCiBEjNqyfM2cO11xzTd3mlqTSdOfXvhrlzDPP5LbbbuP3v/89o0eP5tOf/nSvT68b8Ab7wQ9+wMSJEzfEd8SIESxdupSRI0eydOlS9tlnn81uP2rUKA4//HDuuOMOpkyZAsD9999Pe3s7Rx99dN3nlyR1bc2aNZ0uv/baa/v8vjyF3mDXXnvthtPnAJMnT6alpQWAlpYWTjnllJdt09bWxnPPPQfAqlWruPPOOzn44IO73KckaeAz4A307LPPcsstt2z0TvGZM2dyyy23MG7cOH70ox8xc+ZMAFpbW5kxYwYACxcu5Nhjj+XII4/kz/7sz/jIRz6y0a8mXH/99QZckrYxkZn9PcMWNTc3Z2tra3+PIUl6BVu4cCHjx4/v7zF6pbPHEBH3ZGbzprf1CFySpAIZcEmSCmTAJUkqUN1+jSwiDgau67Do1cAngaur5WOBxcDpmfnyDwCvo77+irlt3Svpdy0laYNZu/Xx/p7a4k06+zrRlStXcvbZZ7Ns2TIignPOOYdzzz231+PU7Qg8Mx/JzKbMbAKOBv4IfA+YCczPzHHA/Oq6JEnF6+zrRIcMGcI///M/8/DDD/Ozn/2Mb3zjGzz88MO9vq9GnUI/Efh1Zj4GnAK0VMtbgFMbNIMkSQ2z/utER44cycSJEwHYZZddGD9+PI8//niv99+ogL8DWP8xNCMyc2l1+UlgRGcbRMQ5EdEaEa0rVqxoxIySJPWJrr5OdPHixdx3330ce+yxvb6Pugc8IoYCk4H/2HRd1n4JvdNfRM/MyzKzOTObhw8fXucpJUnqvfVfJ9rc3MyrXvWqjT7vfM2aNbz97W/n4osvZtddd+31fTXis9DfDNybmeu/ZmtZRIzMzKURMRJY3oAZJEmqu66+TvTFF1/k7W9/O1OnTt3o0zh7oxGn0M/kf0+fA8wDplWXpwFzGzCDJEn9IjOZPn0648eP50Mf+lCf7beu
R+ARsRMwCXhvh8WzgesjYjrwGHB6PWeQJG2juvFrX43w05/+lGuuuYYjjjiCpqYmAL7whS/wlre8pVf7rWvAM/NZYK9Nlq2k9q50SZIGlM6+TvR1r3sd9fjeET+JTZKkAhlwSZIKZMAlSQNGCV+R3ZWezm7AJUkDwrBhw1i5cmWREc9MVq5cybBhw7q9TSN+D1ySpLobPXo0bW1tlPrpncOGDWP06NHdvr0BlyQNCNtttx0HHHBAf4/RMJ5ClySpQAZckqQCGXBJkgpkwCVJKpABlySpQAZckqQCGXBJkgpkwCVJKpABlySpQAZckqQCGXBJkgpkwCVJKpABlySpQAZckqQCGXBJkgpkwCVJKpABlySpQAZckqQCGXBJkgpkwCVJKpABlySpQAZckqQCGXBJkgpkwCVJDbd69WqmTJnCIYccwvjx47nrrrv4xCc+wYQJE2hqauKNb3wjTzzxxMu2W7BgAccddxyHHXYYEyZM4Lrrrtuw7tZbb2XixIkcfvjhTJs2jfb29kY+pIYz4JKkhjv33HM56aST+OUvf8n999/P+PHj+ehHP8oDDzzAggULOPnkk/nMZz7zsu123HFHrr76ah566CFuvvlmzjvvPFavXs26deuYNm0ac+bM4cEHH2T//fenpaWlHx5Z4xhwSVJDPfXUU9x+++1Mnz4dgKFDh7L77ruz6667brjNs88+S0S8bNvXvOY1jBs3DoBRo0axzz77sGLFClauXMnQoUN5zWteA8CkSZO48cYbG/Bo+o8BlyQ11G9/+1uGDx/Oe97zHo466ihmzJjBs88+C8DHP/5xxowZw7e//e1Oj8A7uvvuu1m7di0HHngge++9N+3t7bS2tgJwww03sGTJkro/lv5kwCVJDdXe3s69997L+9//fu677z522mknZs+eDcDnP/95lixZwtSpU7nkkku63MfSpUs566yzuPLKKxk0aBARwZw5czj//PM55phj2GWXXRg8eHCjHlK/MOCSpIYaPXo0o0eP5thjjwVgypQp3HvvvRvdZurUqV2eAn/66ad561vfyuc//3le+9rXblh+3HHHcccdd3D33Xdz/PHHbzidPlAZcElSQ+27776MGTOGRx55BID58+dz6KGH8uijj264zdy5cznkkENetu3atWs57bTTOPvss5kyZcpG65YvXw7ACy+8wEUXXcT73ve+Oj6K/jekvweQJG17vv71rzN16lTWrl3Lq1/9aq688kpmzJjBI488wqBBg9h///259NJLAWhtbeXSSy/l8ssv5/rrr+f2229n5cqVXHXVVQBcddVVNDU18aUvfYmbbrqJdevW8f73v58TTjihHx9h/UVm9vcMW9Tc3Jzr35jQF8bO/H6f7UuwePZb+3sESRqwIuKezGzedLmn0CVJKpABlySpQAZckqQC+SY2SdqWzdqtvycYWGY91bC78ghckqQCGXBJkgpkwCVJKpABlySpQAZckqQCGXBJkgpkwCVJKpABlySpQAZckqQCGXBJkgpkwCVJKpABlySpQAZckqQCGXBJkgpkwCVJKpABlySpQAZckqQCGXBJkgpkwCVJKlBdAx4Ru0fEDRHxy4hYGBHHRcSeEXFLRDxa/b1HPWeQJGkgqvcR+FeBmzPzEOBIYCEwE5ifmeOA+dV1SZLUA3ULeETsBhwPXAGQmWszczVwCtBS3awFOLVeM0iSNFDV8wj8AGAFcGVE3BcRl0fETsCIzFxa3eZJYERnG0fEORHRGhGtK1asqOOYkiSVp54BHwJMBP4tM48CnmWT0+WZmUB2tnFmXpaZzZnZPHz48DqOKUlSeeoZ8DagLTN/Xl2/gVrQl0XESIDq7+V1nEGSpAGpbgHPzCeBJRFxcLXoROBhYB4wrVo2DZhbrxkkSRqohtR5/38PfDsihgK/Ad5D7YeG6yNiOvAYcHqdZ5AkacCpa8AzcwHQ3MmqE+t5v5IkDXR+EpskSQUy4JIkFciAS5JUIAMuSVKBDLgk
SQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFciAS5JUIAMuSVKBDLgkSQUy4JIkFWhIPXceEYuBZ4CXgPbMbI6IPYHrgLHAYuD0zFxVzzkkSRpoGnEE/ueZ2ZSZzdX1mcD8zBwHzK+uS5KkHuiPU+inAC3V5Rbg1H6YQZKkotU74An8v4i4JyLOqZaNyMyl1eUngRGdbRgR50REa0S0rlixos5jSpJUlrq+Bg68LjMfj4h9gFsi4pcdV2ZmRkR2tmFmXgZcBtDc3NzpbSRJ2lbV9Qg8Mx+v/l4OfA84BlgWESMBqr+X13MGSZIGoroFPCJ2iohd1l8G3gg8CMwDplU3mwbMrdcMkiQNVPU8hT4C+F5ErL+f72TmzRHxP8D1ETEdeAw4vY4zSJI0INUt4Jn5G+DITpavBE6s1/1KkrQt8JPYJEkqkAGXJKlABlySpAIZcEmSCmTAJUkqkAGXJKlABlySpAIZcEmSCmTAJUkqkAGXJKlABlySpAIZcEmSCmTAJUkqkAGXJKlABlySpAIZcEmSCmTAJUkqkAGXJKlABlySpAIZcEmSCmTAJUkqkAGXJKlABlySpAIZcEmSCmTAJUkqkAGXJKlABlySpAIZcEmSCmTAJUkqkAGXJKlABlySpAIZcEmSCmTAJUkqkAGXJKlABlySpAIZcEmSCmTAJUkqkAGXJKlABlySpAIZcEmSCmTAJUkqkAGXJKlABlySpAIZcEmSCmTAJUkqkAGXJKlABlySpAIZcEmSCmTAJUkqkAGXJKlABlySpAIZcEmSCmTAJUkqkAGXJKlABlySpAIZcEmSCmTAJUkqkAGXJKlABlySpAIZcEmSCmTAJUkqULcDHhEHRcS3IuLGiDiuB9sNjoj7IuKm6voBEfHziFgUEddFxNCtGVySpG1ZlwGPiGGbLPos8DHgPODfenAf5wILO1y/CPiXzDwIWAVM78G+JEkSmz8C/6+IOLvD9ReBscD+wEvd2XlEjAbeClxeXQ/gBOCG6iYtwKk9G1mSJG0u4CcBu0bEzRFxPPAR4E3AacDUbu7/YuAfgHXV9b2A1ZnZXl1vA/br8dSSJG3jugx4Zr6UmZcAZwCTga8CV2bmhzPzl1vacUScDCzPzHu2ZrCIOCciWiOidcWKFVuzC0mSBqwhXa2IiGOBjwJrgS8AzwGfj4jHgc9m5uot7PtPgckR8RZgGLArtR8Cdo+IIdVR+Gjg8c42zszLgMsAmpubs0ePSpKkAW5zp9D/HfggMAv498z8dWa+A5gHXLelHWfmxzJzdGaOBd4B3JqZU4EfA1Oqm00D5m79+JIkbZs2F/B2/vdNa2vXL8zMn2Tmm3pxnxcAH4qIRdReE7+iF/uSJGmb1OUpdOCdwHupxfvszdxuizLzNuC26vJvgGN6sz9JkrZ1XQY8M38FfLiBs0iSpG7yo1QlSSqQAZckqUBbDHhE7BQRgzpcHxQRO9Z3LEmStDndOQKfD3QM9o7Aj+ozjiRJ6o7uBHxYZq5Zf6W67BG4JEn9qDsBfzYiJq6/EhFHU/tUNkmS1E8293vg650H/EdEPAEEsC+1z0eXJEn9ZIsBz8z/iYhDgIOrRY9k5ov1HUuSJG1Od96F/nfATpn5YGY+COwcEX9b/9EkSVJXuvMa+N90/OaxzFwF/E39
RpIkSVvSnYAPjohYfyUiBgND6zeSJEnaku68ie1m4LqI+Pfq+nurZZIkqZ90J+AXUIv2+6vrtwCX120iSZK0Rd15F/o64N+qP5Ik6RVgiwGPiHHAhcChwLD1yzPz1XWcS5IkbUZ33sR2JbWj73bgz4GrgW/VcyhJkrR53Qn4Dpk5H4jMfCwzZwFvre9YkiRpc7rzJrYXqq8TfTQiPgA8Duxc37EkSdLmdOcI/Fxq3z72QeBo4F3AtHoOJUmSNq9bn4VeXVwDvKe+40iSpO7ozhG4JEl6hTHgkiQVyIBLklSgrQp4RHyyrweRJEndt7VH4DP6dApJktQjXb4LPSKe7moVsEN9xpEkSd2xuV8jWw38SWYu23RFRCyp30iSJGlLNncK/Wpg/y7WfacOs0iSpG7q8gg8M/9pM+suqM84kiSpO7rzWehExNuA1wEJ3JmZ36vrVJIkabO2+C70iPhX4H3AL4AHgfdGxDfqPZgkSepad47ATwDGZ2YCREQL8FBdp5IkSZvVnd8DXwS8qsP1MdUySZLUT7pzBL4LsDAi7q6u/wnQGhHzADJzcr2GkyRJnetOwP3YVEmSXmG6833gP4mIEdSOvAHuzszl9R1LkiRtTnfehX46cDfwV8DpwM8jYkq9B5MkSV3rzin0j1P7SNXlABExHPgRcEM9B5MkSV3rzrvQB21yynxlN7eTJEl10p0j8Jsj4ofAtdX1M4Af1G8kSZK0Jd15E9tHO3yUKsBlfpSqJEn9a4sBj4iLqi8v+W4nyyRJUj/ozmvZkzpZ9ua+HkSSJHVfl0fgEfF+4G+BV0fEAx1W7QL8tN6DSZKkrm3uFPp3qL1Z7UJgZoflz2TmH+o6lSRJ2qwuA56ZTwFPAWc2bhxJktQd/j63JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVKC6BTwihkXE3RFxf0Q8FBGfrpYfEBE/j4hFEXFdRAyt1wySJA1U9TwCfwE4ITOPBJqAkyLitcBFwL9k5kHAKmB6HWeQJGlAqlvAs2ZNdXW76k8CJwA3VMtbgFPrNYMkSQNVXV8Dj4jBEbEAWA7cAvwaWJ2Z7dVN2oD9utj2nIhojYjWFStW1HNMSZKKU9eAZ+ZLmdkEjAaOAQ7pwbaXZWZzZjYPHz68bjNKklSihrwLPTNXAz8GjgN2j4gh1arRwOONmEGSpIGknu9CHx4Ru1eXdwAmAQuphXxKdbNpwNx6zSBJ0kA1ZMs32WojgZaIGEztB4XrM/OmiHgYmBMRnwPuA66o4wySJA1IdQt4Zj4AHNXJ8t9Qez1ckiRtJT+JTZKkAhlwSZIKZMAlSSqQAZckqUAGXJKkAhlwSZIKZMAlSSqQAZckqUAGXJKkAhlwSZIKZMAlSSqQAZckqUAGXJKkAhlwSZIKZMAlSSqQAZckqUAGXJKkAhlwSZIKZMAlSSqQAZckqUAGXJKkAhlwSZIKZMAlSSqQAZckqUAGXJKkAhlwSZIKZMAlSSqQAZckqUAGXJKkAhlwSZIKZMAlSSqQAZckqUAGXJKkAhlwSZIKZMAlSSqQAZckqUAGXJKkAhlwSZIKZMAlSSqQAZckqUAGXJKkAhlwSZIKZMAlSSqQAZckqUAGXJKkAhlwSZIKZMAlSSqQAZckqUAGXJKkAhlwSZIKZMAlSSqQAZckqUAGXJKkAhlwSZIKZMAlSSqQAZckqUAGXJKkAhlwSZIKZMAlSSqQAZckqUAGXJKkAtUt4BExJiJ+HBEPR8RDEXFutXzPiLglIh6t/t6jXjNIkjRQ1fMIvB34cGYeCrwW+LuIOBSYCczPzHHA/Oq6JEnqgboFPDOXZua91eVngIXAfsApQEt1sxbg1HrNIEnSQNWQ18AjYixwFPBzYERmLq1WPQmMaMQMkiQNJHUPeETsDNwInJeZT3dcl5kJZBfbnRMRrRHR
umLFinqPKUlSUeoa8IjYjlq8v52Z360WL4uIkdX6kcDyzrbNzMsyszkzm4cPH17PMSVJKk4934UewBXAwsz8SodV84Bp1eVpwNx6zSBJ0kA1pI77/lPgLOAXEbGgWvaPwGzg+oiYDjwGnF7HGSRJGpDqFvDMvBOILlafWK/7lSRpW+AnsUmSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgeoW8Ij4ZkQsj4gHOyzbMyJuiYhHq7/3qNf9S5I0kNXzCPwq4KRNls0E5mfmOGB+dV2SJPVQ3QKembcDf9hk8SlAS3W5BTi1XvcvSdJA1ujXwEdk5tLq8pPAiK5uGBHnRERrRLSuWLGiMdNJklSIfnsTW2YmkJtZf1lmNmdm8/Dhwxs4mSRJr3yNDviyiBgJUP29vMH3L0nSgNDogM8DplWXpwFzG3z/kiQNCPX8NbJrgbuAgyOiLSKmA7OBSRHxKPAX1XVJktRDQ+q148w8s4tVJ9brPiVJ2lb4SWySJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgEuSVCADLklSgQy4JEkFMuCSJBXIgGvAeOmllzjqqKM4+eSTX7buqquuYvjw4TQ1NdHU1MTll18OwGOPPcbEiRNpamrisMMO49JLL2302JK0VYb09wBSX3xVxn8AAAf2SURBVPnqV7/K+PHjefrppztdf8YZZ3DJJZdstGzkyJHcddddbL/99qxZs4bDDz+cyZMnM2rUqEaMLElbzSNwDQhtbW18//vfZ8aMGT3abujQoWy//fYAvPDCC6xbt64e40lSnzPgGhDOO+88vvjFLzJoUNf/l77xxhuZMGECU6ZMYcmSJRuWL1myhAkTJjBmzBguuOACj77VLZt7yeaFF17gjDPO4KCDDuLYY49l8eLFG63/3e9+x84778yXv/zlBk2rgciAq3g33XQT++yzD0cffXSXt/nLv/xLFi9ezAMPPMCkSZOYNm3ahnVjxozhgQceYNGiRbS0tLBs2bJGjK3CrX/JpjNXXHEFe+yxB4sWLeL888/nggsu2Gj9hz70Id785jc3YkwNYAZcxfvpT3/KvHnzGDt2LO94xzu49dZbede73rXRbfbaa68Np8pnzJjBPffc87L9jBo1isMPP5w77rijIXOrXFt6yWbu3LkbfkicMmUK8+fPJzMB+M///E8OOOAADjvssIbNq4HJgKt4F154IW1tbSxevJg5c+Zwwgkn8K1vfWuj2yxdunTD5Xnz5m04cmpra+O5554DYNWqVdx5550cfPDBjRteRdrSSzaPP/44Y8aMAWDIkCHsttturFy5kjVr1nDRRRfxqU99qpHjaoDyXegasD75yU/S3NzM5MmT+drXvsa8efMYMmQIe+65J1dddRUACxcu5MMf/jARQWbykY98hCOOOKJ/B9crWseXbG677bYebTtr1izOP/98dt555/oMp21KrD+t80rW3Nycra2tfba/sTO/32f7Eiye/db+HkFqmI997GNcc801DBkyhOeff56n
n36at73tbRud9XnTm97ErFmzOO6442hvb2ffffdlxYoVHH/88RveQLl69WoGDRrEZz7zGT7wgQ/018OBWbv1330PRLOe6vNdRsQ9mdm86XKPwCWpBy688EIuvPBCAG677Ta+/OUvv+wlm8mTJ9PS0sJxxx3HDTfcwAknnEBEbPT+ilmzZrHzzjv3b7xVNAMuSX2g40s206dP56yzzuKggw5izz33ZM6cOf09ngYgT6Gr1zyFLhXMU+h9y1PoKor/APStOvwDIGng8dfIJEkqkAGXJKlAnkKXVBTfw9K3Fg/r7wm0tTwClySpQAZckqQCGXBJkgpkwCVJKpABlySpQAZckqQCGXBJkgpkwCVJKpABlySpQAZckqQCGXBJkgpkwCVJKpABlySpQAZckqQCGXBJkgpkwCVJKpABlySpQAZckqQCGXBJkgpkwCVJKpABlySpQAZckqQCGXBJkgpkwCVJKpABlySpQP0S8Ig4KSIeiYhFETGzP2aQJKlkDQ94RAwGvgG8GTgUODMiDm30HJIklaw/jsCPARZl5m8ycy0wBzilH+aQJKlYQ/rhPvcDlnS43gYcu+mNIuIc4Jzq6pqIeKQBs2krBOwN/L6/5xgwPh39PYG2IT5/+1h9nr/7d7awPwLeLZl5GXBZf8+hLYuI1sxs7u85JPWcz99y9ccp9MeBMR2uj66WSZKkbuqPgP8PMC4iDoiIocA7gHn9MIckScVq+Cn0zGyPiA8APwQGA9/MzIcaPYf6lC91SOXy+VuoyMz+nkGSJPWQn8QmSVKBDLgkSQUy4ANURLwUEQsi4qGIuD8iPhwRg6p1b4iIp6r1D0TEjyJin2rdIRFxV0S8EBEfadCsYyPiuWqehyPi0g6z3hwRqyPipkbMIr0S9OL5O7Va9ouI+O+IOLIBs3b6/I2Ipurfkoeqmc6o9yzbGgM+cD2XmU2ZeRgwidpH136qw/o7qvUTqP1mwN9Vy/8AfBD4cm/uPCL26OEmv87MJmACtY/YPbVa/iXgrN7MIhVoa5+/vwX+LDOPAD7LVr5BrY+ev38Ezq4ew0nAxRGx+9bMo84Z8G1AZi6n9ql2H4iIjT4mqLq+C7Bq/W0z83+AF3t5t2dExIPVkcPwHszaDvw3cFB1fT7wTC9nkYrVw+fvf2fmqmr1z6h9zsbW6PXzNzN/lZmPVsufAJYD3d6XtsyAbyMy8zfUfm1vn2rR6yNiAfA74C+Ab/ZkfxFxXXXKbNM/Z1f3dym1o4Ydgdsj4obqW+g2+/+5iNgROBH4RY8eoDSAbeXzdzrwg8721+jnb0QcAwwFft3Nh6xueMV+lKrq7o7MPBkgIi4Avgi8r7sbZ+YWX8/KzCXAZyPic9T+Mfgm0ApM7uTmB1b/ICUwNzM7/YdHErCF529E/Dm1gL+us40b+fyNiJHANcC0zFzXvYen7jDg24iIeDXwErXTWOM3WT0PuLGH+7sOOLiTVV/JzKs73O4Y4D3UXse7Hvg/Xexy/WtokjbRk+dvREwALgfenJkru9hfQ56/EbEr8H3g45n5sy621VYy4NuA6jWsS4FLMjM3eRkNaj+l9+jU1pZ+go+IN1J7I9yT1P4xObf6+lhJPdCT529EvAr4LnBWZv6qq3024vlbfVT294CrM/OGnmyr7jHgA9cO1Smt7YB2aqewvtJh/frX0AJ4CpgBEBH7UjtNtiuwLiLOAw7NzKd7eP8rgb/MzMd68yAi4g7gEGDniGgDpmfmD3uzT6kAW/X8BT4J7AX8axX69q38prG+eP6eDhwP7BUR766WvTszF/Rin+rAj1KVJKlAvgtdkqQCGXBJkgpkwCVJKpABlySpQAZckqQCGXBJkgpkwCVJKtD/B6WzT4V0kO8zAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 504x576 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfAAAAI4CAYAAACV/7uiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3de5xW5X3v/c9PBjegYpQgUQlCjFEi6IhTjY26G63UmDxR1ETQNhqx1Bx2k50m0adpUm1Mk1irJk9sjTU2eAhgPQSjlTyI5mixGXRURI2HaASjIlFUQAX87T/uBXvEGZgB1gzX8Hm/XvOae13XOvxu9J7vXNdas1ZkJpIkqSzb9HYBkiSp+wxwSZIKZIBLklQgA1ySpAIZ4JIkFcgAlySpQAa4JEkFMsClDYiIV9p9vRERK9otn7KJ+x4fET+LiJci4qHNVXNvi4gBEZERsaz6d1oYEd+KiG2q/v8dEXdHxOsRcWlv1hMRgyLi3yPid9V/h3kRcVTdNUmbqqm3C5C2dJm5/ZrXEfEEcEZm3raZdv8KcBmwE/CZzbTPLcnembkwIvYFfgY8CPwAWAicAxy3BdQzE3gM+PuqruOA6yNin8x8uofrk7rMEbi0iSJiYERcEhG/r0Z2/xQR/au+oyPi0Yg4NyL+EBG/jYiPrtk2M+/MzGuAJ7pwnKaIuD4ino2IFyPijojYu13/dhHxnYh4KiKWViP7pqrvTyJibtX+u4g4uZNjjIiI/6xq/U1EnNqu75sRcU1ETIuIlyPivoho7sq/UWY+APwXMKZa/o/MvAn4wwbe86BqxPzudm27V7MgO0XEOyJiVvXvsSQibu9uPZn5Qmael5m/y8w3MvMG4BnggK7sS+otBri06c4F9gPGAgcCfwJ8qV3/SGBb4B3AXwJTI2LURh5rJrBnta+HgKnt+r4D7AP8EbAz8HdAVuF3M/BPwJCqxgc62f9/AA8DuwInAxdFxPvb9U8ArgDeBswBLu5K0RExFjgEuKcr66+RmcuBm4BJ7ZonAj/JzBeAs6p6317VfM6m1hMRw2n8N1vQnVqlnmaAS5vuFODvM/P5zHwWOA/4i3b9q4BzM/P1aur9NuDE7h4kM1dl5pWZ+UpmvkrjF4eDqvO7/YGPA/8rM5/JzNWZ+YvMXF3V8uPMvL7ax+LMvHfd/UfEXsD+wN9m5muZ2UrjF4T27+X2zJxd7fcqYEMj8Aci4gXgBuD/A37Y3fddbdM+wE9ut5+VwG7AiOrf9+ebUk9E/I+q7V8z87cbUavUYzwHLm2CiAgao+En2zU/CezebnlxFbjt+3fbiGM1Ad+iMQp+O/AGEDRG1f1ofJ4f62DTd3bSvq7dqlpXrFPrke2Wn2n3ejmwPeu3b2Yu7MKx1+cnNGYt9gdeBfYCflz1fR34B+COiFgJ/EtmXrgx9VT/vtOBJcDnN7FmqXaOwKVNkI3H+T0D7NGueQSwqN3y2yNiwDr9G3Nx1CeAo4APADvSmC6HRoj/nsZIf88Otnuqk/Z1PQ0MjYiB69S6qJP1e0RmrgSuozEKPxm4cc0vGZm5NDM/m5l7ACcAf7fOlH+XVFfHXwkMAk6qZhikLZoBLm26acDfR8SQiNgF+DJwdbv+/sBXImLbiDiCRghfD43gqMK9f2Nx7XR4R3agMQJdAmxHY6oeWBtyVwLfjohhEdEvIg6NiH40pro/HBETqgvhhkbEfh3s/1HgfuC8iPgfETEOOHWd97JZVHUMoDFz0K963/3Ws8kPaZz7nkS7ae+I+EhEvKuaCVkKrKYxM9GdWgL4PjAcmJCZr3fv3Ui9wwCXNt1XaVzw9ADQBvwKOL9d/xM0RsfP0LgA7BOZ+XjVNx5YQeOc7Huq1z+mY98HFlf7uR/45Tr9f01jqvweGiH/NSAy81HgWOBvaVz13Qrsu+7Oq9mEjwLvrY4xA/hiZq57nM3hPBrv9XPAGdXrL65n
/Z/TCPsdaVxDsMZo4A7g5WqdCzLzv7pZy3uA02hc/Pdc/N+/8T+hm/uRelQ0PrOS6hARRwPfzcx3b3BlSeoGR+CSJBXIAJckqUBOoUuSVCBH4JIkFaiIG7m8/e1vz5EjR/Z2GZIk9bh58+Y9n5lD120vIsBHjhxJa2trb5chSVKPi4gnO2p3Cl2SpAIZ4JIkFcgAlySpQEWcA5ckaUNWrlzJwoULefXVVze88hZowIABDB8+nP79O3scwpsZ4JKkPmHhwoXssMMOjBw5ksYzasqRmSxZsoSFCxcyatSoLm3jFLokqU949dVXGTJkSHHhDRARDBkypFuzBwa4JKnPKDG81+hu7bUGeET874h4ICLmR8S06pm/oyLiroh4NCJmRMS2ddYgSVJfVNs58IjYncbzid+bmSsi4lpgInAMcFFmTo+IS4HJwL/WVYckaes08uxbNuv+nvjmhza4Tr9+/Rg7diyrVq1i9OjRTJ06lUGDBnH66adz8803s8suuzB//vzNUk/dU+hNwMCIaAIGAb8HjgCuq/qnAsfVXIMkST1i4MCBtLW1MX/+fLbddlsuvfRSAE477TRmzZq1WY9VW4Bn5iLgAuB3NIJ7KTAPeDEzV1WrLQR272j7iJgSEa0R0bp48eK6ypQkqRaHHXYYjz76KACHH344O++882bdf20BHhE7AccCo4DdgO2Ao7u6fWZelpktmdkydOhb7uFenIcffpjm5ua1X4MHD+biiy/mpJNOWts2cuRImpubO9z+oosuYt9992XMmDFMmjRp7ZWKc+bMYdy4cTQ3N3PooYeu/Z9FktR7Vq1axa233srYsWNrO0adU+h/Cvw2Mxdn5krgBuD9wNuqKXWA4cCiGmvYYuy99960tbXR1tbGvHnzGDRoEBMmTGDGjBlr20844QSOP/74t2y7aNEivvOd79Da2sr8+fNZvXo106dPB+CTn/wk11xzDW1tbZx88smcd955Pf3WJEmVFStW0NzcTEtLCyNGjGDy5Mm1HavOG7n8DnhfRAwCVgBHAq3AHcCJwHTgVGBmjTVskebMmcOee+7JHnvssbYtM7n22mu5/fbbO9xm1apVrFixgv79+7N8+XJ22203oPFnBy+99BIAS5cuXdsuSep5a86B94TaAjwz74qI64C7gVXAPcBlwC3A9Ig4r2r7fl01bKmmT5/OpEmT3tT2i1/8gmHDhrHXXnu9Zf3dd9+dL3zhC4wYMYKBAwcyfvx4xo8fD8Dll1/OMcccw8CBAxk8eDBz587tkfcgSeplmbnFfx144IHZV7z22ms5ZMiQfOaZZ97UfuaZZ+YFF1zQ4TZ/+MMf8gMf+EA+99xz+frrr+exxx6bV111VWZmTpgwIefOnZuZmeeff35Onjy53jcgSVuoBQsW9HYJud1223XYPnHixHzHO96RTU1Nufvuu+fll1/e4XodvQegNTvIRu+F3sNuvfVWxo0bx7Bhw9a2rVq1ihtuuIF58+Z1uM1tt93GqFGjWHMx3/HHH8+dd97Jn/3Zn3Hvvfdy8MEHA3DSSSdx9NFdvk5QkrSZvfLKKx22T5s2bbMfy1up9rBp06a9Zfr8tttuY5999mH48OEdbjNixAjmzp3L8uXLyUzmzJnD6NGj2WmnnVi6dCm/+c1vAJg9ezajR4+u/T1IknqfI/AetGzZMmbPns33vve9N7V3dE786aef5owzzuA///M/OfjggznxxBMZN24cTU1NHHDAAUyZMoWmpib+7d/+jRNOOIFtttmGnXbaiSuuuKIn35IkqZdEY3p9y9bS0pKtra29XYYkaQv24IMPFj8L2dF7iIh5mdmy7rpOoUuSVCADXJKkAhngkiQVaKu8iG1zP2Jua9eVR+xJUo87Z8fNvL+lG1ylo8eJLlmyhI9//OM8++yzRARTpkzhs5/97CaX4whckqTNpKPHiTY1NfHP//zPLFiwgLlz53LJJZewYMGCTT6WAS5JUg3WPE501113Zdy4cQDs
sMMOjB49mkWLNv05Xga4JEmbWWePE33iiSe455571t5Bc1NslefAJUmqw5rHiUJjBN7+caKvvPIKJ5xwAhdffDGDBw/e5GMZ4JIkbSadPU505cqVnHDCCZxyyikcf/zxm+VYTqFLklSjzGTy5MmMHj2az3/+85ttv47AJUl9Uxf+7Ksn/OpXv+Kqq65i7Nixa6fX//Ef/5Fjjjlmk/ZrgEuStJl09DjRQw89lDqeO+IUuiRJBTLAJUkqkAEuSeozSnhEdme6W7sBLknqEwYMGMCSJUuKDPHMZMmSJQwYMKDL23gRmySpTxg+fDgLFy5k8eLFvV3KRhkwYADDhw/v8voGuCSpT+jfvz+jRo3q7TJ6jFPokiQVyACXJKlABrgkSQUywCVJKpABLklSgQxwSVKPevjhh2lubl77NXjwYC6++GK++MUvss8++7DffvsxYcIEXnzxxQ63v+iii9h3330ZM2YMkyZN4tVXXwXglFNOYe+992bMmDGcfvrprFy5siffVo8zwCVJPWrvvfemra2NtrY25s2bx6BBg5gwYQJHHXUU8+fP57777uM973kP3/jGN96y7aJFi/jOd75Da2sr8+fPZ/Xq1UyfPh1oBPhDDz3E/fffz4oVK7j88st7+q31KANcktRr5syZw5577skee+zB+PHjaWpq3J7kfe97HwsXLuxwm1WrVrFixQpWrVrF8uXL2W233QA45phjiAgigoMOOqjT7fsKA1yS1GumT5/OpEmT3tJ+xRVX8MEPfvAt7bvvvjtf+MIXGDFiBLvuuis77rgj48ePf9M6K1eu5KqrruLoo4+ure4tgQEuSeoVr7/+OjfddBMf/ehH39T+9a9/naamJk455ZS3bPPCCy8wc+ZMfvvb3/L000+zbNkyrr766jet86lPfYrDDz+cww47rNb6e5sBLknqFbfeeivjxo1j2LBha9t+8IMfcPPNN3PNNdcQEW/Z5rbbbmPUqFEMHTqU/v37c/zxx3PnnXeu7T/33HNZvHgxF154YY+8h97kvdAlSb1i2rRpb5o+nzVrFueffz4/+9nPGDRoUIfbjBgxgrlz57J8+XIGDhzInDlzaGlpAeDyyy/nJz/5CXPmzGGbbfr++LTvv0NJ0hZn2bJlzJ49m+OPP35t22c+8xlefvlljjrqKJqbmznzzDMBePrppznmmGMAOPjggznxxBMZN24cY8eO5Y033mDKlCkAnHnmmTz77LMccsghNDc38w//8A89/8Z6UJTw3NSWlpZsbW3dbPsbefYtm21fgie++aHeLkGS+qyImJeZLeu2OwKXJKlABrgkSQUywCVJKpBXoUvS1uycHXu7gr7lnKU9dihH4JIkFcgAlySpQAa4JEkFMsAlSSqQAS5JUoEMcEmSCmSAS5JUIANckqQCGeCSJBXIAJckqUAGuCRJBTLAJUkqkAEuSVKBDHBJkgpUW4BHxN4R0dbu66WI+FxE7BwRsyPiker7TnXVIElSX1VbgGfmw5nZnJnNwIHAcuBG4GxgTmbuBcypliVJUjf01BT6kcBjmfkkcCwwtWqfChzXQzVIktRn9FSATwSmVa+HZebvq9fPAMM62iAipkREa0S0Ll68uCdqlCSpGLUHeERsC3wE+I91+zIzgexou8y8LDNbMrNl6NChNVcpSVJZemIE/kHg7sx8tlp+NiJ2Bai+P9cDNUiS1Kf0RIBP4v9OnwPcBJxavT4VmNkDNUiS1KfUGuARsR1wFHBDu+ZvAkdFxCPAn1bLkiSpG5rq3HlmLgOGrNO2hMZV6ZIkaSN5JzZJkgpkgEuSVCADXJKkAhngkiQVyACXJKlABrgkSQUywCVJKpABLklSgQxwSZIKZIBLklQgA1ySpAIZ4JIkFcgAlySpQAa4JEkFMsAlSSqQAS5JUoEMcEmSCmSAS5JUIANckqQCGeCSJBXIAJckqUAGuCRJBTLAJUkqkAEuSVKBDHBJkgpkgEuSVCADXJKkAhngkiQVyACX
JKlABrgkSQUywCVJKpABLklSgQxwSZIKZIBLklQgA1ySpAIZ4JIkFcgAlySpQAa4JEkFMsAlSSqQAS5JUoEMcEmSCmSAS5JUIANckqQCGeCSJBXIAJckqUAGuCRJBTLAJUkqkAEuSVKBDHBJkgpkgEuSVCADXJKkAhngkiQVyACXJKlABrgkSQUywCVJKlCtAR4Rb4uI6yLioYh4MCIOiYidI2J2RDxSfd+pzhokSeqL6h6BfxuYlZn7APsDDwJnA3Mycy9gTrUsSZK6obYAj4gdgcOB7wNk5uuZ+SJwLDC1Wm0qcFxdNUiS1FfVOQIfBSwG/j0i7omIyyNiO2BYZv6+WucZYFhHG0fElIhojYjWxYsX11imJEnlqTPAm4BxwL9m5gHAMtaZLs/MBLKjjTPzssxsycyWoUOH1limJEnlqTPAFwILM/Ouavk6GoH+bETsClB9f67GGiRJ6pNqC/DMfAZ4KiL2rpqOBBYANwGnVm2nAjPrqkGSpL6qqeb9/y/gmojYFngc+ASNXxqujYjJwJPAx2quQZKkPqfWAM/MNqClg64j6zyuJEl9nXdikySpQAa4JEkFMsAlSSqQAS5JUoEMcEmSCmSAS5JUIANckqQCGeCSJBXIAJckqUAGuCRJBTLAJUkqkAEuSVKBDHBJkgpkgEuSVCADXJKkAhngkiQVyACXJKlABrgkSQUywCVJKpABLklSgQxwSZIKZIBLklQgA1ySpAIZ4JIkFcgAlySpQAa4JEkFMsAlSSqQAS5JUoEMcEmSCmSAS5JUIANckqQCGeCSJBXIAJckqUAGuCRJBTLAJUkqkAEuSVKBDHBJkgpkgEuSVCADXJKkAhngkiQVyACXJKlABrgkSQUywCVJKpABLklSgQxwSZIKZIBLklQgA1ySpAIZ4JIkFcgAlySpQAa4JEkFMsAlSSqQAS5JUoEMcEmSCmSAS5JUIANckqQCNdW584h4AngZWA2sysyWiNgZmAGMBJ4APpaZL9RZhyRJfU1PjMA/kJnNmdlSLZ8NzMnMvYA51bIkSeqG3phCPxaYWr2eChzXCzVIklS0ugM8gf8/IuZFxJSqbVhm/r56/QwwrKMNI2JKRLRGROvixYtrLlOSpLLUeg4cODQzF0XELsDsiHiofWdmZkRkRxtm5mXAZQAtLS0driNJ0taq1hF4Zi6qvj8H3AgcBDwbEbsCVN+fq7MGSZL6otoCPCK2i4gd1rwGxgPzgZuAU6vVTgVm1lWDJEl9VZ1T6MOAGyNizXF+mJmzIuLXwLURMRl4EvhYjTVIktQn1Rbgmfk4sH8H7UuAI+s6riRJWwPvxCZJUoEMcEmSCmSAS5JUIANckqQCGeCSJBXIAJckqUAGuCRJBTLAJUkqkAEuSVKBDHBJkgpkgEuSVCADXJKkAhngkiQVyACXJKlABrgkSQUywCVJKpABLklSgQxwSZIKZIBLklQgA1ySpAIZ4JIkFcgAlySpQAa4JEkFMsAlSSqQAS5JUoEMcEmSCmSAS5JUoC4HeES8OyKujojrI+KQOouSJEnr19RZR0QMyMxX2zV9DfhS9frHQHOdhUmSpM6tbwT+44j4eLvllcBIYA9gdZ1FSZKk9VtfgB8NDI6IWRFxOPAF4M+ACcApPVGcJEnqWKdT6Jm5GvhuRFwFfAX4JPB3mflYTxUnSZI6tr5z4AcDXwReB/4RWAF8PSIWAV/LzBd7pkRJkrSuTgMc+B5wDLA98O+Z+X5gYkT8T2AGjel0SZLUC9YX4KtoXLS2HY1ROACZ+TPgZ/WWJUmS1md9AX4y8Fc0wvvj61lPkiT1sPVdxPYb4G96sBZJktRF3kpVkqQCGeCSJBVogwEeEdtFxDbtlreJiEH1liVJktanKyPwOUD7wB4E3FZPOZIkqSu6EuADMvOVNQvVa0fgkiT1oq4E+LKIGLdmISIOpHFXNkmS1EvW93fga3wO+I+IeBoI4B3ASbVWJUmS1muDAZ6Zv46IfYC9q6aH
M3NlvWVJkqT16cpV6J8GtsvM+Zk5H9g+Ij5Vf2mSJKkzXTkH/pftnzyWmS8Af1lfSZIkaUO6EuD9IiLWLEREP2Db+kqSJEkb0pWL2GYBMyLie9XyX1VtkiSpl3QlwM+iEdqfrJZnA5fXVpEkSdqgrlyF/gbwr9WXJEnaAmwwwCNiL+AbwHuBAWvaM/NdNdYlSZLWoysXsf07jdH3KuADwJXA1XUWJUmS1q8rAT4wM+cAkZlPZuY5wIfqLUuSJK1PVy5ie616nOgjEfEZYBGwfb1lSZKk9enKCPyzNJ4+9tfAgcCfA6d29QAR0S8i7omIm6vlURFxV0Q8GhEzIsK/KZckqZs2GOCZ+evMfCUzF2bmJzLzhMyc241jfBZ4sN3yt4CLMvPdwAvA5O6VLEmSujIC32gRMZzG+fLLq+UAjgCuq1aZChxXZw2SJPVFtQY4cDHwJeCNankI8GJmrqqWFwK7d7RhREyJiNaIaF28eHHNZUqSVJbaAjwiPgw8l5nzNmb7zLwsM1sys2Xo0KGbuTpJksq2UQEeEV/twmrvBz4SEU8A02lMnX8beFtErLn6fTiNq9olSVI3bOwI/IwNrZCZ/29mDs/MkcBE4PbMPAW4AzixWu1UYOZG1iBJ0lar078Dj4iXOusCBm7CMc8CpkfEecA9wPc3YV+SJG2V1ncjlxeBP8rMZ9ftiIinunOQzPwp8NPq9ePAQd3ZXpIkvdn6ptCvBPbopO+HNdQiSZK6qNMReGb+3Xr6zqqnHEmS1BVduRc6EXE8cCiQwC8z88Zaq5IkSeu1wavQI+JfgDOB+4H5wF9FxCV1FyZJkjrXlRH4EcDozEyAiJgKPFBrVZIkab268nfgjwIj2i2/s2qTJEm9pCsj8B2AByPiv6vlPwJaI+ImgMz8SF3FSZKkjnUlwLty21RJktSDNhjgmfmziBhGY+QN8N+Z+Vy9ZUmSpPXpylXoHwP+G/go8DHgrog4cf1bSZKkOnVlCv3LNG6p+hxARAwFbgOuq7MwSZLUua5chb7NOlPmS7q4nSRJqklXRuCzIuInwLRq+STg1vpKkiRJG9KVi9i+2O5WqgCXeStVSZJ61wYDPCK+VT285IYO2iRJUi/oyrnsozpo++DmLkSSJHVdpyPwiPgk8CngXRFxX7uuHYBf1V2YJEnq3Pqm0H9I42K1bwBnt2t/OTP/UGtVkiRpvToN8MxcCiwFJvVcOZIkqSv8e25JkgpkgEuSVCADXJKkAhngkiQVyACXJKlABrgkSQUywCVJKpABLklSgQxwSZIKZIBLklQgA1ySpAIZ4JIkFcgAlySpQAa4JEkFMsAlSSqQAS5JUoEMcEmSCmSAS5JUIANckqQCGeCSJBXIAJckqUAGuCRJBTLAJUkqkAEuSVKBDHBJkgpkgEuSVCADXJKkAhngkiQVyACXJKlABrgkSQUywCVJKpABLklSgQxwSZIKZIBLklQgA1ySpALVFuARMSAi/jsi7o2IByLi3Kp9VETcFRGPRsSMiNi2rhokSeqr6hyBvwYckZn7A83A0RHxPuBbwEWZ+W7gBWByjTVIktQn1Rbg2fBKtdi/+krgCOC6qn0qcFxdNUiS1FfVeg48IvpFRBvwHDAbeAx4MTNXVassBHavswZJkvqiWgM8M1dnZjMwHDgI2Ker20bElIhojYjWxYsX11ajJEkl6pGr0DPzReAO4BDgbRHRVHUNBxZ1ss1lmdmSmS1Dhw7tiTIlSSpGnVehD42It1WvBwJHAQ/SCPITq9VOBWbWVYMkSX1V04ZX2Wi7AlMjoh+NXxSuzcybI2IBMD0izgPuAb5fYw2SJPVJtQV4Zt4HHNBB++M0zodLkqSN5J3YJEkqkAEuSVKBDHBJkgpkgEuSVCADXJKkAhngkiQVyACXJKlABrgkSQUywCVJKpABLklSgQxwSZIKZIBLklQgA1ySpAIZ4JIkFcgAlySpQAa4JEkFMsAlSSqQAS5JUoEMcEmSCmSAS5JU
IANckqQCGeCSJBXIAJckqUAGuCRJBTLAJUkqkAEuSVKBDHBJkgpkgEuSVCADXJKkAhngkiQVyACXJKlABrgkSQUywCVJKpABLklSgQxwSZIKZIBLklQgA1ySpAIZ4JIkFcgAlySpQAa4JEkFMsAlSSqQAS5JUoEMcEmSCmSAS5JUIANckqQCGeCSJBXIAJckqUAGuCRJBTLAJUkqkAEuSVKBDHBJkgpkgEuSVCADXJKkAhngkiQVyACXJKlAtQV4RLwzIu6IiAUR8UBEfLZq3zkiZkfEI9X3neqqQZKkvqrOEfgq4G8y873A+4BPR8R7gbOBOZm5FzCnWpYkSd1QW4Bn5u8z8+7q9cvAg8DuwLHA1Gq1qcBxddUgSVJf1SPnwCNiJHAAcBcwLDN/X3U9AwzrZJspEdEaEa2LFy/uiTIlSSpG7QEeEdsD1wOfy8yX2vdlZgLZ0XaZeVlmtmRmy9ChQ+suU5KkotQa4BHRn0Z4X5OZN1TNz0bErlX/rsBzddYgSVJfVOdV6AF8H3gwMy9s13UTcGr1+lRgZl01SJLUVzXVuO/3A38B3B8RbVXb3wLfBK6NiMnAk8DHaqxBkqQ+qbYAz8xfAtFJ95F1HVeSpK2Bd2KTJKlABrgkSQUywCVJKpABLklSgQxwSZIKZIBLklQgA1ySpAIZ4JIkFcgAlySpQAa4JEkFMsAlSSqQAS5JUoEMcEmSCmSAS5JUIANckqQCGeCSJBXIAJckqUAGuCRJBTLAJUkqkAEuSVKBDHBJkgpkgEuSVCADXJKkAhngkiQVyACXJKlABrgkSQUywCVJKpABLklSgQxwSZIKZIBLklQgA1ySpAIZ4JIkFcgAlySpQAa4JEkFMsAlSSqQAS5JUoEMcEmSCmSAS5JUIANckqQCGeCSJBXIAJckqUAGuCRJBTLAJUkqkAEuSVKBDHBJkgpkgEuSVCADXJKkAhngkiQVyACXJKlABrgkSQUywCVJKpABLklSgQxwSZIKZIBLklQgA1ySpALVFuARcUVEPBcR89u17RwRsyPiker7TnUdX5KkvqzOEfgPgKPXaTsbmJOZewFzqmVJktRNtQV4Zv4c+MM6zccCU6vXU4Hj6jq+JEl9WU+fAx+Wmb+vXj8DDMS2MnUAAAkrSURBVOtsxYiYEhGtEdG6ePHinqlOkqRC9NpFbJmZQK6n/7LMbMnMlqFDh/ZgZZIkbfl6OsCfjYhdAarvz/Xw8SVJ6hN6OsBvAk6tXp8KzOzh40uS1CfU+Wdk04D/AvaOiIURMRn4JnBURDwC/Gm1LEmSuqmprh1n5qROuo6s65iSJG0tvBObJEkFMsAlSSqQAS5JUoEMcEmSCmSAS5JUIANckqQCGeCSJBXIAJckqUAGuCRJBTLAJUkqkAEuSVKBDHBJkgpkgEuSVCADXJKkAhngkiQVyACXJKlABrgkSQUywCVJKpABLkkbYfXq1RxwwAF8+MMffkvfpZdeytixY2lububQQw9lwYIFAFxzzTU0Nzev/dpmm21oa2vr6dLVRxjgkrQRvv3tbzN69OgO+04++WTuv/9+2tra+NKXvsTnP/95AE455RTa2tpoa2vjqquuYtSoUTQ3N/dk2epDDHBJ6qaFCxdyyy23cMYZZ3TYP3jw4LWvly1bRkS8ZZ1p06YxceLE2mpU39fU2wVIUmk+97nPcf755/Pyyy93us4ll1zChRdeyOuvv87tt9/+lv4ZM2Ywc+bMOstUH+cIXJK64eabb2aXXXbhwAMPXO96n/70p3nsscf41re+xXnnnfemvrvuuotBgwYxZsyYOktVH2eAS1I3/OpXv+Kmm25i5MiRTJw4kdtvv50///M/73T9iRMn8qMf/ehNbdOnT2fSpEl1l6o+zgCXpG74xje+wcKFC3niiSeYPn06RxxxBFdfffWb1nnkkUfWvr7lllvYa6+91i6/8cYbXHvttZ7/1ibzHLgkbQZf/epXaWlp4SMf+Qjf/e53ue222+jf
vz877bQTU6dOXbvez3/+c975znfyrne9qxerVV8QmdnbNWxQS0tLtra2brb9jTz7ls22L8ET3/xQb5cgaWOds2NvV9C3nLN0s+8yIuZlZsu67U6hS5JUIANckqQCGeCSJBXIi9gkFcVrWDavJwb0dgXaWI7AJUkqkAEuSVKBDHBJkgpkgEuSVCADXJKkAhngkiQVyABX8U4//XR22WWXTh/NOHPmTPbbbz+am5tpaWnhl7/85dq+s846izFjxjBmzBhmzJjRUyVL0iYzwFW80047jVmzZnXaf+SRR3LvvffS1tbGFVdcwRlnnAE0nhJ1991309bWxl133cUFF1zASy+91FNlS9ImMcBVvMMPP5ydd9650/7tt9+eiABg2bJla18vWLCAww8/nKamJrbbbjv222+/9f4iIElbEgNcW4Ubb7yRffbZhw996ENcccUVAOy///7MmjWL5cuX8/zzz3PHHXfw1FNP9XKlktQ1Bri2ChMmTOChhx7iRz/6EV/5ylcAGD9+PMcccwx//Md/zKRJkzjkkEPo169fL1cqSV1jgGurcvjhh/P444/z/PPPA/DlL3+ZtrY2Zs+eTWbynve8p5crlKSuMcDV5z366KNkJgB33303r732GkOGDGH16tUsWbIEgPvuu4/77ruP8ePH92apktRlPo1MxZs0aRI//elPef755xk+fDjnnnsuK1euBODMM8/k+uuv58orr6R///4MHDiQGTNmEBGsXLmSww47DIDBgwdz9dVX09TkR0JSGfxppeJNmzZtvf1nnXUWZ5111lvaBwwYwIIFC+oqS5Jq5RS6JEkFMsAlSSqQAS5JUoE8B65Nd86OvV1B33LO0t6uQFIBHIFLklQgA1ySpAIZ4JIkFcgAlySpQAa4JEkF6pUAj4ijI+LhiHg0Is7ujRokSSpZjwd4RPQDLgE+CLwXmBQR7+3pOiRJKllvjMAPAh7NzMcz83VgOnBsL9QhSVKxeuNGLrsDT7VbXggcvO5KETEFmFItvhIRD/dAbdoIAW8Hnu/tOvqMc6O3K9BWxM/vZlbP53ePjhq32DuxZeZlwGW9XYc2LCJaM7Olt+uQ1H1+fsvVG1Poi4B3tlseXrVJkqQu6o0A/zWwV0SMiohtgYnATb1QhyRJxerxKfTMXBURnwF+AvQDrsjMB3q6Dm1WnuqQyuXnt1CRmb1dgyRJ6ibvxCZJUoEMcEmSCmSA91ERsToi2iLigYi4NyL+JiK2qfr+JCKWVv33RcRtEbFL1bdPRPxXRLwWEV/ooVpHRsSKqp4FEXFpu1pnRcSLEXFzT9QibQk24fN7StV2f0TcGRH790CtHX5+I6K5+lnyQFXTSXXXsrUxwPuuFZnZnJn7AkfRuHXt37fr/0XVvx+Nvwz4dNX+B+CvgQs25eARsVM3N3ksM5uB/WjcYve4qv2fgL/YlFqkAm3s5/e3wP/MzLHA19jIC9Q20+d3OfDx6j0cDVwcEW/bmHrUMQN8K5CZz9G4q91nIuJNtwmqlncAXlizbmb+Gli5iYc9KSLmVyOHod2odRVwJ/DuankO8PIm1iIVq5uf3zsz84Wqey6N+2xsjE3+/GbmbzLzkar9aeA5oMv70oYZ4FuJzHycxp/t7VI1HRYRbcDvgD8FrujO/iJiRjVltu7Xx6vjXUpj1DAI+HlEXFc9hW69/89FxCDgSOD+br1BqQ/byM/vZODWjvbX05/fiDgI2BZ4rItvWV2wxd5KVbX7RWZ+GCAizgLOB87s6saZucHzWZn5FPC1iDiPxg+DK4BW4CMdrL5n9QMpgZmZ2eEPHknABj6/EfEBGgF+aEcb9+TnNyJ2Ba4CTs3MN7r29tQVBvhWIiLeBaymMY01ep3um4Dru7m/GcDeHXRdmJlXtlvvIOATNM7jXQv8Wye7XHMOTdI6uvP5jYj9gMuBD2bmkk721yOf34gYDNwCfDkz53ayrTaSAb4VqM5hXQp8NzNzndNo0PgtvVtTWxv6DT4i
xtO4EO4ZGj9MPls9PlZSN3Tn8xsRI4AbgL/IzN90ts+e+PxWt8q+EbgyM6/rzrbqGgO87xpYTWn1B1bRmMK6sF3/mnNoASwFzgCIiHfQmCYbDLwREZ8D3puZL3Xz+EuA/yczn9yUNxERvwD2AbaPiIXA5Mz8yabsUyrARn1+ga8CQ4B/qYJ+1UY+aWxzfH4/BhwODImI06q20zKzbRP2qXa8laokSQXyKnRJkgpkgEuSVCADXJKkAhngkiQVyACXJKlABrgkSQUywCVJKtD/AeZY2sWuUy5fAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 504x576 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "for ix in range(0,3):\n",
    "    b1_d = d_list[2*ix]\n",
    "    b2_d = d_list[2*ix + 1]\n",
    "    \n",
    "    \n",
    "\n",
    "    labels = ['DB1 => P1', 'DB2 => P2']\n",
    "    theirs = [round(b1_d['Max weight acc']*100,2), round(b2_d['Max weight acc ngt']*100, 2)]\n",
    "    ours = [round(b1_d['Max weight acc ngt']*100, 2),  round(b2_d['Max weight acc']*100, 2)]\n",
    "\n",
    "    x = np.arange(len(labels))  # the label locations\n",
    "    width = 0.35  # the width of the bars\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(7,8))\n",
    "    rects1 = ax.bar(x - width/2, theirs, width, label='P1')\n",
    "    rects2 = ax.bar(x + width/2, ours, width, label='P2')\n",
    "\n",
    "    # Add some text for labels, title and custom x-axis tick labels, etc.\n",
    "    ax.set_ylabel('top1 acc %')\n",
    "    ax.set_title('Top1 acc on P1 vs P2')\n",
    "    ax.set_xticks(x)\n",
    "    ax.set_xticklabels(labels)\n",
    "    ax.legend()\n",
    "\n",
    "    def autolabel(rects):\n",
    "        \"\"\"Attach a text label above each bar in *rects*, displaying its height.\"\"\"\n",
    "        for rect in rects:\n",
    "            height = rect.get_height()\n",
    "            ax.annotate('{}'.format(height),\n",
    "                        xy=(rect.get_x() + rect.get_width() / 2, height),\n",
    "                        xytext=(0, 3),  # 3 points vertical offset\n",
    "                        textcoords=\"offset points\",\n",
    "                        ha='center', va='bottom')\n",
    "\n",
    "\n",
    "    autolabel(rects1)\n",
    "    autolabel(rects2)\n",
    "\n",
    "    fig.tight_layout()\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# for ix in range(0,3):\n",
    "#     b1_d = d_list[2*ix]\n",
    "#     b2_d = d_list[2*ix + 1]\n",
    "    \n",
    "    \n",
    "\n",
    "#     labels = ['DB1 => P1', 'DB2 => P2']\n",
    "#     theirs = [b1_d['Weight_rt']*100, b2_d['Weight_wg']*100]\n",
    "#     ours = [b1_d['Weight_wg']*100,  b2_d['Weight_rt']*100]\n",
    "\n",
    "#     x = np.arange(len(labels))  # the label locations\n",
    "#     width = 0.35  # the width of the bars\n",
    "\n",
    "#     fig, ax = plt.subplots(figsize=(7,8))\n",
    "#     rects1 = ax.bar(x - width/2, theirs, width, label='P1')\n",
    "#     rects2 = ax.bar(x + width/2, ours, width, label='P2')\n",
    "\n",
    "#     # Add some text for labels, title and custom x-axis tick labels, etc.\n",
    "#     ax.set_ylabel('Weight %')\n",
    "#     ax.set_title('Weights on P1 vs P2')\n",
    "#     ax.set_xticks(x)\n",
    "#     ax.set_xticklabels(labels)\n",
    "#     ax.legend()\n",
    "\n",
    "#     def autolabel(rects):\n",
    "#         \"\"\"Attach a text label above each bar in *rects*, displaying its height.\"\"\"\n",
    "#         for rect in rects:\n",
    "#             height = rect.get_height()\n",
    "#             ax.annotate('{}'.format(height),\n",
    "#                         xy=(rect.get_x() + rect.get_width() / 2, height),\n",
    "#                         xytext=(0, 3),  # 3 points vertical offset\n",
    "#                         textcoords=\"offset points\",\n",
    "#                         ha='center', va='bottom')\n",
    "\n",
    "\n",
    "#     autolabel(rects1)\n",
    "#     autolabel(rects2)\n",
    "\n",
    "#     fig.tight_layout()\n",
    "\n",
    "#     plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[4, 5, 4]"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "np.random.choice(range(6),3, replace=True).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#8 node labels , both patterns have distinct\n",
    "#add even or odd number patterns\n",
    "#synthetic data 1: with 4 patterns \n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "max_nodes = 20\n",
    "def makesPattern(adj, feats, p, n, num_nodes):\n",
    "    l_p = np.argmax(feats[p]) \n",
    "    l_n = np.argmax(feats[n])\n",
    "    \n",
    "    p_neighbors = []\n",
    "    n_neighbors = []\n",
    "    for pix in range(num_nodes):\n",
    "        if p != pix and pix != n:\n",
    "            l_pix = np.argmax(feats[pix]) \n",
    "#             if l_pix == l_p or l_pix == l_n:\n",
    "#                 continue\n",
    "            if adj[p,pix] > 0.0:\n",
    "                p_neighbors.append(pix)\n",
    "            elif adj[n,pix] > 0.0:\n",
    "                n_neighbors.append(pix)\n",
    "    found = False\n",
    "    pattern = -1\n",
    "    for p_nbr in p_neighbors:\n",
    "        l_pnbr = np.argmax(feats[p_nbr]) \n",
    "        for n_nbr in n_neighbors:\n",
    "            l_nnbr = np.argmax(feats[n_nbr]) \n",
    "#             if l_pnbr == l_nnbr:\n",
    "#                 continue\n",
    "            if p_nbr != n_nbr and adj[p_nbr, n_nbr] > 0.:\n",
    "                if (adj[p,n_nbr] < 1. and adj[n, p_nbr] < 1.): #ring\n",
    "                    pattern = 0\n",
    "                    found = True\n",
    "                    return found, pattern\n",
    "                elif (adj[p,n_nbr] < 1. and adj[n, p_nbr] > 0.) or (adj[p,n_nbr] > 0. and adj[n, p_nbr] < 1.) :\n",
    "                    pattern = 1\n",
    "                    found = True\n",
    "                    return found, pattern\n",
    "                elif (adj[p,n_nbr] > 0. and adj[n, p_nbr] > 0.):\n",
    "                    pattern = 2\n",
    "                    found = True\n",
    "                    return found, pattern\n",
    "    return found, pattern\n",
    "        \n",
    "\n",
    "\n",
    "def addPattern(adj,feats,nodes, sub_label, hnodes_candidates):\n",
    "    degree_sum = np.sum(adj,axis=1)\n",
    "    avg_deg = int(np.sum(degree_sum)/nodes)\n",
    "    fake_sub_label = -1\n",
    "    if sub_label == -1:\n",
    "        fake_sub_label = np.random.randint(3)\n",
    "        \n",
    "#     if fake_sub_label > -1 or sub_label > -1: \n",
    "#         if fake_sub_label == -1: #true pattern\n",
    "#             lbls = np.random.choice(hnodes_candidates,4, replace=False).tolist()\n",
    "#         else:  #fake\n",
    "#             lbls = np.random.choice(range(feat_dim),3, replace=True).tolist()\n",
    "#             lbls.append(lbls[np.random.randint(3)])\n",
    "            \n",
    "#         for i in range(4): #add 4 nodes\n",
    "#             feats[nodes+i,lbls[i]] = 1.\n",
    "    if sub_label == 0 or fake_sub_label == 0:\n",
    "        lbls = np.random.choice(range(4),4, replace=False).tolist()\n",
    "    elif sub_label == 1 or fake_sub_label == 1:\n",
    "        lbls = np.random.choice(range(4,8),4, replace=False).tolist()\n",
    "    elif sub_label == 2 or fake_sub_label == 2:\n",
    "        lbls = np.random.choice(range(8,12),4, replace=False).tolist()\n",
    "    else:\n",
    "        assert(False)\n",
    "\n",
    "    for i in range(4): #add 4 nodes\n",
    "        feats[nodes+i,lbls[i]] = 1.\n",
    "\n",
    "    for i in range(3): #make ring\n",
    "        adj[nodes+i,nodes+i+1] = 1.0\n",
    "        adj[nodes+i+1, nodes+i] = 1.0\n",
    "    adj[nodes+3,nodes] = 1.0\n",
    "    adj[nodes,nodes+3] = 1.0\n",
    "            \n",
    "#         elif sub_label == 1 or fake_sub_label == 1:\n",
    "#             for i in range(3): #make ring\n",
    "#                 adj[nodes+i,nodes+i+1] = 1.0\n",
    "#                 adj[nodes+i+1, nodes+i] = 1.0\n",
    "#             adj[nodes+3,nodes] = 1.0\n",
    "#             adj[nodes,nodes+3] = 1.0\n",
    "#             diag = np.random.randint(2)\n",
    "#             adj[nodes+diag, nodes+diag+2] = 1.0\n",
    "#             adj[nodes+diag+2, nodes+diag] = 1.0\n",
    "                \n",
    "            \n",
    "        \n",
    "#         elif sub_label == 2 or fake_sub_label == 2: #make tetra real or fake\n",
    "#             for i in range(3): #make ring\n",
    "#                 adj[nodes+i,nodes+i+1] = 1.0\n",
    "#                 adj[nodes+i+1, nodes+i] = 1.0\n",
    "#             adj[nodes+3,nodes] = 1.0\n",
    "#             adj[nodes,nodes+3] = 1.0\n",
    "#             adj[nodes, nodes+2] = 1.0\n",
    "#             adj[nodes+2, nodes] = 1.0\n",
    "#             adj[nodes+1, nodes+3] = 1.0\n",
    "#             adj[nodes+3, nodes+1] = 1.0\n",
    "#     else:\n",
    "#         assert(False)\n",
    "            \n",
    "            \n",
    "\n",
    "        \n",
    "    for i in range(4):\n",
    "        deg_exp = np.random.randint(avg_deg-1, avg_deg+1)\n",
    "        dest_n = np.random.randint(nodes)\n",
    "        while_count = 0\n",
    "        skip = False\n",
    "        while(makesPattern(adj, feats, dest_n, nodes+i, nodes+4)[0] == True):\n",
    "            dest_n = np.random.randint(nodes)\n",
    "            while_count += 1\n",
    "            if while_count == 5:\n",
    "                skip = True\n",
    "                break\n",
    "        if skip:\n",
    "            continue\n",
    "            \n",
    "        adj[nodes+i, dest_n] = 1.0\n",
    "        adj[dest_n, nodes+i] = 1.0\n",
    "        deg = int(np.sum(adj[nodes+i,:]))\n",
    "        if deg_exp > deg:\n",
    "            for e in range(deg_exp-deg):\n",
    "                dest_n = np.random.randint(nodes)\n",
    "                if adj[nodes+i, dest_n] > 0:\n",
    "                    continue\n",
    "                if(makesPattern(adj, feats, dest_n, nodes+i, nodes+4)[0] == True):\n",
    "                    continue\n",
    "                adj[nodes+i, dest_n] = 1.0\n",
    "                adj[dest_n, nodes+i] = 1.0\n",
    "        \n",
    "    \n",
    "    pos_t = list(range(nodes+4))\n",
    "    pos_covered = []\n",
    "#     for i in range(4):\n",
    "        \n",
    "#         if np.random.rand() < 0.4:\n",
    "#             continue\n",
    "        \n",
    "#         dest_pos = np.random.randint(nodes)\n",
    "#         if dest_pos in pos_covered:\n",
    "#             continue\n",
    "        \n",
    "        \n",
    "        \n",
    "#         temp_feats = np.copy(feats[nodes+i,:])\n",
    "#         feats[nodes+i,:] = feats[dest_pos,:]\n",
    "#         feats[dest_pos,:] = temp_feats\n",
    "        \n",
    "        \n",
    "#         temp_adj1 = np.copy(adj[nodes+i,:])\n",
    "#         temp_adj2 = np.copy(adj[:,nodes+i])\n",
    "\n",
    "#         adj[nodes+i,:] = adj[dest_pos,:]\n",
    "#         adj[:,nodes+i] = adj[:,dest_pos]\n",
    "#         adj[dest_pos,:] = temp_adj1\n",
    "#         adj[:,dest_pos] = temp_adj2\n",
    "        \n",
    "        \n",
    "#         pos_t[nodes+i] = dest_pos\n",
    "#         pos_t[dest_pos] = nodes + i\n",
    "#         pos_covered.append(dest_pos)\n",
    "      \n",
    "        \n",
    "        \n",
    "#     print(\"nodes: \", nodes+4, pos_t[-4:])\n",
    "        \n",
    "        \n",
    "        \n",
    "\n",
    "    return adj, feats, pos_t[nodes:], lbls\n",
    "                            \n",
    "            \n",
    "def drawGraph(adj, feats, nodes, highlight_nodes=None, sublabel_d=None):\n",
    "    node_labels = ['A','B','C','D','E','F','G','H','I','J','K','L']\n",
    "    G_class = nx.from_numpy_array(adj[:nodes,:nodes])\n",
    "\n",
    "    fig, ax_l = plt.subplots(1,1, figsize=(15,10))\n",
    "    colors = []\n",
    "    hn_list = []\n",
    "    for n in range(nodes):\n",
    "        colors.append((0.9,0.9,0.9))\n",
    "    if highlight_nodes is not None:\n",
    "        for ix in range(highlight_nodes.shape[0]):\n",
    "            red = np.random.rand()\n",
    "            green = 1.0 - red \n",
    "            blue = np.random.rand()  \n",
    "#             if sublabel_d[ix] != -1:\n",
    "            if True:\n",
    "                for gx in range(highlight_nodes.shape[1]):\n",
    "                    for h_n in highlight_nodes[ix,gx]:\n",
    "                        if h_n == -1:\n",
    "                            continue\n",
    "#                         print(\"node\", h_n, red, green, blue)\n",
    "                        colors[h_n] = (red,green,blue)\n",
    "                        hn_list.append(h_n)\n",
    "    labels_dict = {}\n",
    "    for n in range(nodes):\n",
    "        lb = np.argmax(feats[n,:])\n",
    "        labels_dict[n] = node_labels[lb]\n",
    "#         labels_dict[n] = str(n) + \" : \" + node_labels[lb]\n",
    "        \n",
    "#     colors[0] = (0.9,0.1,0.1)\n",
    "#     nx.draw_networkx(G_class, ax=ax_l, node_color=colors)\n",
    "    if False:\n",
    "#     if highlight_nodes is not None:\n",
    "        hn_colors = []\n",
    "        hn_labels = {}\n",
    "        hn_edges = []\n",
    "        \n",
    "        for hn in hn_list:\n",
    "            hn_labels[hn] = labels_dict[hn]\n",
    "            hn_colors.append(colors[hn])\n",
    "        \n",
    "        for e in G_class.edges():\n",
    "            if e[0] in hn_list and e[1] in hn_list:\n",
    "                hn_edges.append(e)\n",
    "                \n",
    "        \n",
    "        \n",
    "        \n",
    "        nx.draw_networkx(G_class, labels=hn_labels, nodelist = hn_list, edgelist = hn_edges, ax=ax_l, node_color = hn_colors)\n",
    "    else:   \n",
    "        nx.draw_networkx(G_class, labels=labels_dict, ax=ax_l, node_color = colors)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "total_adds = 2\n",
    "len_data = 8000\n",
    "max_repeat = 1\n",
    "init_max_nodes = 12\n",
    "max_nodes = total_adds*max_repeat*4 + init_max_nodes\n",
    "feat_dim = 12\n",
    "feats_data = np.zeros((len_data, max_nodes, feat_dim))\n",
    "adjs_data = np.zeros((len_data, max_nodes, max_nodes))\n",
    "labels_data = np.zeros((len_data),dtype=np.int32)\n",
    "sub_labels_data = np.zeros((len_data,total_adds),dtype=np.int32) - 1\n",
    "sub_label_nodes = np.zeros((len_data,total_adds,max_repeat,4),dtype=np.int32) - 1\n",
    "num_nodes_data = np.zeros((len_data),dtype=np.int32)\n",
    "\n",
    "#A/B/C/D/E/F\n",
    "for i in range(len_data):\n",
    "# for i in range(20):\n",
    "    adj = np.zeros((max_nodes, max_nodes))\n",
    "    feats = np.zeros((max_nodes, feat_dim))\n",
    "    graph_label = np.random.randint(3) #3 combs \n",
    "  \n",
    "    label_count = 0\n",
    "    \n",
    "    nodes = np.random.randint(6, init_max_nodes)\n",
    "    \n",
    "    for n_ix in range(nodes): # add nodes\n",
    "        l_n = np.random.randint(0,feat_dim)\n",
    "        feats[n_ix, l_n] = 1.0\n",
    "\n",
    "        if n_ix == 0:\n",
    "            continue\n",
    "        p_node = np.random.randint(n_ix) \n",
    "        while makesPattern(adj, feats, p_node, n_ix, nodes)[0] == True: #connect node to graph such that\n",
    "            p_node = np.random.randint(n_ix)                    #it doesn't result in any pattern\n",
    "\n",
    "        adj[p_node,n_ix] = 1.\n",
    "        adj[n_ix, p_node] = 1.\n",
    "\n",
    "    max_edges = min(28,int((nodes*(nodes-1))/3))\n",
    "    if max_edges > nodes:\n",
    "        edge_total = np.random.randint(nodes, max_edges)\n",
    "        for e_ix in range(edge_total-nodes+1):\n",
    "            rand_nix = np.random.randint(nodes)\n",
    "            rand_pix = np.random.randint(nodes)\n",
    "            if rand_nix == rand_pix:\n",
    "                continue\n",
    "            if not makesPattern(adj, feats, rand_pix, rand_nix, nodes)[0]: #add more random edges\n",
    "                adj[rand_nix, rand_pix] = 1.0\n",
    "                adj[rand_pix, rand_nix] = 1.0\n",
    "    \n",
    "    label_list = [(0,1,-1),(1,2,-1),(2,0,-1)]\n",
    "    sublabels_covered = []\n",
    "    highlight_nodes_covered = []\n",
    "    for g_i in range(total_adds):\n",
    "        l_entry = label_list[graph_label]\n",
    "        label = 1\n",
    "        sublabel_ix = np.random.randint(total_adds)\n",
    "        while(l_entry[sublabel_ix] in sublabels_covered):\n",
    "            sublabel_ix = np.random.randint(total_adds)\n",
    "        sub_label = l_entry[sublabel_ix]\n",
    "        sublabels_covered.append(sub_label)\n",
    "        if sub_label == -1:\n",
    "            if np.random.rand() < 0.5:\n",
    "                continue\n",
    "\n",
    "#         print(\"g_i: \", g_i, \"label: \", label, \"sub_label,: \", sub_label)\n",
    "        if max_repeat > 1:\n",
    "            repeat = np.random.randint(1,max_repeat)\n",
    "        else:\n",
    "            repeat = 1\n",
    "        assert(repeat > 0)\n",
    "        for rep in range(repeat):\n",
    "      \n",
    "            #0 or 1\n",
    "            highlight_nodes = None\n",
    "            hnodes_candidates = []\n",
    "\n",
    "            \n",
    "            for i_label in range(feat_dim):\n",
    "                if i_label not in highlight_nodes_covered:\n",
    "                    hnodes_candidates.append(i_label)\n",
    "#             print(hnodes_candidates)\n",
    "            adj, feats, highlight_nodes, label_hnodes = addPattern(adj,feats,nodes, sub_label, hnodes_candidates)\n",
    "            if sub_label != -1:\n",
    "                highlight_nodes_covered.extend(label_hnodes)\n",
    "#             print(\"sublabel: \", sub_label, highlight_nodes)\n",
    "            nodes = nodes + 4\n",
    "\n",
    "#             if sub_label != -1:\n",
    "#                 print(highlight_nodes)\n",
    "            sub_label_nodes[i, g_i, rep] = np.array(highlight_nodes)\n",
    "            \n",
    "        sub_labels_data[i,g_i] = sub_label\n",
    "    \n",
    "            \n",
    "\n",
    "\n",
    "    print(i, graph_label)\n",
    "            \n",
    "\n",
    "#     print(nodes)\n",
    "#     drawGraph(adj, feats, nodes, highlight_nodes=sub_label_nodes[i], sublabel_d=sub_labels_data[i])\n",
    "    \n",
    "    \n",
    "    feats_data[i] = feats\n",
    "    labels_data[i] = graph_label\n",
    "    \n",
    "    num_nodes_data[i] = nodes\n",
    "    \n",
    "    adjs_data[i] = adj\n",
    "    \n",
    "\n",
    "synthetic_data = {}\n",
    "synthetic_data['adj'] = adjs_data\n",
    "synthetic_data['feat'] = feats_data\n",
    "synthetic_data['label'] = labels_data\n",
    "synthetic_data['sub_label'] = sub_labels_data\n",
    "synthetic_data['sub_label_nodes'] = sub_label_nodes\n",
    "synthetic_data['num_nodes'] = num_nodes_data\n",
    "\n",
    "# pickle.dump(synthetic_data, open(\"../../gcn_interpretation/data/synthetic_data_2label_3sublabel/synthetic_data_4000_comb_norep.p\", \"wb\"))\n",
    "\n",
    "\n",
    "pickle.dump(synthetic_data, open(\"../../gcn_interpretation/data/synthetic_data_3label_3sublabel/synthetic_data_8000_comb_norep_max20_12dlbls_nofake.p\", \"wb\"))\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "r_dict = pickle.load(open(\"./data/synthetic/rule_dict_synthetic_train_4k8000_comb_12dlbls_nofake.p\",\"rb\"))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'boundary': [{'basis': array([ -8.318322 , -10.547413 ,  -9.506501 ,  -2.0877955,   6.074997 ,\n",
      "         9.731785 ,  -8.793197 ,  -3.5884438, -10.683054 , -16.14082  ,\n",
      "       -11.680099 ,  -7.010401 ,  13.33712  ,  -6.5393534,  10.34538  ,\n",
      "         9.182712 ,  12.544017 ,  -2.3591433,   9.005638 ,   9.35488  ],\n",
      "      dtype=float32), 'label': 1}, {'basis': array([  7.554856 , -11.220877 , -14.495297 ,   8.8098755,  11.397522 ,\n",
      "        13.182472 ,  -1.2428834, -11.398182 ,   3.991159 ,   1.7135191,\n",
      "       -11.942663 , -12.709032 ,   0.3304658,   6.4476576,   9.878927 ,\n",
      "         9.680606 ,   8.160595 ,  12.102753 ,  -1.8032672,  -2.0428364],\n",
      "      dtype=float32), 'label': 0}], 'label': 2}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'boundary': [{'basis': array([ -8.318322 , -10.547413 ,  -9.506501 ,  -2.0877955,   6.074997 ,\n",
       "            9.731785 ,  -8.793197 ,  -3.5884438, -10.683054 , -16.14082  ,\n",
       "          -11.680099 ,  -7.010401 ,  13.33712  ,  -6.5393534,  10.34538  ,\n",
       "            9.182712 ,  12.544017 ,  -2.3591433,   9.005638 ,   9.35488  ],\n",
       "         dtype=float32),\n",
       "   'label': 1},\n",
       "  {'basis': array([  7.554856 , -11.220877 , -14.495297 ,   8.8098755,  11.397522 ,\n",
       "           13.182472 ,  -1.2428834, -11.398182 ,   3.991159 ,   1.7135191,\n",
       "          -11.942663 , -12.709032 ,   0.3304658,   6.4476576,   9.878927 ,\n",
       "            9.680606 ,   8.160595 ,  12.102753 ,  -1.8032672,  -2.0428364],\n",
       "         dtype=float32),\n",
       "   'label': 0}],\n",
       " 'label': 2}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(r_dict['rules'][4])\n",
    "r_dict['rules'][5]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[1 2 4]\n",
      " [3 6 0]]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(1, 6, 2)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#explain_boundary\n",
    "Boundary wise top4 acc: 0.16889880952380953, top6 acc: 0.33209325396825395, top8 acc: 0.4993799603174603\n",
    "Rule wise top8 acc: 0.6876240079365079\n",
    "Average mask density: 0.14106938052767268\n",
    "    \n",
    "    \n",
    "#experiment_graph\n",
    "Boundary wise top4 acc: 0.6421666666666667, top6 acc: 0.8211666666666667, top8 acc: 0.9174583333333334\n",
    "Rule wise top8 acc: 0.66525  => 0.69\n",
    "    \n",
    "    \n",
    "#explain\n",
    "label => 0\n",
    "Rule wise top8 acc: 0.7739335317460317\n",
    "Average mask density: 0.18899344655847738\n",
    "\n",
    "label=>1\n",
    "Rule wise top4 acc: 0.33649114173228345, top6 acc: 0.48683562992125984, top8 acc: 0.6119586614173228\n",
    "Average mask density: 0.4156269320171312\n",
    "\n",
    "    \n",
    "    \n",
    "#explain_boundary_joint\n",
    "label => 0\n",
    "\n",
    "Rule wise top8 acc: 0.7920386904761905\n",
    "Average mask density: 0.38967474965408205\n",
    "    \n",
    "Boundary wise top4 acc: 0.4037698412698413, top6 acc: 0.6192956349206349, top8 acc: 0.8298611111111112\n",
    "Rule wise top8 acc: 0.8298611111111112\n",
    "Average mask density: 0.3493731921531319\n",
    "    \n",
    "Boundary wise top4 acc: 0.43005952380952384, top6 acc: 0.6479414682539683, top8 acc: 0.8599950396825397\n",
    "Rule wise top8 acc: 0.8599950396825397\n",
    "Average mask density: 0.30247850131080856\n",
    "    \n",
    "\n",
    "label=>1\n",
    "Rule wise top8 acc: 0.7052165354330708\n",
    "Average mask density: 0.3767998004882178\n",
    "\n",
    "Rule wise top8 acc: 0.7217027559055118\n",
    "Average mask density: 0.3166504028421922\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfAAAAI4CAYAAACV/7uiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3de5yWdZ3/8deHs0oCKZYyKlpWoAbqhOkaq5KmREmGptkqimuuppWn9Ze6q0ZWHmJzPYdnbTVdUzykGR5Wy9OgmIp4TAFPEAkEQop8f3/c19DNMHPPDXHPzRdez8djHnNf5891zX3P+76+1ylSSkiSpLx0qncBkiRpxRngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwqYKIOCgiflvvOppFxDoRcXtEzI2Im+pdj9oXEaMj4uEqx/1BRIyvdU1aMxjg6hAR8c2IaIqI+RHxVkT8JiJ2qXdd7UkpXZ9S2rPedZQZBXwM2CCltF+9i1nVouQ7EfHHiHgvIt6OiAci4oCycR6IiEURsWlZvy9GxGtl3a9FxMyIWK+s3+ER8UCFZV8VEe8X79Hmn6dX/Vq2LaV0Vkrp8I5cpvJlgKvmIuI44L+AsyiFz2bARcA+9ayrPRHRpd41tGJz4MWU0uJaLygiOtd6Ga04H/gecDywAdAPOBXYq8V4C4DT2plXZ+C7K7j8s1NKPct+Bq3g9Kul1fS9rH+QAa6aiohewJnA0SmlW1JKC1JKH6SUbk8pnViM0z0i/isi3ix+/isiuhfDdo2IGRFxUrFH9VZEjIyI4RHxYkT8JSJ+ULa80yPi5oi4MSL+GhFPRsSgsuEnR8QrxbApEfG1smGjI+L3ETEuImYDp5c3fxZ7h+OKOuZFxDMRsU3zekbENRExKyJej4hTI6JT2XwfjohzI+LdiPhTROxdYZsNKPYy50TEcxHx1aL/GcB/AN8o9g7HtDJtpW25XFNuRKSI+GTx+qqIuDgi7oqIBcBuxXaeUmyvNyLihDaWOad5WxT9+kbEwojYKCI2jIg7inH+EhEPNW+bFvP5FHAUcEBK6d6U0sKU0ocppYdTSqNbjH4+cGBEfKKt7QicA5wQEb0rjFOViPhG8Xdbv+jeu2gd6Ft0p4g4NiJejYg/R8Q5ra1jMe7PI2J68R6aFBFfKBt2ekRcV7zuX8z3kIiYVsz3lLJxO5W9n2dHxK8i4qMtph0TEdOA+/7RbaDVjwGuWtsJ6AH8usI4pwCfBwYDg4AhlPa6mn28mEc/SgH2C+BbwA7AF4DTImKLsvH3AW4CPgr8Erg1IroWw14ppukFnAFcFxEbl027I/AqpZaCH7Woc09gKPCpYvr9gdnFsP8u+m0J/DNwMHBoi/m+AGwInA1cHhHRckMUdd4O/BbYCDgGuD4iPp1S+k9KrRg3FnuHl7ecnva3ZXu+Waz3R4CHgcuBb6eUPgJsQytBkFL6G3ALcGBZ7/2BB1NKMyntTc8A+lLarj8AWruH8+7A9JRSUxV1vkHpfXBGhXGagAeA5b50rKiU0o3AH4DzI2IDStvl8JTSrLLRvgY0AttTeg8e1sbsnqD092l+f94UET0qLH4X4NPAMOA/ImJA0f8YYCSl99smwLvAhS2m/WdgAPClKlZTuUkp+eNPzX6Ag4C32xnnFWB4WfeXgNeK17sCC4HORfdHKP3z37Fs/EnAyOL16cCjZcM6AW8BX2hj2ZOBfYrXo4FpLYaPBh4uXu8OvEgpIDuVjdMZeB8YWNbv28ADZfN4uWzYusU6fLyVer4AvN1i/v8DnF62ftet5LZcui5lwxPwyeL1VcA1LYZPK9Zl/Xb+hl8EXinr/j1wcPH6TOC25uVUmMep5X+7ot8MYA6wCNi86PcAcDilLwRzga2L5b9WNt1rRb9tinH6FtM8UGH5VxXLmVP2c3XZ8N7F9ngGuLSV7bhXWfdRwMS2tnuLad8FBrX8
+wL9i/k2lI37OKUWCoDngWFlwzYGPgC6lE27ZS0+1/6sHj/ugavWZgMbRuVjcJsAr5d1v170WzqPlNKHxeuFxe93yoYvBHqWdU9vfpFSWkIpBDYBiIiDI2Jy0Zw7h9I/+A1bm7allNJ9wAWU9nJmRsRlRZPqhkDXVtahX1n322Xzea94WV5zs00o7YUuqTCvStrblu1puf5fB4YDr0fEgxGxUxvT3Q+sGxE7RkR/SnuYza0u5wAvA78tmphPbmMesymF0FIppQZK27c7EC2GzaL09zizrZVJKT0L3AEss8wone3dfKLaJWWDzk0p9S77OaRsXnMotexsA5zXyuLKt12b2z0iToiI56N0JcEcSi03G7Y2buHtstfv8ff3zebAr8vey88DH1Jq5WitJq1hDHDV2iPA3yg19bXlTUr/jJptVvRbWeVnJ3cCGoA3I2JzSs2u36F0Fndv4FmWDYaKj+dLKZ2fUtoBGEipKf1E4M+U9nxarsMbK1H7m8CmLY6frsi8Km3LBZT2/gGIiI+3Mv0y659SeiKltA+l5vxbgV+1ttDiC9avKDWjHwjckVL6azHsryml41NKWwJfBY6LiGGtzOY+oCEiGttdy787B9iN0uGUtvwn8K+UfQlKpbO9m09UO7KaBUXEYErN4v9D6Rh8S5uWvW71PVwc7z6J0iGGPsV7cC4tvpxUaTqwd4svHD1SSuXvFR83uQYzwFVTKaW5lI5bXxilk8/WjYiuxUlAZxej/Q9wanHi04bF+Nf9A4vdISL2Lfb6v0fpC8SjwHqU/qHNAoiIQyntTVUlIj5X7GF2pRSGi4AlZeH1o4j4SPFF4biVXIfHKO1lnVRsp12BrwA3VDl9pW35NLB1RAwujrmeXmlGEdEtStfB90opfQDMA5ZUmOSXwDcoHTb5Zdl8RkTEJ4tj/nMp7SUuN5+U0gvApcANEbFHlK557wzs3NYCi73i8yiFYlvjvAzcCBxbofaKiu11HaXj94cC/SLiqBajnRgRfaJ0edt3i2W29BFgMaX3YJeI+A9g/ZUs6xJK77nNixr7RsRqfWWHVi0DXDWXUjqPUqCdSukf13RKe8G3FqOMpXTC0R8pHV98sui3sm6jFCTvAv8C7JtKZ75PofTP/hFKTfDbUjpWW631Ke3Bv0upiXQ2pT1AKJ1QtIDSCXAPUwqwK1a08JTS+5QCe29Ke/YXUTqWPLXKWbS5LVNKL1Jqbv4d8FJRZ3v+BXgtIuYBR1IK57Zqf4zSNtgE+E3ZoK2KZc6ntO0vSind38Zsjqa0d/sz4C+UDn/8kNLfc1ob0/yc0peCSs6k9AWuPSfFsteB/7no/2NKhzYuTqWT9r4FjI2IrcqmvY3S+RiTgTspnejW0j3A3ZTOpXid0pfAlW3m/jkwgdKhib9S+pK640rOSxmKlGxh0ZojIk6ndLLUt+pdi9YeEZGArYq9falDuAcuSVKGDHBJkjJkE7okSRlyD1ySpAxlcYP7DTfcMPXv37/eZUiS1OEmTZr055RS35b9swjw/v3709RUze2RJUlas0TE6631twldkqQMGeCSJGXIAJckKUNZHANvzQcffMCMGTNYtGhRvUtRFXr06EFDQwNdu3Ztf2RJUruyDfAZM2bwkY98hP79+1N6RoJWVyklZs+ezYwZM9hiiy3qXY4krRGybUJftGgRG2ywgeGdgYhggw02sLVEklahbAMcMLwz4t9KklatrANckqS1VbbHwFvqf/Kdq3R+r/3ky+2O88477/D973+fRx99lD59+tCtWzdOOukk+vTpw2677caECRP4yle+AsCIESM44YQT2HXXXdl1112ZP3/+0pvTNDU1ccIJJ/DAAw8st4zRo0fz4IMP0qtXLwDWXXdd/vCHP6zw+jzwwAOce+653HHHHW2O09TUxDXXXMP555+/wvOXJHUs98BXUkqJkSNHMnToUF599VUmTZrEDTfcwIwZMwBoaGjgRz/6UZvTz5w5k9/85jdVLeuc
c85h8uTJTJ48eaXCu1qNjY2rLLw//PDDVTIfSVLrDPCVdN9999GtWzeOPPLIpf0233xzjjnmGAAGDRpEr169uPfee1ud/sQTT6wY8O357ne/y5lnngnAPffcw9ChQ1myZAmjR4/myCOPpLGxkU996lOt7nE//vjj7LTTTmy33XbsvPPOvPDCC0BpL33EiBEAnH766Rx22GHsuuuubLnllssE+3XXXceQIUMYPHgw3/72t5eGdc+ePTn++OMZNGgQjzzyyEqvmySpfQb4SnruuefYfvvtK45zyimnMHbs2FaH7bTTTnTr1o3777+/3WWdeOKJDB48mMGDB3PQQQcB8OMf/5gbb7yR+++/n2OPPZYrr7ySTp1Kf87XXnuNxx9/nDvvvJMjjzxyubO/P/OZz/DQQw/x1FNPceaZZ/KDH/yg1eVOnTqVe+65h8cff5wzzjiDDz74gOeff54bb7yR3//+90yePJnOnTtz/fXXA7BgwQJ23HFHnn76aXbZZZd210uSAMaNG8fWW2/NNttsw4EHHsiiRYuYOHEi22+/PYMHD2aXXXbh5ZdfXm662bNns9tuu9GzZ0++853v1KHy+lpjjoHX29FHH83DDz9Mt27dOOeccwAYOnQoAA8//HCr05x66qmMHTuWn/70pxXnfc455zBq1Khl+q277rr84he/YOjQoYwbN45PfOITS4ftv//+dOrUia222oott9ySqVOnLjPt3LlzOeSQQ3jppZeICD744INWl/vlL3+Z7t270717dzbaaCPeeecdJk6cyKRJk/jc5z4HwMKFC9loo40A6Ny5M1//+tcrrosklXvjjTc4//zzmTJlCuussw77778/N9xwA2eddRa33XYbAwYM4KKLLmLs2LFcddVVy0zbo0cPfvjDH/Lss8/y7LPP1mcF6sg98JW09dZb8+STTy7tvvDCC5k4cSKzZs1aZrxKe+G77747Cxcu5NFHH13a79BDD2Xw4MEMHz683RqeeeYZNthgA958881l+re8ZKtl92mnncZuu+3Gs88+y+23397m9dndu3df+rpz584sXryYlBKHHHLI0mPyL7zwAqeffjpQ+jB17ty53bolqdzixYtZuHAhixcv5r333mOTTTYhIpg3bx5Q2unYZJNNlptuvfXWY5dddqFHjx4dXfJqwQBfSbvvvjuLFi3i4osvXtrvvffeW268Pffck3fffZc//vGPrc7n1FNP5eyzz17afeWVVzJ58mTuuuuuist//fXXOe+883jqqaf4zW9+w2OPPbZ02E033cSSJUt45ZVXePXVV/n0pz+9zLRz586lX79+AMt9o23PsGHDuPnmm5k5cyYAf/nLX3j99VafdCdJ7erXrx8nnHACm222GRtvvDG9evVizz33ZPz48QwfPpyGhgauvfZaTj755HqXutpZY5rQq7nsa1WKCG699Va+//3vc/bZZ9O3b1/WW2+9VpvDTznlFPbZZ59W5zN8+HD69l3uOe3LOPHEE5fZi3/ssccYM2YM5557LptssgmXX345o0eP5oknngBgs802Y8iQIcybN49LLrlkuW+nJ510Eocccghjx47ly19ese02cOBAxo4dy5577smSJUvo2rUrF154IZtvvvkKzUeSAN59911uu+02/vSnP9G7d2/2228/rrvuOm655RbuuusudtxxR8455xyOO+44xo8fX+9yVyuRUqp3De1qbGxMzddMN3v++ecZMGBAnSpafY0ePZoRI0Ysd8x8deDfTFJLN910E3fffTeXX345ANdccw2PPPIIv/3tb3nllVcAmDZtGnvttRdTpkxpdR5XXXUVTU1NXHDBBR1Wd0eKiEkppcaW/W1ClyTVzWabbcajjz7Ke++9R0qJiRMnMnDgQObOncuLL74IwL333uuX/1asMU3oKlnRY9qSVE877rgjo0aNYvvtt6dLly5st912HHHEETQ0NPD1r3+dTp060adPH6644goAJkyYQFNT09L7YPTv
35958+bx/vvvc+utt/Lb3/6WgQMH1nOVOoxN6Oow/s20phg3bhzjx48nIth222258sor6d69O6eeeio33XQTnTt35t/+7d849thjl5t22rRpHH744UyfPp2I4K677qJ///4dvxLKRltN6O6BS9IKaOu65ZQS06dPZ+rUqXTq1GnplRotHXzwwZxyyinssccezJ8/f+kNmKQVZYBL0gpqvm65a9euS69bPvXUU/nlL3+5NJCbb3BUbsqUKSxevJg99tgDKN1+WFpZfvWTpBXQ1nXLr7zyCjfeeCONjY3svffevPTSS8tN++KLL9K7d2/23XdftttuO0488UQf/KOVtubsgZ/eaxXPb267o8yYMYOjjz6aKVOmsGTJEkaMGME555xDt27dVm0tklYbbV23/Le//Y0ePXrQ1NTELbfcwmGHHcZDDz20zLSLFy9e+hyCzTbbjG984xtcddVVjBkzpk5rw6r/37m2qyI7VhX3wFdSSol9992XkSNH8tJLL/Hiiy8yf/58TjnllKrn4TdvKT+/+93v2GKLLejbty9du3Zl33335Q9/+AMNDQ3su+++AHzta19r9e6LDQ0NDB48mC233JIuXbowcuTIZW7JLK0IA3wl3XffffTo0YNDDz0UKN0rfNy4cVxxxRVcdNFFyzwZZ8SIETzwwAPA8o/cPPnkkxk4cCCf/exnOeGEE+qxKpJWQGvXLQ8YMICRI0cufbrggw8+yKc+9anlpv3c5z7HnDlzlj4z4b777ltrLnnSqrfmNKF3sOeee44ddthhmX7rr78+m222GYsXL25zuuZHbp533nnMnj2bMWPGMHXqVCKCOXPm1LpsSf+gtq5bXrhwIQcddBDjxo2jZ8+eS2/72dTUxCWXXML48ePp3Lkz5557LsOGDSOlxA477MC//uu/1nmNlCsDvIOVP3KzV69e9OjRgzFjxjBixAhGjBhR5+okVeOMM87gjDPOWKZf9+7dufPOO5cbt7GxcZl7eO+xxx5tPtxIWhE2oa+kgQMHMmnSpGX6zZs3j2nTptG7d2+WLFmytH/54zrLH7nZpUsXHn/8cUaNGsUdd9zBXnvt1THFS5KyZ4CvpGHDhvHee+9xzTXXAKUT0o4//nhGjx7NlltuyeTJk1myZAnTp0/n8ccfb3Ue8+fPZ+7cuQwfPpxx48bx9NNPd+QqSJIytuY0oXfgqftQepzor3/9a4466ih++MMfsmTJEoYPH85ZZ51Ft27d2GKLLRg4cCADBgxg++23b3Uef/3rX9lnn31YtGgRKSV+9rOfdeg6SJLyteYEeB1suumm3H777a0Ou/7661vtP3/+/KWvN9544zb3ziVJqsQAl5SV/icvf6KYVt5rPepdgVaWx8AlScpQ1gGew6NQVeLfSpJWrWwDvEePHsyePdtgyEBKidmzZ9OjR+3a6saNG8fWW2/NNttsw4EHHsiiRYsYM2YMgwYN4rOf/SyjRo1a5vyDZu+//z6HHnoo2267LYMGDVp6xzxJWt1lewy8oaGBGTNmLL0loVZvPXr0oKGhoSbzbuv5zOPGjWP99dcH4LjjjuOCCy7g5JNPXmbaX/ziFwA888wzzJw5k7333psnnnjCZzRLWu1lG+Bdu3Zliy22qHcZWk209nzm5vBOKbFw4UIiYrnppkyZwu677w6Unt/cu3dvmpqaGDJkSIfWL0kryt0MZa+t5zMDHHrooXz84x9n6tSpHHPMMctNO2jQICZMmMDixYv505/+xKRJk5g+fXpHr4IkrTADXNkrfz7zm2++yYIFC7juuusAuPLKK3nzzTcZMGAAN95443LTHnbYYTQ0NNDY2Mj3vvc9dt5556W3upWk1ZkBruy19XzmZp07d+aAAw7gf//3f5ebtkuXLowbN47Jkydz2223MWfOnFYfAylJqxsDXNlr6/nML7/8MlA6Bj5hwgQ+85nPLDfte++9x4IFCwC499576dKli89nlpSFbE9ik5q19Xzm3XffnXnz
5pFSYtCgQVx88cUATJgwgaamJs4880xmzpzJl770JTp16kS/fv249tpr67w2klSdyOE66sbGxtTU1FTvMlapcePGMX78eCKCbbfdliuvvJIxY8bQ1NRE165dGTJkCJdeeildu3Ztdfp58+YxcOBARo4cyQUXXNDB1Uv1461UV63Xenyz3iWsWWrwYK2ImJRSamzZ3yb0Omi+brmpqYlnn32WDz/8kBtuuIGDDjqIqVOn8swzz7Bw4ULGjx/f5jxOO+00hg4d2oFVS5JWJwZ4nTRft7x48eKl1y0PHz6ciCAiGDJkCDNmzGh12kmTJvHOO+8svVRKkrT2McDroNJ1ywAffPAB1157LXvttddy0y5ZsoTjjz+ec889tyNLliStZjyJrQ7Kr1vu3bs3++23H9dddx3f+ta3ADjqqKMYOnQoX/jCF5ab9qKLLmL48OE1uy3pSjm9V70rWLPU4BiapDWPAV4H5dctA0uvW/7Wt77FGWecwaxZs7j00ktbnfaRRx7hoYce4qKLLmL+/Pm8//779OzZk5/85CcduQqSpDozwOug/LrlddZZh4kTJ9LY2Mj48eO55557mDhxYpsP07j++uuXvr7qqqtoamoyvCVpLeQx8Doov2552223ZcmSJRxxxBEceeSRvPPOO+y0004MHjyYM888E4CmpiYOP/zwOlctSVqdeB24/nEeA1+1PAZekdeBr1peB76KeR24JEmqxACXJClDBrgkSRlaK89C9xjaqvVaj3pXIElrH/fAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShmoa4BHx/Yh4LiKejYj/iYgeEbFFRDwWES9HxI0R0a2WNUiStCaqWYBHRD/gWKAxpbQN0Bk4APgpMC6l9EngXWBMrWqQJGlNVesm9C7AOhHRBVgXeAvYHbi5GH41MLLGNUiStMapWYCnlN4AzgWmUQruucAkYE5KaXEx2gygX2vTR8QREdEUEU2zZs2qVZmSJGWplk3ofYB9gC2ATYD1gL2qnT6ldFlKqTGl1Ni3b98aVSlJUp5q2YT+ReBPKaVZKaUPgFuAfwJ6F03qAA3AGzWsQZKkNVItA3wa8PmIWDciAhgGTAHuB0YV4xwC3FbDGiRJWiPV8hj4Y5ROVnsSeKZY1mXAvwPHRcTLwAbA5bWqQZKkNVWX9kdZeSml/wT+s0XvV4EhtVyuJElrOu/EJklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZqmmAR0TviLg5IqZGxPMRsVNE
fDQi7o2Il4rffWpZgyRJa6Ja74H/HLg7pfQZYBDwPHAyMDGltBUwseiWJEkroGYBHhG9gKHA5QAppfdTSnOAfYCri9GuBkbWqgZJktZUtdwD3wKYBVwZEU9FxPiIWA/4WErprWKct4GP1bAGSZLWSLUM8C7A9sDFKaXtgAW0aC5PKSUgtTZxRBwREU0R0TRr1qwalilJUn5qGeAzgBkppceK7pspBfo7EbExQPF7ZmsTp5QuSyk1ppQa+/btW8MyJUnKT80CPKX0NjA9Ij5d9BoGTAEmAIcU/Q4BbqtVDZIkram61Hj+xwDXR0Q34FXgUEpfGn4VEWOA14H9a1yDJElrnJoGeEppMtDYyqBhtVyuJElrOu/EJklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDbQZ4RHTpyEIkSVL1Ku2BP95hVUiSpBVSKcCjw6qQJEkrpFIzed+IOK6tgSmln9WgHkmSVIVKAd4Z6Il74pIkrXYqBfhbKaUzO6wSSZJUNY+BS5KUoUp74MPKOyKiH6VmdYA3U0qLa1aVJEmqqFKAfzsiupY1oz8CzAG6AVcDP651cZIkqXWVmtD3A84r656dUvossDXw5ZpWJUmSKqp4K9WU0oKyzp8X/T4E1qllUZIkqbJKAd4zIro2d6SUrgKIiO7A+jWuS5IkVVApwG8GLo2IdZt7RMR6wCXFMEmSVCeVAvw0YCYwLSImRcQk4DXgnWKYJEmqkzbPQi+OdZ8cEWcAnyx6v5xSWtghlUmSpDa1+8jQIrCf6YBaJElSlSqehS5JklZPBrgkSRlqN8AjYmI1/SRJUsdp8xh4RPQA1gU2jIg+/P3hJusD/TqgNkmS1IaK90IHvgdsAkzi7wE+D7igxnVJkqQKKl1G9nPg5xFxTErpvzuwJkmS1I52j4G3Ft4R8fHalCNJkqqxsmehX75Kq5AkSStkpQI8peTjRCVJqqOKAR4RnSNiakcVI0mSqtPe88A/BF6IiM06qB5JklSFdu+FDvQBntLUN5EAAA7ZSURBVIuIx4EFzT1TSl+tWVWSJKmiagLcR4dKkrSaqeZpZA92RCGSJKl6K3UWekRctqoLkSRJ1VvZ68AvXaVVSJKkFVLN08g+0bJfSmlSbcqRJEnVqOYktisiogF4AngI+L+U0jO1LUuSJFVSzUls/xwR3YDPAbsCd0ZEz5TSR2tdnCRJal27AR4RuwBfKH56A3dQ2hOXJEl1Uk0T+gOUngf+Y+CulNL7Na1IkiS1q5oA3xD4J2AocGxELAEeSSl5gxdJkuqkmmPgcyLiVWBToAHYGeha68IkSVLbqjkG/iowldJx74uBQ21GlySpvqppQv9kSmlJzSuRJElVa/dGLoa3JEmrn5W9laokSaojA1ySpAxVHeAR8fmIuDsiHoiIkbUsSpIkVdbmSWwR8fGU0ttlvY4DvgYE8Bhwa41rkyRJbah0FvolEfEkcHZKaREwBxgFLAHmdURxkiSpdW02oaeURgJPAXdExMHA94DuwAaATeiSJNVRxWPgKaXbgS8BvYBfAy+mlM5PKc3qiOIkSVLr2gzwiPhqRNwP3A08C3wD2CciboiIT3RUgZIkaXmVjoGPBYYA6wD3pJSGAMdHxFbAj4ADOqA+SZLUikoBPhfYF1gXmNncM6X0Eoa3JEl1VekY+NconbDWBfhmx5QjSZKq0eYeeErpz8B/d2AtkiSpSt5KVZKkDBngkiRlqN0Aj4hjIqJPRxQjSZKqU80e+MeAJyLiVxGxV0RErYuSJEmVtRvgKaVTga2Ay4HRwEsRcZY3c5EkqX6qOgaeUkrA28XPYqAPcHNEnF3D2iRJUhsq3cgFgIj4LnAw8GdgPHBiSumDiOgEvAScVNsSJUlS
S+0GOPBRYN+U0uvlPVNKSyJiRG3KkiRJlVTThP4b4C/NHRGxfkTsCJBSer5WhUmSpLZVE+AXA/PLuucX/SRJUp1UE+BRnMQGlJrOqa7pXZIk1Ug1Af5qRBwbEV2Ln+8Cr9a6MEmS1LZqAvxIYGfgDWAGsCNwRC2LkiRJlbXbFJ5SmonP/5YkabVSzXXgPYAxwNZAj+b+KaXDaliXJEmqoJom9GuBjwNfAh4EGoC/1rIoSZJUWTUB/smU0mnAgpTS1cCXKR0HlyRJdVJNgH9Q/J4TEdsAvYCNaleSJElqTzXXc19WPA/8VGAC0BM4raZVSZKkiioGePHAknkppXeB/wO27JCqJElSRRWb0Iu7rvm0MUmSVjPVHAP/XUScEBGbRsRHm39qXpkkSWpTNcfAv1H8PrqsX8LmdEmS6qaaO7Ft0RGFSJKk6lVzJ7aDW+ufUrpm1ZcjSZKqUU0T+ufKXvcAhgFPAga4JEl1Uk0T+jHl3RHRG7ihZhVJkqR2VXMWeksLAI+LS5JUR9UcA7+d0lnnUAr8gcCvql1ARHQGmoA3UkojImILSnvwGwCTgH9JKb2/ooVLkrQ2q+YY+LllrxcDr6eUZqzAMr4LPA+sX3T/FBiXUrohIi6h9KjSi1dgfpIkrfWqaUKfBjyWUnowpfR7YHZE9K9m5hHRQOnpZeOL7gB2B24uRrkaGLmCNUuStNarJsBvApaUdX9Y9KvGf1G6FWvz9BsAc1JKi4vuGUC/1iaMiCMioikimmbNmlXl4iRJWjtUE+Bdyo9RF6+7tTdRRIwAZqaUJq1MYSmly1JKjSmlxr59+67MLCRJWmNVcwx8VkR8NaU0ASAi9gH+XMV0/wR8NSKGU7p+fH3g50DviOhS7IU3AG+sXOmSJK29qtkDPxL4QURMi4hpwL8D325vopTS/0spNaSU+gMHAPellA4C7gdGFaMdAty2UpVLkrQWq+ZGLq8An4+InkX3/H9wmf8O3BARY4GngMv/wflJkrTWaXcPPCLOiojeKaX5KaX5EdGnCN+qpZQeSCmNKF6/mlIaklL6ZEppv5TS31a2eEmS1lbVNKHvnVKa09yRUnoXGF67kiRJUnuqCfDOEdG9uSMi1gG6VxhfkiTVWDVnoV8PTIyIK4vuQ/FJZJIk1VU1J7H9NCKeBr5Y9PphSume2pYlSZIqqWYPnJTS3cDdABGxS0RcmFI6uqaVSZKkNlUV4BGxHXAgsD/wJ+CWWhYlSZIqazPAI+JTlEL7QEp3XrsRiJTSbh1UmyRJakOlPfCpwEPAiJTSywAR8f0OqUqSJFVU6TKyfYG3gPsj4hcRMQyIjilLkiRV0maAp5RuTSkdAHyG0v3LvwdsFBEXR8SeHVWgJElaXrs3ckkpLUgp/TKl9BVKTw97itL9zCVJUp1Ucye2pVJK7xbP6R5Wq4IkSVL7VijAJUnS6sEAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSA
S5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJylDNAjwiNo2I+yNiSkQ8FxHfLfp/NCLujYiXit99alWDJElrqlrugS8Gjk8pDQQ+DxwdEQOBk4GJKaWtgIlFtyRJWgE1C/CU0lsppSeL138Fngf6AfsAVxejXQ2MrFUNkiStqTrkGHhE9Ae2Ax4DPpZSeqsY9DbwsTamOSIimiKiadasWR1RpiRJ2ah5gEdET+B/ge+llOaVD0spJSC1Nl1K6bKUUmNKqbFv3761LlOSpKzUNMAjoiul8L4+pXRL0fudiNi4GL4xMLOWNUiStCaq5VnoAVwOPJ9S+lnZoAnAIcXrQ4DbalWDJElrqi41nPc/Af8CPBMRk4t+PwB+AvwqIsYArwP717AGSZLWSDUL8JTSw0C0MXhYrZYrSdLawDuxSZKUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJUlwCPiL0i4oWIeDkiTq5HDZIk5azDAzwiOgMXAnsDA4EDI2JgR9chSVLO6rEHPgR4OaX0akrpfeAGYJ861CFJUra61GGZ/YDpZd0zgB1bjhQRRwBHFJ3zI+KFDqhNKyFgQ+DP9a5jjXFG1LsCrUX8/K5itfn8bt5az3oEeFVSSpcBl9W7DrUvIppSSo31rkPSivPzm696NKG/AWxa1t1Q9JMkSVWqR4A/AWwVEVtERDfgAGBCHeqQJClbHd6EnlJaHBHfAe4BOgNXpJSe6+g6tEp5qEPKl5/fTEVKqd41SJKkFeSd2CRJypABLklShgxwERHz2xnePyKeXcF5XhURo1rp/9GIuDciXip+91nReiX9XQd/fveLiOciYklEeOlZnRng6mgnAxNTSlsBE4tuSXl4FtgX+L96FyIDXGUiomdETIyIJyPimYgov8Vtl4i4PiKej4ibI2LdYpodIuLBiJgUEfdExMbtLGYf4Ori9dXAyBqsirTW6YjPb0rp+ZSSd8VcTRjgKrcI+FpKaXtgN+C8iGi+L+CngYtSSgOAecBREdEV+G9gVEppB+AK4EftLONjKaW3itdvAx9b1SshraU64vOr1chqeytV1UUAZ0XEUGAJpfvWNwfs9JTS74vX1wHHAncD2wD3Fv8nOgNvUaWUUooIr2OUVo0O/fyq/gxwlTsI6AvskFL6ICJeA3oUw1oGbaL0D+O5lNJOK7CMdyJi45TSW0Vz3cx/tGhJQMd8frUasQld5XoBM4sP/24s+wSc
zSKi+YP+TeBh4AWgb3P/iOgaEVu3s4wJwCHF60OA21ZZ9dLarSM+v1qNGOAqdz3QGBHPAAcDU8uGvQAcHRHPA32Ai4vnuY8CfhoRTwOTgZ3bWcZPgD0i4iXgi0W3pH9czT+/EfG1iJgB7ATcGRH31GA9VCVvpSpJUobcA5ckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDP1/5Tftbp+0Cp8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 504x576 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "labels = ['label 0', 'label 1']\n",
    "theirs = [82.4, 86.6]\n",
    "ours = [83.9, 88.1]\n",
    "\n",
    "x = np.arange(len(labels))  # the label locations\n",
    "width = 0.35  # the width of the bars\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7,8))\n",
    "rects1 = ax.bar(x - width/2, theirs, width, label='GNN-Explainer')\n",
    "rects2 = ax.bar(x + width/2, ours, width, label='Ours')\n",
    "\n",
    "# Add some text for labels, title and custom x-axis tick labels, etc.\n",
    "ax.set_ylabel('Accuracy % w.r.t GT')\n",
    "ax.set_title('Comparison of ours vs GNN-Explainer')\n",
    "ax.set_xticks(x)\n",
    "ax.set_xticklabels(labels)\n",
    "ax.legend()\n",
    "\n",
    "def autolabel(rects):\n",
    "    \"\"\"Attach a text label above each bar in *rects*, displaying its height.\"\"\"\n",
    "    for rect in rects:\n",
    "        height = rect.get_height()\n",
    "        ax.annotate('{}'.format(height),\n",
    "                    xy=(rect.get_x() + rect.get_width() / 2, height),\n",
    "                    xytext=(0, 3),  # 3 points vertical offset\n",
    "                    textcoords=\"offset points\",\n",
    "                    ha='center', va='bottom')\n",
    "\n",
    "\n",
    "autolabel(rects1)\n",
    "autolabel(rects2)\n",
    "\n",
    "fig.tight_layout()\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "ckpt = torch.load(\"../../gcn_interpretation/gnn-model-explainer/ckpt/Mutagenicity_base_h20_o20.pth.tar\")\n",
    "# ckpt = torch.load(\"../../gcn_interpretation/gnn-model-explainer/ckpt/synthetic_base_h20_o20.pth.tar\")\n",
    "# ckpt2 = torch.load(\"./dnn_invariant/mdls/Mutagenicity_base_h20_o20.pth.tar\")\n",
    "\n",
    "\n",
    "import sys\n",
    "sys.path.append('/home/mohit/Mohit/model_interpretation/ai-adversarial-detection/dnn_invariant')\n",
    "\n",
    "\n",
    "from dnn_invariant.models.models_gcn import *\n",
    "preds = ckpt['cg']['pred'][0,:,:]\n",
    "# model = \n",
    "\n",
    "model = GcnEncoderGraph(\n",
    "        14,#14,  # input_dim,\n",
    "        20,  # args.hidden_dim,\n",
    "        20,  # args.output_dim,\n",
    "        2,#2,  # args.num_classes,\n",
    "        3,#3,  # args.num_gc_layers,\n",
    "        bn=False,\n",
    "        dropout=0.0,\n",
    "        args=None,\n",
    "    ).cuda()\n",
    "\n",
    "# model = GcnEncoderGraph(\n",
    "#         14,#14,  # input_dim,\n",
    "#         20,  # args.hidden_dim,\n",
    "#         20,  # args.output_dim,\n",
    "#         2,#2,  # args.num_classes,\n",
    "#         3,#3,  # args.num_gc_layers,\n",
    "#         bn=False,\n",
    "#         dropout=0.0,\n",
    "#         args=None,\n",
    "#     ).cuda()\n",
    "\n",
    "\n",
    "\n",
    "model.load_state_dict(ckpt['model_state'])\n",
    "# model.load_state_dict(ckpt2['model_state'])\n",
    "\n",
    "import numpy as np\n",
    "label = 0 #0\n",
    "indices = np.argmax(ckpt['cg']['pred'][0,:,:], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 2, 4, 10, 13, 14, 15, 16, 17, 18]\n"
     ]
    }
   ],
   "source": [
    "l_idx = (indices==label).nonzero()[0].tolist()\n",
    "print(l_idx[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1742 1742\n",
      "mask density:  0.5076695498441854\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "masked_adjs = pickle.load(open(\"../../gcn_interpretation/data/Mutagenicity/masked_adjs_0.p\",\"rb\"))\n",
    "# masked_adjs = pickle.load(open(\"../../gcn_interpretation/data/Mutagenicity/masked_adjs_explainer_0.p\",\"rb\"))\n",
    "\n",
    "# masked_adjs = pickle.load(open(\"../../gcn_interpretation/data/synthetic_data_3label_3sublabel/masked_adjs_2.p\",\"rb\"))\n",
    "# masked_adjs = pickle.load(open(\"../../gcn_interpretation/data/synthetic_data_3label_3sublabel/masked_adjs_explainer_0.p\",\"rb\"))\n",
    "print(len(l_idx), len(masked_adjs))\n",
    "masked_adjs = np.stack(masked_adjs,axis=0)\n",
    "print(\"mask density: \", np.sum(masked_adjs)/np.sum((ckpt['cg']['adj'].numpy())[l_idx,:,:]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1016, 20, 20)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1742\n",
      "torch.Size([1742, 20])\n"
     ]
    }
   ],
   "source": [
    "graph_emb_cat = torch.zeros((0,20)).float()\n",
    "flip_labels = 0.\n",
    "print(len(masked_adjs))\n",
    "for eix, ix in enumerate(l_idx):\n",
    "#     adj = torch.from_numpy(masked_adjs[eix]).unsqueeze(0)\n",
    "    adj = ckpt['cg']['adj'][ix:ix+1,:,:]\n",
    "    feat = ckpt['cg']['feat'][ix:ix+1,:,:]\n",
    "    num_nodes = [ckpt['cg']['num_nodes'][ix]]\n",
    "    with torch.no_grad():\n",
    "        \n",
    "        adj = adj.float().cuda()\n",
    "        feat = feat.float().cuda()\n",
    "        pred = model(feat, adj, num_nodes)[0].cpu().numpy()\n",
    "        if (np.argmax(pred[0]) != np.argmax(ckpt['cg']['pred'][0,ix,:])):\n",
    "            print(pred[0])\n",
    "            flip_labels += 1\n",
    "        \n",
    "        graph_emb, _ = model._getOutputOfOneLayer_Group(adj, feat, num_nodes)\n",
    "        graph_emb = graph_emb.cpu()\n",
    "        graph_emb_cat = torch.cat((graph_emb_cat, graph_emb),dim=0)\n",
    "print(graph_emb_cat.shape)\n",
    "# print(flip_labels)#14/1008, 111/17**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "# pickle.dump(l_idx, open(\"../../gcn_interpretation/data/Mutagenicity/index_label0.p\", \"wb\"))\n",
    "# torch.save(graph_emb_cat,\"../../gcn_interpretation/data/Mutagenicity/graph_embeds_label0.p\")\n",
    "pickle.dump(l_idx, open(\"../../gcn_interpretation/data/synthetic_data_3label_3sublabel/index_label0.p\", \"wb\"))\n",
    "torch.save(graph_emb_cat,\"../../gcn_interpretation/data/synthetic_data_3label_3sublabel/graph_embeds_label0.p\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "DATA_DIR = \"../../gcn_interpretation/data/Mutagenicity\"\n",
    "CKPT_PATH = \"../../gcn_interpretation/gnn-model-explainer/ckpt/Mutagenicity_base_h20_o20.pth.tar\"\n",
    "\n",
    "\n",
    "def load_pickle(path):\n",
    "    \"\"\"Unpickle `path` and close the file afterwards. Only use on trusted, locally produced files.\"\"\"\n",
    "    with open(path, \"rb\") as fh:\n",
    "        return pickle.load(fh)\n",
    "\n",
    "\n",
    "graph_emb_cat = torch.load(DATA_DIR + \"/graph_embeds_label0.p\")\n",
    "ckpt = torch.load(CKPT_PATH)\n",
    "l_idx = load_pickle(DATA_DIR + \"/index_label0.p\")\n",
    "masked_adjs = load_pickle(DATA_DIR + \"/masked_adjs.p\")\n",
    "# masked_adjs = load_pickle(DATA_DIR + \"/masked_adjs_explainer.p\")\n",
    "projections = load_pickle(DATA_DIR + \"/projections_label0.p\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "import networkx as nx\n",
    "from grakel.utils import graph_from_networkx\n",
    "\n",
    "from grakel.kernels import WeisfeilerLehman, VertexHistogram, SubgraphMatching\n",
    "from grakel import Graph\n",
    "\n",
    "\n",
    "def node_match_func(n1, n2):\n",
    "    \"\"\"Node predicate for networkx graph edit distance: True iff both 'feat' vectors are identical.\"\"\"\n",
    "    # replaces `np.sum(f1 == f2) == 14`, which silently assumed exactly 14 node features\n",
    "    return bool(np.array_equal(n1['feat'], n2['feat']))\n",
    "\n",
    "\n",
    "def edge_match_func(e1, e2):\n",
    "    \"\"\"Edge predicate for networkx graph edit distance: True iff both 'weight' attributes are equal.\"\"\"\n",
    "    return e1['weight'] == e2['weight']\n",
    "\n",
    "\n",
    "def getGrakelGraph(nx_G):\n",
    "    \"\"\"Convert a networkx graph whose nodes carry a 'label' attribute into a grakel Graph.\n",
    "\n",
    "    Nodes are renumbered 0..n-1 in nx_G.nodes() order, which is also the row order of the\n",
    "    adjacency matrix, so labels and adjacency rows always describe the same node (the old\n",
    "    range(100) scan dropped the labels of nodes numbered >= 100). Self-loops are ignored and\n",
    "    every edge gets the constant label 'a'.\n",
    "    \"\"\"\n",
    "    mat = nx.to_numpy_array(nx_G)  # nx.to_numpy_matrix was removed in networkx 3.0\n",
    "    grakel_labels = {pos: str(lbl) for pos, (_, lbl) in enumerate(nx_G.nodes(data='label'))}\n",
    "    mat_dict = {}\n",
    "    edge_labels = {}\n",
    "    for i in range(mat.shape[0]):\n",
    "        mat_dict[i] = [j for j in range(mat.shape[1]) if j != i and mat[i, j] > 0.]\n",
    "        for j in mat_dict[i]:\n",
    "            edge_labels[(i, j)] = 'a'\n",
    "    return Graph(mat_dict, node_labels=grakel_labels, edge_labels=edge_labels)\n",
    "\n",
    "\n",
    "def getNXGraph(adj, feat, thresh=None, indx=None, edge_thresh=16):\n",
    "    \"\"\"Build an undirected networkx graph from a (masked) adjacency matrix.\n",
    "\n",
    "    Every node gets 'feat' (its row of `feat`) and 'label' (argmax of that row). Edges come from\n",
    "    exactly one source, in this priority:\n",
    "      * indx given   -> cycles through the node lists hnodes_dict[(indx, 0)] and hnodes_dict[(indx, 1)]\n",
    "                        (hnodes_dict is a global loaded in another cell);\n",
    "      * thresh given -> every entry with adj[i, j] > thresh;\n",
    "      * otherwise    -> the edge_thresh largest strictly positive entries of adj.\n",
    "    Nodes that end up without any edge are removed.\n",
    "    \"\"\"\n",
    "    num_nodes = adj.shape[-1]\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(range(num_nodes))\n",
    "    for node in G.nodes():\n",
    "        G.nodes[node]['feat'] = feat[node]\n",
    "        G.nodes[node]['label'] = np.argmax(feat[node])\n",
    "\n",
    "    weighted_edge_list = []\n",
    "    if indx is not None:\n",
    "        for b in range(2):\n",
    "            b_hnodes = hnodes_dict[(indx, b)]\n",
    "            cycle_len = len(b_hnodes)\n",
    "            # close the node list into a cycle; the modulus used to be hard-coded to 4\n",
    "            for i in range(cycle_len):\n",
    "                nxt = b_hnodes[(i + 1) % cycle_len]\n",
    "                weighted_edge_list.append((b_hnodes[i], nxt, 1.0))\n",
    "                weighted_edge_list.append((nxt, b_hnodes[i], 1.0))\n",
    "    elif thresh is not None:\n",
    "        for i, j in np.argwhere(adj[:num_nodes, :num_nodes] > thresh):\n",
    "            weighted_edge_list.append((i, j, 1.0))\n",
    "    else:\n",
    "        # visit entries from largest to smallest and keep the first edge_thresh positive ones\n",
    "        for flat in np.argsort(adj.ravel())[::-1]:\n",
    "            x, y = np.unravel_index(flat, (num_nodes, num_nodes))\n",
    "            if adj[x, y] > 0.:\n",
    "                weighted_edge_list.append((x, y, 1.0))\n",
    "            if len(weighted_edge_list) >= edge_thresh:\n",
    "                break\n",
    "\n",
    "    G.add_weighted_edges_from(weighted_edge_list)\n",
    "    G.remove_nodes_from(list(nx.isolates(G)))\n",
    "    return G\n",
    "\n",
    "\n",
    "def calc_GED(idx, nbrs_np, top_k=10):\n",
    "    \"\"\"Mean kernel similarity between graph idx's explanation and those of its neighbours.\n",
    "\n",
    "    Despite the name, no graph edit distance is computed: each explanation graph is built from\n",
    "    the globals masked_adjs / ckpt / l_idx with getNXGraph + getGrakelGraph and compared with the\n",
    "    query using a normalised grakel SubgraphMatching kernel, so larger values mean more similar.\n",
    "\n",
    "    Args:\n",
    "        idx: position of the query graph in l_idx / masked_adjs.\n",
    "        nbrs_np: numpy array of neighbour positions, nearest first, normally starting with idx.\n",
    "        top_k: how many of the nearest neighbours enter the second average (default 10).\n",
    "\n",
    "    Returns:\n",
    "        (mean similarity over all neighbours, mean similarity over the top_k nearest ones)\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no neighbour other than idx is given.\n",
    "    \"\"\"\n",
    "    nbrs = nbrs_np.tolist()\n",
    "    # drop the query wherever argsort put it (ties at distance 0 can move it off position 0)\n",
    "    # while keeping the same neighbour count as the old nbrs[1:]\n",
    "    nbrs = [n for n in nbrs if n != idx][:len(nbrs) - 1]\n",
    "    if not nbrs:\n",
    "        raise ValueError('calc_GED needs at least one neighbour besides the query graph')\n",
    "\n",
    "    def explanation_graph(pos):\n",
    "        adj = masked_adjs[pos]\n",
    "        feat = ckpt['cg']['feat'][l_idx[pos], :, :].numpy()\n",
    "        return getGrakelGraph(getNXGraph(adj, feat))\n",
    "\n",
    "    gk = SubgraphMatching(normalize=True)\n",
    "    gk.fit_transform([explanation_graph(idx)])  # fit on the query so transform() scores against it\n",
    "    # An earlier variant used nx.algorithms.similarity.graph_edit_distance(..., node_match=node_match_func).\n",
    "    sims = [gk.transform([explanation_graph(nbr)])[0][0] for nbr in nbrs]\n",
    "    top = sims[:max(top_k, 1)]\n",
    "    return sum(sims) / len(sims), sum(top) / len(top)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics.pairwise import euclidean_distances\n",
    "\n",
    "# For each query graph, take its nearest neighbours and measure how similar their explanation\n",
    "# subgraphs are (calc_GED returns kernel similarities, not distances).\n",
    "K_NEIGHBOURS = 20        # neighbours per query; argsort also returns the query itself, hence +1 below\n",
    "MAX_QUERIES = 401        # the previous loop stopped after 401 queries (net_count > 400)\n",
    "USE_PROJECTIONS = False  # True: search neighbours in the rule-boundary projection space instead\n",
    "RESULT_PATH = './ged_result.p'\n",
    "\n",
    "graph_emb_np = graph_emb_cat.numpy()\n",
    "search_space = projections if USE_PROJECTIONS else graph_emb_np\n",
    "indices_covered = []\n",
    "net_avg_ged = 0.\n",
    "net_avg_ged_top10 = 0.\n",
    "net_count = 0.\n",
    "ged_list = []\n",
    "for i in range(min(graph_emb_np.shape[0], MAX_QUERIES)):\n",
    "    dist = euclidean_distances(search_space, search_space[i:i+1, :])\n",
    "    nbrs = dist[:, 0].argsort()[:K_NEIGHBOURS + 1]\n",
    "\n",
    "    avg_ged, avg_ged_top10 = calc_GED(i, nbrs)\n",
    "\n",
    "    ged_list.append((avg_ged, avg_ged_top10, nbrs))\n",
    "    net_avg_ged += avg_ged\n",
    "    net_avg_ged_top10 += avg_ged_top10\n",
    "    net_count += 1.0\n",
    "    print('Net count: ', net_count, 'ged: ', avg_ged, ' ', avg_ged_top10)\n",
    "    indices_covered.extend(nbrs.tolist())\n",
    "\n",
    "    # checkpoint after every query so the saved state always matches what has been accumulated\n",
    "    tpl = (indices_covered, ged_list, net_avg_ged, net_avg_ged_top10, net_count)\n",
    "    with open(RESULT_PATH, 'wb') as fh:\n",
    "        pickle.dump(tpl, fh)\n",
    "\n",
    "if net_count:\n",
    "    print('Average Sim: ', net_avg_ged/net_count, 'Average Sim top10: ', net_avg_ged_top10/net_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rule description produced by the rule-extraction step; the projection cell reads\n",
    "# rule_dict['idx2rule'][graph_idx] and rule_dict['rules'][r_num]['boundary'][b]['basis'].\n",
    "rule_file = \"../../gcn_interpretation/data/Mutagenicity/rule_dict_Mutagenicity_train.p\"\n",
    "# rule_file = \"../../gcn_interpretation/data/synthetic_data_3label_3sublabel/rule_dict_synthetic_train_4k8000_comb_12dlbls_nofake.p\"\n",
    "with open(rule_file, \"rb\") as fh:  # close the handle instead of leaking it\n",
    "    rule_dict = pickle.load(fh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2.43796682 0.        ]\n"
     ]
    }
   ],
   "source": [
    "# Project every graph embedding onto the basis of each decision boundary of its rule:\n",
    "# projections[j, b] = sum(embedding_j * basis_b); columns beyond a rule's boundary count stay 0.\n",
    "graph_emb_np = graph_emb_cat.numpy()\n",
    "num_graphs = graph_emb_np.shape[0]\n",
    "boundaries_per_graph = [rule_dict['rules'][rule_dict['idx2rule'][l_idx[j]]]['boundary'] for j in range(num_graphs)]\n",
    "# two columns as before, widened only if some rule has more than two boundaries (used to raise IndexError)\n",
    "num_cols = max([2] + [len(bounds) for bounds in boundaries_per_graph])\n",
    "projections = np.zeros((num_graphs, num_cols))\n",
    "for j, bounds in enumerate(boundaries_per_graph):\n",
    "    for b_ix in range(len(bounds)):\n",
    "        projections[j, b_ix] = np.sum(graph_emb_np[j] * bounds[b_ix]['basis'])\n",
    "\n",
    "print(projections[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2.43796682 0.        ]\n"
     ]
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node lists keyed by (graph index, b) for b in {0, 1}; read by getNXGraph(..., indx=...).\n",
    "# Presumably the planted motif nodes of the synthetic dataset -- TODO confirm.\n",
    "with open(\"../../gcn_interpretation/data/synthetic_data_3label_3sublabel/hnodes_dict_synthetic_train_4k8000_comb_12dlbls_nofake.p\", \"rb\") as fh:\n",
    "    hnodes_dict = pickle.load(fh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[   1 1236 1094  736 1417 1297  931 1293  855  607 1614  946  359  558\n",
      " 1722 1345  232  205  195  821  701]\n",
      "[   1   18 1446  995  203  878  846 1295   90  409 1392 1262  820  171\n",
      "  251 1120  674  228  262  543 1730]\n"
     ]
    }
   ],
   "source": [
    "# Nearest-neighbour check for one query graph: the 21 closest graphs (query included)\n",
    "# in embedding space vs. in rule-boundary projection space.\n",
    "# to cache the projections: with open('../../gcn_interpretation/data/Mutagenicity/projections_label0.p', 'wb') as fh: pickle.dump(projections, fh)\n",
    "from sklearn.metrics.pairwise import euclidean_distances\n",
    "\n",
    "graph_emb_np = graph_emb_cat.numpy()\n",
    "i = 1\n",
    "N_SHOWN = 21\n",
    "\n",
    "# neighbours in the raw embedding space\n",
    "arr1 = graph_emb_np[i:i + 1, :]\n",
    "dist = euclidean_distances(graph_emb_np, arr1)\n",
    "nbrs = np.argsort(dist[:, 0])[:N_SHOWN]\n",
    "print(nbrs)\n",
    "\n",
    "# neighbours in the boundary-projection space\n",
    "arr1 = projections[i:i + 1, :]\n",
    "dist = euclidean_distances(projections, arr1)\n",
    "p_nbrs = np.argsort(dist[:, 0])[:N_SHOWN]\n",
    "print(p_nbrs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "from grakel.utils import graph_from_networkx\n",
    "\n",
    "from grakel.kernels import WeisfeilerLehman, VertexHistogram\n",
    "from grakel import Graph\n",
    "from grakel.kernels import PropagationAttr, ShortestPath, SubgraphMatching\n",
    "\n",
    "# Visual comparison of two explanations -- graph j=1 and graph j=1236, one of its nearest\n",
    "# neighbours in embedding space -- plus their SubgraphMatching kernel similarity.\n",
    "# NOTE(review): this cell redefines getGrakelGraph / getNXGraph. The getNXGraph below returns\n",
    "# (graph, highlighted_nodes) rather than a graph, so calc_GED (which expects a graph) fails if it\n",
    "# runs after this cell: re-run the functions cell before the GED loop.\n",
    "QUERY_J = 1\n",
    "NEIGHBOUR_J = 1236\n",
    "\n",
    "# node_labels = ['A','B','C','D','E','F','G','H','I','J','K','L']  # synthetic dataset\n",
    "# atom symbols, indexed by the argmax of a node's feature vector\n",
    "node_labels = ['C','O','Cl','H','N','F','Br','S','P','I','Na','K','Li','Ca']\n",
    "\n",
    "\n",
    "def node_match_func(n1, n2):\n",
    "    \"\"\"Node predicate for networkx graph edit distance: True iff both 'feat' vectors are identical.\"\"\"\n",
    "    # replaces `np.sum(f1 == f2) == 14`, which silently assumed exactly 14 node features\n",
    "    return bool(np.array_equal(n1['feat'], n2['feat']))\n",
    "\n",
    "\n",
    "def getGrakelGraph(nx_G):\n",
    "    \"\"\"Convert a networkx graph with 'label' node attributes into a grakel Graph, verbosely.\n",
    "\n",
    "    Nodes are renumbered 0..n-1 in nx_G.nodes() order, the row order of the adjacency matrix;\n",
    "    self-loops are ignored and every edge gets label 'a'. Prints the matrix shape, the node\n",
    "    labels and the adjacency dict for inspection.\n",
    "    \"\"\"\n",
    "    mat = nx.to_numpy_array(nx_G)  # nx.to_numpy_matrix was removed in networkx 3.0\n",
    "    grakel_labels = {pos: str(lbl) for pos, (_, lbl) in enumerate(nx_G.nodes(data='label'))}\n",
    "    mat_dict = {}\n",
    "    edge_labels = {}\n",
    "    for i in range(mat.shape[0]):\n",
    "        mat_dict[i] = [j for j in range(mat.shape[1]) if j != i and mat[i, j] > 0.]\n",
    "        for j in mat_dict[i]:\n",
    "            edge_labels[(i, j)] = 'a'\n",
    "    print(mat.shape)\n",
    "    print(grakel_labels)\n",
    "    print(mat_dict)\n",
    "    return Graph(mat_dict, node_labels=grakel_labels, edge_labels=edge_labels)\n",
    "\n",
    "\n",
    "def getNXGraph(adj, feat, thresh=None, indx=None, edge_thresh=16):\n",
    "    \"\"\"Build a networkx graph from a (masked) adjacency matrix; returns (graph, highlighted nodes).\n",
    "\n",
    "    Every node gets 'feat' (its row of `feat`) and 'label' (argmax of that row). Edges come from\n",
    "    exactly one source: cycles through hnodes_dict[(indx, 0/1)] if indx is given, else every\n",
    "    entry with adj[i, j] > thresh if thresh is given, else the edge_thresh largest positive\n",
    "    entries of adj. The second return value maps each endpoint of an edge chosen by the last\n",
    "    rule to 1 (it stays empty for the other two rules). Nodes without edges are removed.\n",
    "    \"\"\"\n",
    "    num_nodes = adj.shape[-1]\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(range(num_nodes))\n",
    "    node_hlight_dict = {}\n",
    "    for node in G.nodes():\n",
    "        G.nodes[node]['feat'] = feat[node]\n",
    "        G.nodes[node]['label'] = np.argmax(feat[node])\n",
    "\n",
    "    weighted_edge_list = []\n",
    "    if indx is not None:\n",
    "        for b in range(2):\n",
    "            b_hnodes = hnodes_dict[(indx, b)]\n",
    "            cycle_len = len(b_hnodes)\n",
    "            # close the node list into a cycle; the modulus used to be hard-coded to 4\n",
    "            for i in range(cycle_len):\n",
    "                nxt = b_hnodes[(i + 1) % cycle_len]\n",
    "                weighted_edge_list.append((b_hnodes[i], nxt, 1.0))\n",
    "                weighted_edge_list.append((nxt, b_hnodes[i], 1.0))\n",
    "    elif thresh is not None:\n",
    "        for i, j in np.argwhere(adj[:num_nodes, :num_nodes] > thresh):\n",
    "            weighted_edge_list.append((i, j, 1.0))\n",
    "    else:\n",
    "        # visit entries from largest to smallest and keep the first edge_thresh positive ones\n",
    "        for flat in np.argsort(adj.ravel())[::-1]:\n",
    "            x, y = np.unravel_index(flat, (num_nodes, num_nodes))\n",
    "            if adj[x, y] > 0.:\n",
    "                node_hlight_dict[x] = 1\n",
    "                node_hlight_dict[y] = 1\n",
    "                weighted_edge_list.append((x, y, 1.0))\n",
    "            if len(weighted_edge_list) >= edge_thresh:\n",
    "                break\n",
    "\n",
    "    G.add_weighted_edges_from(weighted_edge_list)\n",
    "    G.remove_nodes_from(list(nx.isolates(G)))\n",
    "    return G, node_hlight_dict\n",
    "\n",
    "\n",
    "vmax = 19\n",
    "cmap = plt.get_cmap('tab20')\n",
    "HIGHLIGHT_RGB = (0.9, 0.2, 0.2)\n",
    "BACKGROUND_RGB = (0.9, 0.9, 0.9)\n",
    "\n",
    "\n",
    "def draw_explanation(j, ax_hlight, ax_types):\n",
    "    \"\"\"Print mask statistics for graph l_idx[j], draw it on two axes and return its explanation.\n",
    "\n",
    "    ax_hlight shows the full graph (edges with adj > 0.1) with explanation nodes in red;\n",
    "    ax_types shows the same graph coloured by node type (argmax of the node features).\n",
    "    Returns (explanation graph, highlighted-node dict) from getNXGraph on masked_adjs[j].\n",
    "    \"\"\"\n",
    "    idx = l_idx[j]\n",
    "    m_adj = masked_adjs[j]\n",
    "    adj = ckpt['cg']['adj'][idx].numpy()\n",
    "    f = ckpt['cg']['feat'][idx].numpy()\n",
    "    print('idx: ', idx)\n",
    "    print('m_adj: ', m_adj.shape)\n",
    "    print('mask density: ', np.sum(m_adj) / np.sum(adj))\n",
    "    print('num nodes: ', ckpt['cg']['num_nodes'][idx].item())\n",
    "\n",
    "    G_expl, highlight_dict = getNXGraph(m_adj, f)\n",
    "    G_orig, _ = getNXGraph(adj, f, thresh=0.1)\n",
    "\n",
    "    node_types = [np.argmax(G_orig.nodes[n]['feat']) for n in G_orig.nodes]\n",
    "    labels_dict = {n: node_labels[t] for n, t in zip(G_orig.nodes, node_types)}\n",
    "    node_hlight = [HIGHLIGHT_RGB if n in highlight_dict else BACKGROUND_RGB for n in G_orig.nodes]\n",
    "\n",
    "    nx.draw_networkx(G_orig, font_size=20, ax=ax_hlight,\n",
    "                     node_size=900, alpha=0.8, node_color=node_hlight, labels=labels_dict)\n",
    "    nx.draw_networkx(G_orig, font_size=20, ax=ax_types, vmax=vmax, vmin=0,\n",
    "                     node_size=900, cmap=cmap, alpha=0.8, node_color=node_types, labels=labels_dict)\n",
    "    return G_expl, highlight_dict\n",
    "\n",
    "\n",
    "fig, ax_l = plt.subplots(4, 1, figsize=(30, 100))\n",
    "# plt.switch_backend('agg')\n",
    "\n",
    "G1, _ = draw_explanation(QUERY_J, ax_l[0], ax_l[1])\n",
    "G2, highlight_dict = draw_explanation(NEIGHBOUR_J, ax_l[2], ax_l[3])\n",
    "\n",
    "G1_t = getGrakelGraph(G1)\n",
    "G2_t = getGrakelGraph(G2)\n",
    "\n",
    "# gk = WeisfeilerLehman(normalize=True, base_graph_kernel=VertexHistogram)\n",
    "# gk = WeisfeilerLehman(normalize=True, base_graph_kernel=ShortestPath)  # or PropagationAttr\n",
    "gk = SubgraphMatching(normalize=True)\n",
    "self_sim = gk.fit_transform([G1_t])\n",
    "sim = gk.transform([G2_t])\n",
    "print('self sim: ', self_sim)\n",
    "print('sim: ', sim)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "ename": "IndentationError",
     "evalue": "unindent does not match any outer indentation level (<tokenize>, line 148)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;36m  File \u001b[0;32m\"<tokenize>\"\u001b[0;36m, line \u001b[0;32m148\u001b[0m\n\u001b[0;31m    Average Sim:  0.19888569870152922 Average Sim top10:  0.19996254126560054\u001b[0m\n\u001b[0m    ^\u001b[0m\n\u001b[0;31mIndentationError\u001b[0m\u001b[0;31m:\u001b[0m unindent does not match any outer indentation level\n"
     ]
    }
   ],
   "source": [
    "\n",
    "#mutag label=>0\n",
    "\n",
    "    #explainer\n",
    "        #emb_dist\n",
    "            #mask density = 43%\n",
    "        Average Sim:  0.605687022388006 Average Sim top10:  0.6135237853281609\n",
    "        #top32 Average Sim:  0.582191012442349 Average Sim top10:  0.5900250171362915\n",
    "        #top16 Average Sim:  0.521146184919567 Average Sim top10:  0.5287009090012631\n",
    "\n",
    "                \n",
    "\n",
    "\n",
    "        #subgraphMatching\n",
    "        Average Sim:  0.7239744134928855 Average Sim top10:  0.7369110690810604\n",
    "        #top32 Average Sim:  0.7076160070107486 Average Sim top10:  0.7206085021332926\n",
    "        #top16 Average Sim:  0.702082991510595 Average Sim top10:  0.7099430291077392\n",
    "                \n",
    "Average mask density: 0.429056811546552\n",
    "Total graphs optimized:  1742\n",
    "Flips:  177.0\n",
    "Incorrect preds:  177\n",
    "                \n",
    "\n",
    "    #explain_boundary\n",
    "        mask density 50%\n",
    "        #emb_dist\n",
    "            Average Sim:  0.6071181693559123 Average Sim top10:  0.6141083523798263\n",
    "            #reverse\n",
    "#             Average Sim:  0.4003394140133608 Average Sim top10:  0.3968387194545565\n",
    "            #top32 Average Sim:  0.5890564027303362 Average Sim top10:  0.5958813414881281\n",
    "            #top16 Average Sim:  0.529105758339142 Average Sim top10:  0.5365310968594731\n",
    "\n",
    "\n",
    "            #SubgraphMatching\n",
    "            Average Sim:  0.7203930390000458 Average Sim top10:  0.7323282147519472\n",
    "            #reverse\n",
    "#             Average Sim:  0.3704747213739286 Average Sim top10:  0.36316027924782995\n",
    "            #top32 Average Sim:  0.7102692353599228 Average Sim top10:  0.7218997889857508\n",
    "            #top16 Average Sim:  0.7234416996740904 Average Sim top10:  0.7311349435997672\n",
    "\n",
    "                    \n",
    "        #projection\n",
    "            Average Sim:  0.5204907904326189 Average Sim top10:  0.5204994875298088\n",
    "            #r-projection\n",
    "            #Average Sim:  0.35233320097842413 Average Sim top10:  0.402782097190453\n",
    "            #top32 Average Sim:  0.4995944918184729 Average Sim top10:  0.49922168605374423\n",
    "            #top16 Average Sim:  0.44654142430272337 Average Sim top10:  0.4467787488579792\n",
    "\n",
    "            #SubgraphMatching       \n",
    "            Average Sim:  0.5444968876100226 Average Sim top10:  0.5469656673939503\n",
    "            #r-projection\n",
    "#             Average Sim:  0.330021045287607 Average Sim top10:  0.3914988701481082\n",
    "            #top32 Average Sim:  0.5397867852344546 Average Sim top10:  0.5408066013939422\n",
    "            #top16 Average Sim:  0.5942242171842282 Average Sim top10:  0.5944074384990546\n",
    "        \n",
    "        \n",
    "                    \n",
    "    \n",
    "Average mask density: 0.5406371715768286\n",
    "Flips:  14.0\n",
    "Incorrect preds:  177\n",
    "\n",
    "\n",
    "****************************\n",
    "\n",
    "#synthetic \n",
    "    #label 0\n",
    "        #explain_boundary  #md = 32.7%\n",
    "            #nearest nbr\n",
    "                Average Sim:  0.22586319114203468 Average Sim top10:  0.22761009731140125 \n",
    "                #top32 edges : Average Sim:  0.21680116615954914 Average Sim top10:  0.21797466639075383\n",
    "                #top16 Average Sim:  0.22349614971079276 Average Sim top10:  0.22384026607871818\n",
    "\n",
    "                #subgraphMatching\n",
    "                Average Sim:  0.41220313213567833 Average Sim top10:  0.41524322934465674\n",
    "                #top32 Average Sim:  0.1932423964541316 Average Sim top10:  0.19428077606177765\n",
    "                #top16 Average Sim:  0.4469595324961587 Average Sim top10:  0.44850401705959364\n",
    "  \n",
    "\n",
    "            #projection\n",
    "                Average Sim:  0.2222147290280678 Average Sim top10:  0.22072561334959878\n",
    "                #top32 Average Sim:  0.20670590685712883 Average Sim top10:  0.2068893038866337\n",
    "                #top16 Average Sim:  0.21544024721791316 Average Sim top10:  0.21381325528320638\n",
    "\n",
    "                #subgraphMatching\n",
    "                Average Sim:  0.40555562495146 Average Sim top10:  0.4046330304214079\n",
    "                #top32: Average Sim:  0.18188510143668993 Average Sim top10:  0.18204199215376032\n",
    "                #top16: Average Sim:  0.43433370693352497 Average Sim top10:  0.4338576906521956\n",
    "\n",
    "                \n",
    "Boundary wise top4 acc: 0.4212549603174603, top6 acc: 0.6502976190476191, top8 acc: 0.8813244047619048\n",
    "Rule wise top8 acc: 0.8813244047619048\n",
    "Average mask density: 0.32757252295287176\n",
    "\n",
    "    \n",
    "Boundary wise top4 acc: 0.4242311507936508, top6 acc: 0.6486855158730159, top8 acc: 0.8802083333333334\n",
    "Rule wise top8 acc: 0.8802083333333334\n",
    "mAP score: 0.8566623437566603\n",
    "Average mask density: 0.3280833968227463\n",
    "\n",
    "    \n",
    "\n",
    "\n",
    "        #explainer\n",
    "            #wl\n",
    "            Average Sim:  0.20187916547869134 Average Sim top10:  0.20268048825020357\n",
    "            #subgraph\n",
    "            Average Sim:  0.37156941061590376 Average Sim top10:  0.3727175356089619\n",
    "\n",
    "\n",
    "Rule wise top4 acc: 0.4446924603174603, top6 acc: 0.6630704365079365, top8 acc: 0.8660714285714286\n",
    "Average mask density: 0.27635921317795736\n",
    "    \n",
    "Rule wise top4 acc: 0.43861607142857145, top6 acc: 0.6515376984126984, top8 acc: 0.8550347222222222\n",
    "mAP score: 0.8095546292429388\n",
    "Average mask density: 0.30401947132740464\n",
    "Total graphs optimized:  1008\n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "    #label=>1\n",
    "\n",
    "    #explain_boundary  : gt accuracy 76%  , md = 40%\n",
    "    \n",
    "        #wl\n",
    "            #nearest nbr\n",
    "                Average Sim:  0.21125174857847775 Average Sim top10:  0.21117583114966443  #md = 40% 16 flips\n",
    "                \n",
    "\n",
    "             #projections\n",
    "                #wl \n",
    "                Average Sim:  0.20544961000728287 Average Sim top10:  0.2053958762502683  #md = 40%\n",
    "               \n",
    "\n",
    "        #subgraph\n",
    "        \n",
    "            #nearest nbr \n",
    "                Average Sim:  0.37574226756594803 Average Sim top10:  0.3769780155728799\n",
    "                        \n",
    "            #projections\n",
    "                Average Sim:  0.36534551106851626 Average Sim top10:  0.36570882573919916\n",
    "                        \n",
    "                        \n",
    "Boundary wise top4 acc: 0.3682332677165354, top6 acc: 0.5725885826771654, top8 acc: 0.7743602362204725\n",
    "Rule wise top8 acc: 0.7743602362204725\n",
    "mAP score: 0.6979425923751065\n",
    "Average mask density: 0.4147253256789812\n",
    "(20, 20)\n",
    "Flips:  19.0\n",
    "                        \n",
    "    #explain #md=43%\n",
    "    \n",
    "        #wl\n",
    "            Average Sim:  0.19888569870152922 Average Sim top10:  0.19996254126560054\n",
    "        #subgraph\n",
    "            Average Sim:  0.3672896544438794 Average Sim top10:  0.36945677741494587\n",
    "            \n",
    "            \n",
    "                    \n",
    "\n",
    "                \n",
    "                \n",
    "#synthetic boundary\n",
    "Boundary wise top4 acc: 0.41756889763779526, top6 acc: 0.6343503937007874, top8 acc: 0.8389517716535433  \n",
    "Average mask density: 0.3373429980660664\n",
    "    \n",
    "#explain\n",
    "Rule wise top4 acc: 0.42322834645669294, top6 acc: 0.6290600393700787, top8 acc: 0.8240649606299213\n",
    "Average mask density: 0.3567382504184562\n",
    "    \n",
    "Rule wise top4 acc: 0.4015748031496063, top6 acc: 0.5965797244094488, top8 acc: 0.7854330708661418\n",
    "mAP score: 0.7123143433242682\n",
    "Average mask density: 0.40513322795887924\n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "#label = 2\n",
    "\n",
    "    #boundary\n",
    "            \n",
    "    \n",
    "                #wl\n",
    "\n",
    "\n",
    "                #subgraph\n",
    "                #top16 Average Sim:  0.4354027420865193 Average Sim top10:  0.43820961975133044\n",
    "\n",
    "            #projection\n",
    "                #top16 Average Sim:  0.41725636460462506 Average Sim top10:  0.41894670402498824\n",
    "    \n",
    "        Boundary wise top4 acc: 0.3853739754098361, top6 acc: 0.5959272540983607, top8 acc: 0.8018698770491803\n",
    "        Rule wise top8 acc: 0.8018698770491803\n",
    "        Average mask density: 0.3390896980513315\n",
    "\n",
    "\n",
    "        Boundary wise top4 acc: 0.4369877049180328, top6 acc: 0.6501024590163934, top8 acc: 0.8583504098360656\n",
    "        Rule wise top8 acc: 0.8583504098360656\n",
    "        mAP score: 0.8567959981313538\n",
    "        Average mask density: 0.2786614683868944\n",
    "        (20, 20)\n",
    "        Flips:  61.0    \n",
    "            \n",
    "        Boundary wise top4 acc: 0.41226946721311475, top6 acc: 0.623719262295082, top8 acc: 0.8358094262295082\n",
    "        Rule wise top8 acc: 0.8358094262295082\n",
    "        mAP score: 0.8300776163261389\n",
    "        Average mask density: 0.3092741308520075\n",
    "        (20, 20)\n",
    "        Flips:  16.0\n",
    "\n",
    "    #explain\n",
    "        Rule wise top4 acc: 0.43967725409836067, top6 acc: 0.6498463114754098, top8 acc: 0.8121157786885246\n",
    "        Average mask density: 0.22846055748399163\n",
    "        Total graphs optimized:  976\n",
    "        Flips:  0.0\n",
    "            \n",
    "        Rule wise top4 acc: 0.43148053278688525, top6 acc: 0.6509989754098361, top8 acc: 0.8422131147540983\n",
    "        mAP score: 0.7987645548750902\n",
    "        Average mask density: 0.24725454080788817\n",
    "        Total graphs optimized:  976\n",
    "        Flips:  0.0\n",
    "        Incorrect preds:  0\n",
    "            \n",
    "        \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[   5  103  767  441  843 1400  375 1154 1231 1574 1141 1018    2  653\n",
      " 1696  118  160  902 1299  138 1478]\n",
      "[   5  953  667  919 1701 1312 1429 1038  923 1066 1700  155  293 1222\n",
      " 1576  774 1510  257  654 1043 1556]\n"
     ]
    }
   ],
   "source": [
    "# Nearest neighbours of graph `orig_i`: the 21 closest rows by Euclidean distance,\n",
    "# first in the learned embedding space (`nbrs`), then in the projection space\n",
    "# (`p_nbrs`). The query itself normally comes first (distance 0).\n",
    "# pickle.dump(projections, open(\"../../gcn_interpretation/data/Mutagenicity/projections_label0.p\",\"wb\"))\n",
    "from sklearn.metrics.pairwise import euclidean_distances\n",
    "\n",
    "graph_emb_np = graph_emb_cat.numpy()\n",
    "orig_i = 5  # index of the query graph\n",
    "i = orig_i\n",
    "\n",
    "arr1 = graph_emb_np[i:i + 1, :]  # slice keeps a 2-D single-row query\n",
    "dist = euclidean_distances(graph_emb_np, arr1)\n",
    "nbrs = dist.ravel().argsort()[:21]\n",
    "print(nbrs)\n",
    "\n",
    "arr1 = projections[i:i + 1, :]\n",
    "dist = euclidean_distances(projections, arr1)\n",
    "p_nbrs = dist.ravel().argsort()[:21]\n",
    "print(p_nbrs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "from grakel.utils import graph_from_networkx\n",
    "\n",
    "from grakel.kernels import WeisfeilerLehman, VertexHistogram\n",
    "from grakel import Graph\n",
    "from grakel.kernels import PropagationAttr, ShortestPath, SubgraphMatching\n",
    "\n",
    "from matplotlib.backends.backend_pdf import PdfPages\n",
    "\n",
    "# node_labels = ['A','B','C','D','E','F','G','H','I','J','K','L']\n",
    "node_labels = ['C','O','Cl','H','N','F','Br','S','P','I','Na','K','Li','Ca']\n",
    "\n",
    "def node_match_func(n1, n2):\n",
    "#     print(\"n1:\", n1['feat'])\n",
    "#     print(\"n2: \", n2['feat'])\n",
    "    n1q = (np.sum(n1['feat'] == n2['feat']) == 14)\n",
    "#     print(\"n1q: \", n1q)\n",
    "    return n1q\n",
    "\n",
    "def getGrakelGraph(nx_G, verbose=True):\n",
    "    \"\"\"Convert a networkx graph into a grakel ``Graph`` for kernel comparison.\n",
    "\n",
    "    Nodes are renumbered 0..n-1 following ``nx_G.nodes()`` order. The adjacency\n",
    "    matrix and the node-label dict are both built from that same explicit node\n",
    "    list, so they stay aligned for any node ids.\n",
    "\n",
    "    Args:\n",
    "        nx_G: networkx graph; nodes carrying a 'label' attribute get it (as str)\n",
    "            as their grakel node label.\n",
    "        verbose: print the matrix shape, node labels and adjacency dict.\n",
    "\n",
    "    Returns:\n",
    "        grakel Graph whose adjacency lists every positive, non-self-loop entry\n",
    "        of the adjacency matrix, each such edge labelled 'a'.\n",
    "    \"\"\"\n",
    "    nodelist = list(nx_G.nodes())\n",
    "    # to_numpy_array: to_numpy_matrix is deprecated and removed in networkx 3.0\n",
    "    mat = nx.to_numpy_array(nx_G, nodelist=nodelist)\n",
    "    attr = nx.get_node_attributes(nx_G, 'label')\n",
    "    grakel_node_labels = {pos: str(attr[node])\n",
    "                          for pos, node in enumerate(nodelist) if node in attr}\n",
    "\n",
    "    mat_dict = {}\n",
    "    edge_labels = {}\n",
    "    for i in range(mat.shape[0]):\n",
    "        # neighbours of i, ignoring self-loops\n",
    "        mat_dict[i] = [j for j in range(mat.shape[1]) if mat[i, j] > 0. and i != j]\n",
    "        for j in mat_dict[i]:\n",
    "            edge_labels[(i, j)] = 'a'\n",
    "\n",
    "    if verbose:\n",
    "        print(mat.shape)\n",
    "        print(grakel_node_labels)\n",
    "        print(mat_dict)\n",
    "    return Graph(mat_dict, node_labels=grakel_node_labels, edge_labels=edge_labels)\n",
    "\n",
    "        \n",
    "def getNXGraph(adj, feat, thresh=None, indx=None, edge_thresh=16, rel_thresh=0.75, verbose=True):\n",
    "    \"\"\"Build a networkx graph from a (masked / soft) adjacency matrix.\n",
    "\n",
    "    Every node gets ``feat[node]`` as its 'feat' attribute and\n",
    "    ``np.argmax(feat[node])`` as its 'label'. Edges (weight 1.0) come from\n",
    "    exactly one of three modes:\n",
    "\n",
    "    * ``indx`` given: each node group ``hnodes_dict[(indx, b)]`` for b in\n",
    "      (0, 1) is closed into a cycle. ``hnodes_dict`` is a notebook-level\n",
    "      global that must exist before calling with ``indx``.\n",
    "    * ``thresh`` given: every entry ``adj[i, j] > thresh`` becomes an edge.\n",
    "    * otherwise: adjacency entries are visited from largest to smallest; each\n",
    "      positive entry becomes an edge until an entry falls below\n",
    "      ``rel_thresh`` times the largest one or ``edge_thresh`` entries have\n",
    "      been taken. (x, y) and (y, x) are separate entries, so for a symmetric\n",
    "      ``adj`` one undirected edge usually uses two of the ``edge_thresh``\n",
    "      slots. Endpoints of the taken entries go into the highlight dict.\n",
    "\n",
    "    Nodes left without any edge are removed.\n",
    "\n",
    "    Args:\n",
    "        adj: 2-D array indexed ``adj[i, j]``; its last dimension gives the\n",
    "            number of nodes.\n",
    "        feat: per-node feature rows, ``feat[node]``.\n",
    "        thresh, indx, edge_thresh, rel_thresh: see the modes above.\n",
    "        verbose: in the top-k mode, print the running edge count and the\n",
    "            visited entry value at every step.\n",
    "\n",
    "    Returns:\n",
    "        (G, node_hlight_dict): the graph, and a dict whose keys are the\n",
    "        highlighted nodes (filled only in the top-k mode).\n",
    "    \"\"\"\n",
    "    num_nodes = adj.shape[-1]\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(range(num_nodes))\n",
    "    for node in G.nodes():\n",
    "        G.nodes[node][\"feat\"] = feat[node]\n",
    "        G.nodes[node][\"label\"] = np.argmax(feat[node])\n",
    "\n",
    "    node_hlight_dict = {}\n",
    "    weighted_edge_list = []\n",
    "    if indx is not None:\n",
    "        for b in range(2):\n",
    "            group = hnodes_dict[(indx, b)]\n",
    "            size = len(group)\n",
    "            for k in range(size):\n",
    "                # link each node to the next, wrapping around so any group size forms a cycle\n",
    "                nxt = group[(k + 1) % size]\n",
    "                weighted_edge_list.append((group[k], nxt, 1.0))\n",
    "                weighted_edge_list.append((nxt, group[k], 1.0))\n",
    "    elif thresh is not None:\n",
    "        # np.nonzero scans row-major, i.e. the same order as a nested i/j loop\n",
    "        rows, cols = np.nonzero(np.asarray(adj)[:num_nodes, :num_nodes] > thresh)\n",
    "        for i, j in zip(rows, cols):\n",
    "            weighted_edge_list.append((i, j, 1.0))\n",
    "    else:\n",
    "        order = np.argsort(adj.ravel())[::-1]  # flat indices, largest value first\n",
    "        xs, ys = np.unravel_index(order, (num_nodes, num_nodes))\n",
    "        edge_count = 0.\n",
    "        max_p = -1  # sentinel until the largest entry has been seen\n",
    "        for x, y in zip(xs, ys):\n",
    "            if adj[x, y] < rel_thresh * max_p:\n",
    "                break\n",
    "            if max_p == -1:\n",
    "                max_p = adj[x, y]\n",
    "            if adj[x, y] > 0.:\n",
    "                node_hlight_dict[x] = 1\n",
    "                node_hlight_dict[y] = 1\n",
    "                weighted_edge_list.append((x, y, 1.0))\n",
    "                edge_count += 1.0\n",
    "            if verbose:\n",
    "                print(edge_count, adj[x, y])\n",
    "            if edge_count >= edge_thresh:\n",
    "                break\n",
    "\n",
    "    G.add_weighted_edges_from(weighted_edge_list)\n",
    "    G.remove_nodes_from(list(nx.isolates(G)))\n",
    "    return G, node_hlight_dict\n",
    "\n",
    "\n",
    "\n",
    "# --- Render the query graph and its nearest neighbours into one PDF ---\n",
    "# Page 1 is graph `orig_i`; the following pages are its neighbours (from\n",
    "# `p_nbrs` if `proj` else `nbrs`, skipping position 0). Each page shows the\n",
    "# full input graph twice: top panel with the nodes picked from the masked\n",
    "# adjacency highlighted in red, bottom panel coloured by node label.\n",
    "vmax=19\n",
    "cmap = plt.get_cmap(\"tab20\")\n",
    "\n",
    "proj = False  # True: neighbours in projection space, False: in feature space\n",
    "max_nbrs = 20  # pages = query + up to max_nbrs - 1 neighbours\n",
    "\n",
    "HIGHLIGHT_RGB = (0.9, 0.2, 0.2)\n",
    "BACKGROUND_RGB = (0.9, 0.9, 0.9)\n",
    "\n",
    "\n",
    "def load_masked_graph(j, idx_prefix):\n",
    "    \"\"\"Print stats for explanation `j` and build its masked graph.\n",
    "\n",
    "    Returns (idx, feat, G, highlight_dict): idx = l_idx[j] indexes ckpt['cg'];\n",
    "    G and highlight_dict come from getNXGraph on masked_adjs[j].\n",
    "    \"\"\"\n",
    "    idx = l_idx[j]\n",
    "    print(idx_prefix, idx)\n",
    "    m_adj = masked_adjs[j]\n",
    "    print(\"m_adj: \", m_adj.shape)\n",
    "    print(\"mask density: \", np.sum(m_adj)/np.sum(ckpt['cg']['adj'][idx].numpy()))\n",
    "    print(\"num nodes: \", ckpt['cg']['num_nodes'][idx].item())\n",
    "    feat = ckpt['cg']['feat'][idx].numpy()\n",
    "    G, highlight_dict = getNXGraph(m_adj, feat)\n",
    "    return idx, feat, G, highlight_dict\n",
    "\n",
    "\n",
    "def draw_graph_page(idx, feat, highlight_dict, title=None):\n",
    "    \"\"\"Return a 2-panel figure of input graph `idx` (edges where adj > 0.1).\n",
    "\n",
    "    Top panel: nodes in `highlight_dict` in red, the rest grey.\n",
    "    Bottom panel: nodes coloured by label with `cmap`; `title` goes here.\n",
    "    \"\"\"\n",
    "    fig, ax_l = plt.subplots(2, 1, figsize=(30, 50))\n",
    "    adj = ckpt['cg']['adj'][idx].numpy()\n",
    "    G_orig, _ = getNXGraph(adj, feat, thresh=0.1)\n",
    "\n",
    "    node_colors = []\n",
    "    node_hlight = []\n",
    "    labels_dict = {}\n",
    "    for n in G_orig.nodes:\n",
    "        lbl = np.argmax(G_orig.nodes[n]['feat'])\n",
    "        node_colors.append(lbl)\n",
    "        labels_dict[n] = node_labels[lbl]\n",
    "        node_hlight.append(HIGHLIGHT_RGB if n in highlight_dict else BACKGROUND_RGB)\n",
    "\n",
    "    # one layout shared by both panels so they can be compared node-for-node\n",
    "    pos_layout = nx.kamada_kawai_layout(G_orig, weight=None)\n",
    "    nx.draw_networkx(G_orig, pos=pos_layout, font_size=20, ax=ax_l[0],\n",
    "                     node_size=900, alpha=0.8, node_color=node_hlight, labels=labels_dict)\n",
    "    nx.draw_networkx(G_orig, pos=pos_layout, font_size=20, ax=ax_l[1], vmax=vmax, vmin=0,\n",
    "                     node_size=900, cmap=cmap, alpha=0.8, node_color=node_colors, labels=labels_dict)\n",
    "    if title is not None:\n",
    "        ax_l[1].set_title(title)\n",
    "    return fig\n",
    "\n",
    "\n",
    "orig_j = orig_i\n",
    "s_nbrs = p_nbrs if proj else nbrs\n",
    "space = 'projspace' if proj else 'fspace'\n",
    "pdffilepath = './visualizations/' + str(orig_j) + 'mutag0_20nbrs_' + space + '_boundary_maxp.pdf'\n",
    "\n",
    "# the context manager closes the PDF even if drawing a page raises\n",
    "with PdfPages(pdffilepath) as pdf:\n",
    "    idx, f, G1, highlight_dict = load_masked_graph(orig_j, \"idx: \")\n",
    "    fig = draw_graph_page(idx, f, highlight_dict)\n",
    "    pdf.savefig(fig)\n",
    "    plt.close(fig)\n",
    "\n",
    "    # kernel fitted on the query graph; each neighbour's `sim` is its kernel value against it\n",
    "    G1_t = getGrakelGraph(G1)\n",
    "    gk = SubgraphMatching(normalize=True)\n",
    "    self_sim = gk.fit_transform([G1_t])\n",
    "\n",
    "    for j in s_nbrs[1:max_nbrs]:  # position 0 is normally the query itself\n",
    "        idx, f, G2, highlight_dict = load_masked_graph(j, \"j idx: \")\n",
    "        G2_t = getGrakelGraph(G2)\n",
    "        sim = gk.transform([G2_t])\n",
    "        print(\"sim: \", sim)\n",
    "        fig = draw_graph_page(idx, f, highlight_dict,\n",
    "                              title='j= ' + str(j) + ' : indx =' + str(idx))\n",
    "        pdf.savefig(fig)\n",
    "        plt.close(fig)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfAAAAI4CAYAAACV/7uiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deZyWdb3/8ddHFgmTRSNikcWdJR0RJcuDWyKiHjf0hP5OoBRathxzyV+2oNFiaqC/LI+ViumR0oOKW0qopaUQ6JgLormyqBDiAmiKfH9/3BfTgMMwg9z3zHd4PR+P+zHXfn2umWvu9319r+WOlBKSJCkvWzR1AZIkqfEMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuNQEIuLEiLi7qetoziJi/4hYsAmX962I+NUmWM74iLh2U9QkfRgGuLIWESdExOyIWB4RL0fEnRGxb1PXtSEppetSSsOauo7mJCJSROy4iZb1gfBPKf0wpfSFTbH8RtRxdURMqOQ6tfkwwJWtiPgGMAn4IdAV6AX8HDiyKevakIho3dQ1SMqfAa4sRURH4HzgtJTS1JTSipTSeymlW1NKZxXTbBkRkyJiUfGaFBFbFuP2j4gFEXF2RCwujt6PiogREfF0RLwWEd+qtb7xEXFjRPw2It6KiIcjYvda48+JiGeLcU9GxNG1xo2JiD9HxMSIWAqML4Y9UIyPYtziiHgzIh6LiIFrtjMiromIJRHxYkR8OyK2qLXcByLioohYFhHPR8Sh9fzO+kXEfRHxekQ8ERH/Xmvc1RFxWUTcXmzDzIjYYT3LaRcR10bE0mJZf42IrhFxXETMWWfab0TELRtaR0T8qZjl0aI15T9qLeOMWn+jk2oN37LY9pci4tWIuDwiPhIRWwF3At2LZS2PiO7rNn1HxL4R8ZdiG+ZHxJj1bG/fiPhjUfN04GPrjL8hIl6JiDci4k8RMaAYPg44ETi7qOHWYvh69xWpUVJKvnxl9wKGA6uA1vVMcz7wEPBxoAvwF+D7xbj9i/m/C7QBvggsAf4H2BoYALwN9C2mHw+8B4wspj8TeB5oU4w/DuhO6UPxfwArgG7FuDHFur4KtAY+Ugx7oBh/CDAH6AQE0K/WvNcAtxQ19QGeBsbWWu57Re2tgC8Bi4Co43fRBvg78C2gLXAg8BawSzH+amApsHdR43XAlPX8Xk8BbgXaF+vdE+gAbAm8BvSrNe0jwLENWQeQgB1r9a/5G51f1D8CWAl0LsZPBKYB2xS/n1uBH9Wad8E6dY8Hri26exfbP6pY9rZA1Xq290Hgp8X2DS3mu7bW+JOL9W9JqUWouta4q4EJ6yxvvfuKL1+NeTV5Ab58bcyL0pHNKxuY5llgRK3+Q4AXiu79KQV0q6J/6yJAhtSafg5wVNE9Hnio1rgtgJeBf1vPuquBI4vuMcBL64wfw78C/EBKwfwpYIta07QC3gX61xp2CnBfrWX8vda49sU2fKKOev4NeGWd5V8PjC+6rwZ+VWvcCOCp9WzbyZQ+DO1Wx7hfAD8ougcAy4AtG7IO6g7wt6n1IQ1YXPyeogi+HWqN2wd4vta89QX4/wVuasB+1ovSh4itag37H2oF+DrTdyq2o2OtbZ6wgXXU7Cu+fDXmZRO6crUU+NgGzid3B16s1f9iMaxmGSml94vut4ufr9Ya/zbw0Vr989d0pJRWAwvWLC8iPh8R1UVz7OvAQNZuap3PeqSU7gF+BlwGLI6IKyKiQzF/mzq2oUet/ldqLWdl0Vm75jW6A/OLuje4LEpHunUtB+A3wF3AlCidmvhJRLQpxk0GToiIAP4T+F1K6Z8bsY41lqaUVtUxTxdKH1jm1Pqd/74Y3hDbUfqAtyHdgWUppRW1htX8PSKiVUT8uGgSfxN4oRi1VjN7bQ3YV6QGMcCVqweBfwJH1TPNIkpNpWv0KoZtrO3WdBTnoXsCiyKiN/BL
4CvAtimlTsDjlI4S16j3a/9SSpemlPYE+gM7A2cB/6DURL7uNizciNoXAdutOX/+YZaVStcanJdS6g98Gjgc+Hwx7iFKrQb/BpxAKezL4R+UPmANSCl1Kl4dU0prPhBs6GsW5wN1nuNfx8tA5+K8+hq9anWfQOmiyc8CHSmd5oB//e3XqqOB+4rUIAa4spRSeoPS+evLonTxWfuIaBMRh0bET4rJrge+HRFdIuJjxfQf5v7dPSPimOKo/78ofYB4CNiK0hv1EoDiQquBDV1oROwVEUOKo9gVwDvA6qJ14HfADyJi6+LN/xsbuQ0zKR29nl38nvYHjgCmNHZBEXFARHwyIloBb1L6kFH7yP4aSi0K76WUHmjEol8Ftm/IhEVLwi+BiRHx8aKuHhFxSK1lbRulix3rch3w2Yg4PiJaR8S2EVFVx3peBGYD50VE2yjdonhErUm2prQfLKXUIvDDDWzTh9pXpNoMcGUrpXQxpUD7NqU3xPmUjmxuLiaZQOnN92/AY8DDxbCNdQuli46WUWoePqY4Gn0SuJhSq8CrwCeBPzdiuR0ohdEySs2zS4ELi3FfpRTqzwEPUDr/emVjC08pvUspeA6ldPT6c+DzKaWnGrss4BPAjZTCey7wR9Y+0v4NpVBq7AeN8cDkomn5+AZM/01KF+Y9VDRf/wHYBaDYruuB54rl1T51QkrpJUrn4M+gdOFdNbA7dTsBGFJM9z1KH1DWuIbS32wh8CSlD3S1/RroX9Rw8ybYV6QakdKGWpokRcR4ShdY/Z+mrqW5i4iPULrYbFBK6ZmmrkdqqTwCl7SpfQn4q+EtlZdPhJK0yUTEC5QuyKrv4kJJm4BN6JIkZcgmdEmSMpRFE/rHPvax1KdPn6YuQ5KkipszZ84/UkofeEhRFgHep08fZs+e3dRlSJJUcRHxYl3DbUKXJClDBrgkSRkywCVJylAW58ClHLz33nssWLCAd955p6lLUQO0a9eOnj170qZNmw1PLDVDBri0iSxYsICtt96aPn36UPo2TTVXKSWWLl3KggUL6Nu3b1OXI20Um9ClTeSdd95h2223NbwzEBFsu+22tpYoawa4tAkZ3vnwb6XcGeCSJGXIc+BSmfQ55/ZNurwXfnzYBqd59dVXOf3003nooYfo3Lkzbdu25eyzz6Zz584ccMABTJs2jSOOOAKAww8/nDPPPJP999+f/fffn+XLl9c8MGn27NmceeaZ3HfffR9Yx5gxY/jjH/9Ix44dAWjfvj1/+ctfGr099913HxdddBG33XbbeqeZPXs211xzDZdeemmjly+1dB6BSy1ESomjjjqKoUOH8txzzzFnzhymTJnCggULAOjZsyc/+MEP1jv/4sWLufPOOxu0rgsvvJDq6mqqq6s3KrwbavDgwZssvN9///1NshypuTDApRbinnvuoW3btpx66qk1w3r37s1Xv/pVAHbffXc6duzI9OnT65z/rLPOqjfgN+TrX/86559/PgB33XUXQ4cOZfXq1YwZM4ZTTz2VwYMHs/POO9d5xD1r1iz22Wcf9thjDz796U8zb948oHSUfvjhhwMwfvx4Tj75ZPbff3+23377tYL92muvZe+996aqqopTTjmlJqw/+tGPcsYZZ7D77rvz4IMPbvS2Sc2RAa4aEydOZMCAAQwcOJBRo0bxzjvvMGbMGPr27UtVVRVVVVVUV1d/YL577723ZnxVVRXt2rXj5ptvboIt2Lw98cQTDBo0qN5pzj33XCZMmFDnuH322Ye2bdty7733bnBdZ511Vs3f+8QTTwTgRz/6Eb/97W+59957+drXvsZVV13FFluU3mJeeOEFZs2axe23386pp576gau/d911V+6//34eeeQRzj//fL71rW/Vud6nnnqKu+66i1mzZnHeeefx3nvvMXfuXH7729/y5z//merqalq1asV1110HwIoVKxgyZAiPPvoo++677wa3S/nZnN+3PAcuABYuXMill17Kk08+yUc+8hGO
P/54pkyZApSaS0eOHLneeQ844ICaf5DXXnuNHXfckWHDhlWkbq3faaedxgMPPEDbtm258MILARg6dCgADzzwQJ3zfPvb32bChAlccMEF9S67rn2iffv2/PKXv2To0KFMnDiRHXbYoWbc8ccfzxZbbMFOO+3E9ttvz1NPPbXWvG+88QajR4/mmWeeISJ477336lzvYYcdxpZbbsmWW27Jxz/+cV599VVmzJjBnDlz2GuvvQB4++23+fjHPw5Aq1atOPbYY+vdFuVrc3/f8ghcNVatWsXbb7/NqlWrWLlyJd27d2/0Mm688UYOPfRQ2rdvX4YKVZ8BAwbw8MMP1/RfdtllzJgxgyVLlqw1XX1H4QceeCBvv/02Dz30UM2wk046iaqqKkaMGLHBGh577DG23XZbFi1atNbwdW/ZWrf/O9/5DgcccACPP/44t95663rvz95yyy1rulu1asWqVatIKTF69Oiac/Lz5s1j/PjxQOlpa61atdpg3crX5vy+ZYALgB49enDmmWfSq1cvunXrRseOHWs+jZ577rnstttunH766fzzn/+sdzlTpkxh1KhRlShZ6zjwwAN55513+MUvflEzbOXKlR+YbtiwYSxbtoy//e1vdS7n29/+Nj/5yU9q+q+66iqqq6u544476l3/iy++yMUXX8wjjzzCnXfeycyZM2vG3XDDDaxevZpnn32W5557jl122WWted944w169OgBwNVXX73Bba3toIMO4sYbb2Tx4sVA6WjqxRfr/PZFtTCb+/uWTegCYNmyZdxyyy08//zzdOrUieOOO45rr72WH/3oR3ziE5/g3XffZdy4cVxwwQV897vfrXMZL7/8Mo899hiHHHJIhatvnhpy29emFBHcfPPNnH766fzkJz+hS5cubLXVVnU2h5977rkceeSRdS5nxIgRdOnSpd51nXXWWWsdxc+cOZOxY8dy0UUX0b17d379618zZswY/vrXvwLQq1cv9t57b958800uv/xy2rVrt9byzj77bEaPHs2ECRM47LDG/d769+/PhAkTGDZsGKtXr6ZNmzZcdtll9O7du1HLUX42+/etlFKzf+25555J5fW73/0unXzyyTX9kydPTl/60pfWmubee+9Nhx122HqXMWnSpPTFL36xbDU2d08++WRTl9AsjR49Ot1www1NXUad/JvlbXN53wJmpzqy0SZ0AaUjpIceeoiVK1eSUmLGjBn069ePl19+GSh90Lv55psZOHDgepdx/fXXZ9kMJSlPm/v7lk3oAmDIkCGMHDmSQYMG0bp1a/bYYw/GjRvHoYceypIlS0gpUVVVxeWXXw6UnpB1+eWX86tf/Qoo3SY0f/589ttvv6bcDDVDjT2nLTXU5v6+FaWj8+Zt8ODBac0jHqXmau7cufTr16+py1Aj+DdTDiJiTkpp8LrDbUKXJClDBrgkSRkywCVJypAXsbUU4zs2dQWb3vg3mrqCD2dT/00a8PtYsGABp512Gk8++SSrV6/m8MMP58ILL6Rt27abthZpU/B960PxCFxqIVJKHHPMMRx11FE888wzPP300yxfvpxzzz23wcvwKzelfBjgUgtxzz330K5dO0466SSg9KzwiRMncuWVV/Lzn/+cr3zlKzXTHn744dx3333AB79y85xzzqF///7stttunHnmmU2xKZIawCZ0qYV44okn2HPPPdca1qFDB3r16sWqVavWO9+ar9y8+OKLWbp0KWPHjuWpp54iInj99dfLXbakjeQRuLSZq/2Vmx07dqRdu3aMHTuWqVOnZvftTNLmxACXWoj+/fszZ86ctYa9+eabvPTSS3Tq1InVq1fXDK/9dZ21v3KzdevWzJo1i5EjR3LbbbcxfPjwyhQvqdEMcKmFOOigg1i5ciXXXHMNULog7YwzzmDMmDFsv/32VFdXs3r1aubPn8+sWbPqXMby5ct54403GDFiBBMnTuTRRx+t5CZIagTPgUvlUuHb4CKCm266iS9/+ct8//vf
Z/Xq1YwYMYIf/vCHtG3blr59+9K/f3/69evHoEGD6lzGW2+9xZFHHsk777xDSomf/vSnFd0GSQ1ngEstyHbbbcett95a57jrrruuzuHLly+v6e7Wrdt6j84lNS82oUuSlCEDXJKkDBng0iaUw9fzqsS/lXJngEubSLt27Vi6dKnBkIGUEkuXLqVdu3ZNXYq00byITdpEevbsyYIFC1iyZElTl6IGaNeuHT179mzqMqSNZoBLm0ibNm3o27dvU5chaTNhE/pGmjhxIgMGDGDgwIGMGjWKd955h5/97GfsuOOORAT/+Mc/1jvv5MmT2Wmnndhpp52YPHlyBauWJLUUBvhGWLhwIZdeeimzZ8/m8ccf5/3332fKlCl85jOf4Q9/+AO9e/de77yvvfYa5513HjNnzmTWrFmcd955LFu2rILVS5JaAgN8I61atYq3336bVatWsXLlSrp3784ee+xBnz596p3vrrvu4uCDD2abbbahc+fOHHzwwfz+97+vTNGSpBbDAN8IPXr04Mwzz6RXr15069aNjh07MmzYsAbNu3DhQrbbbrua/p49e7Jw4cJylSpJaqEM8I2wbNkybrnlFp5//nkWLVrEihUruPbaa5u6LEnSZsQA3wh/+MMf6Nu3L126dKFNmzYcc8wx/OUvf2nQvD169GD+/Pk1/QsWLKBHjx7lKlWS1EIZ4BuhV69ePPTQQ6xcuZKUEjNmzKBfv34NmveQQw7h7rvvZtmyZSxbtoy7776bQw45pMwVS5JaGgN8IwwZMoSRI0cyaNAgPvnJT7J69WrGjRvHpZdeWvMwj912240vfOELAMyePbume5tttuE73/kOe+21F3vttRff/e532WabbZpycyRtBubNm0dVVVXNq0OHDkyaNIlHH32UffbZh09+8pMcccQRvPnmm3XOf8kllzBw4EAGDBjApEmTKly96hI5PPZx8ODBafbs2U1dRvM2vmNTV7DpVfj7tKXNxfvvv0+PHj2YOXMmI0eO5KKLLmK//fbjyiuv5Pnnn+f73//+WtM//vjjfO5zn2PWrFm0bduW4cOHc/nll7Pjjjt+uEJ832qQiJiTUhq87nCPwCVpMzNjxgx22GEHevfuzdNPP83QoUMBOPjgg/nf//3fD0w/d+5chgwZQvv27WndujX77bcfU6dOrXTZWocBLkmbmSlTpjBq1CgABgwYwC233ALADTfcsNZFtmsMHDiQ+++/n6VLl7Jy5UruuOOOOqdTZRngkrQZeffdd5k2bRrHHXccAFdeeSU///nP2XPPPXnrrbdo27btB+bp168f3/zmNxk2bBjDhw+nqqqKVq1aVbp0rWOz/DKTPufc3tQlbHIv+K2IkhrgzjvvZNCgQXTt2hWAXXfdlbvvvhuAp59+mttvr/v9cezYsYwdOxaAb33rW36TWzPgEbgkbUauv/76muZzgMWLFwOwevVqJkyYwKmnnlrnfGume+mll5g6dSonnHBC+YtVvQxwSdpMrFixgunTp3PMMcfUDLv++uvZeeed2XXXXenevTsnnXQSAIsWLWLEiBE10x177LH079+fI444gssuu4xOnTpVvH6tray3kUXE6cAXgAQ8BpwEdAOmANsCc4D/TCm9W99yNvVtZC2zCb0Ffhr2NjKpZfM2sgap+G1kEdED+BowOKU0EGgFfA64AJiYUtoRWAaMLVcNkiS1VOVuQm8NfCQiWgPtgZeBA4Ebi/GTgaPKXIMkSS1O2QI8pbQQuAh4iVJwv0Gpyfz1lNKqYrIFQJ3f5BER4yJidkTMXrJkSbnKlCQpS2W7jSwiOgNHAn2B14EbgOENnT+ldAVwBZTOgZejRknKRcu8dqepK8hbOZvQPws8n1JaklJ6D5gKfAboVDSpA/QEFpaxBkmSWqRyBvhLwKcion1EBHAQ8CRwLzCymGY0cEsZa5AkqUUq5znwmZQuVnuY0i1kW1BqEv8m8I2I+DulW8l+Xa4aJElqqcr6KNWU0veA760z
+Dlg73KuV5Kkls4nsUmSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLqms5s2bR1VVVc2rQ4cOTJo0ifHjx9OjR4+a4XfccUed819yySUMHDiQAQMGMGnSpApXLzVfZb0PXJJ22WUXqqurAXj//ffp0aMHRx99NFdddRWnn346Z5555nrnffzxx/nlL3/JrFmzaNu2LcOHD+fwww9nxx13rFT5UrPlEbikipkxYwY77LADvXv3btD0c+fOZciQIbRv357WrVuz3377MXXq1DJXKeXBAJdUMVOmTGHUqFE1/T/72c/YbbfdOPnkk1m2bNkHph84cCD3338/S5cuZeXKldxxxx3Mnz+/kiVLzZYBLqki3n33XaZNm8Zxxx0HwJe+9CWeffZZqqur6datG2ecccYH5unXrx/f/OY3GTZsGMOHD6eqqopWrVpVunSpWTLAJVXEnXfeyaBBg+jatSsAXbt2pVWrVmyxxRZ88YtfZNasWXXON3bsWObMmcOf/vQnOnfuzM4771zJsqVmywCXVBHXX3/9Ws3nL7/8ck33TTfdxMCBA+ucb/HixQC89NJLTJ06lRNOOKG8hUqZ8Cp0SWW3YsUKpk+fzn//93/XDDv77LOprq4mIujTp0/NuEWLFvGFL3yh5rayY489lqVLl9KmTRsuu+wyOnXq1CTbIDU3Briksttqq61YunTpWsN+85vf1Dlt9+7d17on/P777y9rbVKubEKXJClDBrgkSRkywCVJypDnwCXVb3zHpq5g0xv/RlNXIH1oHoFLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuNSPz5s2jqqqq5tWhQwcmTZrEDTfcwIABA9hiiy2YPXt2nfPOnz+fAw44gP79+zNgwAAuueSSClcvqZL8PnCpGdlll12orq4G4P3336dHjx4cffTRrFy5kqlTp3LKKaesd97WrVtz8cUXM2jQIN566y323HNPDj74YPr371+p8iVVkAEuNVMzZsxghx12oHfv3g2avlu3bnTr1g2Arbfemn79+rFw4UIDXGqhbEKXmqkpU6YwatSojZr3hRde4JFHHmHIkCGbuCpJzYUBLjVD7777LtOmTeO4445r9LzLly/n2GOPZdKkSXTo0KEM1UlqDgxwqRm68847GTRoEF27dm3UfO+99x7HHnssJ554Isccc0yZqpPUHBjgUjN0/fXXN7r5PKXE2LFj6devH9/4xjfKVJmk5sIAl5qZFStWMH369LWOoG+66SZ69uzJgw8+yGGHHcYhhxwCwKJFixgxYgQAf/7zn/nNb37DPffcU3Mb2h133NEk2yCp/LwKXWpmttpqK5YuXbrWsKOPPpqjjz76A9N27969JqT33XdfUkoVqVFS0/MIXJKkDBngkiRlyACXJClDngOXNqE+59ze1CVsci+0a+oKJNXFI3BJkjJkgEuSlCEDXJKkDJUtwCNil4iorvV6MyL+KyK2iYjpEfFM8bNzuWqQJKmlKluAp5TmpZSqUkpVwJ7ASuAm4BxgRkppJ2BG0S9JkhqhUk3oBwHPppReBI4EJhfDJwNHVagGSZJajEoF+OeA64vurimll4vuV4A6v24pIsZFxOyImL1kyZJK1ChJUjbKHuAR0Rb4d+CGdcel0oOb63x4c0rpipTS4JTS4C5dupS5SkmS8lKJI/BDgYdTSq8W/a9GRDeA4ufiCtQgSVKLUokAH8W/ms8BpgGji+7RwC0VqEGSpBalrAEeEVsBBwNTaw3+MXBwRDwDfLbolyRJjVDWZ6GnlFYA264zbCmlq9IlSdJG8klskiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJck
KUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlqKwBHhGdIuLGiHgqIuZGxD4RsU1ETI+IZ4qfnctZgyRJLVG5j8AvAX6fUtoV2B2YC5wDzEgp7QTMKPolSVIjlC3AI6IjMBT4NUBK6d2U0uvAkcDkYrLJwFHlqkGSpJaqnEfgfYElwFUR8UhE/CoitgK6ppReLqZ5Beha18wRMS4iZkfE7CVLlpSxTEmS8lPOAG8NDAJ+kVLaA1jBOs3lKaUEpLpmTildkVIanFIa3KVLlzKWKUlSfsoZ4AuABSmlmUX/jZQC/dWI6AZQ/FxcxhokSWqRyhbgKaVXgPkRsUsx6CDgSWAaMLoYNhq4pVw1SJLUUrUu8/K/ClwXEW2B54CTKH1o+F1EjAVeBI4vcw2SJLU4ZQ3wlFI1MLiOUQeVc72SJLV0PolNkqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMtS7nwiPiBeAt4H1gVUppcERsA/wW6AO8AByfUlpWzjokSWppKnEEfkBKqSqlNLjoPweYkVLaCZhR9EuSpEZoiib0I4HJRfdk4KgmqEGSpKyVO8ATcHdEzImIccWwrimll4vuV4Cudc0YEeMiYnZEzF6yZEmZy5QkKS9lPQcO7MlG5uYAABCQSURBVJtSWhgRHwemR8RTtUemlFJEpLpmTCldAVwBMHjw4DqnkSRpc1XWI/CU0sLi52LgJmBv4NWI6AZQ/FxczhokSWqJyhbgEbFVRGy9phsYBjwOTANGF5ONBm4pVw2SJLVUDW5Cj4j2KaWVjVh2V+CmiFiznv9JKf0+Iv4K/C4ixgIvAsc3pmBJktSAAI+ITwO/Aj4K9IqI3YFTUkpfrm++lNJzwO51DF8KHLRx5UqSJGhYE/pE4BBgKUBK6VFgaDmLkiRJ9WvQOfCU0vx1Br1fhlokSVIDNeQc+PyiGT1FRBvg68Dc8pYlSZLq05Aj8FOB04AewEKgquiXJElNpN4j8IhoBfxnSunECtUjSZIaoN4j8JTS+8AJFapFkiQ1UEPOgT8QET+j9BWgK9YMTCk9XLaqJElSvRoS4FXFz/NrDUvAgZu+HEmS1BAbDPCU0gGVKESSJDXcBq9Cj4iOEfHTNV/tGREXR0THShQnSZLq1pDbyK4E3qL0zPLjgTeBq8pZlCRJql9DzoHvkFI6tlb/eRFRXa6CJEnShjXkCPztiNh3TU9EfAZ4u3wlSZKkDWnIEfiXgMm1znsvA8aUrSJJkrRBDbkKvRrYPSI6FP1vlr0qSZJUr4Zchf7DiOiUUnozpfRmRHSOiAmVKE6SJNWtIefAD00pvb6mJ6W0DBhRvpIkSdKGNCTAW0XElmt6IuIjwJb1TC9JksqsIRexXQfMiIg1936fBEwuX0mSJGlDGnIR2wUR8SjwWUrPQP9+SumuslcmSZLWqyFH4KSUfh8RfwWGAv8ob0mSJGlD1nsOPCJui4iBRXc34HHgZOA3EfFfFapPkiTVob6L2Pqm
lB4vuk8CpqeUjgCGUApySZLUROoL8PdqdR8E3AGQUnoLWF3OoiRJUv3qOwc+PyK+CiwABgG/h5rbyNpUoDZJkrQe9R2BjwUGUHru+X/UepjLp/DrRCVJalLrPQJPKS0GTq1j+L3AveUsSpIk1a8hT2KTJEnNjAEuSVKGGvJtZB+rRCGSJKnh6nuQyxERsQR4LCIWRMSnK1iXJEmqR31H4D8A/i2l1A04FvhRZUqSJEkbUl+Ar0opPQWQUpoJbF2ZkiRJ0obU9yCXj0fEN9bXn1L6afnKkiRJ9akvwH/J2kfdtftT2SqSJEkbVN+DXM5b37iI2Ks85UiSpIZo0PeBA0REf2BU8XodGFyuoiRJUv3qDfCI6MO/Qvs9oDcwOKX0QrkLkyRJ61fffeAPArdTCvljU0p7Am8Z3pIkNb36biN7ldJFa12BLsUwL16TJKkZWG+Ap5SOAj4JzAHGR8TzQOeI2LtSxUmSpLrVew48pfQGpe/+vioiugLHAxMjoldKabtKFChJkj6owd9GllJ6NaX0/1JKnwH2LWNNkiRpAzZ4G1lEDAbOpXQFeu3pdytXUZIkqX4NuQ/8OuAs4DFgdXnLkSRJDdGQAF+SUppW9kokSVKDNSTAvxcRvwJmAP9cMzClNLVsVUmSpHo1JMBPAnYF2vCvJvQEGOCSJDWRhgT4XimlXcpeiSRJarCG3Eb2l+KLTDZKRLSKiEci4raiv29EzIyIv0fEbyOi7cYuW5KkzVVDAvxTQHVEzIuIv0XEYxHxt0as4+vA3Fr9FwATU0o7AsuAsY1YliRJomFN6MM3duER0RM4DPgB8I2ICOBA4IRiksnAeOAXG7sOSZI2RxsM8JTSix9i+ZOAsyl9KQrAtsDrKaVVRf8CoEddM0bEOGAcQK9evT5ECZIktTwNfpRqY0XE4cDilNKcjZk/pXRFSmlwSmlwly5dNjyDJEmbkYY0oW+szwD/HhEjgHZAB+ASoFNEtC6OwnsCC8tYgyRJLVLZjsBTSv83pdQzpdQH+BxwT0rpROBeYGQx2WjglnLVIElSS1W2AK/HNyld0PZ3SufEf90ENUiSlLVyNqHXSCndB9xXdD8H7F2J9UqS1FI1xRG4JEn6kAxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShsoW4BHRLiJmRcSjEfFERJxXDO8bETMj4u8R8duIaFuuGiRJaqnKeQT+T+DAlNLuQBUwPCI+BVwATEwp7QgsA8aWsQZJklqksgV4Klle9LYpXgk4ELixGD4ZOKpcNUiS1FKV9Rx4RLSKiGpgMTAdeBZ4PaW0qphkAdBjPfOOi4jZETF7yZIl5SxTkqTslDXAU0rvp5SqgJ7A3sCujZj3ipTS4JTS4C5dupStRkmSclSRq9BTSq8D9wL7AJ0ionUxqiewsBI1SJLUkpTzKvQuEdGp6P4IcDAwl1KQjywmGw3cUq4aJElqqVpveJKN1g2YHBGtKH1Q+F1K6baIeBKYEhETgEeAX5exBkmSWqSyBXhK6W/AHnUMf47S+XBJkrSRfBKbJEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQh
A1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZKluAR8R2EXFvRDwZEU9ExNeL4dtExPSIeKb42blcNUiS1FKV8wh8FXBGSqk/8CngtIjoD5wDzEgp7QTMKPolSVIjlC3AU0ovp5QeLrrfAuYCPYAjgcnFZJOBo8pVgyRJLVVFzoFHRB9gD2Am0DWl9HIx6hWg63rmGRcRsyNi9pIlSypRpiRJ2Sh7gEfER4H/Bf4rpfRm7XEppQSkuuZLKV2RUhqcUhrcpUuXcpcpSVJWyhrgEdGGUnhfl1KaWgx+NSK6FeO7AYvLWYMkSS1ROa9CD+DXwNyU0k9rjZoGjC66RwO3lKsGSZJaqtZlXPZngP8EHouI6mLYt4AfA7+LiLHAi8DxZaxBkqQWqWwBnlJ6AIj1jD6oXOuVJGlz4JPYJEnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKUNkCPCKujIjFEfF4rWHbRMT0iHim+Nm5XOuXJKklK+cR+NXA8HWGnQPMSCntBMwo+iVJUiOVLcBTSn8CXltn8JHA5KJ7MnBUudYvSVJLVulz4F1TSi8X3a8AXdc3YUSMi4jZETF7yZIllalOkqRMNNlFbCmlBKR6xl+RUhqcUhrcpUuXClYmSVLzV+kAfzUiugEUPxdXeP2SJLUIlQ7wacDoons0cEuF1y9JUotQztvIrgceBHaJiAURMRb4MXBwRDwDfLbolyRJjdS6XAtOKY1az6iDyrVOSZI2Fz6JTZKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRkywCVJypABLklShgxwSZIyZIBLkpQhA1ySpAwZ4JIkZcgAlyQpQwa4JEkZMsAlScqQAS5JUoYMcEmSMmSAS5KUIQNckqQMGeCSJGXIAJckKUMGuCRJGTLAJUnKkAEuSVKGDHBJkjJkgEuSlCEDXJKkDDVJgEfE8IiYFxF/j4hzmqIGSZJyVvEAj4hWwGXAoUB/YFRE9K90HZIk5awpjsD3Bv6eUnoupfQuMAU4sgnqkCQpW62bYJ09gPm1+hcAQ9adKCLGAeOK3uURMa8CtWUr4GPAP5q6jk3qvGjqCoT7lsrHfavBetc1sCkCvEFSSlcAVzR1HbmIiNkppcFNXYdaHvctlYv71ofTFE3oC4HtavX3LIZJkqQGaooA/yuwU0T0jYi2wOeAaU1QhyRJ2ap4E3pKaVVEfAW4C2gFXJlSeqLSdbRAnm5QubhvqVzctz6ESCk1dQ2SJKmRfBKbJEkZMsAlScqQAd5MRMTyDYzvExGPN3KZV0fEyPWMGx0RzxSv0Y1ZrvLSBPvW7yPi9Yi4rTHLVH4quW9FRFVEPBgRT0TE3yLiPxpbb0vTbO8D
V/lExDbA94DBQALmRMS0lNKypq1MLcSFQHvglKYuRC3KSuDzKaVnIqI7pfetu1JKrzd1YU3FI/BmJiI+GhEzIuLhiHgsImo/ZrZ1RFwXEXMj4saIaF/Ms2dE/DEi5kTEXRHRbQOrOQSYnlJ6rQjt6cDwMm2SmokK7VuklGYAb5VrO9T8VGLfSik9nVJ6puheBCwGupRtozJggDc/7wBHp5QGAQcAF0fEmmfz7QL8PKXUD3gT+HJEtAH+HzAypbQncCXwgw2so67H2fbYhNug5qkS+5Y2TxXdtyJib6At8Owm3Ibs2ITe/ATww4gYCqymFKxdi3HzU0p/LrqvBb4G/B4YCEwv/l9aAS9XtGLlwn1L5VKxfas4Uv8NMDqltHqTbUGGDPDm50RKzUJ7ppTei4gXgHbFuHVv2k+U/nGeSCnts74FRsQQ4L+L3u9SenTt/rUm6Qnc92ELV7NX9n0rpeRTFTdPFdm3IqIDcDtwbkrpoU25ATmyCb356QgsLv4JDmDtb6HpFRFrdvgTgAeAeUCXNcMjok1EDKi9wJTSzJRSVfGaRukpeMMionNEdAaGFcPUslVi39Lmqez7VvHo7ZuAa1JKN5Z9izJggDc/1wGDI+Ix4PPAU7XGzQNOi4i5QGfgF8V3qo8ELoiIR4Fq4NP1rSCl9BrwfUrPpf8rcH4xTC1b2fctgIi4H7gBOCgiFkTEIZt4O9T8VGLfOh4YCoyJiOriVbWpNyQnPkpVkqQMeQQuSVKGDHBJkjJkgEuSlCEDXJKkDBngkiRlyACXJClDBrgkSRn6//S1mAnjLv/YAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 504x576 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "labels = ['label-0', 'label-1', 'label-2']\n",
    "theirs = [81.0, 71.2, 79.9]\n",
    "ours = [85.7, 75.9, 85.7]\n",
    "\n",
    "\n",
    "x = np.arange(len(labels))  # the label locations\n",
    "width = 0.35  # the width of the bars\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7,8))\n",
    "rects1 = ax.bar(x - width/2, theirs, width, label='GNN-Explainer')\n",
    "rects2 = ax.bar(x + width/2, ours, width, label='Ours')\n",
    "\n",
    "\n",
    "# Add some text for labels, title and custom x-axis tick labels, etc.\n",
    "ax.set_ylabel('mAP Score')\n",
    "ax.set_title('Comparison on synthetic data')\n",
    "ax.set_xticks(x)\n",
    "ax.set_xticklabels(labels)\n",
    "ax.legend()\n",
    "\n",
    "def autolabel(rects):\n",
    "    \"\"\"Attach a text label above each bar in *rects*, displaying its height.\"\"\"\n",
    "    for rect in rects:\n",
    "        height = rect.get_height()\n",
    "        ax.annotate('{}'.format(height),\n",
    "                    xy=(rect.get_x() + rect.get_width() / 2, height),\n",
    "                    xytext=(0, 3),  # 3 points vertical offset\n",
    "                    textcoords=\"offset points\",\n",
    "                    ha='center', va='bottom')\n",
    "\n",
    "\n",
    "autolabel(rects1)\n",
    "autolabel(rects2)\n",
    "\n",
    "\n",
    "fig.tight_layout()\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "data = torch.load(\"../../gcn_interpretation/data/syn2/torch_data_all_binary.pth\")\n",
    "max_s = -1\n",
    "size_l = 1400\n",
    "max_nodes = 800\n",
    "adj_t = torch.zeros((size_l,max_nodes,max_nodes)).float()\n",
    "feats_t = torch.zeros((size_l,max_nodes,10)).float()\n",
    "label_n = np.zeros((size_l), dtype=np.int32)\n",
    "pred_n = np.zeros((size_l), dtype=np.int32)\n",
    "new_idx_n = np.zeros((size_l), dtype=np.int32)\n",
    "old_idx_n = np.zeros((size_l), dtype=np.int32)\n",
    "num_nodes_n = np.zeros((size_l), dtype=np.int32)\n",
    "\n",
    "for i in range(len(data['adj'])):\n",
    "    adj = torch.from_numpy(data['adj'][i]).float()\n",
    "    feat = torch.from_numpy(data['feat'][i]).float()\n",
    "    \n",
    "    new_idx = data['new_idx'][i]\n",
    "    old_idx = data['old_idx'][i]\n",
    "    pred = data['pred'][i][new_idx]\n",
    "    label = data['label'][i][new_idx]\n",
    "    nodes = adj.shape[0]\n",
    "    \n",
    "   \n",
    "    adj_t[i,:nodes,:nodes] = adj\n",
    "    feats_t[i,:nodes,:] = feat\n",
    "    new_idx_n[i] = new_idx\n",
    "    old_idx_n[i] = old_idx\n",
    "    label_n[i] = label\n",
    "    pred_n[i] = pred\n",
    "    num_nodes_n[i] = nodes\n",
    "\n",
    "label_t = torch.from_numpy(label_n).long()\n",
    "pred_t = torch.from_numpy(pred_n).long()\n",
    "new_idx_t = new_idx_n\n",
    "old_idx_t = old_idx_n\n",
    "num_nodes_t = num_nodes_n\n",
    "\n",
    "\n",
    "\n",
    "train_idx = 1400\n",
    "val_idx = 100\n",
    "\n",
    "train_data = (adj_t[:train_idx], feats_t[:train_idx], label_t[:train_idx], num_nodes_t[:train_idx], new_idx_t[:train_idx], pred_t[:train_idx], old_idx_t[:train_idx])\n",
    "val_data = (adj_t[-val_idx:], feats_t[-val_idx:], label_t[-val_idx:], num_nodes_t[-val_idx:], new_idx_t[-val_idx:], pred_t[-val_idx:], old_idx_t[-val_idx:])\n",
    "torch.save(train_data, \"../../gcn_interpretation/data/syn2/data_train_all_binary.pth\")\n",
    "torch.save(val_data, \"../../gcn_interpretation/data/syn2/data_val_all_binary.pth\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# feats_t[1100,30]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "rule_dict_save = pickle.load(open(\"./data/syn2/rule_dict_syn2_train_all.p\",\"rb\"))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0\n",
      "1 6\n",
      "2 0\n",
      "3 2\n",
      "4 7\n",
      "5 3\n",
      "6 6\n",
      "7 5\n",
      "8 1\n",
      "9 2\n",
      "10 1\n",
      "11 3\n",
      "12 4\n",
      "13 7\n",
      "14 4\n",
      "15 5\n",
      "{'boundary': [{'basis': array([-0.02347088, -1.8197222 ,  1.5888853 , -1.7498988 , -1.6175075 ,\n",
      "       -1.0512319 ,  0.23229808,  1.1729697 , -0.22705048, -1.867039  ,\n",
      "       -0.15411872,  0.13588102,  1.7263393 , -0.37647748, -0.95449084,\n",
      "       -0.20250729,  0.92751074,  0.02742016,  0.39611596,  1.1381477 ],\n",
      "      dtype=float32), 'label': 1}, {'basis': array([-1.2900279 , -1.3250693 ,  0.34246606, -1.5184367 ,  0.03459716,\n",
      "       -1.617794  , -1.6998549 , -0.4130766 , -0.16435182, -0.96562403,\n",
      "       -0.99315643, -0.36107472,  0.27809435, -1.4130523 , -0.9585426 ,\n",
      "        0.88673985,  1.5287997 , -1.5177855 ,  0.745323  ,  0.60322267],\n",
      "      dtype=float32), 'label': 0}, {'basis': array([-0.07995409,  0.6771175 ,  1.7803952 , -1.2056451 , -1.3839505 ,\n",
      "       -1.8348911 , -0.6672793 ,  1.5163072 ,  0.14460826, -1.4733435 ,\n",
      "       -0.08043012,  0.70931345,  1.4786825 ,  0.14916843,  0.5295185 ,\n",
      "        1.2369479 ,  1.942304  , -0.74345994,  0.24632329,  0.53768086],\n",
      "      dtype=float32), 'label': 2}, {'basis': array([ 0.10984761,  0.27392077, -0.11600071, -0.13829672, -1.7389183 ,\n",
      "       -0.60024065, -1.6401954 ,  1.0220759 ,  1.5080538 , -0.3225721 ,\n",
      "        0.04317671, -0.29631454,  1.8684973 , -0.5401777 , -0.90113825,\n",
      "       -0.39788035,  0.80636847, -0.46592382,  2.0415363 , -0.6359102 ],\n",
      "      dtype=float32), 'label': 5}, {'basis': array([-0.45765457, -0.11724791,  1.2156559 , -0.80521685, -0.22446936,\n",
      "       -0.35310388,  0.35468143, -1.0041423 ,  0.87264556, -1.1374793 ,\n",
      "       -0.8096405 , -1.3784218 ,  0.67562115, -0.04139638,  0.46277976,\n",
      "       -0.24425724,  0.81083876, -0.7461382 ,  1.3261323 , -0.22820425],\n",
      "      dtype=float32), 'label': 7}], 'label': 6}\n",
      "\n",
      "\n",
      "\n",
      "{'boundary': [{'basis': array([-1.2900279 , -1.3250693 ,  0.34246606, -1.5184367 ,  0.03459716,\n",
      "       -1.617794  , -1.6998549 , -0.4130766 , -0.16435182, -0.96562403,\n",
      "       -0.99315643, -0.36107472,  0.27809435, -1.4130523 , -0.9585426 ,\n",
      "        0.88673985,  1.5287997 , -1.5177855 ,  0.745323  ,  0.60322267],\n",
      "      dtype=float32), 'label': 0}, {'basis': array([ 0.10984761,  0.27392077, -0.11600071, -0.13829672, -1.7389183 ,\n",
      "       -0.60024065, -1.6401954 ,  1.0220759 ,  1.5080538 , -0.3225721 ,\n",
      "        0.04317671, -0.29631454,  1.8684973 , -0.5401777 , -0.90113825,\n",
      "       -0.39788035,  0.80636847, -0.46592382,  2.0415363 , -0.6359102 ],\n",
      "      dtype=float32), 'label': 5}, {'basis': array([-0.45765457, -0.11724791,  1.2156559 , -0.80521685, -0.22446936,\n",
      "       -0.35310388,  0.35468143, -1.0041423 ,  0.87264556, -1.1374793 ,\n",
      "       -0.8096405 , -1.3784218 ,  0.67562115, -0.04139638,  0.46277976,\n",
      "       -0.24425724,  0.81083876, -0.7461382 ,  1.3261323 , -0.22820425],\n",
      "      dtype=float32), 'label': 7}, {'basis': array([-0.02347088, -1.8197222 ,  1.5888853 , -1.7498988 , -1.6175075 ,\n",
      "       -1.0512319 ,  0.23229808,  1.1729697 , -0.22705048, -1.867039  ,\n",
      "       -0.15411872,  0.13588102,  1.7263393 , -0.37647748, -0.95449084,\n",
      "       -0.20250729,  0.92751074,  0.02742016,  0.39611596,  1.1381477 ],\n",
      "      dtype=float32), 'label': 1}, {'basis': array([-0.07995409,  0.6771175 ,  1.7803952 , -1.2056451 , -1.3839505 ,\n",
      "       -1.8348911 , -0.6672793 ,  1.5163072 ,  0.14460826, -1.4733435 ,\n",
      "       -0.08043012,  0.70931345,  1.4786825 ,  0.14916843,  0.5295185 ,\n",
      "        1.2369479 ,  1.942304  , -0.74345994,  0.24632329,  0.53768086],\n",
      "      dtype=float32), 'label': 2}], 'label': 6}\n"
     ]
    }
   ],
   "source": [
    "for b in range(len(rule_dict_save['rules'])):\n",
    "    print(b, rule_dict_save['rules'][b]['label'])\n",
    "print(rule_dict_save['rules'][1])\n",
    "print(\"\\n\\n\")\n",
    "print(rule_dict_save['rules'][6])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "419\n",
      "[4, 7, 10]\n",
      "505\n",
      "[4, 7, 10]\n",
      "506\n",
      "[4, 7, 10]\n",
      "545\n",
      "[4, 7, 10]\n",
      "555\n",
      "[4, 7, 10]\n",
      "560\n",
      "[4, 7, 10]\n",
      "571\n",
      "[4, 7, 10]\n",
      "580\n",
      "[4, 7, 10]\n"
     ]
    }
   ],
   "source": [
    "for i in range(600):\n",
    "#     print((rule_dict_save['idx2rule'][i]))\n",
    "    if len(rule_dict_save['idx2rule'][i]) > 2:\n",
    "        print(i)\n",
    "        print((rule_dict_save['idx2rule'][i]))\n",
    "#     for r in range(len(rule_dict_save['idx2rule'][i])):\n",
    "#         print(\"rule: \",r, \"label: \", rule_dict_save['rules'][r]['label'])\n",
    "#         print(rule_dict_save['rules'][r])\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n",
      "[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]\n",
      "[-1. -1.  0.  0.  0.  0.  0.  0.  0.  0.]\n",
      "[0.5 0.5 1.  1.  1.  1.  1.  1.  1.  1. ]\n",
      "[1. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[0.5 0.5 1.  1.  1.  1.  1.  1.  1.  1. ]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([-1., -1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "random_mu = [0.0] * 8\n",
    "random_sigma = [1.0] * 8\n",
    "print(random_mu)\n",
    "print(random_sigma)\n",
    "# Create two grids\n",
    "mu_1, sigma_1 = np.array([-1.0] * 2 + random_mu), np.array([0.5] * 2 + random_sigma)\n",
    "mu_2, sigma_2 = np.array([1.0] * 2 + random_mu), np.array([0.5] * 2 + random_sigma)\n",
    "print(mu_1)\n",
    "print(sigma_1)\n",
    "\n",
    "print(mu_2)\n",
    "print(sigma_2)\n",
    "np.array([-1.0] * 2 + random_mu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "#syn1\n",
    "    #explain\n",
    "    avg map score:  0.811522113951501, \"size\": 0.009,\n",
    "    avg map score:  0.810499875544593, \"size\": 0.02,\n",
    "\n",
    "    #boundary\n",
    "    avg map score:  0.9017117369954706  \"size\": 0.09,\n",
    "             \"size\": 0.15,# syn1\n",
    "\n",
    "                \n",
    "#syn2\n",
    "\n",
    "    #explain\n",
    "    avg map score:  0.7502992748124584 , \"size\": 0.02\n",
    "    avg map score:  0.7509465051411234, \"size\": 0.02\n",
    "    avg map score:  0.7735268011698395, \"size\": 0.01\n",
    "    avg map score:  0.7645930558803843, \"size\": 0.008\n",
    "\n",
    "    #boundary\n",
    "    avg map score:    0.6934569870477331, \"size\": 0.09,\n",
    "    avg map score:  0.6200851770965283, \"size\": 0.2,\n",
    "    avg map score:  0.7467577810590663,  \"size\": 0.05,\n",
    "    avg map score:  0.7696084388682304, ,  \"size\": 0.03,\n",
    "    avg map score:  0.7695666747119264, \"size\": 0.03,\n",
    "    avg map score:  0.73722353044399, \"size\": 0.01\n",
    "    avg map score:  0.7708977721762329, \"size\": 0.02\n",
    "    avg map score:  0.7717151540157421, \"size\": 0.02\n",
    "    avg map score:  0.7636393078139467, \"size\": 0.015\n",
    "            \n",
    "\n",
    "#syn2_bin\n",
    "    #explain\n",
    "    avg map score:  0.8185703908808911, \"size\": 0.009,\n",
    "    avg map score:  0.8212237823097907, \"size\": 0.015,\n",
    "    avg map score:  0.7890471662501071, \"size\": 0.025,\n",
    "    \n",
    "    #boundary\n",
    "    avg map score:  0.8349695331961251, \"size\": 0.025,\n",
    "    avg map score:  0.8331501899783295, \"size\": 0.015,\n",
    "    avg map score:  0.8252089044707217, \"size\": 0.035,\n",
    "\n",
    "\n",
    "#syn3\n",
    "\n",
    "    #explain\n",
    "    avg map score:  0.7534129195898359, \"size\": 0.008\n",
    "    avg map score:  0.756609433286107,  \"size\": 0.01\n",
    "    avg map score:  0.7443666076476425, \"size\": 0.006\n",
    "\n",
    "\n",
    "    #boundary\n",
    "    avg map score:  0.7599638752353364, \"size\": 0.025\n",
    "    avg map score:  0.7834871301842039, \"size\": 0.015\n",
    "    avg map score:  0.8508268504198123, \"size\": 0.009\n",
    "    avg map score:  0.872233121779854,  \"size\": 0.007\n",
    "    avg map score:  0.8849059867176325, \"size\": 0.005\n",
    "    avg map score:  0.8404759346707498, \"size\": 0.003"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
