{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import torch\n",
    "\n",
    "from sig.utils.graph_utils import plot_important_subgraph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: Modify these specs to optimize the visualization\n",
    "\n",
    "# Figure size in inches for the graph-mode figures.\n",
    "figsize = (12, 12)\n",
    "\n",
    "# Node colours: seaborn's 'muted' palette followed by four flat-UI colours.\n",
    "# Nodes are coloured by indexing this list with their node type.\n",
    "flatui = [\"#9b59b6\", \"#3498db\", \"#95a5a6\", \"#e74c3c\"]\n",
    "node_cmap = sns.color_palette(\"muted\").as_hex() + flatui\n",
    "node_feat_cmap = 'Reds'  # matplotlib colormap name for the node-feature heatmap\n",
    "\n",
    "# Edge colours\n",
    "edge_importance_cmap = plt.cm.YlOrBr  # colormap for the edge-importance heatmap\n",
    "important_edge_color = '#ff0000'      # edges inside the extracted subgraph\n",
    "unimportant_edge_color = '#d3d3d3'    # edges outside the extracted subgraph\n",
    "\n",
    "# Extra keyword arguments forwarded to the graph-drawing calls\n",
    "draw_kwargs = dict(\n",
    "    width=5,            # edge width\n",
    "    arrows=False,       # do not draw edge arrows\n",
    "    with_labels=False,  # do not draw node labels\n",
    "    node_size=450,      # node size\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example indices to plot for each dataset.\n",
    "dataset_example_lookup = dict(\n",
    "    mutagenicity=[98, 409],\n",
    "    reddit_binary=[106, 157],\n",
    ")\n",
    "# Plot modes wanted for each dataset; the driver loop intersects these with\n",
    "# the modes each explainer supports.\n",
    "dataset_plot_mode_lookup = dict(\n",
    "    mutagenicity=['graph', 'node_feat'],\n",
    "    reddit_binary=['graph'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hidden size of the trained model per dataset (part of the output path).\n",
    "model_size_lookup = {'mutagenicity': 16, 'reddit_binary': 32}\n",
    "\n",
    "# One row per explainer:\n",
    "#   (explainer, model name in output path, output subdir, supported plot modes)\n",
    "_explainer_specs = [\n",
    "    ('gnn',           'gcn',   'nonsig',        ['graph', 'node_feat']),\n",
    "    ('gat',           'gat',   'nonsig',        ['graph']),\n",
    "    ('pred_grad',     'gcn',   'nonsig',        ['graph']),\n",
    "    ('mag_pred_grad', 'gcn',   'nonsig',        ['node_feat']),\n",
    "    ('pred_sig_grad', 'sigcn', 'sig_small_reg', ['graph', 'node_feat']),\n",
    "]\n",
    "model_lookup = {name: model for name, model, _, _ in _explainer_specs}\n",
    "output_subdir_lookup = {name: subdir for name, _, subdir, _ in _explainer_specs}\n",
    "explainer_plot_mode_lookup = {name: modes for name, _, _, modes in _explainer_specs}\n",
    "\n",
    "# Values accepted by plot_fig; also the order the driver loop visits them.\n",
    "allowed_explainers = ['gat', 'pred_grad', 'mag_pred_grad', 'pred_sig_grad', 'gnn']\n",
    "allowed_datasets = ['mutagenicity', 'reddit_binary']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _read_gpickle(path):\n",
    "    \"\"\"Load a networkx graph saved with ``nx.write_gpickle``.\n",
    "\n",
    "    ``nx.read_gpickle`` was removed in networkx 3.0; it was a thin wrapper\n",
    "    around ``pickle.load``, so reading the file directly works on any version.\n",
    "    NOTE: unpickling can execute arbitrary code -- only load trusted files.\n",
    "    \"\"\"\n",
    "    with open(path, 'rb') as f:\n",
    "        return pickle.load(f)\n",
    "\n",
    "\n",
    "def _plot_edge_explanation(G, graph_info, node_color, pos, output_prefix):\n",
    "    \"\"\"Save the edge-importance heatmap and the important-subgraph figure.\"\"\"\n",
    "    edge_score = graph_info['edge_score']\n",
    "    edge_index = graph_info['edge_index']\n",
    "    important_edge_mask = graph_info['important_edge_mask']\n",
    "\n",
    "    # Edge importance heatmap: edges coloured by edge_score.\n",
    "    # NOTE(review): assumes edge_score is ordered like G.edges() -- confirm.\n",
    "    plt.figure(figsize=figsize)\n",
    "    nx.draw_networkx(\n",
    "        G,\n",
    "        pos=pos,\n",
    "        node_color=node_color,\n",
    "        edge_color=edge_score.detach().cpu().numpy(),\n",
    "        edge_cmap=edge_importance_cmap,\n",
    "        **draw_kwargs\n",
    "    )\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(output_prefix + '_heatmap.png', format='PNG')\n",
    "    plt.close()\n",
    "\n",
    "    # Important subgraph. Layout/save/close act on the *current* figure, so\n",
    "    # they follow whichever figure plot_important_subgraph draws on.\n",
    "    plt.figure(figsize=figsize)\n",
    "    plot_important_subgraph(\n",
    "        edge_index,\n",
    "        important_edge_mask,\n",
    "        node_color,\n",
    "        G=G,\n",
    "        pos=pos,\n",
    "        important_edge_color=important_edge_color,\n",
    "        unimportant_edge_color=unimportant_edge_color,\n",
    "        **draw_kwargs\n",
    "    )\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(output_prefix + '_subgraph.png', format='PNG')\n",
    "    plt.close()\n",
    "\n",
    "\n",
    "def _plot_node_feat_explanation(node_feat_score, output_prefix):\n",
    "    \"\"\"Save ``node_feat_score`` as a one-row heatmap with one cell per entry.\"\"\"\n",
    "    # NOTE(review): assumes node_feat_score is a 1-D tensor -- confirm.\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.imshow(\n",
    "        node_feat_score.detach().cpu().numpy()[np.newaxis, :],\n",
    "        cmap=node_feat_cmap,\n",
    "        aspect='auto'\n",
    "    )\n",
    "    # Minor ticks on the cell boundaries carry the black separator lines.\n",
    "    ax.set_xticks(np.arange(-.5, node_feat_score.shape[0], 1), minor=True)\n",
    "    ax.grid(which='minor', color='black', linestyle='-', linewidth=2)\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(output_prefix + '_heatmap.png', format='PNG')\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "def plot_fig(\n",
    "    explainer,\n",
    "    plot_mode,\n",
    "    dataset,\n",
    "    index,\n",
    "    minsize=15\n",
    "):\n",
    "    \"\"\"Render and save the explanation figure(s) for one example.\n",
    "\n",
    "    Loads ``index_<index>.gpkl`` (the graph) and ``index_<index>_info.pt``\n",
    "    (the explanation) from the output directory of ``explainer`` on\n",
    "    ``dataset`` and writes PNGs under ``figs/<dataset>/``:\n",
    "\n",
    "    * ``plot_mode == 'graph'``: ``*_heatmap.png`` (edges coloured by\n",
    "      ``edge_score``) and ``*_subgraph.png`` (edges selected by\n",
    "      ``important_edge_mask`` highlighted).\n",
    "    * any other allowed mode (``'node_feat'``): ``*_heatmap.png`` of\n",
    "      ``node_feat_score`` as a single-row heatmap.\n",
    "\n",
    "    Args:\n",
    "        explainer: one of ``allowed_explainers``.\n",
    "        plot_mode: one of ``explainer_plot_mode_lookup[explainer]``.\n",
    "        dataset: one of ``allowed_datasets``.\n",
    "        index: example index used in the input and output file names.\n",
    "        minsize: suffix in the ``'graph'``-mode input directory and output\n",
    "            file names; only used when ``plot_mode == 'graph'``.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if an argument is not an allowed value.\n",
    "    \"\"\"\n",
    "    assert explainer in allowed_explainers, (\n",
    "        'explainer has to be {}'.format(allowed_explainers))\n",
    "    assert dataset in allowed_datasets, (\n",
    "        'dataset has to be {}'.format(allowed_datasets))\n",
    "    assert plot_mode in explainer_plot_mode_lookup[explainer], (\n",
    "        'plot_mode={} is not available for {}'.format(plot_mode, explainer))\n",
    "\n",
    "    source_dir = 'output/real_graph/hidden_{}/{}/{}/{}/model_output_0/'.format(\n",
    "        model_size_lookup[dataset],\n",
    "        output_subdir_lookup[explainer],\n",
    "        dataset,\n",
    "        model_lookup[explainer]\n",
    "    )\n",
    "    output_dir = 'figs/{}/'.format(dataset)\n",
    "    if plot_mode == 'graph':\n",
    "        source_dir += '{}_explainer_{}_files_minsize_{}/'.format(explainer, plot_mode, minsize)\n",
    "        output_name = 'index_{}_{}_{}_minsize_{}'.format(index, explainer, plot_mode, minsize)\n",
    "    else:\n",
    "        source_dir += '{}_explainer_{}_files/'.format(explainer, plot_mode)\n",
    "        output_name = 'index_{}_{}_{}'.format(index, explainer, plot_mode)\n",
    "    source_prefix = source_dir + 'index_{}'.format(index)\n",
    "    output_prefix = output_dir + output_name\n",
    "\n",
    "    G = _read_gpickle(source_prefix + '.gpkl')\n",
    "    # NOTE(review): torch.load unpickles its input -- only load trusted files.\n",
    "    # torch>=2.6 defaults to weights_only=True, which may reject non-tensor\n",
    "    # entries in this dict; confirm against the installed torch version.\n",
    "    graph_info = torch.load(source_prefix + '_info.pt')\n",
    "    node_color = [node_cmap[i] for i in graph_info['node_type']]\n",
    "    pos = graph_info['pos']\n",
    "\n",
    "    # savefig does not create missing directories.\n",
    "    os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "    if plot_mode == 'graph':\n",
    "        _plot_edge_explanation(G, graph_info, node_color, pos, output_prefix)\n",
    "    else:\n",
    "        _plot_node_feat_explanation(graph_info['node_feat_score'], output_prefix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render every supported (explainer, dataset, plot mode, example) combination.\n",
    "# Plot modes are filtered in the explainer's listed order rather than via a set\n",
    "# intersection, whose iteration order for strings depends on the hash seed and\n",
    "# therefore changes between kernel restarts.\n",
    "for explainer in allowed_explainers:\n",
    "    for dataset in allowed_datasets:\n",
    "        dataset_modes = dataset_plot_mode_lookup[dataset]\n",
    "        plot_modes = [\n",
    "            mode for mode in explainer_plot_mode_lookup[explainer]\n",
    "            if mode in dataset_modes\n",
    "        ]\n",
    "        for plot_mode in plot_modes:\n",
    "            for index in dataset_example_lookup[dataset]:\n",
    "                plot_fig(explainer, plot_mode, dataset, index)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sig",
   "language": "python",
   "name": "sig"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
