{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import time\n",
    "import argparse\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import networkx as nx\n",
    "from scipy import sparse\n",
    "from scipy.linalg import fractional_matrix_power\n",
    "\n",
    "from utils import *\n",
    "from models import GNN\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_StoreAction(option_strings=['--dataset'], dest='dataset', nargs=None, const=None, default='cora', type=None, choices=None, help='Dataset name.', metavar=None)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hyper-parameters are declared argparse-style so they mirror a command-line script.\n",
    "parser = argparse.ArgumentParser()\n",
    "\n",
    "# Runtime switches\n",
    "parser.add_argument('--no-cuda', default=False, action='store_true',\n",
    "                    help='Disables CUDA training.')\n",
    "parser.add_argument('--fastmode', default=False, action='store_true',\n",
    "                    help='Validate during training pass.')\n",
    "parser.add_argument('--seed', default=42, type=int, help='Random seed.')\n",
    "\n",
    "# Optimisation settings\n",
    "parser.add_argument('--epochs', default=200, type=int,\n",
    "                    help='Number of epochs to train.')\n",
    "parser.add_argument('--lr', default=0.023, type=float,\n",
    "                    help='Initial learning rate.')\n",
    "parser.add_argument('--weight_decay', default=5e-4, type=float,\n",
    "                    help='Weight decay (L2 loss on parameters).')\n",
    "\n",
    "# Model settings\n",
    "parser.add_argument('--hidden', default=128, type=int,\n",
    "                    help='Number of hidden units.')\n",
    "parser.add_argument('--dropout', default=0.91, type=float,\n",
    "                    help='Dropout rate (1 - keep probability).')\n",
    "\n",
    "# Data selection (left as the final expression, so the cell still echoes this action)\n",
    "parser.add_argument('--dataset', default='cora', help='Dataset name.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A notebook has no real command line, so parse an empty argument list to obtain the defaults.\n",
    "args = parser.parse_args([])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x2ca2fd12900>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Seed both NumPy's and PyTorch's global generators from the configured value.\n",
    "np.random.seed(args.seed)\n",
    "# Kept as the last expression: torch.manual_seed returns the generator, which the cell echoes.\n",
    "torch.manual_seed(args.seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset selected by --dataset.\n",
    "# A and X are densified with .toarray() in later cells; labels and the idx_* splits become LongTensors.\n",
    "(A_norm, A, X, labels,\n",
    " idx_train, idx_val, idx_test) = load_data(args.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a NetworkX graph from the dense adjacency matrix.\n",
    "# from_numpy_array is the supported constructor for ndarrays; from_numpy_matrix\n",
    "# was deprecated and is no longer available in NetworkX 3.x.\n",
    "G = nx.from_numpy_array(A.toarray())\n",
    "\n",
    "# Attach every node's label as the \"attr_name\" node attribute (node ids are 0..N-1).\n",
    "feature_dictionary = {node_id: labels[node_id] for node_id in range(len(labels))}\n",
    "\n",
    "nx.set_node_attributes(G, feature_dictionary, \"attr_name\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every node, build the induced subgraph on the node itself plus its direct neighbours.\n",
    "A_array = A.toarray()\n",
    "\n",
    "sub_graphs = []\n",
    "for i in range(len(A_array)):\n",
    "    # Neighbours are the columns holding exactly 1 in row i. The centre node is added\n",
    "    # exactly once: subgraph() works on the set of nodes, so repeating it for every\n",
    "    # column only builds a needlessly large list.\n",
    "    s_indexes = [i] + np.flatnonzero(A_array[i] == 1).tolist()\n",
    "    sub_graphs.append(G.subgraph(s_indexes))\n",
    "\n",
    "# Node ids of each subgraph, in that subgraph's own iteration order.\n",
    "subgraph_nodes_list = [list(sub_graph.nodes) for sub_graph in sub_graphs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense adjacency matrix of every ego subgraph, in the same order as sub_graphs.\n",
    "sub_graphs_adj = [\n",
    "    nx.adjacency_matrix(sub_graph).toarray()\n",
    "    for sub_graph in sub_graphs\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Edge count of every ego subgraph.\n",
    "sub_graph_edges = [sub_graph.number_of_edges() for sub_graph in sub_graphs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node count of every ego subgraph.\n",
    "sub_graph_nodes_count = list(map(len, subgraph_nodes_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Structural weight for every (node, neighbour) pair, stored in a dense N x N tensor.\n",
    "new_adj = torch.zeros(A_array.shape[0], A_array.shape[0])\n",
    "\n",
    "for node, nodes_list in enumerate(subgraph_nodes_list):\n",
    "    sub_adj = sub_graphs_adj[node]\n",
    "    # Row/column of each node id inside sub_adj, computed once per ego subgraph\n",
    "    # instead of calling list.index() inside the innermost loop.\n",
    "    position = {member: pos for pos, member in enumerate(nodes_list)}\n",
    "    node_set = set(nodes_list)\n",
    "\n",
    "    for index in nodes_list:\n",
    "        if index == node:\n",
    "            continue\n",
    "\n",
    "        # Nodes that appear in both ego subgraphs.\n",
    "        c_neighbors = node_set.intersection(subgraph_nodes_list[index])\n",
    "        count = torch.tensor(0).float()\n",
    "        if index in c_neighbors:\n",
    "            # Sum all adjacency entries among the shared nodes.\n",
    "            for i, item1 in enumerate(nodes_list):\n",
    "                if item1 in c_neighbors:\n",
    "                    for item2 in c_neighbors:\n",
    "                        count += sub_adj[i][position[item2]]\n",
    "\n",
    "        # Halve the sum (a symmetric adjacency lists each edge twice), divide by k*(k-1)\n",
    "        # and scale by k, where k is the number of shared nodes. The three steps are kept\n",
    "        # as separate float32 tensor operations. With k == 1 the division is by zero,\n",
    "        # which gives nan/inf on a tensor rather than raising; non-finite entries are\n",
    "        # handled later by np.nan_to_num.\n",
    "        shared = len(c_neighbors)\n",
    "        new_adj[node][index] = count/2\n",
    "        new_adj[node][index] = new_adj[node][index]/(shared*(shared-1))\n",
    "        new_adj[node][index] = new_adj[node][index] * shared"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model inputs as tensors: dense features, integer labels and the three index splits.\n",
    "features = torch.FloatTensor(X.toarray())\n",
    "labels = torch.LongTensor(labels)\n",
    "idx_train, idx_val, idx_test = map(torch.LongTensor, (idx_train, idx_val, idx_test))\n",
    "\n",
    "# Row-normalise the structural scores (rows summing to zero turn into NaN here),\n",
    "# then add the original adjacency on top.\n",
    "weight = torch.FloatTensor(new_adj)\n",
    "row_totals = weight.sum(1, keepdim=True)\n",
    "weight = weight / row_totals + torch.FloatTensor(A_array)\n",
    "\n",
    "# Put each node's total row weight on the diagonal as its self-connection.\n",
    "coeff = torch.diag(weight.sum(1))\n",
    "weight = weight + coeff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Back to NumPy; NaN becomes 0 and +/-inf become large finite values.\n",
    "weight = weight.detach().numpy()\n",
    "weight = np.nan_to_num(weight, nan=0)\n",
    "\n",
    "# Symmetric normalisation D^-1/2 . W . D^-1/2 with D = diag(row_sum + 1).\n",
    "# D is diagonal, so its inverse square root is simply element-wise; this avoids the\n",
    "# dense fractional_matrix_power call on an N x N np.matrix and two N x N matrix products.\n",
    "row_sum = np.sum(weight, axis=1)\n",
    "degree = row_sum + 1\n",
    "degree_matrix = np.diag(degree)\n",
    "\n",
    "d_inv_sqrt = 1.0 / np.sqrt(degree.astype(np.float64))\n",
    "D = np.diag(d_inv_sqrt)\n",
    "# Broadcasting scales row i by d_inv_sqrt[i] and column j by d_inv_sqrt[j].\n",
    "A_tilde_hat = d_inv_sqrt[:, None] * weight * d_inv_sqrt[None, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalised weight matrix as a float tensor; this is what the model receives as its adjacency.\n",
    "adj = torch.FloatTensor(A_tilde_hat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model: input width from the feature matrix, output width from the largest label id.\n",
    "model = GNN(\n",
    "    nfeat=features.shape[1],\n",
    "    nclass=labels.max().item() + 1,\n",
    "    nhid=args.hidden,\n",
    "    dropout=args.dropout,\n",
    ")\n",
    "\n",
    "# Optimizer: Adam over all model parameters with the configured learning rate and weight decay.\n",
    "optimizer = optim.Adam(\n",
    "    model.parameters(),\n",
    "    weight_decay=args.weight_decay,\n",
    "    lr=args.lr,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch):\n",
    "    \"\"\"Run one full-batch training step and print train/validation metrics.\n",
    "\n",
    "    Args:\n",
    "        epoch: Zero-based epoch index; it is printed as ``epoch + 1``.\n",
    "\n",
    "    Relies on the notebook-level ``model``, ``optimizer``, ``features``, ``adj``,\n",
    "    ``labels``, ``idx_train``, ``idx_val`` and ``args``.\n",
    "    \"\"\"\n",
    "    t = time.time()\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    output = model(features, adj)\n",
    "    loss_train = F.nll_loss(output[idx_train], labels[idx_train])\n",
    "    acc_train = accuracy(output[idx_train], labels[idx_train])\n",
    "    loss_train.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # Validation metrics are only reported, never back-propagated, so autograd\n",
    "    # bookkeeping is switched off for this part.\n",
    "    with torch.no_grad():\n",
    "        if not args.fastmode:\n",
    "            # Evaluate validation set performance separately, deactivates dropout during validation run.\n",
    "            model.eval()\n",
    "            output = model(features, adj)\n",
    "\n",
    "        # In fast mode the output of the training pass above is reused here.\n",
    "        loss_val = F.nll_loss(output[idx_val], labels[idx_val])\n",
    "        acc_val = accuracy(output[idx_val], labels[idx_val])\n",
    "\n",
    "    print('Epoch: {:04d}'.format(epoch+1),\n",
    "          'loss_train: {:.4f}'.format(loss_train.item()),\n",
    "          'acc_train: {:.4f}'.format(acc_train.item()),\n",
    "          'loss_val: {:.4f}'.format(loss_val.item()),\n",
    "          'acc_val: {:.4f}'.format(acc_val.item()),\n",
    "          'time: {:.4f}s'.format(time.time() - t))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test():\n",
    "    \"\"\"Evaluate the trained model on the test split and print loss and accuracy.\n",
    "\n",
    "    Relies on the notebook-level ``model``, ``features``, ``adj``, ``labels``\n",
    "    and ``idx_test``.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    # Pure evaluation: no gradients are needed, so skip building the autograd graph.\n",
    "    with torch.no_grad():\n",
    "        output = model(features, adj)\n",
    "        loss_test = F.nll_loss(output[idx_test], labels[idx_test])\n",
    "        acc_test = accuracy(output[idx_test], labels[idx_test])\n",
    "    print(\"Test set results:\",\n",
    "          \"loss= {:.4f}\".format(loss_test.item()),\n",
    "          \"accuracy= {:.4f}\".format(acc_test.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0001 loss_train: 1.9810 acc_train: 0.1357 loss_val: 1.9815 acc_val: 0.1360 time: 5.5072s\n",
      "Epoch: 0002 loss_train: 1.9982 acc_train: 0.1857 loss_val: 1.9568 acc_val: 0.0620 time: 2.3208s\n",
      "Epoch: 0003 loss_train: 1.9526 acc_train: 0.1286 loss_val: 1.9539 acc_val: 0.0660 time: 2.1303s\n",
      "Epoch: 0004 loss_train: 1.9429 acc_train: 0.1571 loss_val: 1.9425 acc_val: 0.0840 time: 2.0914s\n",
      "Epoch: 0005 loss_train: 1.9459 acc_train: 0.1786 loss_val: 1.9417 acc_val: 0.1580 time: 2.0974s\n",
      "Epoch: 0006 loss_train: 1.9373 acc_train: 0.1714 loss_val: 1.9482 acc_val: 0.1180 time: 2.0655s\n",
      "Epoch: 0007 loss_train: 1.9375 acc_train: 0.1143 loss_val: 1.9458 acc_val: 0.1120 time: 2.0545s\n",
      "Epoch: 0008 loss_train: 1.9414 acc_train: 0.1357 loss_val: 1.9417 acc_val: 0.1020 time: 2.1373s\n",
      "Epoch: 0009 loss_train: 1.9486 acc_train: 0.1286 loss_val: 1.9393 acc_val: 0.0860 time: 2.2719s\n",
      "Epoch: 0010 loss_train: 1.9415 acc_train: 0.1929 loss_val: 1.9322 acc_val: 0.0780 time: 2.1512s\n",
      "Epoch: 0011 loss_train: 1.9295 acc_train: 0.1429 loss_val: 1.8992 acc_val: 0.1480 time: 2.0954s\n",
      "Epoch: 0012 loss_train: 1.8718 acc_train: 0.2214 loss_val: 2.0375 acc_val: 0.0800 time: 2.2749s\n",
      "Epoch: 0013 loss_train: 1.8839 acc_train: 0.2429 loss_val: 1.9032 acc_val: 0.1360 time: 2.2859s\n",
      "Epoch: 0014 loss_train: 1.8388 acc_train: 0.2000 loss_val: 1.9470 acc_val: 0.1440 time: 2.1752s\n",
      "Epoch: 0015 loss_train: 1.9105 acc_train: 0.1500 loss_val: 1.8595 acc_val: 0.1760 time: 2.2709s\n",
      "Epoch: 0016 loss_train: 1.7978 acc_train: 0.2143 loss_val: 1.7007 acc_val: 0.1900 time: 2.3537s\n",
      "Epoch: 0017 loss_train: 1.6340 acc_train: 0.3143 loss_val: 1.5884 acc_val: 0.1820 time: 2.2330s\n",
      "Epoch: 0018 loss_train: 1.5294 acc_train: 0.2643 loss_val: 1.5405 acc_val: 0.3300 time: 2.0814s\n",
      "Epoch: 0019 loss_train: 1.4219 acc_train: 0.2786 loss_val: 1.7218 acc_val: 0.2040 time: 2.1383s\n",
      "Epoch: 0020 loss_train: 1.5682 acc_train: 0.2786 loss_val: 1.6030 acc_val: 0.2120 time: 2.1094s\n",
      "Epoch: 0021 loss_train: 1.4008 acc_train: 0.3500 loss_val: 1.6913 acc_val: 0.1420 time: 2.1642s\n",
      "Epoch: 0022 loss_train: 1.4544 acc_train: 0.2929 loss_val: 1.6054 acc_val: 0.2460 time: 2.1483s\n",
      "Epoch: 0023 loss_train: 1.3767 acc_train: 0.2786 loss_val: 1.5339 acc_val: 0.3580 time: 2.4425s\n",
      "Epoch: 0024 loss_train: 1.3425 acc_train: 0.3071 loss_val: 1.5326 acc_val: 0.3580 time: 2.4854s\n",
      "Epoch: 0025 loss_train: 1.3615 acc_train: 0.3500 loss_val: 1.5090 acc_val: 0.3780 time: 2.2181s\n",
      "Epoch: 0026 loss_train: 1.3980 acc_train: 0.3643 loss_val: 1.4939 acc_val: 0.3680 time: 2.1822s\n",
      "Epoch: 0027 loss_train: 1.2925 acc_train: 0.3429 loss_val: 1.5723 acc_val: 0.3320 time: 2.1682s\n",
      "Epoch: 0028 loss_train: 1.3441 acc_train: 0.3357 loss_val: 1.5393 acc_val: 0.2640 time: 2.1532s\n",
      "Epoch: 0029 loss_train: 1.2588 acc_train: 0.4214 loss_val: 1.6524 acc_val: 0.2680 time: 2.1443s\n",
      "Epoch: 0030 loss_train: 1.2504 acc_train: 0.4143 loss_val: 1.5616 acc_val: 0.3820 time: 2.1612s\n",
      "Epoch: 0031 loss_train: 1.2191 acc_train: 0.3786 loss_val: 1.5231 acc_val: 0.4480 time: 2.1423s\n",
      "Epoch: 0032 loss_train: 1.2456 acc_train: 0.4071 loss_val: 1.4950 acc_val: 0.4360 time: 2.1652s\n",
      "Epoch: 0033 loss_train: 1.1804 acc_train: 0.5786 loss_val: 1.5114 acc_val: 0.4360 time: 2.1752s\n",
      "Epoch: 0034 loss_train: 1.1908 acc_train: 0.4786 loss_val: 1.4743 acc_val: 0.4520 time: 2.1183s\n",
      "Epoch: 0035 loss_train: 1.1402 acc_train: 0.4929 loss_val: 1.5510 acc_val: 0.4820 time: 2.1503s\n",
      "Epoch: 0036 loss_train: 1.1986 acc_train: 0.4571 loss_val: 1.5221 acc_val: 0.4420 time: 2.1572s\n",
      "Epoch: 0037 loss_train: 1.0854 acc_train: 0.5214 loss_val: 1.4719 acc_val: 0.5540 time: 2.1892s\n",
      "Epoch: 0038 loss_train: 1.0959 acc_train: 0.5857 loss_val: 1.4645 acc_val: 0.5640 time: 2.1263s\n",
      "Epoch: 0039 loss_train: 1.0204 acc_train: 0.6214 loss_val: 1.5192 acc_val: 0.5200 time: 2.1114s\n",
      "Epoch: 0040 loss_train: 1.0131 acc_train: 0.5571 loss_val: 1.5890 acc_val: 0.5020 time: 2.1353s\n",
      "Epoch: 0041 loss_train: 1.0441 acc_train: 0.5929 loss_val: 1.5267 acc_val: 0.4880 time: 2.1373s\n",
      "Epoch: 0042 loss_train: 0.8915 acc_train: 0.6643 loss_val: 1.7149 acc_val: 0.5320 time: 2.2051s\n",
      "Epoch: 0043 loss_train: 1.0164 acc_train: 0.5429 loss_val: 1.8186 acc_val: 0.5220 time: 2.1423s\n",
      "Epoch: 0044 loss_train: 1.0503 acc_train: 0.5500 loss_val: 1.5530 acc_val: 0.5440 time: 2.3407s\n",
      "Epoch: 0045 loss_train: 0.9197 acc_train: 0.6214 loss_val: 2.0022 acc_val: 0.4300 time: 2.4106s\n",
      "Epoch: 0046 loss_train: 1.1931 acc_train: 0.4500 loss_val: 1.4804 acc_val: 0.5040 time: 2.6040s\n",
      "Epoch: 0047 loss_train: 0.9031 acc_train: 0.6000 loss_val: 1.4847 acc_val: 0.5500 time: 2.7766s\n",
      "Epoch: 0048 loss_train: 0.9837 acc_train: 0.5857 loss_val: 1.4120 acc_val: 0.5900 time: 2.4694s\n",
      "Epoch: 0049 loss_train: 0.9796 acc_train: 0.5857 loss_val: 1.4323 acc_val: 0.5600 time: 2.6280s\n",
      "Epoch: 0050 loss_train: 0.8625 acc_train: 0.6214 loss_val: 1.4722 acc_val: 0.5640 time: 2.7516s\n",
      "Epoch: 0051 loss_train: 0.8718 acc_train: 0.6357 loss_val: 1.4787 acc_val: 0.6100 time: 2.6968s\n",
      "Epoch: 0052 loss_train: 0.8458 acc_train: 0.6286 loss_val: 1.4878 acc_val: 0.6120 time: 2.4564s\n",
      "Epoch: 0053 loss_train: 0.7753 acc_train: 0.6929 loss_val: 1.5288 acc_val: 0.5880 time: 2.5302s\n",
      "Epoch: 0054 loss_train: 0.7402 acc_train: 0.6857 loss_val: 1.7144 acc_val: 0.5800 time: 2.6180s\n",
      "Epoch: 0055 loss_train: 0.8489 acc_train: 0.6429 loss_val: 1.7697 acc_val: 0.5460 time: 2.5203s\n",
      "Epoch: 0056 loss_train: 0.6869 acc_train: 0.6857 loss_val: 1.8326 acc_val: 0.5760 time: 2.4903s\n",
      "Epoch: 0057 loss_train: 0.7844 acc_train: 0.6429 loss_val: 1.6447 acc_val: 0.4980 time: 2.4495s\n",
      "Epoch: 0058 loss_train: 0.8055 acc_train: 0.6429 loss_val: 1.6619 acc_val: 0.5200 time: 2.6439s\n",
      "Epoch: 0059 loss_train: 0.7808 acc_train: 0.6714 loss_val: 1.6481 acc_val: 0.6040 time: 2.6230s\n",
      "Epoch: 0060 loss_train: 0.6553 acc_train: 0.7786 loss_val: 2.1018 acc_val: 0.5460 time: 2.6110s\n",
      "Epoch: 0061 loss_train: 1.1182 acc_train: 0.5286 loss_val: 1.5151 acc_val: 0.5740 time: 3.1336s\n",
      "Epoch: 0062 loss_train: 0.5993 acc_train: 0.7786 loss_val: 1.6172 acc_val: 0.5460 time: 2.5622s\n",
      "Epoch: 0063 loss_train: 0.8243 acc_train: 0.6357 loss_val: 1.6417 acc_val: 0.5920 time: 2.4275s\n",
      "Epoch: 0064 loss_train: 0.7056 acc_train: 0.6929 loss_val: 1.8093 acc_val: 0.5900 time: 2.3926s\n",
      "Epoch: 0065 loss_train: 0.6248 acc_train: 0.6929 loss_val: 2.0779 acc_val: 0.5980 time: 2.4145s\n",
      "Epoch: 0066 loss_train: 0.6416 acc_train: 0.6857 loss_val: 1.9788 acc_val: 0.6140 time: 2.4345s\n",
      "Epoch: 0067 loss_train: 0.7777 acc_train: 0.7357 loss_val: 1.9447 acc_val: 0.5640 time: 2.3926s\n",
      "Epoch: 0068 loss_train: 0.6126 acc_train: 0.7429 loss_val: 1.7782 acc_val: 0.5660 time: 2.4365s\n",
      "Epoch: 0069 loss_train: 0.6251 acc_train: 0.7429 loss_val: 1.6533 acc_val: 0.5280 time: 2.4235s\n",
      "Epoch: 0070 loss_train: 0.7027 acc_train: 0.6786 loss_val: 1.7826 acc_val: 0.5800 time: 2.4245s\n",
      "Epoch: 0071 loss_train: 0.8417 acc_train: 0.7000 loss_val: 1.9924 acc_val: 0.6100 time: 2.4076s\n",
      "Epoch: 0072 loss_train: 0.5616 acc_train: 0.8214 loss_val: 2.0291 acc_val: 0.5860 time: 2.6539s\n",
      "Epoch: 0073 loss_train: 0.5752 acc_train: 0.7571 loss_val: 2.0732 acc_val: 0.5760 time: 2.4704s\n",
      "Epoch: 0074 loss_train: 0.7076 acc_train: 0.6857 loss_val: 2.0931 acc_val: 0.5840 time: 2.4784s\n",
      "Epoch: 0075 loss_train: 0.6214 acc_train: 0.7214 loss_val: 2.0128 acc_val: 0.5620 time: 2.6200s\n",
      "Epoch: 0076 loss_train: 0.7430 acc_train: 0.6786 loss_val: 1.9121 acc_val: 0.5520 time: 2.3986s\n",
      "Epoch: 0077 loss_train: 0.9012 acc_train: 0.6071 loss_val: 1.6472 acc_val: 0.5620 time: 2.4754s\n",
      "Epoch: 0078 loss_train: 0.7042 acc_train: 0.7286 loss_val: 1.7324 acc_val: 0.5780 time: 2.4136s\n",
      "Epoch: 0079 loss_train: 0.6143 acc_train: 0.6857 loss_val: 1.8147 acc_val: 0.6200 time: 2.7137s\n",
      "Epoch: 0080 loss_train: 0.6709 acc_train: 0.7429 loss_val: 1.8708 acc_val: 0.6360 time: 2.4056s\n",
      "Epoch: 0081 loss_train: 0.5839 acc_train: 0.7286 loss_val: 1.8747 acc_val: 0.6220 time: 2.4455s\n",
      "Epoch: 0082 loss_train: 0.6342 acc_train: 0.7357 loss_val: 1.7934 acc_val: 0.6100 time: 2.4834s\n",
      "Epoch: 0083 loss_train: 0.6129 acc_train: 0.7500 loss_val: 1.7151 acc_val: 0.5920 time: 2.7786s\n",
      "Epoch: 0084 loss_train: 0.5057 acc_train: 0.7786 loss_val: 1.6341 acc_val: 0.6160 time: 2.6988s\n",
      "Epoch: 0085 loss_train: 0.6927 acc_train: 0.7143 loss_val: 1.8859 acc_val: 0.5680 time: 2.4933s\n",
      "Epoch: 0086 loss_train: 0.8591 acc_train: 0.6929 loss_val: 1.7229 acc_val: 0.6060 time: 2.8035s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0087 loss_train: 0.5971 acc_train: 0.7786 loss_val: 1.8231 acc_val: 0.5940 time: 3.5346s\n",
      "Epoch: 0088 loss_train: 0.5459 acc_train: 0.8071 loss_val: 1.9137 acc_val: 0.5800 time: 3.5006s\n",
      "Epoch: 0089 loss_train: 0.5132 acc_train: 0.7929 loss_val: 1.9386 acc_val: 0.6200 time: 3.3032s\n",
      "Epoch: 0090 loss_train: 0.5235 acc_train: 0.8214 loss_val: 1.9655 acc_val: 0.6200 time: 2.7108s\n",
      "Epoch: 0091 loss_train: 0.5946 acc_train: 0.7857 loss_val: 1.7513 acc_val: 0.6340 time: 3.0010s\n",
      "Epoch: 0092 loss_train: 0.4899 acc_train: 0.8071 loss_val: 1.5668 acc_val: 0.6240 time: 2.2310s\n",
      "Epoch: 0093 loss_train: 0.4831 acc_train: 0.8214 loss_val: 1.5644 acc_val: 0.6200 time: 2.1782s\n",
      "Epoch: 0094 loss_train: 0.4358 acc_train: 0.8214 loss_val: 1.6252 acc_val: 0.6080 time: 2.5542s\n",
      "Epoch: 0095 loss_train: 0.5375 acc_train: 0.8071 loss_val: 1.8403 acc_val: 0.5960 time: 2.4634s\n",
      "Epoch: 0096 loss_train: 0.6297 acc_train: 0.8214 loss_val: 2.0353 acc_val: 0.6100 time: 2.1532s\n",
      "Epoch: 0097 loss_train: 0.6464 acc_train: 0.7214 loss_val: 1.8270 acc_val: 0.6540 time: 2.3766s\n",
      "Epoch: 0098 loss_train: 0.4827 acc_train: 0.7857 loss_val: 1.9226 acc_val: 0.5560 time: 2.7826s\n",
      "Epoch: 0099 loss_train: 0.5261 acc_train: 0.7500 loss_val: 1.5138 acc_val: 0.6300 time: 2.3268s\n",
      "Epoch: 0100 loss_train: 0.9134 acc_train: 0.7286 loss_val: 2.7861 acc_val: 0.4720 time: 2.2640s\n",
      "Epoch: 0101 loss_train: 1.7446 acc_train: 0.5429 loss_val: 1.6417 acc_val: 0.6420 time: 2.1004s\n",
      "Epoch: 0102 loss_train: 0.6340 acc_train: 0.7643 loss_val: 1.6832 acc_val: 0.5640 time: 2.2699s\n",
      "Epoch: 0103 loss_train: 0.6853 acc_train: 0.6929 loss_val: 1.4891 acc_val: 0.5960 time: 2.3757s\n",
      "Epoch: 0104 loss_train: 0.8657 acc_train: 0.6214 loss_val: 1.2151 acc_val: 0.6480 time: 2.2889s\n",
      "Epoch: 0105 loss_train: 0.7424 acc_train: 0.7357 loss_val: 1.3146 acc_val: 0.5880 time: 2.0994s\n",
      "Epoch: 0106 loss_train: 0.7832 acc_train: 0.7000 loss_val: 1.3885 acc_val: 0.5260 time: 2.0854s\n",
      "Epoch: 0107 loss_train: 1.0445 acc_train: 0.5786 loss_val: 1.2313 acc_val: 0.5780 time: 2.1463s\n",
      "Epoch: 0108 loss_train: 0.6560 acc_train: 0.7786 loss_val: 1.2353 acc_val: 0.6100 time: 2.1074s\n",
      "Epoch: 0109 loss_train: 0.6568 acc_train: 0.7571 loss_val: 1.3869 acc_val: 0.5880 time: 2.3417s\n",
      "Epoch: 0110 loss_train: 0.7005 acc_train: 0.7571 loss_val: 1.4386 acc_val: 0.5920 time: 2.1542s\n",
      "Epoch: 0111 loss_train: 0.6433 acc_train: 0.7286 loss_val: 1.5087 acc_val: 0.6080 time: 2.3258s\n",
      "Epoch: 0112 loss_train: 0.5578 acc_train: 0.7714 loss_val: 1.6151 acc_val: 0.5720 time: 2.6010s\n",
      "Epoch: 0113 loss_train: 0.5925 acc_train: 0.7786 loss_val: 1.6938 acc_val: 0.5960 time: 2.2490s\n",
      "Epoch: 0114 loss_train: 0.5894 acc_train: 0.7643 loss_val: 1.7954 acc_val: 0.5820 time: 2.0814s\n",
      "Epoch: 0115 loss_train: 0.6452 acc_train: 0.7643 loss_val: 1.6970 acc_val: 0.6020 time: 2.2560s\n",
      "Epoch: 0116 loss_train: 0.5638 acc_train: 0.8071 loss_val: 1.7029 acc_val: 0.6140 time: 2.3088s\n",
      "Epoch: 0117 loss_train: 0.5592 acc_train: 0.8429 loss_val: 1.6938 acc_val: 0.6280 time: 2.4445s\n",
      "Epoch: 0118 loss_train: 0.4772 acc_train: 0.8357 loss_val: 1.7250 acc_val: 0.6300 time: 2.4145s\n",
      "Epoch: 0119 loss_train: 0.5377 acc_train: 0.7857 loss_val: 1.7344 acc_val: 0.6080 time: 2.5402s\n",
      "Epoch: 0120 loss_train: 0.4588 acc_train: 0.8143 loss_val: 1.6999 acc_val: 0.6160 time: 2.7147s\n",
      "Epoch: 0121 loss_train: 0.5063 acc_train: 0.8143 loss_val: 1.6444 acc_val: 0.6080 time: 2.7497s\n",
      "Epoch: 0122 loss_train: 0.4130 acc_train: 0.8786 loss_val: 1.6504 acc_val: 0.6020 time: 2.8474s\n",
      "Epoch: 0123 loss_train: 0.4373 acc_train: 0.8357 loss_val: 1.6908 acc_val: 0.6040 time: 2.8235s\n",
      "Epoch: 0124 loss_train: 0.6414 acc_train: 0.8000 loss_val: 1.6672 acc_val: 0.6300 time: 2.4514s\n",
      "Epoch: 0125 loss_train: 0.4784 acc_train: 0.8286 loss_val: 1.6223 acc_val: 0.6420 time: 2.6489s\n",
      "Epoch: 0126 loss_train: 0.3934 acc_train: 0.8571 loss_val: 1.6827 acc_val: 0.6620 time: 2.4245s\n",
      "Epoch: 0127 loss_train: 0.4825 acc_train: 0.8214 loss_val: 1.8980 acc_val: 0.6520 time: 2.4165s\n",
      "Epoch: 0128 loss_train: 0.5085 acc_train: 0.8286 loss_val: 1.9428 acc_val: 0.6540 time: 2.6130s\n",
      "Epoch: 0129 loss_train: 0.4401 acc_train: 0.8286 loss_val: 2.6570 acc_val: 0.5900 time: 2.4524s\n",
      "Epoch: 0130 loss_train: 1.6166 acc_train: 0.7286 loss_val: 2.0046 acc_val: 0.6280 time: 2.4066s\n",
      "Epoch: 0131 loss_train: 0.2914 acc_train: 0.9000 loss_val: 2.4034 acc_val: 0.5880 time: 2.5502s\n",
      "Epoch: 0132 loss_train: 0.6868 acc_train: 0.7500 loss_val: 2.6382 acc_val: 0.5380 time: 2.7108s\n",
      "Epoch: 0133 loss_train: 0.8211 acc_train: 0.6643 loss_val: 2.5332 acc_val: 0.5180 time: 2.4644s\n",
      "Epoch: 0134 loss_train: 0.8975 acc_train: 0.6929 loss_val: 2.2987 acc_val: 0.5160 time: 2.3916s\n",
      "Epoch: 0135 loss_train: 0.9081 acc_train: 0.6786 loss_val: 2.0411 acc_val: 0.5520 time: 2.4175s\n",
      "Epoch: 0136 loss_train: 0.6614 acc_train: 0.7571 loss_val: 1.9181 acc_val: 0.5320 time: 2.3786s\n",
      "Epoch: 0137 loss_train: 0.7348 acc_train: 0.7571 loss_val: 1.8591 acc_val: 0.4520 time: 2.4006s\n",
      "Epoch: 0138 loss_train: 0.7395 acc_train: 0.6643 loss_val: 1.7285 acc_val: 0.4700 time: 2.3946s\n",
      "Epoch: 0139 loss_train: 0.6731 acc_train: 0.7071 loss_val: 1.6254 acc_val: 0.5540 time: 2.4514s\n",
      "Epoch: 0140 loss_train: 0.6176 acc_train: 0.7643 loss_val: 1.5414 acc_val: 0.6000 time: 2.4275s\n",
      "Epoch: 0141 loss_train: 0.6033 acc_train: 0.7643 loss_val: 1.5711 acc_val: 0.6140 time: 2.7935s\n",
      "Epoch: 0142 loss_train: 0.6680 acc_train: 0.7500 loss_val: 1.6462 acc_val: 0.6400 time: 2.4265s\n",
      "Epoch: 0143 loss_train: 0.6221 acc_train: 0.7357 loss_val: 1.9644 acc_val: 0.6380 time: 2.5841s\n",
      "Epoch: 0144 loss_train: 0.6201 acc_train: 0.7714 loss_val: 1.9946 acc_val: 0.6420 time: 2.4355s\n",
      "Epoch: 0145 loss_train: 0.4927 acc_train: 0.7857 loss_val: 2.0665 acc_val: 0.6180 time: 2.4425s\n",
      "Epoch: 0146 loss_train: 0.4557 acc_train: 0.8214 loss_val: 1.9110 acc_val: 0.6380 time: 2.4155s\n",
      "Epoch: 0147 loss_train: 0.5593 acc_train: 0.7857 loss_val: 2.0895 acc_val: 0.6600 time: 2.3936s\n",
      "Epoch: 0148 loss_train: 0.7752 acc_train: 0.8143 loss_val: 2.2669 acc_val: 0.6520 time: 2.5432s\n",
      "Epoch: 0149 loss_train: 0.4399 acc_train: 0.8286 loss_val: 2.3233 acc_val: 0.6460 time: 2.5801s\n",
      "Epoch: 0150 loss_train: 0.4820 acc_train: 0.8143 loss_val: 2.1314 acc_val: 0.6600 time: 2.5502s\n",
      "Epoch: 0151 loss_train: 0.3297 acc_train: 0.8929 loss_val: 1.9703 acc_val: 0.6520 time: 2.7048s\n",
      "Epoch: 0152 loss_train: 0.4363 acc_train: 0.8429 loss_val: 1.9178 acc_val: 0.6440 time: 2.6479s\n",
      "Epoch: 0153 loss_train: 0.4058 acc_train: 0.8857 loss_val: 1.9305 acc_val: 0.6500 time: 2.6878s\n",
      "Epoch: 0154 loss_train: 0.4280 acc_train: 0.8429 loss_val: 2.0567 acc_val: 0.6560 time: 2.5412s\n",
      "Epoch: 0155 loss_train: 0.3845 acc_train: 0.8786 loss_val: 2.3363 acc_val: 0.6540 time: 2.4116s\n",
      "Epoch: 0156 loss_train: 0.4990 acc_train: 0.8143 loss_val: 2.3586 acc_val: 0.6720 time: 2.4574s\n",
      "Epoch: 0157 loss_train: 0.4379 acc_train: 0.8429 loss_val: 2.4164 acc_val: 0.6700 time: 2.6260s\n",
      "Epoch: 0158 loss_train: 0.4228 acc_train: 0.8857 loss_val: 2.3742 acc_val: 0.6640 time: 2.7836s\n",
      "Epoch: 0159 loss_train: 0.3479 acc_train: 0.8786 loss_val: 2.2988 acc_val: 0.6680 time: 2.4155s\n",
      "Epoch: 0160 loss_train: 0.3269 acc_train: 0.9000 loss_val: 2.3064 acc_val: 0.6580 time: 2.3866s\n",
      "Epoch: 0161 loss_train: 0.3720 acc_train: 0.8429 loss_val: 2.3790 acc_val: 0.6640 time: 2.2141s\n",
      "Epoch: 0162 loss_train: 0.2737 acc_train: 0.9071 loss_val: 2.3645 acc_val: 0.6740 time: 2.0625s\n",
      "Epoch: 0163 loss_train: 0.3666 acc_train: 0.8500 loss_val: 2.3321 acc_val: 0.6740 time: 2.0994s\n",
      "Epoch: 0164 loss_train: 0.3616 acc_train: 0.8929 loss_val: 2.1172 acc_val: 0.6920 time: 2.2749s\n",
      "Epoch: 0165 loss_train: 0.2264 acc_train: 0.9214 loss_val: 1.9869 acc_val: 0.6960 time: 2.2580s\n",
      "Epoch: 0166 loss_train: 0.2327 acc_train: 0.9429 loss_val: 1.9346 acc_val: 0.6880 time: 2.2699s\n",
      "Epoch: 0167 loss_train: 0.3639 acc_train: 0.8857 loss_val: 1.9806 acc_val: 0.6820 time: 2.5781s\n",
      "Epoch: 0168 loss_train: 0.4175 acc_train: 0.8857 loss_val: 2.0753 acc_val: 0.6820 time: 2.3118s\n",
      "Epoch: 0169 loss_train: 0.4705 acc_train: 0.8357 loss_val: 2.3361 acc_val: 0.6960 time: 2.4874s\n",
      "Epoch: 0170 loss_train: 0.2287 acc_train: 0.9286 loss_val: 2.5447 acc_val: 0.6940 time: 2.4844s\n",
      "Epoch: 0171 loss_train: 0.2757 acc_train: 0.8929 loss_val: 2.5009 acc_val: 0.6900 time: 2.1014s\n",
      "Epoch: 0172 loss_train: 0.2741 acc_train: 0.8929 loss_val: 2.4094 acc_val: 0.6760 time: 2.2121s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0173 loss_train: 0.3331 acc_train: 0.9000 loss_val: 2.3231 acc_val: 0.6480 time: 2.3397s\n",
      "Epoch: 0174 loss_train: 0.2368 acc_train: 0.9214 loss_val: 2.3163 acc_val: 0.6440 time: 2.1822s\n",
      "Epoch: 0175 loss_train: 0.4058 acc_train: 0.8643 loss_val: 2.4127 acc_val: 0.6720 time: 2.1692s\n",
      "Epoch: 0176 loss_train: 0.2456 acc_train: 0.9214 loss_val: 2.5213 acc_val: 0.6960 time: 2.2300s\n",
      "Epoch: 0177 loss_train: 0.3076 acc_train: 0.9143 loss_val: 2.6238 acc_val: 0.6860 time: 2.1094s\n",
      "Epoch: 0178 loss_train: 0.5174 acc_train: 0.8643 loss_val: 2.6923 acc_val: 0.6640 time: 2.0505s\n",
      "Epoch: 0179 loss_train: 0.4631 acc_train: 0.8714 loss_val: 2.5804 acc_val: 0.6700 time: 2.0874s\n",
      "Epoch: 0180 loss_train: 0.3483 acc_train: 0.8857 loss_val: 2.3962 acc_val: 0.6820 time: 2.0854s\n",
      "Epoch: 0181 loss_train: 0.3869 acc_train: 0.8714 loss_val: 2.4909 acc_val: 0.6480 time: 2.0645s\n",
      "Epoch: 0182 loss_train: 0.3698 acc_train: 0.8786 loss_val: 2.4901 acc_val: 0.6580 time: 2.1522s\n",
      "Epoch: 0183 loss_train: 0.4338 acc_train: 0.8714 loss_val: 2.4655 acc_val: 0.6800 time: 2.3587s\n",
      "Epoch: 0184 loss_train: 0.3183 acc_train: 0.9000 loss_val: 2.4732 acc_val: 0.6660 time: 2.3148s\n",
      "Epoch: 0185 loss_train: 1.0860 acc_train: 0.7929 loss_val: 2.1153 acc_val: 0.6920 time: 2.1024s\n",
      "Epoch: 0186 loss_train: 0.5978 acc_train: 0.9000 loss_val: 1.8662 acc_val: 0.6600 time: 2.0794s\n",
      "Epoch: 0187 loss_train: 0.4306 acc_train: 0.8571 loss_val: 1.7565 acc_val: 0.6520 time: 2.1652s\n",
      "Epoch: 0188 loss_train: 0.4558 acc_train: 0.8357 loss_val: 1.6162 acc_val: 0.6700 time: 2.4395s\n",
      "Epoch: 0189 loss_train: 0.7120 acc_train: 0.8143 loss_val: 1.4304 acc_val: 0.6580 time: 2.4913s\n",
      "Epoch: 0190 loss_train: 0.3312 acc_train: 0.8786 loss_val: 1.4113 acc_val: 0.6460 time: 2.2699s\n",
      "Epoch: 0191 loss_train: 0.4366 acc_train: 0.8286 loss_val: 1.4311 acc_val: 0.6460 time: 2.3936s\n",
      "Epoch: 0192 loss_train: 0.4070 acc_train: 0.8571 loss_val: 1.4689 acc_val: 0.6500 time: 2.3936s\n",
      "Epoch: 0193 loss_train: 0.4141 acc_train: 0.9000 loss_val: 1.5282 acc_val: 0.6840 time: 2.5582s\n",
      "Epoch: 0194 loss_train: 0.4066 acc_train: 0.8714 loss_val: 1.5892 acc_val: 0.7000 time: 2.6609s\n",
      "Epoch: 0195 loss_train: 0.3303 acc_train: 0.8786 loss_val: 1.6844 acc_val: 0.6920 time: 2.7287s\n",
      "Epoch: 0196 loss_train: 0.4558 acc_train: 0.8286 loss_val: 1.8238 acc_val: 0.6860 time: 2.7227s\n",
      "Epoch: 0197 loss_train: 0.4056 acc_train: 0.8500 loss_val: 1.9401 acc_val: 0.6980 time: 2.6439s\n",
      "Epoch: 0198 loss_train: 0.2619 acc_train: 0.9071 loss_val: 2.0520 acc_val: 0.6980 time: 2.4056s\n",
      "Epoch: 0199 loss_train: 0.2775 acc_train: 0.9214 loss_val: 2.0901 acc_val: 0.7020 time: 2.3806s\n",
      "Epoch: 0200 loss_train: 0.5653 acc_train: 0.8643 loss_val: 2.1728 acc_val: 0.7080 time: 2.4056s\n",
      "Optimization Finished!\n",
      "Total time elapsed: 494.5094s\n",
      "Test set results: loss= 1.7424 accuracy= 0.7570\n"
     ]
    }
   ],
   "source": [
    "# Train for the configured number of epochs and time the whole run.\n",
    "t_total = time.time()\n",
    "for epoch in range(args.epochs):\n",
    "    train(epoch)\n",
    "\n",
    "print(\"Optimization Finished!\")\n",
    "print(\"Total time elapsed: {:.4f}s\".format(time.time() - t_total))\n",
    "\n",
    "# Final evaluation on the held-out test split.\n",
    "test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
