{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b941618a",
   "metadata": {},
   "source": [
    "## ZINC Visualization with Init PE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58a892b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Remember the kernel's starting directory the first time this cell runs, so that\n",
    "# re-running the cell does not climb one more directory level on every execution.\n",
    "if 'NOTEBOOK_DIR' not in globals():\n",
    "    NOTEBOOK_DIR = os.getcwd()\n",
    "\n",
    "os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))  # go to root folder of the project\n",
    "print(os.getcwd())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ab99754",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import networkx as nx\n",
    "from itertools import count\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from itertools import count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2adb4608",
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.data import LoadData\n",
    "\n",
    "# Size argument handed to the positional-encoding initialiser below.\n",
    "pos_enc_dim = 24\n",
    "\n",
    "# Load the 'ZINC' dataset through the project's loader.\n",
    "zinc_d = LoadData('ZINC')\n",
    "\n",
    "# Presumably these two calls populate ndata['pos_enc'] and ndata['eigvec'], which the\n",
    "# cells below read -- TODO confirm against data/data.py.\n",
    "zinc_d._init_positional_encodings(pos_enc_dim, 'rand_walk')\n",
    "zinc_d._add_eig_vecs(36)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b12e503",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_nodes = []\n",
    "num_unique_RWPEs = []\n",
    "num_unique_LapPEs = []\n",
    "\n",
    "# Single pass over the validation split: every sample is fetched once and its graph\n",
    "# (element 0 of the sample) feeds all three per-graph statistics.\n",
    "for sample in zinc_d.val:\n",
    "    graph = sample[0]\n",
    "    num_nodes.append(graph.number_of_nodes())\n",
    "    # torch.unique(..., dim=0) de-duplicates whole rows, so the length of the result is\n",
    "    # the number of distinct rows of the encoding matrix (assumed one row per node).\n",
    "    num_unique_RWPEs.append(len(torch.unique(graph.ndata['pos_enc'], dim=0)))\n",
    "    num_unique_LapPEs.append(len(torch.unique(graph.ndata['eigvec'], dim=0)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edcf8c63",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_init_PE_comparison(num_nodes, num_unique_PEs, PE='RWPE'):\n",
    "    \"\"\"Draw a 2D histogram of graph size against the number of unique PEs per graph.\n",
    "\n",
    "    Args:\n",
    "        num_nodes: sequence holding the number of nodes of each graph.\n",
    "        num_unique_PEs: sequence of the same length holding the number of unique\n",
    "            positional-encoding rows of each graph.\n",
    "        PE: name of the encoding; it is inserted into the title and the y-axis label\n",
    "            (e.g. 'RWPE' or 'LapPE'), so any name yields a fully labelled figure.\n",
    "\n",
    "    The figure is shown with plt.show(); nothing is returned.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(dpi=100, figsize=(7, 6))\n",
    "    ax.set_title(f\"ZINC val (1K): Comparison of no. of nodes v/s no. of unique {PE}s\", fontsize=15)\n",
    "    ax.set_xlabel(\"Number of nodes\", fontsize=10)\n",
    "    ax.set_ylabel(f\"Number of unique {PE}s\", fontsize=10)\n",
    "\n",
    "    node_counts = np.array(num_nodes)\n",
    "    pe_counts = np.array(num_unique_PEs)\n",
    "    # hist2d returns (counts, xedges, yedges, image); the image is what the colorbar describes.\n",
    "    *_, image = ax.hist2d(node_counts, pe_counts, (50, 50), cmap=plt.cm.Reds)\n",
    "    fig.colorbar(image, ax=ax)\n",
    "    ax.set_xlim([10, 35])\n",
    "    ax.set_ylim([10, 35])\n",
    "\n",
    "    # Reference diagonal y = x, where the number of unique PEs equals the number of nodes.\n",
    "    # It is drawn after the limits are fixed so it does not rescale the axes.\n",
    "    diagonal = np.linspace(9, 40, 100)\n",
    "    ax.plot(diagonal, diagonal, '-r')\n",
    "\n",
    "    #fig.savefig('out_ZINC_PE_viz/ZINC_valset_'+PE+'.pdf', bbox_inches='tight') \n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f1de5dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw one histogram per encoding type: RWPE first, then LapPE.\n",
    "for pe_name, unique_counts in (('RWPE', num_unique_RWPEs), ('LapPE', num_unique_LapPEs)):\n",
    "    plot_init_PE_comparison(num_nodes, unique_counts, pe_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "607a71d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Candidate pairs of validation-set graph indices. The labels are the original author's\n",
    "# (presumably whether the number of unique RWPEs equals the number of nodes -- TODO confirm).\n",
    "GRAPH_IDS_NOT_EQUAL = [212, 672]\n",
    "GRAPH_IDS_EQUAL = [91, 967]\n",
    "\n",
    "# Selection consumed by the drawing loop; point it at the other list to switch.\n",
    "graph_ids = GRAPH_IDS_EQUAL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "960aa77f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed so the force-directed layout (and therefore each figure) is reproducible.\n",
    "LAYOUT_SEED = 0\n",
    "\n",
    "for idx in graph_ids:\n",
    "    fig = plt.figure(dpi=100, figsize=(8, 6))\n",
    "    g_zinc_trial = zinc_d.val[idx][0]\n",
    "    g = g_zinc_trial.to_networkx()\n",
    "    nodes = list(g.nodes())\n",
    "\n",
    "    # groups: the distinct rows of the encoding matrix.\n",
    "    # group_of_row[i]: index into groups of the row equal to row i, i.e. the colour class.\n",
    "    # This replaces a tensor-keyed dict scanned with torch.equal, which could leave the\n",
    "    # colour unbound or stale when no key matched.\n",
    "    groups, group_of_row = torch.unique(g_zinc_trial.ndata['pos_enc'], dim=0, return_inverse=True)\n",
    "    # Assumes networkx node n corresponds to row n of ndata (DGL integer node ids) -- TODO confirm.\n",
    "    colors = [int(group_of_row[n]) for n in nodes]\n",
    "\n",
    "    pos = nx.spring_layout(g, seed=LAYOUT_SEED)\n",
    "    ec = nx.draw_networkx_edges(g, pos, alpha=0.2)\n",
    "    nc = nx.draw_networkx_nodes(g, pos, nodelist=nodes, node_color=colors, \n",
    "                                node_size=100, cmap=plt.cm.jet)\n",
    "    plt.colorbar(nc)\n",
    "    #plt.xlabel(\"ZINC Val id: \" +str(idx), fontsize=16)\n",
    "    # Counts come from this graph itself, so the cell does not depend on the statistics\n",
    "    # lists built in another cell.\n",
    "    plt.title(\"nodes: \"+ str(g_zinc_trial.number_of_nodes()) + \" | unique RWPEs: \"+str(len(groups)), fontsize=16)\n",
    "    #fig.savefig('out_ZINC_PE_viz/ZINC_valset_graph_'+str(idx)+'.pdf', bbox_inches='tight') \n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59af93f9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e552577",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
