{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import time\n",
    "import argparse\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import networkx as nx\n",
    "from scipy import sparse\n",
    "from scipy.linalg import fractional_matrix_power\n",
    "\n",
    "from utils import *\n",
    "from models import GNN\n",
    "from dataset_utils import DataLoader\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_StoreAction(option_strings=['--dataset'], dest='dataset', nargs=None, const=None, default='citeseer', type=None, choices=None, help='Dataset name.', metavar=None)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument('--no-cuda', action='store_true', default=False,\n",
    "                    help='Disables CUDA training.')\n",
    "parser.add_argument('--fastmode', action='store_true', default=False,\n",
    "                    help='Validate during training pass.')\n",
    "parser.add_argument('--seed', type=int, default=42, help='Random seed.')\n",
    "parser.add_argument('--epochs', type=int, default=200,\n",
    "                    help='Number of epochs to train.')\n",
    "parser.add_argument('--lr', type=float, default=0.019,\n",
    "                    help='Initial learning rate.')\n",
    "parser.add_argument('--weight_decay', type=float, default=5e-2,\n",
    "                    help='Weight decay (L2 loss on parameters).')\n",
    "parser.add_argument('--hidden', type=int, default=128,\n",
    "                    help='Number of hidden units.')\n",
    "parser.add_argument('--dropout', type=float, default=0.9,\n",
    "                    help='Dropout rate (1 - keep probability).')\n",
    "parser.add_argument('--dataset', default='citeseer', help='Dataset name.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = parser.parse_args(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x278d55a4900>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.random.seed(args.seed)\n",
    "torch.manual_seed(args.seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "dname = args.dataset\n",
    "dataset = DataLoader(dname)\n",
    "data = dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "A_norm, A, X, labels, idx_train, idx_val, idx_test = load_citation_data(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "G = nx.from_numpy_matrix(A)\n",
    "feature_dictionary = {}\n",
    "\n",
    "for i in np.arange(len(labels)):\n",
    "    feature_dictionary[i] = labels[i]\n",
    "\n",
    "nx.set_node_attributes(G, feature_dictionary, \"attr_name\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "sub_graphs = []\n",
    "\n",
    "for i in np.arange(len(A)):\n",
    "    s_indexes = []\n",
    "    for j in np.arange(len(A)):\n",
    "        s_indexes.append(i)\n",
    "        if(A[i][j]==1):\n",
    "            s_indexes.append(j)\n",
    "    sub_graphs.append(G.subgraph(s_indexes))\n",
    "\n",
    "subgraph_nodes_list = []\n",
    "\n",
    "for i in np.arange(len(sub_graphs)):\n",
    "    subgraph_nodes_list.append(list(sub_graphs[i].nodes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "sub_graphs_adj = []\n",
    "for index in np.arange(len(sub_graphs)):\n",
    "    sub_graphs_adj.append(nx.adjacency_matrix(sub_graphs[index]).toarray())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_adj = torch.zeros(A.shape[0], A.shape[0])\n",
    "\n",
    "for node in np.arange(len(subgraph_nodes_list)):\n",
    "    sub_adj = sub_graphs_adj[node]\n",
    "    for neighbors in np.arange(len(subgraph_nodes_list[node])):\n",
    "        index = subgraph_nodes_list[node][neighbors]\n",
    "        count = torch.tensor(0).float()\n",
    "        if(index==node):\n",
    "            continue\n",
    "        else:\n",
    "            c_neighbors = set(subgraph_nodes_list[node]).intersection(subgraph_nodes_list[index])\n",
    "            if index in c_neighbors:\n",
    "                nodes_list = subgraph_nodes_list[node]\n",
    "                sub_graph_index = nodes_list.index(index)\n",
    "                c_neighbors_list = list(c_neighbors)\n",
    "                for i, item1 in enumerate(nodes_list):\n",
    "                    if(item1 in c_neighbors):\n",
    "                        for item2 in c_neighbors_list:\n",
    "                            j = nodes_list.index(item2)\n",
    "                            count += sub_adj[i][j]\n",
    "\n",
    "            new_adj[node][index] = count/2\n",
    "            new_adj[node][index] = new_adj[node][index]/(len(c_neighbors)*(len(c_neighbors)-1))\n",
    "            new_adj[node][index] = new_adj[node][index] * (len(c_neighbors)**1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "features = torch.FloatTensor(X)\n",
    "labels = torch.LongTensor(labels)\n",
    "\n",
    "weight = torch.FloatTensor(new_adj)\n",
    "weight = weight / weight.sum(1, keepdim=True)\n",
    "\n",
    "weight = weight + torch.FloatTensor(A)\n",
    "\n",
    "coeff = weight.sum(1, keepdim=True)\n",
    "coeff = torch.diag((coeff.T)[0])\n",
    "\n",
    "weight = weight + coeff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "weight = weight.detach().numpy()\n",
    "weight = np.nan_to_num(weight, nan=0)\n",
    "\n",
    "row_sum = np.array(np.sum(weight, axis=1))\n",
    "degree_matrix = np.matrix(np.diag(row_sum+1))\n",
    "\n",
    "D = fractional_matrix_power(degree_matrix, -0.5)\n",
    "A_tilde_hat = D.dot(weight).dot(D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "adj = torch.FloatTensor(A_tilde_hat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model and optimizer\n",
    "model = GNN(nfeat=features.shape[1],\n",
    "            nhid=args.hidden,\n",
    "            nclass=labels.max().item() + 1,\n",
    "            dropout=args.dropout)\n",
    "\n",
    "optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch):\n",
    "    t = time.time()\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    output = model(features, adj)\n",
    "    loss_train = F.nll_loss(output[idx_train], labels[idx_train])\n",
    "    acc_train = accuracy(output[idx_train], labels[idx_train])\n",
    "    loss_train.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if not args.fastmode:\n",
    "        # Evaluate validation set performance separately, deactivates dropout during validation run.\n",
    "        model.eval()\n",
    "        output = model(features, adj)\n",
    "\n",
    "    loss_val = F.nll_loss(output[idx_val], labels[idx_val])\n",
    "    acc_val = accuracy(output[idx_val], labels[idx_val])\n",
    "    print('Epoch: {:04d}'.format(epoch+1),\n",
    "          'loss_train: {:.4f}'.format(loss_train.item()),\n",
    "          'acc_train: {:.4f}'.format(acc_train.item()),\n",
    "          'loss_val: {:.4f}'.format(loss_val.item()),\n",
    "          'acc_val: {:.4f}'.format(acc_val.item()),\n",
    "          'time: {:.4f}s'.format(time.time() - t))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test():\n",
    "    model.eval()\n",
    "    output = model(features, adj)\n",
    "    loss_test = F.nll_loss(output[idx_test], labels[idx_test])\n",
    "    acc_test = accuracy(output[idx_test], labels[idx_test])\n",
    "    print(\"Test set results:\",\n",
    "          \"loss= {:.4f}\".format(loss_test.item()),\n",
    "          \"accuracy= {:.4f}\".format(acc_test.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0001 loss_train: 1.8565 acc_train: 0.1833 loss_val: 1.7710 acc_val: 0.1880 time: 0.8591s\n",
      "Epoch: 0002 loss_train: 1.7756 acc_train: 0.2167 loss_val: 1.7379 acc_val: 0.3060 time: 0.9165s\n",
      "Epoch: 0003 loss_train: 1.6989 acc_train: 0.3500 loss_val: 1.7081 acc_val: 0.4180 time: 0.8833s\n",
      "Epoch: 0004 loss_train: 1.6287 acc_train: 0.3917 loss_val: 1.6795 acc_val: 0.4700 time: 0.8674s\n",
      "Epoch: 0005 loss_train: 1.6024 acc_train: 0.4583 loss_val: 1.6568 acc_val: 0.5120 time: 0.7567s\n",
      "Epoch: 0006 loss_train: 1.5413 acc_train: 0.5583 loss_val: 1.6378 acc_val: 0.5220 time: 0.8116s\n",
      "Epoch: 0007 loss_train: 1.4961 acc_train: 0.6333 loss_val: 1.6243 acc_val: 0.5640 time: 0.7836s\n",
      "Epoch: 0008 loss_train: 1.5057 acc_train: 0.5917 loss_val: 1.6191 acc_val: 0.5820 time: 0.7859s\n",
      "Epoch: 0009 loss_train: 1.4694 acc_train: 0.7250 loss_val: 1.6178 acc_val: 0.5940 time: 0.8141s\n",
      "Epoch: 0010 loss_train: 1.4753 acc_train: 0.6917 loss_val: 1.6203 acc_val: 0.6340 time: 0.7849s\n",
      "Epoch: 0011 loss_train: 1.4720 acc_train: 0.7083 loss_val: 1.6243 acc_val: 0.6420 time: 0.7823s\n",
      "Epoch: 0012 loss_train: 1.4991 acc_train: 0.7083 loss_val: 1.6279 acc_val: 0.6520 time: 0.7518s\n",
      "Epoch: 0013 loss_train: 1.4843 acc_train: 0.6750 loss_val: 1.6326 acc_val: 0.6580 time: 0.7656s\n",
      "Epoch: 0014 loss_train: 1.5096 acc_train: 0.7333 loss_val: 1.6367 acc_val: 0.6620 time: 0.7461s\n",
      "Epoch: 0015 loss_train: 1.5119 acc_train: 0.7167 loss_val: 1.6400 acc_val: 0.6580 time: 0.7880s\n",
      "Epoch: 0016 loss_train: 1.4936 acc_train: 0.7583 loss_val: 1.6423 acc_val: 0.6520 time: 0.7508s\n",
      "Epoch: 0017 loss_train: 1.5399 acc_train: 0.7167 loss_val: 1.6417 acc_val: 0.6540 time: 0.7654s\n",
      "Epoch: 0018 loss_train: 1.5482 acc_train: 0.6750 loss_val: 1.6393 acc_val: 0.6700 time: 0.7509s\n",
      "Epoch: 0019 loss_train: 1.5274 acc_train: 0.6833 loss_val: 1.6334 acc_val: 0.6740 time: 0.8133s\n",
      "Epoch: 0020 loss_train: 1.5209 acc_train: 0.6583 loss_val: 1.6252 acc_val: 0.6780 time: 0.7700s\n",
      "Epoch: 0021 loss_train: 1.5261 acc_train: 0.5833 loss_val: 1.6149 acc_val: 0.6940 time: 0.7815s\n",
      "Epoch: 0022 loss_train: 1.5137 acc_train: 0.6667 loss_val: 1.6023 acc_val: 0.7000 time: 0.7503s\n",
      "Epoch: 0023 loss_train: 1.5035 acc_train: 0.6333 loss_val: 1.5893 acc_val: 0.6960 time: 0.7682s\n",
      "Epoch: 0024 loss_train: 1.4744 acc_train: 0.6167 loss_val: 1.5755 acc_val: 0.6980 time: 0.7984s\n",
      "Epoch: 0025 loss_train: 1.4392 acc_train: 0.7083 loss_val: 1.5618 acc_val: 0.6940 time: 0.8000s\n",
      "Epoch: 0026 loss_train: 1.4312 acc_train: 0.6500 loss_val: 1.5479 acc_val: 0.6780 time: 0.8003s\n",
      "Epoch: 0027 loss_train: 1.4054 acc_train: 0.6250 loss_val: 1.5347 acc_val: 0.6620 time: 0.7552s\n",
      "Epoch: 0028 loss_train: 1.3928 acc_train: 0.6833 loss_val: 1.5201 acc_val: 0.6540 time: 0.7939s\n",
      "Epoch: 0029 loss_train: 1.3276 acc_train: 0.7250 loss_val: 1.5032 acc_val: 0.6580 time: 0.7705s\n",
      "Epoch: 0030 loss_train: 1.2877 acc_train: 0.7500 loss_val: 1.4824 acc_val: 0.6600 time: 0.7978s\n",
      "Epoch: 0031 loss_train: 1.2810 acc_train: 0.7000 loss_val: 1.4623 acc_val: 0.6720 time: 0.7838s\n",
      "Epoch: 0032 loss_train: 1.2296 acc_train: 0.7667 loss_val: 1.4411 acc_val: 0.6940 time: 0.8007s\n",
      "Epoch: 0033 loss_train: 1.2138 acc_train: 0.7500 loss_val: 1.4206 acc_val: 0.7120 time: 0.8480s\n",
      "Epoch: 0034 loss_train: 1.2100 acc_train: 0.7500 loss_val: 1.3992 acc_val: 0.7120 time: 0.8008s\n",
      "Epoch: 0035 loss_train: 1.1709 acc_train: 0.7417 loss_val: 1.3800 acc_val: 0.7060 time: 0.7997s\n",
      "Epoch: 0036 loss_train: 1.1154 acc_train: 0.7667 loss_val: 1.3682 acc_val: 0.7040 time: 0.8476s\n",
      "Epoch: 0037 loss_train: 1.0439 acc_train: 0.7583 loss_val: 1.3608 acc_val: 0.7020 time: 0.8990s\n",
      "Epoch: 0038 loss_train: 1.0737 acc_train: 0.8000 loss_val: 1.3512 acc_val: 0.6960 time: 0.8826s\n",
      "Epoch: 0039 loss_train: 1.0450 acc_train: 0.8000 loss_val: 1.3377 acc_val: 0.6800 time: 0.8829s\n",
      "Epoch: 0040 loss_train: 1.0218 acc_train: 0.7917 loss_val: 1.3254 acc_val: 0.6740 time: 0.9069s\n",
      "Epoch: 0041 loss_train: 0.9725 acc_train: 0.7500 loss_val: 1.3165 acc_val: 0.6720 time: 0.8555s\n",
      "Epoch: 0042 loss_train: 1.0042 acc_train: 0.8417 loss_val: 1.3028 acc_val: 0.6620 time: 0.7567s\n",
      "Epoch: 0043 loss_train: 0.9429 acc_train: 0.8250 loss_val: 1.2823 acc_val: 0.6780 time: 0.7803s\n",
      "Epoch: 0044 loss_train: 0.8813 acc_train: 0.8417 loss_val: 1.2602 acc_val: 0.6860 time: 0.7676s\n",
      "Epoch: 0045 loss_train: 0.8839 acc_train: 0.8417 loss_val: 1.2449 acc_val: 0.7060 time: 0.7976s\n",
      "Epoch: 0046 loss_train: 0.8227 acc_train: 0.8583 loss_val: 1.2390 acc_val: 0.7140 time: 0.7705s\n",
      "Epoch: 0047 loss_train: 0.8914 acc_train: 0.8250 loss_val: 1.2519 acc_val: 0.6960 time: 0.8123s\n",
      "Epoch: 0048 loss_train: 0.8238 acc_train: 0.8083 loss_val: 1.2653 acc_val: 0.6700 time: 0.7550s\n",
      "Epoch: 0049 loss_train: 0.9688 acc_train: 0.7583 loss_val: 1.2607 acc_val: 0.6700 time: 0.8424s\n",
      "Epoch: 0050 loss_train: 0.8659 acc_train: 0.8250 loss_val: 1.2413 acc_val: 0.6740 time: 0.8348s\n",
      "Epoch: 0051 loss_train: 0.7388 acc_train: 0.8833 loss_val: 1.2236 acc_val: 0.6960 time: 0.8647s\n",
      "Epoch: 0052 loss_train: 0.8330 acc_train: 0.8417 loss_val: 1.2110 acc_val: 0.7220 time: 0.8828s\n",
      "Epoch: 0053 loss_train: 0.8061 acc_train: 0.7917 loss_val: 1.2092 acc_val: 0.7120 time: 0.8054s\n",
      "Epoch: 0054 loss_train: 0.7954 acc_train: 0.8250 loss_val: 1.2106 acc_val: 0.7000 time: 0.7679s\n",
      "Epoch: 0055 loss_train: 0.8217 acc_train: 0.7833 loss_val: 1.2091 acc_val: 0.7000 time: 0.8300s\n",
      "Epoch: 0056 loss_train: 0.7850 acc_train: 0.8750 loss_val: 1.2003 acc_val: 0.6960 time: 0.8650s\n",
      "Epoch: 0057 loss_train: 0.8437 acc_train: 0.7750 loss_val: 1.1900 acc_val: 0.6920 time: 0.8510s\n",
      "Epoch: 0058 loss_train: 0.7182 acc_train: 0.8083 loss_val: 1.1816 acc_val: 0.6980 time: 0.7722s\n",
      "Epoch: 0059 loss_train: 0.7612 acc_train: 0.8583 loss_val: 1.1763 acc_val: 0.7060 time: 0.7981s\n",
      "Epoch: 0060 loss_train: 0.6959 acc_train: 0.8750 loss_val: 1.1738 acc_val: 0.7080 time: 0.7841s\n",
      "Epoch: 0061 loss_train: 0.8028 acc_train: 0.8417 loss_val: 1.1768 acc_val: 0.7000 time: 0.7990s\n",
      "Epoch: 0062 loss_train: 0.6933 acc_train: 0.8417 loss_val: 1.1745 acc_val: 0.7080 time: 0.7841s\n",
      "Epoch: 0063 loss_train: 0.7867 acc_train: 0.8083 loss_val: 1.1834 acc_val: 0.6980 time: 0.8723s\n",
      "Epoch: 0064 loss_train: 0.8350 acc_train: 0.7667 loss_val: 1.2003 acc_val: 0.6940 time: 0.9354s\n",
      "Epoch: 0065 loss_train: 0.7270 acc_train: 0.9000 loss_val: 1.2031 acc_val: 0.6980 time: 0.9166s\n",
      "Epoch: 0066 loss_train: 0.6872 acc_train: 0.8500 loss_val: 1.1938 acc_val: 0.6900 time: 0.9634s\n",
      "Epoch: 0067 loss_train: 0.7775 acc_train: 0.8000 loss_val: 1.1841 acc_val: 0.7000 time: 0.8318s\n",
      "Epoch: 0068 loss_train: 0.6968 acc_train: 0.8667 loss_val: 1.1819 acc_val: 0.7020 time: 0.7936s\n",
      "Epoch: 0069 loss_train: 0.7166 acc_train: 0.8417 loss_val: 1.1853 acc_val: 0.6900 time: 0.7853s\n",
      "Epoch: 0070 loss_train: 0.7820 acc_train: 0.7500 loss_val: 1.1804 acc_val: 0.6880 time: 0.8308s\n",
      "Epoch: 0071 loss_train: 0.7428 acc_train: 0.7833 loss_val: 1.1736 acc_val: 0.6960 time: 0.7799s\n",
      "Epoch: 0072 loss_train: 0.8135 acc_train: 0.7750 loss_val: 1.1798 acc_val: 0.6860 time: 0.8182s\n",
      "Epoch: 0073 loss_train: 0.7298 acc_train: 0.8583 loss_val: 1.1900 acc_val: 0.6860 time: 0.7431s\n",
      "Epoch: 0074 loss_train: 0.7757 acc_train: 0.7917 loss_val: 1.1961 acc_val: 0.6980 time: 0.8110s\n",
      "Epoch: 0075 loss_train: 0.8027 acc_train: 0.7833 loss_val: 1.1860 acc_val: 0.6940 time: 0.7703s\n",
      "Epoch: 0076 loss_train: 0.8401 acc_train: 0.7417 loss_val: 1.1737 acc_val: 0.7100 time: 0.7825s\n",
      "Epoch: 0077 loss_train: 0.7517 acc_train: 0.8000 loss_val: 1.1622 acc_val: 0.7240 time: 0.7686s\n",
      "Epoch: 0078 loss_train: 0.6564 acc_train: 0.8750 loss_val: 1.1580 acc_val: 0.7060 time: 0.7673s\n",
      "Epoch: 0079 loss_train: 0.6773 acc_train: 0.8417 loss_val: 1.1505 acc_val: 0.6940 time: 0.7540s\n",
      "Epoch: 0080 loss_train: 0.7259 acc_train: 0.8500 loss_val: 1.1445 acc_val: 0.6800 time: 0.8617s\n",
      "Epoch: 0081 loss_train: 0.8047 acc_train: 0.7667 loss_val: 1.1460 acc_val: 0.7040 time: 0.8225s\n",
      "Epoch: 0082 loss_train: 0.7596 acc_train: 0.8000 loss_val: 1.1678 acc_val: 0.6880 time: 0.7657s\n",
      "Epoch: 0083 loss_train: 0.7365 acc_train: 0.7750 loss_val: 1.1975 acc_val: 0.6600 time: 0.7646s\n",
      "Epoch: 0084 loss_train: 0.7735 acc_train: 0.8083 loss_val: 1.1871 acc_val: 0.6780 time: 0.7832s\n",
      "Epoch: 0085 loss_train: 0.7245 acc_train: 0.8167 loss_val: 1.1699 acc_val: 0.6960 time: 0.8284s\n",
      "Epoch: 0086 loss_train: 0.6582 acc_train: 0.8167 loss_val: 1.1657 acc_val: 0.7020 time: 0.8014s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0087 loss_train: 0.7304 acc_train: 0.8083 loss_val: 1.1657 acc_val: 0.7060 time: 0.7671s\n",
      "Epoch: 0088 loss_train: 0.8136 acc_train: 0.7750 loss_val: 1.1582 acc_val: 0.7080 time: 0.7819s\n",
      "Epoch: 0089 loss_train: 0.7722 acc_train: 0.7500 loss_val: 1.1687 acc_val: 0.7000 time: 0.7685s\n",
      "Epoch: 0090 loss_train: 0.7648 acc_train: 0.7667 loss_val: 1.1954 acc_val: 0.6780 time: 0.8123s\n",
      "Epoch: 0091 loss_train: 0.7304 acc_train: 0.8000 loss_val: 1.2118 acc_val: 0.6700 time: 0.7867s\n",
      "Epoch: 0092 loss_train: 0.7279 acc_train: 0.8167 loss_val: 1.2058 acc_val: 0.6800 time: 0.7817s\n",
      "Epoch: 0093 loss_train: 0.8071 acc_train: 0.8250 loss_val: 1.1955 acc_val: 0.6800 time: 0.8160s\n",
      "Epoch: 0094 loss_train: 0.7634 acc_train: 0.8000 loss_val: 1.1927 acc_val: 0.6760 time: 0.7640s\n",
      "Epoch: 0095 loss_train: 0.8028 acc_train: 0.7667 loss_val: 1.1867 acc_val: 0.6920 time: 0.7578s\n",
      "Epoch: 0096 loss_train: 0.7581 acc_train: 0.8250 loss_val: 1.1625 acc_val: 0.6920 time: 0.8582s\n",
      "Epoch: 0097 loss_train: 0.7597 acc_train: 0.8250 loss_val: 1.1490 acc_val: 0.6800 time: 0.8908s\n",
      "Epoch: 0098 loss_train: 0.6605 acc_train: 0.8083 loss_val: 1.1449 acc_val: 0.6980 time: 0.7811s\n",
      "Epoch: 0099 loss_train: 0.7384 acc_train: 0.7833 loss_val: 1.1504 acc_val: 0.6920 time: 0.8311s\n",
      "Epoch: 0100 loss_train: 0.7338 acc_train: 0.8250 loss_val: 1.1665 acc_val: 0.6720 time: 0.8022s\n",
      "Epoch: 0101 loss_train: 0.6977 acc_train: 0.8583 loss_val: 1.1611 acc_val: 0.6880 time: 0.7691s\n",
      "Epoch: 0102 loss_train: 0.9112 acc_train: 0.7583 loss_val: 1.1682 acc_val: 0.6940 time: 0.7662s\n",
      "Epoch: 0103 loss_train: 0.6770 acc_train: 0.8500 loss_val: 1.1780 acc_val: 0.7000 time: 0.7823s\n",
      "Epoch: 0104 loss_train: 0.8017 acc_train: 0.7333 loss_val: 1.1698 acc_val: 0.6980 time: 0.7830s\n",
      "Epoch: 0105 loss_train: 0.6459 acc_train: 0.8500 loss_val: 1.1572 acc_val: 0.7020 time: 0.7432s\n",
      "Epoch: 0106 loss_train: 0.7029 acc_train: 0.8250 loss_val: 1.1430 acc_val: 0.7060 time: 0.7835s\n",
      "Epoch: 0107 loss_train: 0.7302 acc_train: 0.8083 loss_val: 1.1297 acc_val: 0.7160 time: 0.7772s\n",
      "Epoch: 0108 loss_train: 0.7097 acc_train: 0.8000 loss_val: 1.1219 acc_val: 0.7180 time: 0.7807s\n",
      "Epoch: 0109 loss_train: 0.6554 acc_train: 0.8000 loss_val: 1.1189 acc_val: 0.6900 time: 0.7696s\n",
      "Epoch: 0110 loss_train: 0.7424 acc_train: 0.7917 loss_val: 1.1260 acc_val: 0.6900 time: 0.7983s\n",
      "Epoch: 0111 loss_train: 0.6146 acc_train: 0.8500 loss_val: 1.1458 acc_val: 0.6840 time: 0.7840s\n",
      "Epoch: 0112 loss_train: 0.6414 acc_train: 0.8167 loss_val: 1.1452 acc_val: 0.6840 time: 0.7674s\n",
      "Epoch: 0113 loss_train: 0.6641 acc_train: 0.8167 loss_val: 1.1227 acc_val: 0.6960 time: 0.8089s\n",
      "Epoch: 0114 loss_train: 0.7813 acc_train: 0.8083 loss_val: 1.1239 acc_val: 0.7120 time: 0.7604s\n",
      "Epoch: 0115 loss_train: 0.6694 acc_train: 0.8333 loss_val: 1.1308 acc_val: 0.7120 time: 0.7829s\n",
      "Epoch: 0116 loss_train: 0.6935 acc_train: 0.8500 loss_val: 1.1522 acc_val: 0.7000 time: 0.7509s\n",
      "Epoch: 0117 loss_train: 0.7872 acc_train: 0.7750 loss_val: 1.1744 acc_val: 0.6920 time: 0.8123s\n",
      "Epoch: 0118 loss_train: 0.6779 acc_train: 0.8500 loss_val: 1.1843 acc_val: 0.6880 time: 0.7714s\n",
      "Epoch: 0119 loss_train: 0.6899 acc_train: 0.8417 loss_val: 1.1886 acc_val: 0.6860 time: 0.7962s\n",
      "Epoch: 0120 loss_train: 0.7518 acc_train: 0.8083 loss_val: 1.1859 acc_val: 0.6920 time: 0.7695s\n",
      "Epoch: 0121 loss_train: 0.6507 acc_train: 0.8167 loss_val: 1.1591 acc_val: 0.7100 time: 0.7500s\n",
      "Epoch: 0122 loss_train: 0.7787 acc_train: 0.7250 loss_val: 1.1389 acc_val: 0.7080 time: 0.7673s\n",
      "Epoch: 0123 loss_train: 0.6562 acc_train: 0.8333 loss_val: 1.1417 acc_val: 0.6960 time: 0.8125s\n",
      "Epoch: 0124 loss_train: 0.7091 acc_train: 0.8167 loss_val: 1.1586 acc_val: 0.6800 time: 0.7706s\n",
      "Epoch: 0125 loss_train: 0.7198 acc_train: 0.8250 loss_val: 1.1703 acc_val: 0.6560 time: 0.7825s\n",
      "Epoch: 0126 loss_train: 0.6917 acc_train: 0.8250 loss_val: 1.1618 acc_val: 0.6700 time: 0.7981s\n",
      "Epoch: 0127 loss_train: 0.6880 acc_train: 0.7917 loss_val: 1.1591 acc_val: 0.6780 time: 0.7701s\n",
      "Epoch: 0128 loss_train: 0.7062 acc_train: 0.8333 loss_val: 1.1533 acc_val: 0.7120 time: 0.8118s\n",
      "Epoch: 0129 loss_train: 0.7322 acc_train: 0.7917 loss_val: 1.1614 acc_val: 0.6980 time: 0.7653s\n",
      "Epoch: 0130 loss_train: 0.7123 acc_train: 0.8250 loss_val: 1.1524 acc_val: 0.6940 time: 0.7890s\n",
      "Epoch: 0131 loss_train: 0.7853 acc_train: 0.7750 loss_val: 1.1382 acc_val: 0.7040 time: 0.7812s\n",
      "Epoch: 0132 loss_train: 0.7348 acc_train: 0.8250 loss_val: 1.1406 acc_val: 0.7000 time: 0.7849s\n",
      "Epoch: 0133 loss_train: 0.7720 acc_train: 0.8000 loss_val: 1.1506 acc_val: 0.7160 time: 0.7525s\n",
      "Epoch: 0134 loss_train: 0.6931 acc_train: 0.8250 loss_val: 1.1627 acc_val: 0.7020 time: 0.7957s\n",
      "Epoch: 0135 loss_train: 0.6612 acc_train: 0.8250 loss_val: 1.1732 acc_val: 0.6900 time: 0.7850s\n",
      "Epoch: 0136 loss_train: 0.7181 acc_train: 0.8000 loss_val: 1.1816 acc_val: 0.6820 time: 0.7820s\n",
      "Epoch: 0137 loss_train: 0.7335 acc_train: 0.7750 loss_val: 1.1817 acc_val: 0.6680 time: 0.9376s\n",
      "Epoch: 0138 loss_train: 0.8083 acc_train: 0.7750 loss_val: 1.1587 acc_val: 0.6800 time: 0.9043s\n",
      "Epoch: 0139 loss_train: 0.6966 acc_train: 0.8083 loss_val: 1.1486 acc_val: 0.6900 time: 0.9051s\n",
      "Epoch: 0140 loss_train: 0.6691 acc_train: 0.8333 loss_val: 1.1394 acc_val: 0.7000 time: 0.8658s\n",
      "Epoch: 0141 loss_train: 0.6950 acc_train: 0.7917 loss_val: 1.1305 acc_val: 0.6820 time: 0.8659s\n",
      "Epoch: 0142 loss_train: 0.7996 acc_train: 0.7667 loss_val: 1.1352 acc_val: 0.6760 time: 0.8689s\n",
      "Epoch: 0143 loss_train: 0.6533 acc_train: 0.8333 loss_val: 1.1495 acc_val: 0.6760 time: 0.8818s\n",
      "Epoch: 0144 loss_train: 0.7533 acc_train: 0.7417 loss_val: 1.1545 acc_val: 0.6920 time: 0.9158s\n",
      "Epoch: 0145 loss_train: 0.8025 acc_train: 0.7583 loss_val: 1.1767 acc_val: 0.6780 time: 0.8824s\n",
      "Epoch: 0146 loss_train: 0.8177 acc_train: 0.8333 loss_val: 1.1909 acc_val: 0.6760 time: 0.8702s\n",
      "Epoch: 0147 loss_train: 0.7511 acc_train: 0.8083 loss_val: 1.1939 acc_val: 0.6780 time: 0.8963s\n",
      "Epoch: 0148 loss_train: 0.6859 acc_train: 0.8333 loss_val: 1.1859 acc_val: 0.6760 time: 0.8694s\n",
      "Epoch: 0149 loss_train: 0.7761 acc_train: 0.8250 loss_val: 1.1878 acc_val: 0.6840 time: 0.9125s\n",
      "Epoch: 0150 loss_train: 0.8075 acc_train: 0.8250 loss_val: 1.2048 acc_val: 0.6700 time: 0.8869s\n",
      "Epoch: 0151 loss_train: 0.8298 acc_train: 0.7417 loss_val: 1.2167 acc_val: 0.6600 time: 0.8980s\n",
      "Epoch: 0152 loss_train: 0.8156 acc_train: 0.8167 loss_val: 1.2098 acc_val: 0.6720 time: 0.8903s\n",
      "Epoch: 0153 loss_train: 0.7433 acc_train: 0.7750 loss_val: 1.1886 acc_val: 0.6740 time: 0.8908s\n",
      "Epoch: 0154 loss_train: 0.6532 acc_train: 0.8417 loss_val: 1.1747 acc_val: 0.6940 time: 1.1150s\n",
      "Epoch: 0155 loss_train: 0.6900 acc_train: 0.8083 loss_val: 1.1691 acc_val: 0.7020 time: 0.9168s\n",
      "Epoch: 0156 loss_train: 0.6460 acc_train: 0.8333 loss_val: 1.1627 acc_val: 0.7100 time: 0.8858s\n",
      "Epoch: 0157 loss_train: 0.7292 acc_train: 0.8000 loss_val: 1.1533 acc_val: 0.7100 time: 0.9139s\n",
      "Epoch: 0158 loss_train: 0.7515 acc_train: 0.8083 loss_val: 1.1408 acc_val: 0.7060 time: 0.8999s\n",
      "Epoch: 0159 loss_train: 0.7341 acc_train: 0.7750 loss_val: 1.1395 acc_val: 0.6920 time: 0.9003s\n",
      "Epoch: 0160 loss_train: 0.6538 acc_train: 0.8333 loss_val: 1.1369 acc_val: 0.6840 time: 0.8886s\n",
      "Epoch: 0161 loss_train: 0.7410 acc_train: 0.8000 loss_val: 1.1494 acc_val: 0.6740 time: 0.8935s\n",
      "Epoch: 0162 loss_train: 0.6524 acc_train: 0.8333 loss_val: 1.1544 acc_val: 0.6660 time: 0.9005s\n",
      "Epoch: 0163 loss_train: 0.6591 acc_train: 0.8333 loss_val: 1.1568 acc_val: 0.6700 time: 0.9005s\n",
      "Epoch: 0164 loss_train: 0.7632 acc_train: 0.7917 loss_val: 1.1478 acc_val: 0.6780 time: 0.8692s\n",
      "Epoch: 0165 loss_train: 0.7787 acc_train: 0.7500 loss_val: 1.1296 acc_val: 0.7020 time: 0.8983s\n",
      "Epoch: 0166 loss_train: 0.7392 acc_train: 0.8083 loss_val: 1.1259 acc_val: 0.7020 time: 0.8842s\n",
      "Epoch: 0167 loss_train: 0.6890 acc_train: 0.8667 loss_val: 1.1317 acc_val: 0.6920 time: 0.8969s\n",
      "Epoch: 0168 loss_train: 0.7449 acc_train: 0.8083 loss_val: 1.1324 acc_val: 0.6860 time: 0.9159s\n",
      "Epoch: 0169 loss_train: 0.7591 acc_train: 0.7750 loss_val: 1.1450 acc_val: 0.6940 time: 1.0735s\n",
      "Epoch: 0170 loss_train: 0.7095 acc_train: 0.8250 loss_val: 1.1718 acc_val: 0.6800 time: 1.1110s\n",
      "Epoch: 0171 loss_train: 0.7640 acc_train: 0.7750 loss_val: 1.1856 acc_val: 0.6660 time: 0.9143s\n",
      "Epoch: 0172 loss_train: 0.7058 acc_train: 0.8750 loss_val: 1.1754 acc_val: 0.6440 time: 0.8845s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0173 loss_train: 0.7399 acc_train: 0.8167 loss_val: 1.1606 acc_val: 0.6780 time: 0.8906s\n",
      "Epoch: 0174 loss_train: 0.7113 acc_train: 0.8417 loss_val: 1.1490 acc_val: 0.6880 time: 0.8593s\n",
      "Epoch: 0175 loss_train: 0.7066 acc_train: 0.8417 loss_val: 1.1405 acc_val: 0.7100 time: 0.9016s\n",
      "Epoch: 0176 loss_train: 0.7250 acc_train: 0.7917 loss_val: 1.1329 acc_val: 0.7060 time: 0.8825s\n",
      "Epoch: 0177 loss_train: 0.5777 acc_train: 0.8667 loss_val: 1.1324 acc_val: 0.7140 time: 0.8681s\n",
      "Epoch: 0178 loss_train: 0.6608 acc_train: 0.8333 loss_val: 1.1390 acc_val: 0.7120 time: 0.9127s\n",
      "Epoch: 0179 loss_train: 0.5956 acc_train: 0.8750 loss_val: 1.1537 acc_val: 0.6920 time: 0.8697s\n",
      "Epoch: 0180 loss_train: 0.6169 acc_train: 0.8333 loss_val: 1.1738 acc_val: 0.6760 time: 0.8994s\n",
      "Epoch: 0181 loss_train: 0.6511 acc_train: 0.8667 loss_val: 1.1737 acc_val: 0.6660 time: 0.8917s\n",
      "Epoch: 0182 loss_train: 0.5893 acc_train: 0.8333 loss_val: 1.1609 acc_val: 0.6760 time: 0.9327s\n",
      "Epoch: 0183 loss_train: 0.7206 acc_train: 0.7417 loss_val: 1.1407 acc_val: 0.6900 time: 0.9069s\n",
      "Epoch: 0184 loss_train: 0.6769 acc_train: 0.8667 loss_val: 1.1200 acc_val: 0.7000 time: 0.8538s\n",
      "Epoch: 0185 loss_train: 0.6133 acc_train: 0.8583 loss_val: 1.1068 acc_val: 0.7140 time: 0.8971s\n",
      "Epoch: 0186 loss_train: 0.7279 acc_train: 0.8083 loss_val: 1.1064 acc_val: 0.7060 time: 0.9158s\n",
      "Epoch: 0187 loss_train: 0.7127 acc_train: 0.8000 loss_val: 1.1262 acc_val: 0.6960 time: 0.9012s\n",
      "Epoch: 0188 loss_train: 0.7266 acc_train: 0.7917 loss_val: 1.1582 acc_val: 0.6680 time: 0.8862s\n",
      "Epoch: 0189 loss_train: 0.7557 acc_train: 0.8250 loss_val: 1.1739 acc_val: 0.6580 time: 0.8818s\n",
      "Epoch: 0190 loss_train: 0.7135 acc_train: 0.8000 loss_val: 1.1708 acc_val: 0.6660 time: 0.8825s\n",
      "Epoch: 0191 loss_train: 0.7306 acc_train: 0.7583 loss_val: 1.1645 acc_val: 0.6840 time: 0.8842s\n",
      "Epoch: 0192 loss_train: 0.7405 acc_train: 0.8083 loss_val: 1.1447 acc_val: 0.6880 time: 0.9040s\n",
      "Epoch: 0193 loss_train: 0.7021 acc_train: 0.8250 loss_val: 1.1260 acc_val: 0.7080 time: 0.9092s\n",
      "Epoch: 0194 loss_train: 0.6719 acc_train: 0.8250 loss_val: 1.1209 acc_val: 0.7080 time: 0.8711s\n",
      "Epoch: 0195 loss_train: 0.6698 acc_train: 0.8083 loss_val: 1.1197 acc_val: 0.6960 time: 0.8952s\n",
      "Epoch: 0196 loss_train: 0.7767 acc_train: 0.7583 loss_val: 1.1297 acc_val: 0.7040 time: 0.8871s\n",
      "Epoch: 0197 loss_train: 0.6347 acc_train: 0.8417 loss_val: 1.1364 acc_val: 0.7060 time: 0.8685s\n",
      "Epoch: 0198 loss_train: 0.7867 acc_train: 0.7333 loss_val: 1.1408 acc_val: 0.6960 time: 0.9295s\n",
      "Epoch: 0199 loss_train: 0.6090 acc_train: 0.8417 loss_val: 1.1371 acc_val: 0.6940 time: 0.9013s\n",
      "Epoch: 0200 loss_train: 0.6405 acc_train: 0.8083 loss_val: 1.1204 acc_val: 0.7140 time: 0.8839s\n",
      "Optimization Finished!\n",
      "Total time elapsed: 170.7523s\n",
      "Test set results: loss= 1.0855 accuracy= 0.7230\n"
     ]
    }
   ],
   "source": [
    "# Train model\n",
    "t_total = time.time()\n",
    "for epoch in range(args.epochs):\n",
    "    train(epoch)\n",
    "print(\"Optimization Finished!\")\n",
    "print(\"Total time elapsed: {:.4f}s\".format(time.time() - t_total))\n",
    "\n",
    "# Testing\n",
    "test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
