{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import numpy\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import networkx as nx\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "# NOTE(review): this is a raw cell, so it is never executed. Its paths look relative to the\n",
    "# repository root, unlike the '../data/...' path of the executed loader -- TODO confirm which is current.\n",
    "with open('generalized_graph_learning/data/debug/last_eval_graphs.pickle', 'rb') as eval_file:\n",
    "    inf_train_graphs, val_graphs, test_graphs = pickle.load(eval_file)\n",
    "\n",
    "with open('generalized_graph_learning/data/debug/last_training_graphs.pickle', 'rb') as train_file:\n",
    "    train_graphs = pickle.load(train_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pickle.load can execute arbitrary code: only open pickles that this project wrote itself.\n",
    "with open('../data/debug/graphs_C_non_smooth.pickle', 'rb') as graphs_file:\n",
    "    inf_train_graphs, val_graphs, test_graphs = pickle.load(graphs_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def mask_diagonal_and_sigmoid(A):\n",
    "    \"\"\"Element-wise sigmoid of ``A`` with the diagonal of every matrix forced to exactly 0.\n",
    "\n",
    "    A: tensor of logits whose last two dims are the matrix rows/columns, e.g. a batch of\n",
    "       shape (batch, n, n) (presumably adjacency logits -- TODO confirm with the producer).\n",
    "    Returns a tensor with A's shape: sigmoid(A) off the diagonal, 0 on the diagonal.\n",
    "    \"\"\"\n",
    "    # One (rows, cols) boolean identity mask is broadcast over the leading dims, so no\n",
    "    # per-batch copy is materialised. Masking *after* the sigmoid gives an exact 0 even when\n",
    "    # A holds very large logits (the former `A - 1e8` offset left such entries unmasked).\n",
    "    diag_mask = torch.eye(A.shape[-2], A.shape[-1], dtype=torch.bool, device=A.device)\n",
    "    return torch.sigmoid(A).masked_fill(diag_mask, 0.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.colorbar.Colorbar at 0x7f7b98558da0>"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAD3CAYAAAAdUOFNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAARLUlEQVR4nO3df6xfdX3H8eeLIjYwdGIdYxQncSVZ57AzDWzRTAzqCn9YzRZDzTJcyOofdlnmtqTLFjTsH93izJYQtutsYCbKGAmziZ1XZRqSZZJeIukoG3jHQFqRWmCGhGjpva/98f22fHu93+85997z6/s9r0dycs8533M/591v03c/v87nyDYR0V/ntR1ARLQrSSCi55IEInouSSCi55IEInouSSCi55IEIqaIpAOSTkh6ZMznkvS3khYlHZH0tqIykwQipsudwK4Jn98AbBtue4E7igpMEoiYIrYfAJ6fcMlu4B898C3gpyVdNqnMJIGI2XI58PTI8bHhubHOrzWciOA33nWRn3t+qdS1Dx358VHgRyOn5mzP1RLYUJJARM1OPr/Eg/NbS137qsv+50e2d27gdseBK0aOtw7PjZXmQETtzJKXS20VOAj8znCU4FeBH9p+ZtIvpCYQUTMDy1TztK6kLwLXAVskHQM+DrwKwPbfAYeAG4FF4CXgd4vKTBKIqJkxL7tcn0BhWfaegs8NfHQtZSYJRDSgqppAHVrrE5C0S9Jjw5lN+9uKYy0kPSnpPyU9LGmh7XhWs9qMMkmXSPqapO8Mf76uzRhHjYn3E5KOD7/nhyXd2GaMG2VgCZfa2tBKEpC0Cbidweym7cAeSdvbiGUd3mV7xwZ7cOt0Jz85o2w/cL/tbcD9w+OuuJPVZ8B9Zvg977B9qOGYKreMS21taKsmcA2waPsJ26eAuxnMdIoNGjOjbDdw13D/LuD9jQY1QYkZcFPPwJJdamtDW0lgzbOaOsLAVyU9JGlv28GswaUjw0TfBy5tM5iS9g0fgDnQpebLei2X3NqQeQJr8w7bb2PQjPmopF9vO6C1GvYed7eXauAO4M3ADuAZ4NPthrMxLtkf0Ks+AdYxq6kLbB8f/jwB3MegWTMNnj3zEMnw54mW45nI9rO2l2wvA59ler7nVdnwcsmtDW0lgcPANklXSroAuInBTKfOknSRpIvP7APvBVZ9pruDDgI3D/dvBr7UYiyFVjz19gGm53seQyyV3NrQyjwB26cl7QPmgU3AAdtH24hlDS4F7pMEg+/tC7a/0m5IP2nMjLJPAvdIugV4CvhgexGea0y810nawaDZ8iTwkdYCrICB5Q43wJSXj0TU6y1XX+B7vvyGUtf+0hu/91DTw8+ZMRhRs8FkoXaq+mUkCUQ0YNlJAhG9lZpARM8Z8bI3tR3GWK1OFpqyWXdAYm7CtMVb5ExNoKtDhG3PGJzGv+zEXL9pi7eAWPJ5pbY2pDkQUbPBykJt/387XqNJ4AK92pu56OzxZi7kNbrknIkKV139UpMhFXr8yIXnHK835pXlNGm1mIs0+Wdaea83Xn4+O9+6+Zx42/z+VvMiL5y0XW7wnxnuGJS0C/gbBrP+/sH2Jyddv5mLuFbXTyxzfv7hjYRUud/4uR2F15SJuUw5XdLkn2kav7+v+96nyl5rq7WqfhnrjmzKFwaJaNQyKrW1YSM1gbMLgwBIOrMwyKNVBBYxK4w45e52v20kstUWBrl25UXD4Z69MGibRvRN7zsGh69QmgPW3DkVMSuWZnTa8FQuDBLRNCOWZrQmcHZhEAb/+G8CPlRJVBEzZrnDowPrTgLrWRjkqqtfKhwO6tpQ0Pz3GhwqK3GvIlXFUmpotIJ4u3ivqg2mDc9gEgAYrgc/9WvCR9Sp6w8QdXfcImJG2HR6slCSQETt2psIVEaSQETNBm8gSk0gotdmtmMwIooZZY3BiL5LTSCixzJEOOWanFzS1L2qmgA1
i99NHQZvIEpNIKLXuryyUHfTU8SMsMWyzyu1FZG0S9JjkhYl7V/l8zdK+oakb0s6IunGojJTE4hoQBXzBEZW83oPg/U7Dks6aHt0IZ8/B+6xfcdwpa9DwJsmlZuaQETNBouKVLK82NnVvGyfAs6s5rXydq8Z7r8W+F5RoakJRNRuTQuNbpG0MHI8N1yYB8qt5vUJ4KuSfh+4CHh30Q2TBCJqZljLEOHJDb6afA9wp+1PS/o14POS3mJ7edwvJAlE1KzCGYNlVvO6BdgFYPs/JG0GtgAnxhWaPoGIBixzXqmtwNnVvCRdwGA1r4MrrvkucD2ApF8ENgM/mFRoagI9NM0Tb6bRYD2BjdcExq3mJek2YMH2QeCPgM9K+kMGLZEP2564wG+SQEQDqnqAaLXVvGzfOrL/KPD2tZSZJBBRs0GfQHdb3kkCEQ3o8rThJIGImhlxejlPEUb0WtYYjOixqkYH6pIkENGAdAxG9FjWGIyI9AlE9NlgebEkgYj+coYII3rtzKIiXZUkENGANAcieix9AhGRJBDRZ5knENF3htOZMTjw+JELC1e1qeoVWRFdkT6BiJjdJCDpSeBFYAk4vcGlkiNmUh/6BN5l+2QF5UTMLM94EoiIAl2eMbjRLkszeOXRQ5L2rnaBpL2SFiQtvMyPN3i7iOljD/oEymxt2GhN4B22j0v6GeBrkv7b9gOjFwzfozYH8BpdMnH984jZJJaWuztEuKHIbB8f/jwB3MfgrakRsYKtUlsb1p0EJF0k6eIz+8B7gUeqCixiVpyZJzCLzYFLgfsknSnnC7a/MukXrrr6JebnJ08GmtWJQE1Ngspkq8la+X486BfoqnUnAdtPAG+tMJaImdXl0YEMEUbUzGSeQETPzf6MwYgosLycJBDRW3aaAxG9l+ZARM/N5BBhRJSX5sBQmZWFplGXJujM4vdbpTa+H9PelOAyuvtUQ8QMccmtiKRdkh6TtChp/5hrPijpUUlHJX2hqMw0ByLqZnAFQ4SSNgG3A+8BjgGHJR20/ejINduAPwXebvuF4RO+E6UmENGAip4ivAZYtP2E7VPA3cDuFdf8HnC77RcG9/WJokKTBCIaYJfbClwOPD1yfGx4btRVwFWS/l3StyTtKio0zYGImq3x2YEtkhZGjueGC/OUdT6wDbgO2Ao8IOmXbf/fpF+IiDoZKJ8ETk5Ytfs4cMXI8dbhuVHHgAdtvwz8r6THGSSFw+NumOZARAMqag4cBrZJulLSBcBNwMEV1/wLg1oAkrYwaB48ManQJIGIJlQwRmj7NLAPmAf+C7jH9lFJt0l63/CyeeA5SY8C3wD+xPZzk8rtXHOgSxNvyupaPNE1qmSIEMD2IeDQinO3juwb+NhwK6VzSSBi5uQpwogoNR2wJUkCEY1ITSCi31ITiOi5JIGIHqvoAaK6JAlENCE1gYieyxBhRL8pNYGIHiu7bFBLkgQiaqc0ByJ6LzWBiJ5bbjuA8ZIEIuq2tkVFGpckENGAjA5E9F2SQHnTuEDHNC6EEnFG55JAxCzqcnOgcI1BSQcknZD0yMi5SyR9TdJ3hj9fV2+YEVPOKre1oMxCo3cCK19gsB+43/Y24P7hcUSsxgyGCMtsLShMArYfAJ5fcXo3cNdw/y7g/RXHFTFT5HJbG9bbJ3Cp7WeG+98HLq0onojZ1OE+gQ13DNq2ND6HSdoL7AXYzIUbvV3EdOpwEljvy0eelXQZwPDn2Def2p6zvdP2zlfx6nXeLmJ6lW0KtNUcWG8SOAjcPNy/GfhSNeFEzKgOjw4UNgckfZHBu822SDoGfBz4JHCPpFuAp4AP1hnkSl2bnJOJQFGow82BwiRge8+Yj66vOJaImaU8RRjRYy2298tIEohoQpJARM8lCUT0W5ebA+sdIoyIGZGaQEQTOlwTSBKIqJszRHjWVVe/xPz85Ik+ZSbeZHJOTJ3UBCL6S6RjMCJccisgaZekxyQtShq7mI+k35RkSTuLykwSiKhbRU8R
StoE3A7cAGwH9kjavsp1FwN/ADxYJrwkgYgmVFMTuAZYtP2E7VPA3QxW+VrpL4BPAT8qE1qSQEQDtFxuK3A58PTI8bHhuVfuI70NuML2l8vGlo7BiCaU7xjcImlh5HjO9lyZX5R0HvDXwIfXElqSQETdSnb6DZ20Pa4z7zhwxcjx1uG5My4G3gJ8UxLAzwIHJb3P9mhiOUeSQEQDKhoiPAxsk3Qlg3/8NwEfOvOh7R8CW87eU/om8MeTEgA0nAQeP3Jh4USfrq0aFFGJCpKA7dOS9gHzwCbggO2jkm4DFmwfXE+5qQlENKCqyUK2DwGHVpy7dcy115UpM0kgogkdnjGYJBBRszaXEy8jSSCiCUkCEf2WmkBE3yUJRPRckkBEj6VjMCJSE4jouawxGNFzaQ5E9NnaniJsXJJARBOSBCL6q+urDScJRDQhSSCi3+TuZoEkgYi65TVkr6jqNWQRU6e7FYHUBCKa0OWOwcL3Dkg6IOmEpEdGzn1C0nFJDw+3G+sNM2LKVfQasjqUefnIncCuVc5/xvaO4XZolc8jAip7DVldCpOA7QeA5xuIJWJ2TXlNYJx9ko4MmwuvG3eRpL2SFiQt/OC5pQ3cLmI6nZksNLU1gTHuAN4M7ACeAT497kLbc7Z32t75htdvWuftIqabll1qa8O6koDtZ20v2V4GPsvgbakRsZqyTYFpqglIumzk8APAI+OujYjK3kpci8J5ApK+CFzH4G2px4CPA9dJ2sEgdz0JfKTMzfIasuitDs8TKEwCtvescvpzNcQSMbO6PFkoMwYj6mYgDxBF9FseIIrosSwqEtF3dpoDEX2XmkBE3yUJRPRbagIRfWagpecCykgSiGhAl4cIN/IocUSUdWaEoGgrIGmXpMckLUrav8rnH5P06PAx//sl/XxRmUkCEQ2oYj0BSZuA24EbgO3AHknbV1z2bWCn7auBe4G/LIotSSCibtU9SnwNsGj7CdungLuB3efcyv6G7ZeGh98CthYVmj6BiJoNZgyW7hjcImlh5HjO9txw/3Lg6ZHPjgHXTijrFuBfi26YJBDRhPIdgydt79zo7ST9NrATeGfRtUkCEQ2o6DVkx4ErRo63Ds+dey/p3cCfAe+0/eOiQpMEIupmVzVP4DCwTdKVDP7x3wR8aPQCSb8C/D2wy/aJMoV2LglM46pBWQ1ptpT5+9x0WeEl56hixqDt05L2AfPAJuCA7aOSbgMWbB8E/gr4KeCfJQF81/b7JpXbuSQQMZMqeopw+KKfQyvO3Tqy/+61lpkkEFG3vJU4IrKeQETfdTcHJAlENKGiIcJaJAlE1M3AUpJARG8JpyYQ0XtJArMtE4GmR3UTuxbXduMkgYgeM2t5gKhxSQIRDUifQETfJQlE9JgNy91tDyQJRDShuzkgSSCiCekTiOi7JIGIHssbiF7xIi+c/LrvfWrk1BbgZJMxVCAx16+2eMutCFRqIlDhSz1ekVeTn2X7DaPHkhaqWFm1SYm5ftMWbylJAhE9ZmCpu8MDSQIRtTM4SWCcueJLOicx12/a4i2W5sDqRl6vNDUSc/2mLd5CGR2IiNQEIvouSSCix2xYWmo7irGSBCKakJpARM8lCUT0WWVvJa5FkkBE3QzOZKGInktNIKLn0icQ0WMZIowIZ6HRiD7LoiIR/dbxB4jOazuAiF7wcrmtgKRdkh6TtChp/yqfv1rSPw0/f1DSm4rKTBKIqJkBL7vUNomkTcDtwA3AdmCPpO0rLrsFeMH2LwCfAT5VFF+SQETd7KpqAtcAi7afsH0KuBvYveKa3cBdw/17geslaVKh6ROIaICrGSK8HHh65PgYcO24a2yflvRD4PVMWL05SSCiZi/ywvzXfe+WkpdvlrQwcjxX90pLSQIRNbO9q6KijgNXjBxvHZ5b7Zpjks4HXgs8N6nQ9AlETI/DwDZJV0q6ALgJOLjimoPAzcP93wL+zZ48SSE1gYgpMWzj7wPmgU3AAdtHJd0GLNg+CHwO
+LykReB5BoliIhUkiYiYcWkORPRckkBEzyUJRPRckkBEzyUJRPRckkBEzyUJRPRckkBEz/0/v2YzA4zYh+8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 288x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAADzCAYAAACGwaNbAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAWnElEQVR4nO3dbYxcZ3UH8P9/Zr279mbtBDZ1UtuUFByphqYutRJVoGIUXhw+YFArZKNKQYpqKtVVW9pKrloF5H6BVpT2Q4S6FCsREgkpUmAlXByagiJVJfWmiUKSNsU1efHi+GW9MY7X8e7MnH6Yu854vTPnzO59mZ37/0lXnpl9/Mzj2fXZ+3LuOTQziEh5VYpegIgUS0FApOQUBERKTkFApOQUBERKTkFApOQGil6ASL/7yAdGbPpcPTT2yWcuHzGzXRkv6SoKAiIZO3uujieObA6NXXPz/41lvJxrKAiIZM5Qt0bRi2hLQUAkYwaggd7NzFUQEMmYwTBvsXMCRVAQEMlBL+8JFHaJkOQuki+QPEbyQFHr6AbJF0n+mOTTJCeLXs9SSB4ieZrksy2vvYXk90n+JPnzhiLX2KrNej9Pcir5nJ8m+dEi17hSBqAOC21FKCQIkKwCuA/AXQC2AdhLclsRa1mGD5jZdjPbUfRC2rgfwOJLTAcAPGZmWwE8ljzvFffj2vUCwJeTz3m7mR3OeU2pa8BCWxGK2hO4HcAxMztuZnMAHgKwu6C19BUzexzAuUUv7wbwQPL4AQAfz3VRHbRZb18xAHWz0FaEooLAJgCvtDw/kbzW6wzAoySfJLmv6MV0YaOZnUwevwpgY5GLCdpP8pnkcKFnDl+WqxHciqC04e68z8zeg+ZhzB+Q/K2iF9Qta1aR6d2zVE1fAfAOANsBnATwpWKXszIWPB9QqnMCAKYAbGl5vjl5raeZ2VTy52kAj6B5WLManCJ5MwAkf54ueD0dmdkpM6ubWQPAV7F6PuclmQHzwa0IRQWBowC2kryF5CCAPQAmClpLCMkRkqMLjwF8GMCznf9Wz5gAcHfy+G4A3ylwLa6FgJX4BFbP59wGUQ9uRSgkT8DMaiT3AzgCoArgkJk9V8RaurARwCMkgebn9g0z+16xS7oWyQcB7AQwRvIEgM8B+AKAh0neA+AlAJ8sboVXa7PenSS3o3nY8iKAzxS2wBQYgEYPH4BRhUZFsvXu2wbt4e/eGBr7rrf97Mm8Lz8rY1AkY81koWJ29SMUBERy0DAFAZHS0p6ASMkZiHmrFr2MtgpNFlplWXcAtOY8rLb1ehb2BHr1EmHRGYOr8ZutNWdvta3XQdStEtqKoMMBkYw1KwsV/fu2vVyDwCCHbBgjV54PYx3W8y1XJSrcetusO8+Pz/nXXKtz/noq8/6YxcF5cOQGjIxtuWrN1Tk/1yJ0cjiSsuH8LLF27STDQxuw/rpNV77AQOaK0V8wAzkmVu1+F3d4aAPWj266evLAZ2MV/70ql2v+RFX/P+zPL508a2axi//o4xODJHcB+Ac0s/7+ycy+0Gn8MEZwB+/sOOeRI0+77/vOB3/fHTP6U/8bufaMf99Wfcj/5o2+fNmfZ9g/MVSZj6yn879r8DU/+lUv+mNs0P/R4LxfMqs+MuiOiWDd/2xq69a4Y4aPn3HHNK6/zh3z6FMHX3IHJcxY2K5+xLJXtsoLg4jkqgGGtiKsZE/gSmEQACC5UBjk+TQWJtIvDMSc9e7pt5WsbKnCIHcsHpRc7tkHNM8BiJRN6U8Mmtk4gHEA15wEFCmLep+mDa/KwiAieTMQ9T7dE7hSGATN//x7AHwqlVWJ9JlGD18dWHYQWE5hkFtvm3UvAX7kF7f77/33/vqWul5+zTyB70t9yB9TW+df/ptb749Zd8q/5NYY7LxbOXD+DXeO2vVr3TGVuXQ65kQuI9qA/42ozvj5I5VZ/8e5fsOoOwaB9XSjmTbch0EA
AJJ68Ku+JrxIlnr9BqLevW4h0ifM0NPJQgoCIpkrLhEoQkFAJGPNDkTaExAptb49MSgiPgNVY1Ck7LQnIFJiukTY4sfnbnRrAUQSgd75xz9yx1z6eDrt69ZO+/exD09d8McE7s+P/LJYM+0nzXgGzr7uD2r4/27W/EQgDvg//DaUUs2BWb+uQyVQlyDyb+9GswOR9gRESq2XKwv1bngS6RNmRMMqoc1DchfJF0geI3lgia+/jeQPSD5F8hmSH/Xm1J6ASA7SyBNoqeb1ITTrdxwlOWFmrYV8/grAw2b2laTS12EAb+80r/YERDLWLCqSSnmxK9W8zGwOwEI1r8Vvtz55vAHAz7xJtScgkrmuCo2OkZxseT6eFOYBYtW8Pg/gUZJ/CGAEwAe9N1QQEMmYAd1cIjy7wtbkewHcb2ZfIvmbAL5O8t1m1vaSh4KASMZSzBiMVPO6B8AuADCz/yA5DGAMwOl2k+qcgEgOGqiENseVal4kB9Gs5jWxaMzLAO4EAJK/AmAYQMdmC7nuCVTn/KYgkYpAkUSgtd/+T3cMf+Nd7pj5DcPuGAQ68VTPX/TnSUFjNFA1aNavPmTTM/6brfPfi4ExOHPOH7PebwiC6dfcITYQaKoy6Dcx6UaznsDK9wTaVfMieRDApJlNAPhTAF8l+SdoHol82qzzD6gOB0RykNYNREtV8zKze1sePw/gvd3MqSAgkrHmOYHePfJWEBDJQS+nDSsIiGTMQNQauotQpNRUY1CkxNK6OpAVBQGRHOjEoEiJqcZgi8o8sPZM56otaQXMSCKQPdmxa1pznp3v8ecJtK3ifM2fZyiQpHLBSToKJAvhkp8shJtu9MdcnnOHWCXww9/wk60iVYwskOTDYb+vnL2efmKXzgmIlFizvJiCgEh5mS4RipTaQlGRXqUgIJIDHQ6IlJjOCYiIgoBImSlPQKTsDKgpY7DJKkB9qHNErPu5HKHWYJGKQJFEoOoP/8sdU7n5JndMROh3hVcZJ9CKC1X/chW9pCQANhJITApU+2GkatDcvD8mkARldT/pCEz3t7bOCYhI/wYBki8CuACgDqC2wlLJIn2pDOcEPmBmZ1OYR6RvWZ8HARFx9HLG4EpPWRqaLY+eJLlvqQEk95GcJDlZeyOfstsivcSseU4gshVhpXsC7zOzKZK/AOD7JP/HzB5vHZD0URsHgJGxLf49oyJ9h6g3evcS4YpWZmZTyZ+nATyCZtdUEVnEjKGtCMsOAiRHSI4uPAbwYQDPprUwkX6xkCfQj4cDGwE8wmZixQCAb5jZ9zr9heqcYfTlzskstXV+Isvw1AV/dYHWYJGKQJFEoNrJV90xAzdt9NcTqCzE12c7D5g5778PA5WQhgb9MRcvuWMi7cMs0BYNl9NJgookHTVmnc+4Wxb6cSzMsoOAmR0H8GsprkWkb/Xy1QFdIhTJmEF5AiIl1/8ZgyLiaDQUBERKy0yHAyKlp8MBkZLry0uEIhKnw4GEEagPd07omFsfSBYa9JddPe/frBRpDRYRSQSqvXrKHVPd+sv+mzlttEi/opJ5CUcAGjOBikCDfkIRLvrvZYFfk/QqKgGwQEIR1/qfT2V01B2Dc/6QBYbiUoIjeveuBpE+YsHNQ3IXyRdIHiN5oM2YT5J8nuRzJL/hzanDAZGsGWApXCIkWQVwH4APATgB4CjJCTN7vmXMVgB/AeC9ZjaT3OHbkfYERHKQ0l2EtwM4ZmbHzWwOwEMAdi8a83sA7jOzmeb72mlvUgUBkRyYxTbHJgCvtDw/kbzW6lYAt5L8d5I/IrnLm1SHAyIZ6/LegTGSky3Px5PCPFEDALYC2AlgM4DHSf6qmbU906sgIJI1Q/PSWMzZDlW7pwBsaXm+OXmt1QkAT5jZPICfkvxfNIPC0XZvqMMBkRykdDhwFMBWkreQHASwB8DEojHfRnMvACTH0Dw8ON5pUgUBkTykcI3QzGoA9gM4AuC/ATxsZs+RPEjyY8mwIwCmST4P4AcA
/tzMpjvNm+/hgAGV+c4txNadCrSJSil0hSr5pDRPJBGo/pOOAbs5z/UbOg9oBK42bxxzh7AR+D5EEnhm/YQsBtp+WS0wTyARKNZiLO0cX6ZyiRAAzOwwgMOLXru35bEB+GyyheicgEjWdBehiKS+c5EiBQGRXGhPQKTctCcgUnIKAiIlltINRFlREBDJg/YEREpOlwgTFaA+1DnTpzHof1hrplNqE3Uh0Co9kBDjtgYD3IpAQCARCED9tc5txirr1vlrmfLbpnFkxB3TOOdXH6psCFTpqXdOIAMQa0M2H0g6qvtJUJHEpG5RewIiJRYtG1QQBQGRzFGHAyKlpz0BkZILnPYoioKASNa6KyqSOwUBkRzo6oBI2SkINLFmGHxtruOYgfNvpPJejdG1/qDImNnA9emZztfugVhnoEhBEC8PoDHr5yxUb7zRHWNv+N+HSqSIRyQHYNAvymKv+zkdlfV+ToIFOiKx6nfB6ifaExDJQS8fDriFukgeInma5LMtr72F5PdJ/iT584ZslymyyhljWwEi1fruB7C4gcEBAI+Z2VYAjyXPRWQphuYlwshWADcImNnjuLYH624ADySPHwDw8ZTXJdJXaLGtCMs9J7DRzE4mj18F4PfmFimzHj4nsOITg2ZmZPsYRnIfgH0AMDzk3yUn0pd6OAgst4L/KZI3A0DyZ9vOp2Y2bmY7zGzHmgH/9lSRfhM9FCjqcGC5QWACwN3J47sBfCed5Yj0qR6+OuAeDpB8EM3eZmMkTwD4HIAvAHiY5D0AXgLwycibsWGoXuycLFS73k/gGTj7ujumMhtIOroUGBNIHDH6sdQihUcCnYG8giCRRKD6mTPumEqgqIhV/H93qN9PIBGIgaIsDafgCgBYoOFfdeyt7hgE6tFc/cZdjs+RGwTMbG+bL92Z8lpE+hZ1F6FIiRV4vB+hICCSBwUBkZJTEBApt14+HFjuJUIR6RPaExDJQw/vCSgIiGTNdInwCiNhg53fsjLnd4hBw/9EbXrGn+cmP7GGgS5FHBp0xzRm/G49bPj/dq8zUKgiUKS70EX/3x1JqrHZS+4YBro8RToHRRKKIpWO6qfP+vN0S3sCIuVF6MSgiFhwc5DcRfIFksdIti3mQ/K3SRrJHd6cCgIiWUvpLkKSVQD3AbgLwDYAe0luW2LcKIA/AvBEZHkKAiJ5SGdP4HYAx8zsuJnNAXgIzSpfi/01gC8CCJXuVhAQyQEbsc2xCcArLc9PJK+9+T7kewBsMbPvRtemE4MieYifGBwjOdnyfNzMxiN/kWQFwN8B+HQ3S1MQEMla8KRf4qyZtTuZNwVgS8vzzclrC0YBvBvAD0kCwE0AJkh+zMxaA8tVFAREcpDSJcKjALaSvAXN//x7AHxq4Ytmdh7Alco0JH8I4M86BQAg7zZkZuB8IBnIm6cWmGNdoMXY5c5VjgDARvx5eDGQEDPoJxQhkDTTONc56SjSGixSESiSCFQ/O+3Pc0OgL00lUH+oVnOHeIlUAGCBpK3KBr+dGfziTIveuMvxS01hViO5H8ARAFUAh8zsOZIHAUya2cRy5tWegEgO0koWMrPDAA4veu3eNmN3RuZUEBDJQw9nDCoIiGSsyHLiEQoCInlQEBApN+0JiJSdgoBIySkIiJSYTgy+yapEfaRz0kwkmYgDfmswBpKFLJKkMu0nl2D9df6Yi34bMpv1E2LcRJZA5ZxQa7BARaBIIlB9xq/wxDV+IlXlukA1pEAiUKitXKQ9XbcUBETKTTUGRUpOhwMiZdbdXYS5UxAQyYOCgEh59Xq1YQUBkTwoCIiUG613o4CCgEjW1IasOzbgV72xQNsvnDnnj2n40ZmBRCCb9ZNLLPCbIKkL15mXDDS4xl/L64HWaoEqR5GKQJFEIJsPVHiqBVqMBdrTRZKFEGh51rXe3RHovSAg0o96+cSg+2uX5CGSp0k+2/La50lOkXw62T6a7TJFVrmU2pBlIdJ85H4Au5Z4/ctmtj3ZDi/x
dREBUmtDlhU3CJjZ4wACB9gi0tYq3xNoZz/JZ5LDhba3k5HcR3KS5OT8vH9CSqTfLCQLrdo9gTa+AuAdALYDOAngS+0Gmtm4me0wsx1r1vi3g4r0IzYstBVhWUHAzE6ZWd3MGgC+ima3VBFZSvRQYDXtCZC8ueXpJwA8226siKTWlTgTbp4AyQcB7ESzW+oJAJ8DsJPkdjRj14sAPhN6NwPoJLtUZ/wKPCGBJJ9QO7O5eX/M5cv+ewWSbyzQast7r1Ai0LCfeGORhJnAeiMVgSKJQI0LF9wxHPLn4bz//eRooA2ZX3jpaj2cJ+D+ZJrZ3iVe/loGaxHpW72cLKSMQZGsGQDdQCRSbrqBSKTEVFREpOzMdDggUnbaExApOwUBkXLTnkDCKkRtXefKN5VZf0mc9ZNzIu3DLFCFB5GWVJHWVpGEorXD/nvNd67mU1nvJ7o0XjvvryWQUMSRlFqDBSoCRRKBQp/xsP8Z16dTvmnWEKpiVRTtCYjkoJcvEa7kVmIRiVq4QuBtDpK7SL5A8hjJA0t8/bMkn09u83+M5C95cyoIiOQgjXoCJKsA7gNwF4BtAPaS3LZo2FMAdpjZbQC+BeBvvLUpCIhkLb1biW8HcMzMjpvZHICHAOy+6q3MfmBmC3fh/QjAZm9SnRMQyVgzYzB8YnCM5GTL83EzG08ebwLwSsvXTgC4o8Nc9wD4F+8NFQRE8hA/MXjWzHas9O1I/i6AHQDe741VEBDJQUptyKYAbGl5vjl57er3Ij8I4C8BvN/M3OumCgIiWTNLK0/gKICtJG9B8z//HgCfah1A8tcB/COAXWZ2OjJprkGgcrmG4eNnOo6p3+Anu1S8VlwALFDJJ7UKO4HqQ6FEoEAbMm89dtGvzBRpiea2OwNgkUSgSNuvwJhQRaBAIlDjDT/5qxKYp9vKQmlkDJpZjeR+AEcAVAEcMrPnSB4EMGlmEwD+FsB1AP45aWv3spl9rNO82hMQyUNKdxEmjX4OL3rt3pbHH+x2TgUBkaypK7GIqJ6ASNn1bgxQEBDJQ0qXCDOhICCSNQNQVxAQKS3CtCcgUnoKAolqBY3rnfZgA4EbGyOVaAJVgyItuyIJPI1ZP0GnEmltFTh75LUqYyDxpjr2VndM/fRZd0xlg/9vskhlpkBCVqQ1WKQiUCQRKJJQ1DUFAZESM3RzA1HuFAREcqBzAiJlpyAgUmJmofNYRVEQEMlD78YABQGRPOicgEjZKQiIlJg6EL3p55dOnn30qYMvtbw0BsDPSukty1tzyp2tunT1mgM5UiGdi0StxLWfcZeVfNpKax7AberxJrUmv8LMbmx9TnIyjcqqedKas7fa1huiICBSYoZQzcaiKAiIZM4AUxBoZ9wf0nO05uyttvX6dDiwtJb2SquG1py91bZel64OiIj2BETKTkFApMTMQoVTiqIgIJIH7QmIlJyCgEiZpdaVOBMKAiJZM8CULCRSctoTECk5nRMQKTFdIhQRU6FRkTJTURGRcuvxG4gCjf9EZMWsEdscJHeRfIHkMZIHlvj6EMlvJl9/guTbvTkVBEQyZgCsYaGtE5JVAPcBuAvANgB7SW5bNOweADNm9k4AXwbwRW99CgIiWTNLa0/gdgDHzOy4mc0BeAjA7kVjdgN4IHn8LQB3kp1ba+ucgEgOLJ1LhJsAvNLy/ASAO9qNMbMayfMA3ooOFbIVBEQydgEzR/7VvjUWHD5McrLl+XjWlZYUBEQyZma7UppqCsCWluebk9eWGnOC5ACADQCmO02qcwIiq8dRAFtJ3kJyEMAeABOLxkwAuDt5/DsA/s2sc5KC9gREVonkGH8/gCMAqgAOmdlzJA8CmDSzCQBfA/B1ksfQ7Hu1x5uXTpAQkT6nwwGRklMQECk5BQGRklMQECk5BQGRklMQECk5BQGRklMQECm5/wfjh6ySl1d6FAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 288x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "A_pred_batch, A_batch = test_graphs[0]\n",
    "\n",
    "idx_graph = 5\n",
    "# Sigmoid of the predicted logits (diagonal zeroed) and the target matrix for graph\n",
    "# `idx_graph` of the first test batch, as numpy arrays for plotting.\n",
    "A_pred = mask_diagonal_and_sigmoid(A_pred_batch)[idx_graph].detach().cpu().numpy()\n",
    "A = A_batch[idx_graph].cpu().numpy()\n",
    "\n",
    "# plt.matshow opens its own figure, so the previous plt.figure() calls only produced empty\n",
    "# figures. Build each figure explicitly instead and label it.\n",
    "for matrix, title in [(A, 'A (target)'), (A_pred, 'A_pred (sigmoid, diagonal masked)')]:\n",
    "    fig, ax = plt.subplots(figsize=(4, 4))\n",
    "    image = ax.matshow(matrix)\n",
    "    ax.set_title(title)\n",
    "    fig.colorbar(image, ax=ax)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(20, 20)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "numpy.shape(A_pred)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
