{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Much of this code is copied directly from the paper's reference implementation, since it is not deep-learning based. Slight adaptations were made to make it compatible with our code base.\n",
    "\n",
    "Source: https://github.com/InterpretableClustering/InterpretableClustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\tensorboard\\compat\\tensorflow_stub\\dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\tensorboard\\compat\\tensorflow_stub\\dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\tensorboard\\compat\\tensorflow_stub\\dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\tensorboard\\compat\\tensorflow_stub\\dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\tensorboard\\compat\\tensorflow_stub\\dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\tensorboard\\compat\\tensorflow_stub\\dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n",
      "Using TensorFlow backend.\n",
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import random\n",
    "from utils import *\n",
    "\n",
    "from dynamicgem.embedding.dynAERNN  import DynAERNN \n",
    "import dgl  \n",
    "import scipy as sp\n",
    "import scipy.linalg as linalg\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.cluster.vq import kmeans,vq\n",
    "from scipy import stats \n",
    "from sklearn.cluster import SpectralClustering\n",
    "from sklearn import metrics\n",
    "\n",
    "from itertools import permutations\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#read in the data sets\n",
    "def read_data(id: str): #DBLP3, DBLP5, Brain, Reddit, DBLPE\n",
    "    \"\"\"Load a temporal-graph dataset from ``Datasets/`` and wrap it in an STG_Dataset.\n",
    "\n",
    "    Nodes whose adjacency column sums to zero over all snapshots are dropped, then\n",
    "    the remaining nodes are randomly split 70/20/10 into train/val/test masks using\n",
    "    the global ``random`` state (seed it for a reproducible split).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    id : str\n",
    "        One of \"DBLP3\", \"DBLP5\", \"Brain\", \"Reddit\" or \"DBLPE\".\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    STG_Dataset\n",
    "        Built from the adjacency tensor (time, node, node), node features\n",
    "        (node, time, feat), labels, the node/timestep/class/feature counts and\n",
    "        the boolean train/val/test node masks.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    KeyError\n",
    "        If ``id`` is not a known dataset name.\n",
    "    \"\"\"\n",
    "    #pick which dataset to load\n",
    "    dataset_dict = {\n",
    "        \"DBLP3\": \"Datasets/DBLP3.npz\",\n",
    "        \"DBLP5\": \"Datasets/DBLP5.npz\",\n",
    "        \"Brain\": \"Datasets/Brain.npz\",\n",
    "        \"Reddit\": \"Datasets/reddit.npz\",\n",
    "        \"DBLPE\": \"Datasets/DBLPE.npz\",\n",
    "    }\n",
    "\n",
    "    #read every array we need inside the context manager so the .npz archive is closed\n",
    "    with np.load(dataset_dict[id]) as dataset:\n",
    "        adjs = dataset[\"adjs\"] #(time, node, node)\n",
    "        labels = dataset[\"labels\"]\n",
    "        #DBLPE is featureless; every other dataset ships node attributes\n",
    "        raw_feats = None if id == \"DBLPE\" else dataset[\"attmats\"] #(node, time, feat)\n",
    "\n",
    "    #Remove nodes with no connections at any timestep\n",
    "    temporal_sum = np.add.reduce(adjs, axis=0, keepdims=False)\n",
    "    row_sum = np.add.reduce(temporal_sum, axis=0, keepdims=False)\n",
    "    non_zero_indices = np.flatnonzero(row_sum)\n",
    "    adjs = adjs[:, non_zero_indices, :]\n",
    "    adjs = adjs[:, :, non_zero_indices]\n",
    "    labels = labels[non_zero_indices]\n",
    "\n",
    "    n_nodes = adjs.shape[1]\n",
    "    n_timesteps = adjs.shape[0]\n",
    "\n",
    "    if raw_feats is None:\n",
    "        #DBLPE: labels are (nodes, time, class); every node gets a one-hot identity\n",
    "        #feature at every timestep, so feats is (node, time, node)\n",
    "        feats = np.repeat(np.eye(n_nodes)[:, np.newaxis, :], n_timesteps, axis=1)\n",
    "    else:\n",
    "        #static feature-full graphs: labels are (nodes, class)\n",
    "        feats = raw_feats[non_zero_indices]\n",
    "\n",
    "    #the class axis is the LAST axis of both label layouts; labels.shape[1]\n",
    "    #would be the time axis for DBLPE\n",
    "    n_class = int(labels.shape[-1])\n",
    "    n_feat = feats.shape[2]\n",
    "\n",
    "    #Train Val Test split: shuffle the node ids, then mark 70% / 20% / 10% of them.\n",
    "    #Masks are filled by index assignment (O(n)) instead of `i in list` scans (O(n^2)).\n",
    "    nodes_id = list(range(n_nodes))\n",
    "    random.shuffle(nodes_id)\n",
    "    idx_train = np.zeros(n_nodes, dtype=bool)\n",
    "    idx_train[nodes_id[:(7*n_nodes)//10]] = True\n",
    "    idx_val = np.zeros(n_nodes, dtype=bool)\n",
    "    idx_val[nodes_id[(7*n_nodes)//10: (9*n_nodes)//10]] = True\n",
    "    idx_test = np.zeros(n_nodes, dtype=bool)\n",
    "    idx_test[nodes_id[(9*n_nodes)//10: n_nodes]] = True\n",
    "\n",
    "    #custom data type that holds everything i might need\n",
    "    return STG_Dataset(np.array(adjs),\n",
    "                        np.array(adjs),\n",
    "                        np.array(feats),\n",
    "                        np.array(feats),\n",
    "                        np.array(labels),\n",
    "                        np.array(labels),\n",
    "                        n_nodes, n_timesteps, n_class, n_feat,\n",
    "                        idx_train,\n",
    "                        idx_val,\n",
    "                        idx_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "#dynaernn is a prebuilt model\n",
    "def dynaernn(data):\n",
    "    \"\"\"Cluster the nodes of ``data`` with DynAERNN embeddings + k-means and score the result.\n",
    "\n",
    "    A DynAERNN model (embedding size = ``data.n_class``) is trained on growing prefixes\n",
    "    of the snapshot graphs, the last learned embedding is clustered into\n",
    "    ``data.n_class`` groups with k-means and, because cluster ids are arbitrary,\n",
    "    every cluster->class permutation is tried. Accuracy, weighted F1 and weighted\n",
    "    one-vs-rest AUC are each maximised independently over the permutations and printed.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        ``(0, best_accuracy, spec_norm)``; ``spec_norm`` is carry-over from the\n",
    "        reference code and is ``[]`` whenever it cannot be computed.\n",
    "    \"\"\"\n",
    "    #parameters\n",
    "    length = data.n_timestamps\n",
    "    lookup = length - 2\n",
    "    dim_emb = data.n_class\n",
    "\n",
    "    #NOTE: device placement is not forced here. Calling tf.device('/gpu:0') without a\n",
    "    #`with` block only builds and discards a context manager and has no effect, so\n",
    "    #placement is left to DynAERNN's own TensorFlow session.\n",
    "\n",
    "    #embedding is the model\n",
    "    embedding = DynAERNN(d   = dim_emb,\n",
    "        beta           = 5,\n",
    "        n_prev_graphs  = lookup,\n",
    "        nu1            = 1e-6,\n",
    "        nu2            = 1e-6,\n",
    "        n_aeunits      = [50, 30],\n",
    "        n_lstmunits    = [50,dim_emb],\n",
    "        rho            = 0.3,\n",
    "        n_iter         = 2,\n",
    "        xeta           = 1e-3,\n",
    "        n_batch        = 10,\n",
    "        modelfile      = ['./intermediate/enc_model_dynAERNN.json', \n",
    "                            './intermediate/dec_model_dynAERNN.json'],\n",
    "        weightfile     = ['./intermediate/enc_weights_dynAERNN.hdf5', \n",
    "                            './intermediate/dec_weights_dynAERNN.hdf5'],\n",
    "        savefilesuffix = \"testing\")\n",
    "\n",
    "    #one networkx graph per adjacency snapshot\n",
    "    graphs = [nx.Graph(data.adjs[l,:,:]) for l in range(length)]\n",
    "    #learn embeddings on growing prefixes of the snapshot sequence\n",
    "    embs = []\n",
    "    for temp_var in range(lookup, length):\n",
    "        emb, _ = embedding.learn_embeddings(graphs[:temp_var])\n",
    "        embs.append(emb)\n",
    "    #cluster the last embedding with k-means and assign each node to its nearest centroid\n",
    "    centroid = kmeans(embs[-1], data.n_class)[0]\n",
    "    result = vq(embs[-1], centroid)[0]\n",
    "\n",
    "    #convert to one hot for easy analysis\n",
    "    one_hot_result = one_hot(result, data.n_class)\n",
    "    #the ground truth does not depend on the permutation: compute it once\n",
    "    labels = np.argmax(data.labels, axis=1)\n",
    "    one_hot_labels = one_hot(labels)\n",
    "\n",
    "    #initializations\n",
    "    acc_test = 0\n",
    "    f1_test = 0\n",
    "    auc_test = 0\n",
    "\n",
    "    #for each class permutation (n_class! of them, so only feasible for few classes)\n",
    "    for count, i in enumerate(permutations(range(data.n_class)), start=1):\n",
    "        #relabel the clusters with this permutation and score it\n",
    "        perm_result = np.matmul(one_hot_result, one_hot(np.array(i)))\n",
    "        pred_labels = np.argmax(perm_result, axis=1)\n",
    "        acc_test = max(metrics.accuracy_score(labels, pred_labels), acc_test)\n",
    "        f1_test = max(metrics.f1_score(labels, pred_labels, average='weighted'), f1_test)\n",
    "        auc_test = max(metrics.roc_auc_score(one_hot_labels, perm_result, multi_class='ovr', average='weighted'), auc_test)\n",
    "        #every 10000 update me\n",
    "        if count % 10000 == 0:\n",
    "            print(count)\n",
    "            print(acc_test, f1_test, auc_test)\n",
    "    print(str(acc_test)+'\\t'+str(f1_test)+'\\t'+str(auc_test))\n",
    "\n",
    "    #carry over code, not necessary: neither `adj` nor `Probability_matrix` is assigned\n",
    "    #in this function, so this normally falls back to []. `except Exception` (rather\n",
    "    #than a bare except) still lets KeyboardInterrupt/SystemExit propagate.\n",
    "    try:\n",
    "        spec_norm = getKlargestSigVec(adj-Probability_matrix, 2)[0]\n",
    "    except Exception:\n",
    "        spec_norm = []\n",
    "    return 0, acc_test, spec_norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def spectral(data):\n",
    "    \"\"\"Spectral clustering of the time-aggregated graph, scored against ``data.labels``.\n",
    "\n",
    "    The adjacency snapshots are summed over time, self-loops are added and the matrix\n",
    "    is symmetrically normalised as D^{-1/2} (A + I) D^{-1/2}. The left singular vectors\n",
    "    of its ``data.n_class`` largest singular values are clustered with k-means, and the\n",
    "    best accuracy / weighted F1 / weighted one-vs-rest AUC over all cluster->class\n",
    "    permutations are printed (each maximised independently).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        ``(0, best_accuracy, spec_norm)``; ``spec_norm`` is carry-over from the\n",
    "        reference code and is ``[]`` whenever it cannot be computed.\n",
    "    \"\"\"\n",
    "    adj = np.add.reduce(data.adjs_timestep, axis=0, keepdims=False, dtype=np.float32)\n",
    "\n",
    "    #normalize the adj matrix: D^{-1/2} (A + I) D^{-1/2}.\n",
    "    #Scaling rows and columns by broadcasting equals multiplying by the diagonal\n",
    "    #matrix on both sides, in O(n^2) instead of O(n^3). (Applying D^{-1/2} on the\n",
    "    #left twice would give the asymmetric D^{-1} (A + I) D^{-1/2}.)\n",
    "    adj += np.eye(adj.shape[0], dtype=np.float32)\n",
    "    d = np.add.reduce(adj, axis=1)\n",
    "    #float64 keeps the precision of a float64 diagonal normalising matrix\n",
    "    d_inv_sqrt = (d**(-0.5)).astype(np.float64)\n",
    "    adj = d_inv_sqrt[:, np.newaxis] * adj * d_inv_sqrt[np.newaxis, :]\n",
    "\n",
    "    #spectral embedding: left singular vectors of the top-k singular values\n",
    "    Lbar = np.array(adj)\n",
    "    top_k = data.n_class\n",
    "    kSigVal, kSigVec = getKlargestSigVec(Lbar, top_k)\n",
    "    #k-means on the embedding (cast to real float for scipy), then nearest-centroid labels\n",
    "    centroid = kmeans(kSigVec.astype(float), data.n_class)[0]\n",
    "    result = vq(kSigVec.astype(float), centroid)[0]\n",
    "\n",
    "    #convert to one hot for easy analysis\n",
    "    one_hot_result = one_hot(result, data.n_class)\n",
    "    #the ground truth does not depend on the permutation: compute it once\n",
    "    labels = np.argmax(data.labels, axis=1)\n",
    "    one_hot_labels = one_hot(labels)\n",
    "\n",
    "    #initializations\n",
    "    acc_test = 0\n",
    "    f1_test = 0\n",
    "    auc_test = 0\n",
    "\n",
    "    #for each class permutation (n_class! of them, so only feasible for few classes)\n",
    "    for count, i in enumerate(permutations(range(data.n_class)), start=1):\n",
    "        #relabel the clusters with this permutation and score it\n",
    "        perm_result = np.matmul(one_hot_result, one_hot(np.array(i)))\n",
    "        pred_labels = np.argmax(perm_result, axis=1)\n",
    "        acc_test = max(metrics.accuracy_score(labels, pred_labels), acc_test)\n",
    "        f1_test = max(metrics.f1_score(labels, pred_labels, average='weighted'), f1_test)\n",
    "        auc_test = max(metrics.roc_auc_score(one_hot_labels, perm_result, multi_class='ovr', average='weighted'), auc_test)\n",
    "        #every 10000 update me\n",
    "        if count % 10000 == 0:\n",
    "            print(count)\n",
    "            print(acc_test, f1_test, auc_test)\n",
    "    print(str(acc_test)+'\\t'+str(f1_test)+'\\t'+str(auc_test))\n",
    "\n",
    "    #carry over code, not necessary: `Probability_matrix` is not assigned in this\n",
    "    #function, so this normally falls back to []. `except Exception` (rather than a\n",
    "    #bare except) still lets KeyboardInterrupt/SystemExit propagate.\n",
    "    try:\n",
    "        spec_norm = getKlargestSigVec(adj-Probability_matrix, 2)[0]\n",
    "    except Exception:\n",
    "        spec_norm = []\n",
    "    return 0, acc_test, spec_norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getKlargestSigVec(Lbar,k):\n",
    "\t\"\"\"input\n",
    "\t\"matrix Lbar and k\n",
    "\t\"return\n",
    "\t\"k largest singular values and their corresponding eigen vectors\n",
    "\t\"\"\"\n",
    "\tlsigvec,sigval,rsigvec=linalg.svd(Lbar)\n",
    "\tdim=len(sigval)\n",
    " \n",
    "\t#find top k largest left sigval\n",
    "\tdictSigval=dict(zip(sigval,range(0,dim)))\n",
    "\tkSig=np.sort(sigval)[::-1][:k]#[0:k]\n",
    "\tix=[dictSigval[k] for k in kSig]\n",
    "\treturn sigval[ix],lsigvec[:,ix]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "#basic one hot encoding\n",
    "def one_hot(l,classnum=1): #classnum fix some special case\n",
    "    one_hot_l=np.zeros((len(l),max(l.max()+1,classnum)))\n",
    "    for i in range(len(l)):\n",
    "        one_hot_l[i][l[i]]=1\n",
    "    return one_hot_l"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "STG_Dataset(adjs=array([[[1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., ..., 1., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 1., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 1.]],\n",
       "\n",
       "       [[1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., ..., 1., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 1., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 1.]],\n",
       "\n",
       "       [[1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., ..., 1., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 1., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 1.]],\n",
       "\n",
       "       ...,\n",
       "\n",
       "       [[1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., ..., 1., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 1., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 1.]],\n",
       "\n",
       "       [[1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., ..., 1., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 1., 1.],\n",
       "        [0., 0., 0., ..., 0., 1., 1.]],\n",
       "\n",
       "       [[1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., ..., 1., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 1., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 1.]]]), adjs_timestep=array([[[1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., ..., 1., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 1., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 1.]],\n",
       "\n",
       "       [[1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., ..., 1., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 1., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 1.]],\n",
       "\n",
       "       [[1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., ..., 1., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 1., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 1.]],\n",
       "\n",
       "       ...,\n",
       "\n",
       "       [[1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., ..., 1., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 1., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 1.]],\n",
       "\n",
       "       [[1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., ..., 1., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 1., 1.],\n",
       "        [0., 0., 0., ..., 0., 1., 1.]],\n",
       "\n",
       "       [[1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        [1., 1., 1., ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., ..., 1., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 1., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 1.]]]), feats=array([[[ 2.93138406e+00,  2.71742502e+00,  2.02877947e+00, ...,\n",
       "          3.15081253e-01, -1.00153006e-01, -2.91507221e-02],\n",
       "        [-1.47461175e+00, -1.58453631e+00, -8.66416540e-01, ...,\n",
       "          6.39972224e-01,  7.50973463e-01,  1.68464292e-01],\n",
       "        [ 1.55804106e+00, -2.14564332e+00,  5.91066934e-01, ...,\n",
       "          6.08396813e-01, -4.66517779e-01,  6.55890697e-02],\n",
       "        ...,\n",
       "        [ 2.32409961e+00,  9.73057426e-01, -2.45927919e+00, ...,\n",
       "          2.16803578e-01,  8.47853776e-01,  1.06558255e+00],\n",
       "        [-1.23653956e+00, -1.57395665e+00, -1.00887414e+00, ...,\n",
       "         -3.24659246e-01,  2.31697745e-01,  9.27076559e-01],\n",
       "        [-2.62599830e+00, -1.10942746e+00,  8.20539578e-01, ...,\n",
       "          4.16532796e-02, -5.18122719e-01,  9.24578586e-01]],\n",
       "\n",
       "       [[ 8.52821826e+00,  3.58951036e+00,  1.46914809e+00, ...,\n",
       "          3.75203360e-01, -6.25842659e-03, -7.89301257e-02],\n",
       "        [ 4.77120953e+00,  4.22594906e-01, -1.35995508e+00, ...,\n",
       "          7.23495468e-01,  3.31812551e-01, -1.78128130e-01],\n",
       "        [ 3.84976478e+00, -1.43112030e+00,  3.37267017e-02, ...,\n",
       "         -1.81939521e-02, -2.28613344e-01, -1.80574750e-01],\n",
       "        ...,\n",
       "        [ 9.55216137e-02,  2.53785067e-01, -1.02241391e+00, ...,\n",
       "          2.83347703e-01,  2.37848964e-01,  6.35258054e-01],\n",
       "        [-2.91723666e+00, -1.39992688e+00, -3.04946277e-01, ...,\n",
       "         -2.78556356e-01,  3.37992062e-01,  8.14526620e-01],\n",
       "        [-4.17054317e+00, -1.14877357e+00,  9.26227166e-01, ...,\n",
       "         -3.05063723e-01, -1.81174674e-01,  6.73323795e-01]],\n",
       "\n",
       "       [[ 5.59852601e+00,  3.97787991e+00,  1.98605508e+00, ...,\n",
       "          2.99125206e-01,  1.61002275e-01, -1.56042058e-01],\n",
       "        [ 1.49517753e+00, -2.81448285e-01, -1.23307006e+00, ...,\n",
       "          3.08222094e-01,  3.07199113e-01,  1.39141936e-02],\n",
       "        [ 2.20271192e+00, -2.02663807e+00, -1.06757042e-01, ...,\n",
       "          2.77344605e-01, -2.22992290e-01, -2.58803508e-01],\n",
       "        ...,\n",
       "        [ 2.09145368e+00,  6.19908933e-01, -1.81310294e+00, ...,\n",
       "          3.01964806e-01,  3.18539718e-01,  6.57088306e-01],\n",
       "        [-1.46294369e+00, -1.34519279e+00, -5.18675203e-01, ...,\n",
       "          4.64181943e-02,  5.11771388e-01,  8.71544862e-01],\n",
       "        [-2.90706936e+00, -1.10065231e+00,  8.37228952e-01, ...,\n",
       "          7.34385850e-03, -2.90840321e-01,  8.57165795e-01]],\n",
       "\n",
       "       ...,\n",
       "\n",
       "       [[ 2.33464301e+00,  7.69308280e-01,  2.00183198e+00, ...,\n",
       "         -1.45372166e-01,  4.09646617e-01, -7.09866566e-02],\n",
       "        [ 4.99634452e+00, -6.30644749e-01, -1.79011808e+00, ...,\n",
       "          7.33277701e-01,  2.15224853e-01, -3.14556542e-01],\n",
       "        [ 2.11398046e+00, -6.61167833e-01,  4.11500024e-01, ...,\n",
       "          1.55077791e-01,  2.68693833e-01, -1.41754316e-01],\n",
       "        ...,\n",
       "        [ 1.66111677e+00, -6.88564449e-01,  2.63617171e+00, ...,\n",
       "          8.52422996e-02, -5.19956301e-01,  1.04606242e-01],\n",
       "        [-5.20700067e+00, -1.63056000e+00, -1.87059396e+00, ...,\n",
       "          4.13141398e-01, -1.88544993e-01,  5.88912416e-01],\n",
       "        [-2.55430142e+00, -1.37558841e+00, -6.02125809e-01, ...,\n",
       "         -6.58926939e-01, -1.43563936e-01,  2.38516809e-02]],\n",
       "\n",
       "       [[ 3.89366261e+00,  1.54358319e+00,  1.66455965e+00, ...,\n",
       "          9.34796190e-02,  6.87181712e-01, -2.24318340e-01],\n",
       "        [ 6.23053110e+00, -7.98057005e-01, -2.71353174e+00, ...,\n",
       "          1.10604161e+00, -4.19309542e-01,  1.26272889e-03],\n",
       "        [ 3.51968075e+00, -1.06928300e+00, -1.43978023e-01, ...,\n",
       "          1.74826929e-01,  1.97073885e-01, -4.50263290e-01],\n",
       "        ...,\n",
       "        [ 3.14001533e+00, -8.86196821e-01,  2.78883305e+00, ...,\n",
       "         -9.86043015e-02, -6.34122448e-01, -1.24328560e-01],\n",
       "        [-5.73593019e+00, -1.94659104e+00, -2.14398914e+00, ...,\n",
       "          7.47610540e-01, -2.36949688e-01,  9.71212752e-01],\n",
       "        [-3.14882633e+00, -1.54628706e+00, -7.74910945e-01, ...,\n",
       "         -8.46581121e-01,  3.35656430e-01, -1.57022377e-01]],\n",
       "\n",
       "       [[ 2.82528028e+00,  1.01916556e+00,  2.83550893e+00, ...,\n",
       "         -3.49299556e-01,  8.39011608e-01, -1.33985757e-01],\n",
       "        [ 5.91323162e+00, -2.57793651e-01, -1.77191356e+00, ...,\n",
       "          8.88027998e-01,  5.20197016e-01, -1.00435688e+00],\n",
       "        [ 2.68645288e+00, -5.72841833e-01,  6.66139887e-01, ...,\n",
       "          7.92110183e-01,  2.40546658e-01, -5.54907088e-01],\n",
       "        ...,\n",
       "        [ 3.18455609e+00, -2.18213857e-01,  3.43769556e+00, ...,\n",
       "          6.49864429e-01, -7.26173042e-01,  1.20061263e+00],\n",
       "        [-4.43332185e+00, -8.14903245e-01, -1.36189238e+00, ...,\n",
       "          1.04807722e+00, -9.26878066e-01,  4.40942004e-01],\n",
       "        [-2.06460775e+00, -1.16549082e+00, -8.55053501e-01, ...,\n",
       "         -9.38559361e-01, -1.87949089e-01, -1.97852940e-01]]]), feats_timestep=array([[[ 2.93138406e+00,  2.71742502e+00,  2.02877947e+00, ...,\n",
       "          3.15081253e-01, -1.00153006e-01, -2.91507221e-02],\n",
       "        [-1.47461175e+00, -1.58453631e+00, -8.66416540e-01, ...,\n",
       "          6.39972224e-01,  7.50973463e-01,  1.68464292e-01],\n",
       "        [ 1.55804106e+00, -2.14564332e+00,  5.91066934e-01, ...,\n",
       "          6.08396813e-01, -4.66517779e-01,  6.55890697e-02],\n",
       "        ...,\n",
       "        [ 2.32409961e+00,  9.73057426e-01, -2.45927919e+00, ...,\n",
       "          2.16803578e-01,  8.47853776e-01,  1.06558255e+00],\n",
       "        [-1.23653956e+00, -1.57395665e+00, -1.00887414e+00, ...,\n",
       "         -3.24659246e-01,  2.31697745e-01,  9.27076559e-01],\n",
       "        [-2.62599830e+00, -1.10942746e+00,  8.20539578e-01, ...,\n",
       "          4.16532796e-02, -5.18122719e-01,  9.24578586e-01]],\n",
       "\n",
       "       [[ 8.52821826e+00,  3.58951036e+00,  1.46914809e+00, ...,\n",
       "          3.75203360e-01, -6.25842659e-03, -7.89301257e-02],\n",
       "        [ 4.77120953e+00,  4.22594906e-01, -1.35995508e+00, ...,\n",
       "          7.23495468e-01,  3.31812551e-01, -1.78128130e-01],\n",
       "        [ 3.84976478e+00, -1.43112030e+00,  3.37267017e-02, ...,\n",
       "         -1.81939521e-02, -2.28613344e-01, -1.80574750e-01],\n",
       "        ...,\n",
       "        [ 9.55216137e-02,  2.53785067e-01, -1.02241391e+00, ...,\n",
       "          2.83347703e-01,  2.37848964e-01,  6.35258054e-01],\n",
       "        [-2.91723666e+00, -1.39992688e+00, -3.04946277e-01, ...,\n",
       "         -2.78556356e-01,  3.37992062e-01,  8.14526620e-01],\n",
       "        [-4.17054317e+00, -1.14877357e+00,  9.26227166e-01, ...,\n",
       "         -3.05063723e-01, -1.81174674e-01,  6.73323795e-01]],\n",
       "\n",
       "       [[ 5.59852601e+00,  3.97787991e+00,  1.98605508e+00, ...,\n",
       "          2.99125206e-01,  1.61002275e-01, -1.56042058e-01],\n",
       "        [ 1.49517753e+00, -2.81448285e-01, -1.23307006e+00, ...,\n",
       "          3.08222094e-01,  3.07199113e-01,  1.39141936e-02],\n",
       "        [ 2.20271192e+00, -2.02663807e+00, -1.06757042e-01, ...,\n",
       "          2.77344605e-01, -2.22992290e-01, -2.58803508e-01],\n",
       "        ...,\n",
       "        [ 2.09145368e+00,  6.19908933e-01, -1.81310294e+00, ...,\n",
       "          3.01964806e-01,  3.18539718e-01,  6.57088306e-01],\n",
       "        [-1.46294369e+00, -1.34519279e+00, -5.18675203e-01, ...,\n",
       "          4.64181943e-02,  5.11771388e-01,  8.71544862e-01],\n",
       "        [-2.90706936e+00, -1.10065231e+00,  8.37228952e-01, ...,\n",
       "          7.34385850e-03, -2.90840321e-01,  8.57165795e-01]],\n",
       "\n",
       "       ...,\n",
       "\n",
       "       [[ 2.33464301e+00,  7.69308280e-01,  2.00183198e+00, ...,\n",
       "         -1.45372166e-01,  4.09646617e-01, -7.09866566e-02],\n",
       "        [ 4.99634452e+00, -6.30644749e-01, -1.79011808e+00, ...,\n",
       "          7.33277701e-01,  2.15224853e-01, -3.14556542e-01],\n",
       "        [ 2.11398046e+00, -6.61167833e-01,  4.11500024e-01, ...,\n",
       "          1.55077791e-01,  2.68693833e-01, -1.41754316e-01],\n",
       "        ...,\n",
       "        [ 1.66111677e+00, -6.88564449e-01,  2.63617171e+00, ...,\n",
       "          8.52422996e-02, -5.19956301e-01,  1.04606242e-01],\n",
       "        [-5.20700067e+00, -1.63056000e+00, -1.87059396e+00, ...,\n",
       "          4.13141398e-01, -1.88544993e-01,  5.88912416e-01],\n",
       "        [-2.55430142e+00, -1.37558841e+00, -6.02125809e-01, ...,\n",
       "         -6.58926939e-01, -1.43563936e-01,  2.38516809e-02]],\n",
       "\n",
       "       [[ 3.89366261e+00,  1.54358319e+00,  1.66455965e+00, ...,\n",
       "          9.34796190e-02,  6.87181712e-01, -2.24318340e-01],\n",
       "        [ 6.23053110e+00, -7.98057005e-01, -2.71353174e+00, ...,\n",
       "          1.10604161e+00, -4.19309542e-01,  1.26272889e-03],\n",
       "        [ 3.51968075e+00, -1.06928300e+00, -1.43978023e-01, ...,\n",
       "          1.74826929e-01,  1.97073885e-01, -4.50263290e-01],\n",
       "        ...,\n",
       "        [ 3.14001533e+00, -8.86196821e-01,  2.78883305e+00, ...,\n",
       "         -9.86043015e-02, -6.34122448e-01, -1.24328560e-01],\n",
       "        [-5.73593019e+00, -1.94659104e+00, -2.14398914e+00, ...,\n",
       "          7.47610540e-01, -2.36949688e-01,  9.71212752e-01],\n",
       "        [-3.14882633e+00, -1.54628706e+00, -7.74910945e-01, ...,\n",
       "         -8.46581121e-01,  3.35656430e-01, -1.57022377e-01]],\n",
       "\n",
       "       [[ 2.82528028e+00,  1.01916556e+00,  2.83550893e+00, ...,\n",
       "         -3.49299556e-01,  8.39011608e-01, -1.33985757e-01],\n",
       "        [ 5.91323162e+00, -2.57793651e-01, -1.77191356e+00, ...,\n",
       "          8.88027998e-01,  5.20197016e-01, -1.00435688e+00],\n",
       "        [ 2.68645288e+00, -5.72841833e-01,  6.66139887e-01, ...,\n",
       "          7.92110183e-01,  2.40546658e-01, -5.54907088e-01],\n",
       "        ...,\n",
       "        [ 3.18455609e+00, -2.18213857e-01,  3.43769556e+00, ...,\n",
       "          6.49864429e-01, -7.26173042e-01,  1.20061263e+00],\n",
       "        [-4.43332185e+00, -8.14903245e-01, -1.36189238e+00, ...,\n",
       "          1.04807722e+00, -9.26878066e-01,  4.40942004e-01],\n",
       "        [-2.06460775e+00, -1.16549082e+00, -8.55053501e-01, ...,\n",
       "         -9.38559361e-01, -1.87949089e-01, -1.97852940e-01]]]), labels=array([[0., 0., 1., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       ...,\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.]]), labels_timestep=array([[0., 0., 1., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       ...,\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.]]), n_nodes=5000, n_timestamps=12, n_class=10, n_feat=20, idx_train=array([ True,  True,  True, ...,  True, False,  True]), idx_val=array([False, False, False, ..., False,  True, False]), idx_test=array([False, False, False, ..., False, False, False]))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "read_data(\"Brain\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7594306049822064\t0.6654759318224387\t0.5037102833514956\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(0, 0.7594306049822064, [])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "spectral(read_data(\"DBLP3\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\embedding\\dynAERNN.py:124: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.\n",
      "\n",
      "WARNING:tensorflow:From C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\embedding\\dynAERNN.py:134: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.\n",
      "\n",
      "WARNING:tensorflow:From C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\keras\\backend\\tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n",
      "\n",
      "WARNING:tensorflow:From C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\keras\\backend\\tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
      "\n",
      "WARNING:tensorflow:From C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\keras\\backend\\tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\utils\\dnn_utils.py:491: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(50, activation=<keras.lay..., kernel_regularizer=<keras.reg...)`\n",
      "  W_regularizer=Reg.l1_l2(l1=nu1, l2=nu2))(y[i])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\keras\\activations.py:211: UserWarning: Do not pass a layer instance (such as LeakyReLU) as the activation argument of another layer. Instead, advanced activation layers should be used just like any other layer in a model.\n",
      "  identifier=identifier.__class__.__name__))\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\utils\\dnn_utils.py:491: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(30, activation=<keras.lay..., kernel_regularizer=<keras.reg...)`\n",
      "  W_regularizer=Reg.l1_l2(l1=nu1, l2=nu2))(y[i])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\utils\\dnn_utils.py:493: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(3, activation=<keras.lay..., kernel_regularizer=<keras.reg...)`\n",
      "  W_regularizer=Reg.l1_l2(l1=nu1, l2=nu2))(y[K - 1])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\utils\\dnn_utils.py:495: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=Tensor(\"in..., outputs=Tensor(\"de...)`\n",
      "  encoder = Model(input=x, output=y[K])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\utils\\dnn_utils.py:533: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(30, activation=<keras.lay..., kernel_regularizer=<keras.reg...)`\n",
      "  W_regularizer=Reg.l1_l2(l1=nu1, l2=nu2))(y_hat[i + 1])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\utils\\dnn_utils.py:533: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(50, activation=<keras.lay..., kernel_regularizer=<keras.reg...)`\n",
      "  W_regularizer=Reg.l1_l2(l1=nu1, l2=nu2))(y_hat[i + 1])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\utils\\dnn_utils.py:535: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(1405, activation=<keras.lay..., kernel_regularizer=<keras.reg...)`\n",
      "  W_regularizer=Reg.l1_l2(l1=nu1, l2=nu2))(y_hat[1])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\utils\\dnn_utils.py:540: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=Tensor(\"in..., outputs=Tensor(\"de...)`\n",
      "  decoder = Model(input=y, output=x_hat)\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\utils\\dnn_utils.py:885: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=Tensor(\"in..., outputs=[<tf.Tenso...)`\n",
      "  autoencoder = Model(input=x_in, output=[x_hat, y])\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\keras\\optimizers.py:790: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\embedding\\dynAERNN.py:205: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=[<tf.Tenso..., outputs=Tensor(\"su...)`\n",
      "  self._model = Model(input=[x_in, x_pred], output=x_diff)\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\embedding\\dynAERNN.py:225: UserWarning: The semantics of the Keras 2 argument `steps_per_epoch` is not the same as the Keras 1 argument `samples_per_epoch`. `steps_per_epoch` is the number of batches to draw from the generator at each epoch. Basically steps_per_epoch = samples_per_epoch/batch_size. Similarly `nb_val_samples`->`validation_steps` and `val_samples`->`steps` arguments have changed. Update your method calls accordingly.\n",
      "  verbose=1\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\embedding\\dynAERNN.py:225: UserWarning: Update your `fit_generator` call to the Keras 2 API: `fit_generator(generator=<generator..., verbose=1, steps_per_epoch=1124, epochs=2)`\n",
      "  verbose=1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\tensorflow\\python\\ops\\math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "Epoch 1/2\n",
      "1124/1124 [==============================] - 23s 20ms/step - loss: 7.6121e-04\n",
      "Epoch 2/2\n",
      "1124/1124 [==============================] - 18s 16ms/step - loss: 7.7946e-05\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\utils\\dnn_utils.py:491: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(50, activation=<keras.lay..., kernel_regularizer=<keras.reg...)`\n",
      "  W_regularizer=Reg.l1_l2(l1=nu1, l2=nu2))(y[i])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\keras\\activations.py:211: UserWarning: Do not pass a layer instance (such as LeakyReLU) as the activation argument of another layer. Instead, advanced activation layers should be used just like any other layer in a model.\n",
      "  identifier=identifier.__class__.__name__))\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\utils\\dnn_utils.py:491: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(30, activation=<keras.lay..., kernel_regularizer=<keras.reg...)`\n",
      "  W_regularizer=Reg.l1_l2(l1=nu1, l2=nu2))(y[i])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\utils\\dnn_utils.py:493: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(3, activation=<keras.lay..., kernel_regularizer=<keras.reg...)`\n",
      "  W_regularizer=Reg.l1_l2(l1=nu1, l2=nu2))(y[K - 1])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\utils\\dnn_utils.py:495: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=Tensor(\"in..., outputs=Tensor(\"de...)`\n",
      "  encoder = Model(input=x, output=y[K])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\utils\\dnn_utils.py:533: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(30, activation=<keras.lay..., kernel_regularizer=<keras.reg...)`\n",
      "  W_regularizer=Reg.l1_l2(l1=nu1, l2=nu2))(y_hat[i + 1])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\utils\\dnn_utils.py:533: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(50, activation=<keras.lay..., kernel_regularizer=<keras.reg...)`\n",
      "  W_regularizer=Reg.l1_l2(l1=nu1, l2=nu2))(y_hat[i + 1])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\utils\\dnn_utils.py:535: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(1405, activation=<keras.lay..., kernel_regularizer=<keras.reg...)`\n",
      "  W_regularizer=Reg.l1_l2(l1=nu1, l2=nu2))(y_hat[1])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\utils\\dnn_utils.py:540: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=Tensor(\"in..., outputs=Tensor(\"de...)`\n",
      "  decoder = Model(input=y, output=x_hat)\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\utils\\dnn_utils.py:885: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=Tensor(\"in..., outputs=[<tf.Tenso...)`\n",
      "  autoencoder = Model(input=x_in, output=[x_hat, y])\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\embedding\\dynAERNN.py:205: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=[<tf.Tenso..., outputs=Tensor(\"su...)`\n",
      "  self._model = Model(input=[x_in, x_pred], output=x_diff)\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\embedding\\dynAERNN.py:225: UserWarning: The semantics of the Keras 2 argument `steps_per_epoch` is not the same as the Keras 1 argument `samples_per_epoch`. `steps_per_epoch` is the number of batches to draw from the generator at each epoch. Basically steps_per_epoch = samples_per_epoch/batch_size. Similarly `nb_val_samples`->`validation_steps` and `val_samples`->`steps` arguments have changed. Update your method calls accordingly.\n",
      "  verbose=1\n",
      "C:\\Users\\conno\\anaconda3\\envs\\dynamicgem_env\\lib\\site-packages\\dynamicgem\\embedding\\dynAERNN.py:225: UserWarning: Update your `fit_generator` call to the Keras 2 API: `fit_generator(generator=<generator..., verbose=1, steps_per_epoch=1124, epochs=2)`\n",
      "  verbose=1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2\n",
      "1124/1124 [==============================] - 24s 21ms/step - loss: 39.4928\n",
      "Epoch 2/2\n",
      "1124/1124 [==============================] - 21s 18ms/step - loss: 32.6706\n",
      "0.47758007117437723\t0.5370336334670105\t0.5040663783546102\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(0, 0.47758007117437723, [])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dynaernn(read_data(\"DBLP3\"))"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "50c3ad3fdabee9fdefd23e1a4e55e7732f1cc2e5176c2a0141ad64aed25ac9fb"
  },
  "kernelspec": {
   "display_name": "Python 3.6.13 64-bit ('original_repro': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
