{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.layers import Dense, Activation, Dropout, ReLU, Input\n",
    "from keras.models import Model\n",
    "from keras.regularizers import l2\n",
    "from keras.optimizers import Adam\n",
    "from keras.layers.normalization import BatchNormalization\n",
    "from keras.constraints import unit_norm\n",
    "from keras import optimizers\n",
    "from keras import regularizers\n",
    "from keras import initializers\n",
    "import keras.backend as K\n",
    "from sklearn.utils import class_weight\n",
    "from scipy.linalg import fractional_matrix_power\n",
    "import numpy as np\n",
    "import networkx as nx\n",
    "\n",
    "from utils import *\n",
    "from gat_layer import GAT\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore', category=DeprecationWarning)\n",
    "warnings.filterwarnings('ignore', category=FutureWarning)\n",
    "warnings.filterwarnings('ignore', category=UserWarning)\n",
    "\n",
    "import tensorflow as tf\n",
    "tf.logging.set_verbosity(tf.logging.ERROR)\n",
    "\n",
    "import os\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" \n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the 'cora' dataset: graph, features, label splits, indices and masks.\n",
    "(A, X, Y_train, Y_val, Y_test,\n",
    " train_idx, val_idx, test_idx,\n",
    " train_mask, val_mask, test_mask, Y) = load_data('cora')\n",
    "\n",
    "# Densify the adjacency matrix into a plain ndarray.\n",
    "A = np.asarray(A.todense())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalise the node features with the project helper, then densify to an ndarray.\n",
    "X = np.asarray(normalize_features(X).todense())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the graph from the dense adjacency matrix.\n",
    "# from_numpy_array is the supported replacement for from_numpy_matrix,\n",
    "# which was removed in networkx 3.0.\n",
    "G = nx.from_numpy_array(A)\n",
    "\n",
    "# Attach row i of Y to node i under the attribute key 'attr_name'.\n",
    "feature_dictionary = {i: Y[i] for i in range(len(Y))}\n",
    "nx.set_node_attributes(G, feature_dictionary, \"attr_name\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "A_array = A\n",
    "\n",
    "# One ego-network per node: the node itself plus every node j with A[i][j] == 1.\n",
    "sub_graphs = []\n",
    "for i in np.arange(len(A_array)):\n",
    "    # The centre node is added exactly once; the neighbours come from a single\n",
    "    # vectorised scan of row i instead of a Python loop over every column.\n",
    "    s_indexes = [i]\n",
    "    s_indexes.extend(np.flatnonzero(A_array[i] == 1))\n",
    "    sub_graphs.append(G.subgraph(s_indexes))\n",
    "\n",
    "# Node ids of every ego-network, in the order networkx iterates them.\n",
    "subgraph_nodes_list = [list(sub_graph.nodes) for sub_graph in sub_graphs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense adjacency matrix of every ego-network, in the same order as sub_graphs.\n",
    "sub_graphs_adj = [nx.adjacency_matrix(sub_graph).toarray() for sub_graph in sub_graphs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of edges in every ego-network, in the same order as sub_graphs.\n",
    "sub_graph_edges = [sub_graph.number_of_edges() for sub_graph in sub_graphs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Structural score for every (node, neighbour) pair, derived from how densely the\n",
    "# nodes shared by both ego-networks are connected to each other.\n",
    "new_adj = np.zeros((A_array.shape[0], A_array.shape[0]))\n",
    "\n",
    "for node, nodes_list in enumerate(subgraph_nodes_list):\n",
    "    sub_adj = sub_graphs_adj[node]\n",
    "    members = set(nodes_list)\n",
    "    # Row/column of each member inside this ego-network's adjacency matrix;\n",
    "    # a dict lookup replaces repeated list.index() scans in the inner loops.\n",
    "    position = {member: pos for pos, member in enumerate(nodes_list)}\n",
    "\n",
    "    for index in nodes_list:\n",
    "        if index == node:\n",
    "            continue\n",
    "\n",
    "        # Nodes present in both ego-networks. `index` is always one of them,\n",
    "        # because it belongs to this node's list and to its own list.\n",
    "        c_neighbors = members.intersection(subgraph_nodes_list[index])\n",
    "        rows = [position[member] for member in c_neighbors]\n",
    "\n",
    "        # Sum of all adjacency entries among the shared nodes (both (a, b) and\n",
    "        # (b, a) are included, hence the division by two below).\n",
    "        count = sub_adj[np.ix_(rows, rows)].sum()\n",
    "\n",
    "        # Same three arithmetic steps as before so the values stay unchanged.\n",
    "        new_adj[node][index] = count/2\n",
    "        new_adj[node][index] = new_adj[node][index]/(len(c_neighbors)*(len(c_neighbors)-1))\n",
    "        new_adj[node][index] = new_adj[node][index] * (len(c_neighbors)**1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class ids shifted by one so that 0 is free to mean 'no training label'.\n",
    "labels = 1 + np.argmax(Y, axis=1)\n",
    "\n",
    "# Start from all zeros and copy the shifted ids of the training nodes only.\n",
    "labels_train = np.zeros(labels.shape)\n",
    "labels_train[train_idx] = labels[train_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row-normalise the structural scores. Rows whose sum is zero stay at zero instead\n",
    "# of producing 0/0 = NaN, which would otherwise spread over the whole row (and its\n",
    "# diagonal entry) in the additions below and erase that row's entries from A.\n",
    "row_totals = new_adj.sum(axis=1, keepdims=True)\n",
    "weight = np.divide(new_adj, row_totals, out=np.zeros_like(new_adj), where=row_totals != 0)\n",
    "\n",
    "# Add the original adjacency, then place each row's total on the diagonal.\n",
    "weight = weight + A\n",
    "coeff = np.diag(weight.sum(axis=1))\n",
    "weight = weight + coeff\n",
    "\n",
    "# Safety net: replace any NaN that is still present with 0.\n",
    "weight = np.nan_to_num(weight, nan=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "row_sum = np.array(np.sum(weight, axis=1))\n",
    "\n",
    "# Symmetric normalisation D^(-1/2) . W . D^(-1/2) with D = diag(row_sum + 1).\n",
    "# D is diagonal, so its inverse square root is the element-wise power of the\n",
    "# diagonal. Scaling rows and columns directly gives the same product without\n",
    "# building dense N x N degree matrices or calling the general-purpose\n",
    "# fractional_matrix_power routine.\n",
    "d_inv_sqrt = np.power(row_sum + 1, -0.5)\n",
    "adj = d_inv_sqrt[:, np.newaxis] * weight * d_inv_sqrt[np.newaxis, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def att_factor(inputs, num_nodes, dropout):\n",
    "    \"\"\"Apply one attention stage: BatchNormalization -> GAT -> ReLU -> Dropout.\n",
    "\n",
    "    inputs: tensor fed into the stage.\n",
    "    num_nodes: first positional argument forwarded to the GAT layer.\n",
    "    dropout: rate of the Dropout layer that closes the stage.\n",
    "\n",
    "    Reads the notebook-level `adj` matrix and returns the stage's output tensor.\n",
    "    \"\"\"\n",
    "    gat_options = dict(\n",
    "        num_attention_heads=8,\n",
    "        attention_combine='concat',\n",
    "        attention_dropout=0.6,\n",
    "        kernel_initializer=initializers.glorot_normal(seed=1),\n",
    "        kernel_regularizer=l2(6e-3),\n",
    "        kernel_constraint=unit_norm(),\n",
    "        use_bias=True,\n",
    "        bias_initializer=initializers.glorot_normal(seed=1),\n",
    "        bias_constraint=unit_norm(),\n",
    "    )\n",
    "\n",
    "    # Layers are created and applied in this exact order.\n",
    "    hidden = BatchNormalization()(inputs)\n",
    "    hidden = GAT(num_nodes, adj, **gat_options)(hidden)\n",
    "    hidden = ReLU()(hidden)\n",
    "    return Dropout(dropout)(hidden)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def att_block(inputs, num_nodes=(64,), dropout=(0.6,)):\n",
    "    \"\"\"Chain one att_factor stage per (num_nodes, dropout) pair.\n",
    "\n",
    "    inputs: tensor fed into the first stage.\n",
    "    num_nodes: per-stage size argument passed to att_factor; defaults to one stage of 64.\n",
    "    dropout: per-stage dropout rate passed to att_factor; defaults to 0.6.\n",
    "\n",
    "    Returns the output tensor of the last stage.\n",
    "    Raises ValueError when the two sequences differ in length.\n",
    "    \"\"\"\n",
    "    if len(num_nodes) != len(dropout):\n",
    "        raise ValueError('num_nodes and dropout must have the same length')\n",
    "\n",
    "    x = inputs\n",
    "\n",
    "    # Iterate over the configuration itself so that adding an entry adds a stage;\n",
    "    # a fixed range(1) would silently ignore anything beyond the first entry.\n",
    "    for units, rate in zip(num_nodes, dropout):\n",
    "        x = att_factor(x, units, rate)\n",
    "\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def att_block_model(x_train):\n",
    "    \"\"\"Build and compile the attention classifier.\n",
    "\n",
    "    x_train: 2-D array; only x_train.shape[1] is read, to size the input layer.\n",
    "\n",
    "    Returns a compiled Keras Model with a 7-way softmax output, trained with\n",
    "    categorical cross-entropy and Adam (lr=0.001), reporting accuracy.\n",
    "    \"\"\"\n",
    "    inputs = Input((x_train.shape[1],))\n",
    "\n",
    "    x = att_block(inputs)\n",
    "\n",
    "    predictions = Dense(7, kernel_initializer=initializers.glorot_normal(seed=1), \n",
    "                        kernel_regularizer=regularizers.l2(1e-10),\n",
    "                        kernel_constraint=unit_norm(), \n",
    "                        activity_regularizer=regularizers.l2(1e-10),\n",
    "                        use_bias=True, \n",
    "                        bias_initializer=initializers.glorot_normal(seed=1), \n",
    "                        bias_constraint=unit_norm(), \n",
    "                        activation='softmax', name='fc_'+str(1))(x)\n",
    "\n",
    "    # Keras 2 names these arguments `inputs`/`outputs`; the singular spellings are\n",
    "    # legacy Keras 1 keywords.\n",
    "    model = Model(inputs=inputs, outputs=predictions)\n",
    "\n",
    "    model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=0.001), metrics=['acc'])\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_1 (InputLayer)         (None, 1433)              0         \n",
      "_________________________________________________________________\n",
      "batch_normalization_1 (Batch (None, 1433)              5732      \n",
      "_________________________________________________________________\n",
      "gat_1 (GAT)                  (None, 512)               735248    \n",
      "_________________________________________________________________\n",
      "re_lu_1 (ReLU)               (None, 512)               0         \n",
      "_________________________________________________________________\n",
      "dropout_9 (Dropout)          (None, 512)               0         \n",
      "_________________________________________________________________\n",
      "fc_1 (Dense)                 (None, 7)                 3591      \n",
      "=================================================================\n",
      "Total params: 744,571\n",
      "Trainable params: 741,705\n",
      "Non-trainable params: 2,866\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Build the classifier sized for the feature matrix and print its layer summary.\n",
    "model = att_block_model(x_train=X)\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0000 train_acc= 0.1500 test_acc= 0.1162\n",
      "Epoch: 0001 train_acc= 0.2143 test_acc= 0.1353\n",
      "Epoch: 0002 train_acc= 0.2643 test_acc= 0.1581\n",
      "Epoch: 0003 train_acc= 0.2786 test_acc= 0.1728\n",
      "Epoch: 0004 train_acc= 0.4000 test_acc= 0.2051\n",
      "Epoch: 0005 train_acc= 0.5286 test_acc= 0.2596\n",
      "Epoch: 0006 train_acc= 0.6286 test_acc= 0.3419\n",
      "Epoch: 0007 train_acc= 0.7143 test_acc= 0.4154\n",
      "Epoch: 0008 train_acc= 0.8286 test_acc= 0.5147\n",
      "Epoch: 0009 train_acc= 0.8857 test_acc= 0.5846\n",
      "Epoch: 0010 train_acc= 0.9286 test_acc= 0.6559\n",
      "Epoch: 0011 train_acc= 0.9429 test_acc= 0.6993\n",
      "Epoch: 0012 train_acc= 0.9500 test_acc= 0.7331\n",
      "Epoch: 0013 train_acc= 0.9571 test_acc= 0.7640\n",
      "Epoch: 0014 train_acc= 0.9571 test_acc= 0.7824\n",
      "Epoch: 0015 train_acc= 0.9571 test_acc= 0.7912\n",
      "Epoch: 0016 train_acc= 0.9571 test_acc= 0.8037\n",
      "Epoch: 0017 train_acc= 0.9643 test_acc= 0.8037\n",
      "Epoch: 0018 train_acc= 0.9643 test_acc= 0.8103\n",
      "Epoch: 0019 train_acc= 0.9714 test_acc= 0.8169\n",
      "Epoch: 0020 train_acc= 0.9714 test_acc= 0.8221\n",
      "Epoch: 0021 train_acc= 0.9786 test_acc= 0.8265\n",
      "Epoch: 0022 train_acc= 0.9786 test_acc= 0.8272\n",
      "Epoch: 0023 train_acc= 0.9786 test_acc= 0.8294\n",
      "Epoch: 0024 train_acc= 0.9786 test_acc= 0.8301\n",
      "Epoch: 0025 train_acc= 0.9786 test_acc= 0.8338\n",
      "Epoch: 0026 train_acc= 0.9786 test_acc= 0.8360\n",
      "Epoch: 0027 train_acc= 0.9786 test_acc= 0.8346\n",
      "Epoch: 0028 train_acc= 0.9857 test_acc= 0.8353\n",
      "Epoch: 0029 train_acc= 0.9786 test_acc= 0.8338\n",
      "Epoch: 0030 train_acc= 0.9786 test_acc= 0.8360\n",
      "Epoch: 0031 train_acc= 0.9857 test_acc= 0.8353\n",
      "Epoch: 0032 train_acc= 0.9857 test_acc= 0.8353\n",
      "Epoch: 0033 train_acc= 0.9857 test_acc= 0.8360\n",
      "Epoch: 0034 train_acc= 0.9857 test_acc= 0.8360\n",
      "Epoch: 0035 train_acc= 0.9857 test_acc= 0.8368\n",
      "Epoch: 0036 train_acc= 0.9857 test_acc= 0.8380\n",
      "Epoch: 0037 train_acc= 0.9857 test_acc= 0.8353\n",
      "Epoch: 0038 train_acc= 0.9929 test_acc= 0.8331\n",
      "Epoch: 0039 train_acc= 0.9929 test_acc= 0.8316\n",
      "Epoch: 0040 train_acc= 0.9929 test_acc= 0.8324\n",
      "Epoch: 0041 train_acc= 0.9929 test_acc= 0.8331\n",
      "Epoch: 0042 train_acc= 0.9929 test_acc= 0.8353\n",
      "Epoch: 0043 train_acc= 0.9929 test_acc= 0.8331\n",
      "Epoch: 0044 train_acc= 0.9929 test_acc= 0.8360\n",
      "Epoch: 0045 train_acc= 0.9929 test_acc= 0.8368\n",
      "Epoch: 0046 train_acc= 0.9929 test_acc= 0.8353\n",
      "Epoch: 0047 train_acc= 0.9929 test_acc= 0.8331\n",
      "Epoch: 0048 train_acc= 0.9929 test_acc= 0.8338\n",
      "Epoch: 0049 train_acc= 0.9929 test_acc= 0.8346\n",
      "Epoch: 0050 train_acc= 0.9929 test_acc= 0.8346\n",
      "Epoch: 0051 train_acc= 0.9929 test_acc= 0.8346\n",
      "Epoch: 0052 train_acc= 0.9929 test_acc= 0.8346\n",
      "Epoch: 0053 train_acc= 0.9929 test_acc= 0.8346\n",
      "Epoch: 0054 train_acc= 0.9929 test_acc= 0.8331\n",
      "Epoch: 0055 train_acc= 0.9929 test_acc= 0.8301\n",
      "Epoch: 0056 train_acc= 0.9929 test_acc= 0.8272\n",
      "Epoch: 0057 train_acc= 0.9929 test_acc= 0.8257\n",
      "Epoch: 0058 train_acc= 0.9929 test_acc= 0.8257\n",
      "Epoch: 0059 train_acc= 0.9929 test_acc= 0.8250\n",
      "Epoch: 0060 train_acc= 0.9929 test_acc= 0.8243\n",
      "Epoch: 0061 train_acc= 0.9929 test_acc= 0.8243\n",
      "Epoch: 0062 train_acc= 0.9929 test_acc= 0.8243\n",
      "Epoch: 0063 train_acc= 0.9929 test_acc= 0.8228\n",
      "Epoch: 0064 train_acc= 0.9929 test_acc= 0.8235\n",
      "Epoch: 0065 train_acc= 0.9929 test_acc= 0.8250\n",
      "Epoch: 0066 train_acc= 0.9929 test_acc= 0.8265\n",
      "Epoch: 0067 train_acc= 0.9929 test_acc= 0.8265\n",
      "Epoch: 0068 train_acc= 0.9929 test_acc= 0.8272\n",
      "Epoch: 0069 train_acc= 0.9929 test_acc= 0.8279\n",
      "Epoch: 0070 train_acc= 0.9929 test_acc= 0.8265\n",
      "Epoch: 0071 train_acc= 0.9929 test_acc= 0.8257\n",
      "Epoch: 0072 train_acc= 0.9929 test_acc= 0.8235\n",
      "Epoch: 0073 train_acc= 0.9929 test_acc= 0.8221\n",
      "Epoch: 0074 train_acc= 0.9929 test_acc= 0.8228\n",
      "Epoch: 0075 train_acc= 0.9929 test_acc= 0.8213\n",
      "Epoch: 0076 train_acc= 0.9929 test_acc= 0.8228\n",
      "Epoch: 0077 train_acc= 0.9929 test_acc= 0.8228\n",
      "Epoch: 0078 train_acc= 0.9929 test_acc= 0.8221\n",
      "Epoch: 0079 train_acc= 0.9929 test_acc= 0.8221\n",
      "Epoch: 0080 train_acc= 0.9929 test_acc= 0.8235\n",
      "Epoch: 0081 train_acc= 0.9929 test_acc= 0.8243\n",
      "Epoch: 0082 train_acc= 0.9929 test_acc= 0.8250\n",
      "Epoch: 0083 train_acc= 0.9929 test_acc= 0.8243\n",
      "Epoch: 0084 train_acc= 0.9929 test_acc= 0.8235\n",
      "Epoch: 0085 train_acc= 0.9929 test_acc= 0.8221\n",
      "Epoch: 0086 train_acc= 0.9929 test_acc= 0.8235\n",
      "Epoch: 0087 train_acc= 0.9929 test_acc= 0.8257\n",
      "Epoch: 0088 train_acc= 0.9929 test_acc= 0.8228\n",
      "Epoch: 0089 train_acc= 0.9929 test_acc= 0.8235\n",
      "Epoch: 0090 train_acc= 0.9929 test_acc= 0.8228\n",
      "Epoch: 0091 train_acc= 0.9929 test_acc= 0.8243\n",
      "Epoch: 0092 train_acc= 0.9929 test_acc= 0.8250\n",
      "Epoch: 0093 train_acc= 0.9929 test_acc= 0.8265\n",
      "Epoch: 0094 train_acc= 0.9929 test_acc= 0.8257\n",
      "Epoch: 0095 train_acc= 0.9929 test_acc= 0.8228\n",
      "Epoch: 0096 train_acc= 0.9929 test_acc= 0.8235\n",
      "Epoch: 0097 train_acc= 0.9929 test_acc= 0.8221\n",
      "Epoch: 0098 train_acc= 0.9929 test_acc= 0.8213\n",
      "Epoch: 0099 train_acc= 0.9929 test_acc= 0.8213\n",
      "Epoch: 0100 train_acc= 0.9929 test_acc= 0.8221\n",
      "Epoch: 0101 train_acc= 0.9929 test_acc= 0.8206\n",
      "Epoch: 0102 train_acc= 0.9929 test_acc= 0.8199\n",
      "Epoch: 0103 train_acc= 0.9929 test_acc= 0.8184\n",
      "Epoch: 0104 train_acc= 0.9929 test_acc= 0.8184\n",
      "Epoch: 0105 train_acc= 0.9929 test_acc= 0.8176\n",
      "Epoch: 0106 train_acc= 0.9929 test_acc= 0.8169\n",
      "Epoch: 0107 train_acc= 0.9929 test_acc= 0.8162\n",
      "Epoch: 0108 train_acc= 0.9929 test_acc= 0.8169\n",
      "Epoch: 0109 train_acc= 0.9929 test_acc= 0.8154\n",
      "Epoch: 0110 train_acc= 0.9929 test_acc= 0.8162\n",
      "Epoch: 0111 train_acc= 0.9929 test_acc= 0.8184\n",
      "Epoch: 0112 train_acc= 0.9929 test_acc= 0.8213\n",
      "Epoch: 0113 train_acc= 0.9929 test_acc= 0.8176\n",
      "Epoch: 0114 train_acc= 0.9929 test_acc= 0.8162\n",
      "Epoch: 0115 train_acc= 0.9929 test_acc= 0.8162\n",
      "Epoch: 0116 train_acc= 0.9929 test_acc= 0.8191\n",
      "Epoch: 0117 train_acc= 0.9929 test_acc= 0.8191\n",
      "Epoch: 0118 train_acc= 0.9929 test_acc= 0.8199\n",
      "Epoch: 0119 train_acc= 0.9929 test_acc= 0.8228\n",
      "Epoch: 0120 train_acc= 0.9929 test_acc= 0.8221\n",
      "Epoch: 0121 train_acc= 0.9929 test_acc= 0.8221\n",
      "Epoch: 0122 train_acc= 0.9929 test_acc= 0.8228\n",
      "Epoch: 0123 train_acc= 0.9929 test_acc= 0.8184\n",
      "Epoch: 0124 train_acc= 0.9929 test_acc= 0.8162\n",
      "Epoch: 0125 train_acc= 0.9929 test_acc= 0.8169\n",
      "Epoch: 0126 train_acc= 0.9929 test_acc= 0.8191\n",
      "Epoch: 0127 train_acc= 0.9929 test_acc= 0.8206\n",
      "Epoch: 0128 train_acc= 0.9929 test_acc= 0.8206\n",
      "Epoch: 0129 train_acc= 0.9929 test_acc= 0.8206\n",
      "Epoch: 0130 train_acc= 0.9929 test_acc= 0.8206\n",
      "Epoch: 0131 train_acc= 0.9929 test_acc= 0.8184\n",
      "Epoch: 0132 train_acc= 0.9929 test_acc= 0.8162\n",
      "Epoch: 0133 train_acc= 0.9929 test_acc= 0.8125\n",
      "Epoch: 0134 train_acc= 0.9929 test_acc= 0.8118\n",
      "Epoch: 0135 train_acc= 0.9929 test_acc= 0.8125\n",
      "Epoch: 0136 train_acc= 0.9929 test_acc= 0.8147\n",
      "Epoch: 0137 train_acc= 0.9929 test_acc= 0.8125\n",
      "Epoch: 0138 train_acc= 0.9929 test_acc= 0.8118\n",
      "Epoch: 0139 train_acc= 0.9929 test_acc= 0.8110\n",
      "Epoch: 0140 train_acc= 0.9929 test_acc= 0.8118\n",
      "Epoch: 0141 train_acc= 0.9929 test_acc= 0.8140\n",
      "Epoch: 0142 train_acc= 0.9929 test_acc= 0.8147\n",
      "Epoch: 0143 train_acc= 0.9929 test_acc= 0.8132\n",
      "Epoch: 0144 train_acc= 0.9929 test_acc= 0.8132\n",
      "Epoch: 0145 train_acc= 0.9929 test_acc= 0.8110\n",
      "Epoch: 0146 train_acc= 0.9929 test_acc= 0.8103\n",
      "Epoch: 0147 train_acc= 0.9929 test_acc= 0.8118\n",
      "Epoch: 0148 train_acc= 0.9929 test_acc= 0.8096\n",
      "Epoch: 0149 train_acc= 0.9929 test_acc= 0.8096\n",
      "Epoch: 0150 train_acc= 0.9929 test_acc= 0.8110\n",
      "Epoch: 0151 train_acc= 0.9929 test_acc= 0.8110\n",
      "Epoch: 0152 train_acc= 0.9929 test_acc= 0.8096\n",
      "Epoch: 0153 train_acc= 0.9929 test_acc= 0.8103\n",
      "Epoch: 0154 train_acc= 0.9929 test_acc= 0.8140\n",
      "Epoch: 0155 train_acc= 0.9929 test_acc= 0.8169\n",
      "Epoch: 0156 train_acc= 0.9929 test_acc= 0.8184\n",
      "Epoch: 0157 train_acc= 0.9929 test_acc= 0.8199\n",
      "Epoch: 0158 train_acc= 0.9929 test_acc= 0.8199\n",
      "Epoch: 0159 train_acc= 0.9929 test_acc= 0.8199\n",
      "Epoch: 0160 train_acc= 0.9929 test_acc= 0.8191\n",
      "Epoch: 0161 train_acc= 0.9929 test_acc= 0.8140\n",
      "Epoch: 0162 train_acc= 0.9929 test_acc= 0.8051\n",
      "Epoch: 0163 train_acc= 0.9929 test_acc= 0.8088\n",
      "Epoch: 0164 train_acc= 0.9929 test_acc= 0.8059\n",
      "Epoch: 0165 train_acc= 0.9929 test_acc= 0.8096\n",
      "Epoch: 0166 train_acc= 0.9929 test_acc= 0.8096\n",
      "Epoch: 0167 train_acc= 0.9929 test_acc= 0.8110\n",
      "Epoch: 0168 train_acc= 0.9929 test_acc= 0.8125\n",
      "Epoch: 0169 train_acc= 0.9929 test_acc= 0.8132\n",
      "Epoch: 0170 train_acc= 0.9929 test_acc= 0.8125\n",
      "Epoch: 0171 train_acc= 0.9929 test_acc= 0.8140\n",
      "Epoch: 0172 train_acc= 0.9929 test_acc= 0.8140\n",
      "Epoch: 0173 train_acc= 0.9929 test_acc= 0.8140\n",
      "Epoch: 0174 train_acc= 0.9929 test_acc= 0.8132\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0175 train_acc= 1.0000 test_acc= 0.8103\n",
      "Epoch: 0176 train_acc= 1.0000 test_acc= 0.8059\n",
      "Epoch: 0177 train_acc= 1.0000 test_acc= 0.8074\n",
      "Epoch: 0178 train_acc= 1.0000 test_acc= 0.8081\n",
      "Epoch: 0179 train_acc= 1.0000 test_acc= 0.8110\n",
      "Epoch: 0180 train_acc= 1.0000 test_acc= 0.8118\n",
      "Epoch: 0181 train_acc= 1.0000 test_acc= 0.8125\n",
      "Epoch: 0182 train_acc= 0.9929 test_acc= 0.8118\n",
      "Epoch: 0183 train_acc= 0.9929 test_acc= 0.8132\n",
      "Epoch: 0184 train_acc= 0.9929 test_acc= 0.8132\n",
      "Epoch: 0185 train_acc= 0.9929 test_acc= 0.8154\n",
      "Epoch: 0186 train_acc= 1.0000 test_acc= 0.8162\n",
      "Epoch: 0187 train_acc= 1.0000 test_acc= 0.8140\n",
      "Epoch: 0188 train_acc= 1.0000 test_acc= 0.8125\n",
      "Epoch: 0189 train_acc= 0.9929 test_acc= 0.8125\n",
      "Epoch: 0190 train_acc= 0.9929 test_acc= 0.8140\n",
      "Epoch: 0191 train_acc= 0.9929 test_acc= 0.8132\n",
      "Epoch: 0192 train_acc= 0.9929 test_acc= 0.8125\n",
      "Epoch: 0193 train_acc= 0.9929 test_acc= 0.8103\n",
      "Epoch: 0194 train_acc= 0.9929 test_acc= 0.8110\n",
      "Epoch: 0195 train_acc= 0.9929 test_acc= 0.8096\n",
      "Epoch: 0196 train_acc= 0.9929 test_acc= 0.8088\n",
      "Epoch: 0197 train_acc= 0.9929 test_acc= 0.8088\n",
      "Epoch: 0198 train_acc= 0.9929 test_acc= 0.8074\n",
      "Epoch: 0199 train_acc= 0.9929 test_acc= 0.8051\n"
     ]
    }
   ],
   "source": [
    "nb_epochs = 200\n",
    "\n",
    "# Balanced class weights computed from the training nodes only. labels_train holds\n",
    "# class id + 1 (0 marks nodes without a training label), so shift back to the\n",
    "# 0-based class indices that Keras expects as class_weight keys. Including the\n",
    "# 0 placeholder as a class would shift every weight by one position.\n",
    "train_classes = labels_train[train_idx].astype(int) - 1\n",
    "present_classes = np.unique(train_classes)\n",
    "\n",
    "# The result gets its own name: rebinding `class_weight` would shadow the imported\n",
    "# sklearn module and make this cell fail when it is executed a second time.\n",
    "balanced_weights = class_weight.compute_class_weight('balanced', classes=present_classes, y=train_classes)\n",
    "class_weight_dic = dict(zip(present_classes.tolist(), balanced_weights))\n",
    "\n",
    "for epoch in range(nb_epochs):\n",
    "    # One full-batch pass per iteration; train_mask is passed as the per-sample weight.\n",
    "    model.fit(X, Y_train, sample_weight=train_mask, batch_size=A.shape[0], epochs=1, shuffle=False, \n",
    "              class_weight=class_weight_dic, verbose=0)\n",
    "    Y_pred = model.predict(X, batch_size=A.shape[0])\n",
    "    _, train_acc = evaluate_preds(Y_pred, [Y_train], [train_idx])\n",
    "    _, val_acc = evaluate_preds(Y_pred, [Y_val], [val_idx])\n",
    "    _, test_acc = evaluate_preds(Y_pred, [Y_test], [test_idx])\n",
    "    print(\"Epoch: {:04d}\".format(epoch), \"train_acc= {:.4f}\".format(train_acc[0]), \"test_acc= {:.4f}\".format(test_acc[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
