{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import os\n",
    "\n",
    "# Configure GPU visibility before torch (or any module importing it) can initialise CUDA.\n",
    "# An empty CUDA_VISIBLE_DEVICES hides every GPU, so this notebook runs on the CPU.\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
    "\n",
    "import time\n",
    "import argparse\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import networkx as nx\n",
    "from scipy import sparse\n",
    "from scipy.linalg import fractional_matrix_power\n",
    "\n",
    "# Project-local modules. The star import presumably supplies helpers used later\n",
    "# (load_citation_data, accuracy) -- TODO confirm and import them explicitly.\n",
    "from utils import *\n",
    "from models import GNN\n",
    "from dataset_utils import DataLoader\n",
    "\n",
    "# Suppress every warning (including deprecation notices) to keep the training log short.\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_StoreAction(option_strings=['--dataset'], dest='dataset', nargs=None, const=None, default='cora', type=None, choices=None, help='Dataset name.', metavar=None)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hyper-parameters. parse_args is later given an empty argument list, so these defaults are what run.\n",
    "# NOTE(review): args.no_cuda is never read; GPUs are hidden via CUDA_VISIBLE_DEVICES in the first cell.\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument('--no-cuda', action='store_true', default=False,\n",
    "                    help='Disables CUDA training.')\n",
    "parser.add_argument('--fastmode', action='store_true', default=False,\n",
    "                    help='Validate during training pass.')\n",
    "parser.add_argument('--seed', type=int, default=42, help='Random seed.')\n",
    "parser.add_argument('--epochs', type=int, default=200,\n",
    "                    help='Number of epochs to train.')\n",
    "parser.add_argument('--lr', type=float, default=0.001,\n",
    "                    help='Initial learning rate.')\n",
    "parser.add_argument('--weight_decay', type=float, default=8e-3,\n",
    "                    help='Weight decay (L2 loss on parameters).')\n",
    "parser.add_argument('--hidden', type=int, default=64,\n",
    "                    help='Number of hidden units.')\n",
    "parser.add_argument('--dropout', type=float, default=0.9,\n",
    "                    help='Dropout rate (1 - keep probability).')\n",
    "# Trailing ';' stops Jupyter from echoing the returned _StoreAction object.\n",
    "parser.add_argument('--dataset', default='cora', help='Dataset name.');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A notebook has no command line: parse an empty list so every option keeps its default.\n",
    "args = parser.parse_args(args=[])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x1e515de4900>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Seed numpy and torch so repeated runs are reproducible.\n",
    "# The trailing ';' hides the torch.Generator object that manual_seed returns.\n",
    "np.random.seed(args.seed)\n",
    "torch.manual_seed(args.seed);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the dataset object for the configured name and keep its first element.\n",
    "dname = args.dataset\n",
    "dataset = DataLoader(dname)\n",
    "data = dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack the loader's outputs. A_norm is not used by any later cell.\n",
    "# NOTE(review): later cells index A as A[i][j] and read A.shape, so A is presumably a dense numpy array.\n",
    "(A_norm, A, X, labels,\n",
    " idx_train, idx_val, idx_test) = load_citation_data(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from_numpy_matrix was removed in networkx 3.0; from_numpy_array (networkx >= 2.0)\n",
    "# builds the same graph from a dense array, so prefer it when it exists.\n",
    "from_dense = getattr(nx, \"from_numpy_array\", None) or nx.from_numpy_matrix\n",
    "G = from_dense(A)\n",
    "\n",
    "# Store each node's label under the \"attr_name\" node attribute.\n",
    "feature_dictionary = dict(enumerate(labels))\n",
    "\n",
    "nx.set_node_attributes(G, feature_dictionary, \"attr_name\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1-hop ego subgraph of every node: node i itself plus every j with A[i][j] == 1.\n",
    "# A vectorised row lookup replaces a Python scan over all n columns per row.\n",
    "sub_graphs = []\n",
    "for i in np.arange(len(A)):\n",
    "    s_indexes = [i] + list(np.flatnonzero(A[i] == 1))\n",
    "    sub_graphs.append(G.subgraph(s_indexes))\n",
    "\n",
    "# Node list of each subgraph, in the subgraph's own iteration order.\n",
    "subgraph_nodes_list = [list(sub_graph.nodes) for sub_graph in sub_graphs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense adjacency of each ego subgraph (networkx orders rows/columns by the subgraph's nodes).\n",
    "sub_graphs_adj = [nx.adjacency_matrix(sub_graph).toarray() for sub_graph in sub_graphs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Structural edge weights. For every edge (node, index), take the c nodes shared by both\n",
    "# endpoints' 1-hop subgraphs and sum node's subgraph adjacency over those shared nodes\n",
    "# (an edge between two distinct shared nodes contributes twice). value = count / (2 * (c - 1)).\n",
    "new_adj = torch.zeros(A.shape[0], A.shape[0])\n",
    "\n",
    "for node in np.arange(len(subgraph_nodes_list)):\n",
    "    nodes_list = subgraph_nodes_list[node]\n",
    "    sub_adj = sub_graphs_adj[node]\n",
    "    node_set = set(nodes_list)\n",
    "    # Row/column of each member inside sub_adj (O(1) lookup instead of list.index).\n",
    "    position = {member: k for k, member in enumerate(nodes_list)}\n",
    "    for index in nodes_list:\n",
    "        if index == node:\n",
    "            continue\n",
    "        c_neighbors = node_set.intersection(subgraph_nodes_list[index])\n",
    "        n_common = len(c_neighbors)\n",
    "        if n_common < 2:\n",
    "            # Only reachable when A[node][index] == 1 but A[index][node] != 1 (asymmetric A):\n",
    "            # the formula below would divide by zero and write NaN/inf, so leave the entry at 0.\n",
    "            continue\n",
    "        rows = [position[member] for member in c_neighbors]\n",
    "        count = float(sub_adj[np.ix_(rows, rows)].sum())\n",
    "        # Same float32 steps as count/2 / (c*(c-1)) * c.\n",
    "        edge_value = torch.tensor(count, dtype=torch.float32) / 2\n",
    "        edge_value = edge_value / (n_common * (n_common - 1))\n",
    "        new_adj[int(node), int(index)] = edge_value * n_common"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "features = torch.FloatTensor(X)\n",
    "labels = torch.LongTensor(labels)\n",
    "\n",
    "# Row-normalise the structural weights. Rows that sum to zero (e.g. nodes with no\n",
    "# neighbours) stay all-zero instead of becoming 0/0 = NaN.\n",
    "weight = new_adj.float()\n",
    "row_totals = weight.sum(1, keepdim=True)\n",
    "row_totals[row_totals == 0] = 1\n",
    "weight = weight / row_totals\n",
    "\n",
    "# Add the raw adjacency, then place each row's total on the diagonal as a self weight.\n",
    "weight = weight + torch.FloatTensor(A)\n",
    "\n",
    "coeff = torch.diag(weight.sum(1))\n",
    "\n",
    "weight = weight + coeff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Symmetric normalisation A_tilde_hat = D^-1/2 W D^-1/2, with D = diag(row sums of W + 1).\n",
    "# Using a new name keeps `weight` a tensor, so re-running this cell does not fail on .detach().\n",
    "weight_np = np.nan_to_num(weight.detach().numpy(), nan=0)\n",
    "\n",
    "row_sum = weight_np.sum(axis=1)\n",
    "d_inv_sqrt = (row_sum + 1) ** -0.5\n",
    "\n",
    "# Scaling rows and columns by a diagonal matrix is elementwise; this avoids running a\n",
    "# general fractional_matrix_power on a dense n x n matrix and two dense matrix products.\n",
    "A_tilde_hat = d_inv_sqrt[:, None] * weight_np * d_inv_sqrt[None, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense float32 propagation matrix passed to the model as `adj`.\n",
    "adj = torch.tensor(A_tilde_hat, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model and optimizer\n",
    "# Output size is one past the largest label (assumes labels are 0-based class ids).\n",
    "n_classes = int(labels.max()) + 1\n",
    "model = GNN(\n",
    "    nfeat=features.shape[1],\n",
    "    nhid=args.hidden,\n",
    "    nclass=n_classes,\n",
    "    dropout=args.dropout,\n",
    ")\n",
    "optimizer = optim.Adam(\n",
    "    model.parameters(), lr=args.lr, weight_decay=args.weight_decay\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch):\n",
    "    \"\"\"Run one full-graph optimisation step on the training nodes and print metrics.\n",
    "\n",
    "    Args:\n",
    "        epoch: zero-based epoch index (printed as epoch + 1).\n",
    "    \"\"\"\n",
    "    t = time.time()\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    output = model(features, adj)\n",
    "    loss_train = F.nll_loss(output[idx_train], labels[idx_train])\n",
    "    acc_train = accuracy(output[idx_train], labels[idx_train])\n",
    "    loss_train.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if not args.fastmode:\n",
    "        # Evaluate validation set performance separately, deactivates dropout during validation run.\n",
    "        # no_grad: this forward pass is never back-propagated, so skip building its autograd graph.\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            output = model(features, adj)\n",
    "\n",
    "    loss_val = F.nll_loss(output[idx_val], labels[idx_val])\n",
    "    acc_val = accuracy(output[idx_val], labels[idx_val])\n",
    "    print('Epoch: {:04d}'.format(epoch+1),\n",
    "          'loss_train: {:.4f}'.format(loss_train.item()),\n",
    "          'acc_train: {:.4f}'.format(acc_train.item()),\n",
    "          'loss_val: {:.4f}'.format(loss_val.item()),\n",
    "          'acc_val: {:.4f}'.format(acc_val.item()),\n",
    "          'time: {:.4f}s'.format(time.time() - t))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test():\n",
    "    \"\"\"Evaluate the model in eval mode on the test nodes and print loss and accuracy.\"\"\"\n",
    "    model.eval()\n",
    "    # Inference only: no autograd graph is needed.\n",
    "    with torch.no_grad():\n",
    "        output = model(features, adj)\n",
    "    loss_test = F.nll_loss(output[idx_test], labels[idx_test])\n",
    "    acc_test = accuracy(output[idx_test], labels[idx_test])\n",
    "    print(\"Test set results:\",\n",
    "          \"loss= {:.4f}\".format(loss_test.item()),\n",
    "          \"accuracy= {:.4f}\".format(acc_test.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0001 loss_train: 1.9803 acc_train: 0.1429 loss_val: 1.8961 acc_val: 0.2100 time: 0.6183s\n",
      "Epoch: 0002 loss_train: 1.9559 acc_train: 0.1500 loss_val: 1.8945 acc_val: 0.2140 time: 0.5717s\n",
      "Epoch: 0003 loss_train: 1.9697 acc_train: 0.1643 loss_val: 1.8928 acc_val: 0.2140 time: 0.5407s\n",
      "Epoch: 0004 loss_train: 1.9813 acc_train: 0.1571 loss_val: 1.8911 acc_val: 0.2180 time: 0.5549s\n",
      "Epoch: 0005 loss_train: 1.9538 acc_train: 0.1571 loss_val: 1.8893 acc_val: 0.2300 time: 0.5884s\n",
      "Epoch: 0006 loss_train: 1.9510 acc_train: 0.2000 loss_val: 1.8876 acc_val: 0.2420 time: 0.5720s\n",
      "Epoch: 0007 loss_train: 1.9517 acc_train: 0.1643 loss_val: 1.8859 acc_val: 0.2480 time: 0.5866s\n",
      "Epoch: 0008 loss_train: 1.9364 acc_train: 0.1929 loss_val: 1.8842 acc_val: 0.2520 time: 0.6297s\n",
      "Epoch: 0009 loss_train: 1.9380 acc_train: 0.1714 loss_val: 1.8826 acc_val: 0.2600 time: 0.5395s\n",
      "Epoch: 0010 loss_train: 1.9356 acc_train: 0.1643 loss_val: 1.8811 acc_val: 0.2640 time: 0.5095s\n",
      "Epoch: 0011 loss_train: 1.9317 acc_train: 0.1714 loss_val: 1.8796 acc_val: 0.2640 time: 0.5047s\n",
      "Epoch: 0012 loss_train: 1.9310 acc_train: 0.1857 loss_val: 1.8780 acc_val: 0.2700 time: 0.4993s\n",
      "Epoch: 0013 loss_train: 1.9361 acc_train: 0.2000 loss_val: 1.8765 acc_val: 0.2760 time: 0.4890s\n",
      "Epoch: 0014 loss_train: 1.9198 acc_train: 0.2143 loss_val: 1.8749 acc_val: 0.2780 time: 0.5349s\n",
      "Epoch: 0015 loss_train: 1.9146 acc_train: 0.1857 loss_val: 1.8734 acc_val: 0.2820 time: 0.4927s\n",
      "Epoch: 0016 loss_train: 1.9254 acc_train: 0.1929 loss_val: 1.8719 acc_val: 0.2840 time: 0.5052s\n",
      "Epoch: 0017 loss_train: 1.9113 acc_train: 0.2571 loss_val: 1.8703 acc_val: 0.2860 time: 0.5217s\n",
      "Epoch: 0018 loss_train: 1.9079 acc_train: 0.2286 loss_val: 1.8686 acc_val: 0.2920 time: 0.4807s\n",
      "Epoch: 0019 loss_train: 1.9101 acc_train: 0.1929 loss_val: 1.8670 acc_val: 0.3020 time: 0.4903s\n",
      "Epoch: 0020 loss_train: 1.8989 acc_train: 0.1786 loss_val: 1.8653 acc_val: 0.3040 time: 0.5253s\n",
      "Epoch: 0021 loss_train: 1.8974 acc_train: 0.2429 loss_val: 1.8637 acc_val: 0.3040 time: 0.4568s\n",
      "Epoch: 0022 loss_train: 1.8782 acc_train: 0.2214 loss_val: 1.8620 acc_val: 0.3080 time: 0.4259s\n",
      "Epoch: 0023 loss_train: 1.8725 acc_train: 0.2214 loss_val: 1.8604 acc_val: 0.3180 time: 0.4318s\n",
      "Epoch: 0024 loss_train: 1.8767 acc_train: 0.2500 loss_val: 1.8587 acc_val: 0.3240 time: 0.4602s\n",
      "Epoch: 0025 loss_train: 1.8851 acc_train: 0.2071 loss_val: 1.8570 acc_val: 0.3260 time: 0.4221s\n",
      "Epoch: 0026 loss_train: 1.8752 acc_train: 0.2429 loss_val: 1.8552 acc_val: 0.3300 time: 0.4348s\n",
      "Epoch: 0027 loss_train: 1.8515 acc_train: 0.2786 loss_val: 1.8535 acc_val: 0.3340 time: 0.4568s\n",
      "Epoch: 0028 loss_train: 1.8662 acc_train: 0.2357 loss_val: 1.8516 acc_val: 0.3360 time: 0.4408s\n",
      "Epoch: 0029 loss_train: 1.8764 acc_train: 0.2214 loss_val: 1.8498 acc_val: 0.3420 time: 0.4318s\n",
      "Epoch: 0030 loss_train: 1.8421 acc_train: 0.2714 loss_val: 1.8480 acc_val: 0.3460 time: 0.4404s\n",
      "Epoch: 0031 loss_train: 1.8506 acc_train: 0.2857 loss_val: 1.8462 acc_val: 0.3480 time: 0.4747s\n",
      "Epoch: 0032 loss_train: 1.8534 acc_train: 0.3000 loss_val: 1.8443 acc_val: 0.3520 time: 0.4459s\n",
      "Epoch: 0033 loss_train: 1.8322 acc_train: 0.3214 loss_val: 1.8424 acc_val: 0.3520 time: 0.4199s\n",
      "Epoch: 0034 loss_train: 1.8368 acc_train: 0.3000 loss_val: 1.8405 acc_val: 0.3520 time: 0.4465s\n",
      "Epoch: 0035 loss_train: 1.8359 acc_train: 0.3143 loss_val: 1.8385 acc_val: 0.3540 time: 0.4456s\n",
      "Epoch: 0036 loss_train: 1.8265 acc_train: 0.2643 loss_val: 1.8364 acc_val: 0.3540 time: 0.4202s\n",
      "Epoch: 0037 loss_train: 1.8069 acc_train: 0.3500 loss_val: 1.8343 acc_val: 0.3540 time: 0.4388s\n",
      "Epoch: 0038 loss_train: 1.8190 acc_train: 0.3214 loss_val: 1.8321 acc_val: 0.3580 time: 0.4610s\n",
      "Epoch: 0039 loss_train: 1.8176 acc_train: 0.3429 loss_val: 1.8298 acc_val: 0.3700 time: 0.4408s\n",
      "Epoch: 0040 loss_train: 1.8288 acc_train: 0.2857 loss_val: 1.8275 acc_val: 0.3720 time: 0.4279s\n",
      "Epoch: 0041 loss_train: 1.7796 acc_train: 0.4286 loss_val: 1.8252 acc_val: 0.3760 time: 0.4512s\n",
      "Epoch: 0042 loss_train: 1.8012 acc_train: 0.3929 loss_val: 1.8228 acc_val: 0.3780 time: 0.4388s\n",
      "Epoch: 0043 loss_train: 1.8081 acc_train: 0.4000 loss_val: 1.8204 acc_val: 0.3820 time: 0.4408s\n",
      "Epoch: 0044 loss_train: 1.7702 acc_train: 0.3929 loss_val: 1.8179 acc_val: 0.3880 time: 0.4179s\n",
      "Epoch: 0045 loss_train: 1.7975 acc_train: 0.3571 loss_val: 1.8153 acc_val: 0.3980 time: 0.4604s\n",
      "Epoch: 0046 loss_train: 1.7656 acc_train: 0.4357 loss_val: 1.8126 acc_val: 0.4040 time: 0.4408s\n",
      "Epoch: 0047 loss_train: 1.7747 acc_train: 0.4000 loss_val: 1.8100 acc_val: 0.4080 time: 0.4249s\n",
      "Epoch: 0048 loss_train: 1.7432 acc_train: 0.5071 loss_val: 1.8073 acc_val: 0.4140 time: 0.4438s\n",
      "Epoch: 0049 loss_train: 1.7579 acc_train: 0.4357 loss_val: 1.8045 acc_val: 0.4200 time: 0.4488s\n",
      "Epoch: 0050 loss_train: 1.7493 acc_train: 0.4714 loss_val: 1.8017 acc_val: 0.4240 time: 0.4478s\n",
      "Epoch: 0051 loss_train: 1.7489 acc_train: 0.4429 loss_val: 1.7989 acc_val: 0.4320 time: 0.4249s\n",
      "Epoch: 0052 loss_train: 1.7515 acc_train: 0.4786 loss_val: 1.7959 acc_val: 0.4380 time: 0.4553s\n",
      "Epoch: 0053 loss_train: 1.7576 acc_train: 0.3714 loss_val: 1.7930 acc_val: 0.4500 time: 0.4299s\n",
      "Epoch: 0054 loss_train: 1.7267 acc_train: 0.4857 loss_val: 1.7899 acc_val: 0.4580 time: 0.4408s\n",
      "Epoch: 0055 loss_train: 1.7066 acc_train: 0.5357 loss_val: 1.7868 acc_val: 0.4620 time: 0.4279s\n",
      "Epoch: 0056 loss_train: 1.7073 acc_train: 0.5500 loss_val: 1.7837 acc_val: 0.4700 time: 0.4658s\n",
      "Epoch: 0057 loss_train: 1.7161 acc_train: 0.5286 loss_val: 1.7805 acc_val: 0.4840 time: 0.4318s\n",
      "Epoch: 0058 loss_train: 1.7305 acc_train: 0.4643 loss_val: 1.7772 acc_val: 0.5040 time: 0.4229s\n",
      "Epoch: 0059 loss_train: 1.7117 acc_train: 0.5286 loss_val: 1.7738 acc_val: 0.5140 time: 0.4513s\n",
      "Epoch: 0060 loss_train: 1.6836 acc_train: 0.5714 loss_val: 1.7704 acc_val: 0.5240 time: 0.4418s\n",
      "Epoch: 0061 loss_train: 1.6932 acc_train: 0.5286 loss_val: 1.7669 acc_val: 0.5360 time: 0.4365s\n",
      "Epoch: 0062 loss_train: 1.6902 acc_train: 0.5714 loss_val: 1.7634 acc_val: 0.5520 time: 0.4249s\n",
      "Epoch: 0063 loss_train: 1.6579 acc_train: 0.5857 loss_val: 1.7598 acc_val: 0.5620 time: 0.4767s\n",
      "Epoch: 0064 loss_train: 1.6365 acc_train: 0.6214 loss_val: 1.7559 acc_val: 0.5780 time: 0.4534s\n",
      "Epoch: 0065 loss_train: 1.6471 acc_train: 0.6500 loss_val: 1.7520 acc_val: 0.5880 time: 0.4418s\n",
      "Epoch: 0066 loss_train: 1.6409 acc_train: 0.6143 loss_val: 1.7480 acc_val: 0.5980 time: 0.4199s\n",
      "Epoch: 0067 loss_train: 1.6536 acc_train: 0.6643 loss_val: 1.7440 acc_val: 0.6080 time: 0.4608s\n",
      "Epoch: 0068 loss_train: 1.6659 acc_train: 0.5857 loss_val: 1.7398 acc_val: 0.6160 time: 0.4259s\n",
      "Epoch: 0069 loss_train: 1.6312 acc_train: 0.6071 loss_val: 1.7357 acc_val: 0.6220 time: 0.4239s\n",
      "Epoch: 0070 loss_train: 1.6345 acc_train: 0.6786 loss_val: 1.7315 acc_val: 0.6320 time: 0.4639s\n",
      "Epoch: 0071 loss_train: 1.5780 acc_train: 0.8071 loss_val: 1.7272 acc_val: 0.6440 time: 0.4299s\n",
      "Epoch: 0072 loss_train: 1.6102 acc_train: 0.6357 loss_val: 1.7229 acc_val: 0.6460 time: 0.4398s\n",
      "Epoch: 0073 loss_train: 1.5794 acc_train: 0.6786 loss_val: 1.7187 acc_val: 0.6520 time: 0.4269s\n",
      "Epoch: 0074 loss_train: 1.6131 acc_train: 0.6357 loss_val: 1.7144 acc_val: 0.6580 time: 0.4748s\n",
      "Epoch: 0075 loss_train: 1.5667 acc_train: 0.6571 loss_val: 1.7100 acc_val: 0.6680 time: 0.4249s\n",
      "Epoch: 0076 loss_train: 1.5993 acc_train: 0.6429 loss_val: 1.7055 acc_val: 0.6700 time: 0.4388s\n",
      "Epoch: 0077 loss_train: 1.5697 acc_train: 0.6286 loss_val: 1.7010 acc_val: 0.6780 time: 0.4269s\n",
      "Epoch: 0078 loss_train: 1.5567 acc_train: 0.6786 loss_val: 1.6964 acc_val: 0.6820 time: 0.4896s\n",
      "Epoch: 0079 loss_train: 1.5582 acc_train: 0.6786 loss_val: 1.6919 acc_val: 0.6860 time: 0.5026s\n",
      "Epoch: 0080 loss_train: 1.5532 acc_train: 0.6857 loss_val: 1.6873 acc_val: 0.6980 time: 0.4843s\n",
      "Epoch: 0081 loss_train: 1.5410 acc_train: 0.7500 loss_val: 1.6827 acc_val: 0.7000 time: 0.5202s\n",
      "Epoch: 0082 loss_train: 1.5177 acc_train: 0.7214 loss_val: 1.6780 acc_val: 0.7100 time: 0.4793s\n",
      "Epoch: 0083 loss_train: 1.5616 acc_train: 0.6643 loss_val: 1.6732 acc_val: 0.7160 time: 0.4827s\n",
      "Epoch: 0084 loss_train: 1.5265 acc_train: 0.7286 loss_val: 1.6685 acc_val: 0.7260 time: 0.5067s\n",
      "Epoch: 0085 loss_train: 1.4954 acc_train: 0.7143 loss_val: 1.6637 acc_val: 0.7360 time: 0.5005s\n",
      "Epoch: 0086 loss_train: 1.4491 acc_train: 0.7571 loss_val: 1.6589 acc_val: 0.7380 time: 0.5182s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0087 loss_train: 1.4926 acc_train: 0.7357 loss_val: 1.6541 acc_val: 0.7400 time: 0.4515s\n",
      "Epoch: 0088 loss_train: 1.4965 acc_train: 0.7143 loss_val: 1.6493 acc_val: 0.7420 time: 0.4973s\n",
      "Epoch: 0089 loss_train: 1.4303 acc_train: 0.7714 loss_val: 1.6445 acc_val: 0.7440 time: 0.4299s\n",
      "Epoch: 0090 loss_train: 1.4508 acc_train: 0.7214 loss_val: 1.6396 acc_val: 0.7460 time: 0.4239s\n",
      "Epoch: 0091 loss_train: 1.4527 acc_train: 0.7714 loss_val: 1.6347 acc_val: 0.7480 time: 0.4659s\n",
      "Epoch: 0092 loss_train: 1.4583 acc_train: 0.7643 loss_val: 1.6298 acc_val: 0.7500 time: 0.4209s\n",
      "Epoch: 0093 loss_train: 1.4275 acc_train: 0.7429 loss_val: 1.6248 acc_val: 0.7520 time: 0.4398s\n",
      "Epoch: 0094 loss_train: 1.4201 acc_train: 0.7429 loss_val: 1.6197 acc_val: 0.7540 time: 0.4588s\n",
      "Epoch: 0095 loss_train: 1.4193 acc_train: 0.7429 loss_val: 1.6146 acc_val: 0.7540 time: 0.4328s\n",
      "Epoch: 0096 loss_train: 1.4085 acc_train: 0.7786 loss_val: 1.6094 acc_val: 0.7600 time: 0.4269s\n",
      "Epoch: 0097 loss_train: 1.4357 acc_train: 0.7286 loss_val: 1.6042 acc_val: 0.7620 time: 0.4366s\n",
      "Epoch: 0098 loss_train: 1.4269 acc_train: 0.7214 loss_val: 1.5989 acc_val: 0.7640 time: 0.4548s\n",
      "Epoch: 0099 loss_train: 1.4218 acc_train: 0.7214 loss_val: 1.5938 acc_val: 0.7660 time: 0.4308s\n",
      "Epoch: 0100 loss_train: 1.3253 acc_train: 0.8071 loss_val: 1.5885 acc_val: 0.7680 time: 0.4857s\n",
      "Epoch: 0101 loss_train: 1.3604 acc_train: 0.8143 loss_val: 1.5833 acc_val: 0.7660 time: 0.4448s\n",
      "Epoch: 0102 loss_train: 1.3494 acc_train: 0.7643 loss_val: 1.5781 acc_val: 0.7660 time: 0.4249s\n",
      "Epoch: 0103 loss_train: 1.3684 acc_train: 0.7571 loss_val: 1.5728 acc_val: 0.7720 time: 0.4240s\n",
      "Epoch: 0104 loss_train: 1.3621 acc_train: 0.7643 loss_val: 1.5675 acc_val: 0.7700 time: 0.4763s\n",
      "Epoch: 0105 loss_train: 1.3636 acc_train: 0.8286 loss_val: 1.5623 acc_val: 0.7760 time: 0.4360s\n",
      "Epoch: 0106 loss_train: 1.4022 acc_train: 0.7357 loss_val: 1.5572 acc_val: 0.7740 time: 0.5266s\n",
      "Epoch: 0107 loss_train: 1.3586 acc_train: 0.7714 loss_val: 1.5520 acc_val: 0.7740 time: 0.4907s\n",
      "Epoch: 0108 loss_train: 1.3185 acc_train: 0.8286 loss_val: 1.5469 acc_val: 0.7720 time: 0.5726s\n",
      "Epoch: 0109 loss_train: 1.3424 acc_train: 0.7571 loss_val: 1.5417 acc_val: 0.7700 time: 0.4747s\n",
      "Epoch: 0110 loss_train: 1.2847 acc_train: 0.8286 loss_val: 1.5365 acc_val: 0.7660 time: 0.4468s\n",
      "Epoch: 0111 loss_train: 1.3532 acc_train: 0.7786 loss_val: 1.5313 acc_val: 0.7680 time: 0.5386s\n",
      "Epoch: 0112 loss_train: 1.2806 acc_train: 0.8000 loss_val: 1.5261 acc_val: 0.7660 time: 0.5037s\n",
      "Epoch: 0113 loss_train: 1.2539 acc_train: 0.8429 loss_val: 1.5208 acc_val: 0.7660 time: 0.4498s\n",
      "Epoch: 0114 loss_train: 1.2775 acc_train: 0.7929 loss_val: 1.5155 acc_val: 0.7660 time: 0.5116s\n",
      "Epoch: 0115 loss_train: 1.2263 acc_train: 0.7857 loss_val: 1.5101 acc_val: 0.7660 time: 0.4717s\n",
      "Epoch: 0116 loss_train: 1.2905 acc_train: 0.7929 loss_val: 1.5047 acc_val: 0.7660 time: 0.4259s\n",
      "Epoch: 0117 loss_train: 1.2249 acc_train: 0.8286 loss_val: 1.4992 acc_val: 0.7660 time: 0.4209s\n",
      "Epoch: 0118 loss_train: 1.2318 acc_train: 0.8357 loss_val: 1.4938 acc_val: 0.7660 time: 0.5147s\n",
      "Epoch: 0119 loss_train: 1.1989 acc_train: 0.8143 loss_val: 1.4882 acc_val: 0.7660 time: 0.4229s\n",
      "Epoch: 0120 loss_train: 1.2208 acc_train: 0.8714 loss_val: 1.4825 acc_val: 0.7680 time: 0.4767s\n",
      "Epoch: 0121 loss_train: 1.2223 acc_train: 0.8214 loss_val: 1.4769 acc_val: 0.7720 time: 0.4548s\n",
      "Epoch: 0122 loss_train: 1.2004 acc_train: 0.7786 loss_val: 1.4713 acc_val: 0.7720 time: 0.4568s\n",
      "Epoch: 0123 loss_train: 1.1755 acc_train: 0.8714 loss_val: 1.4655 acc_val: 0.7720 time: 0.4518s\n",
      "Epoch: 0124 loss_train: 1.2279 acc_train: 0.7929 loss_val: 1.4597 acc_val: 0.7720 time: 0.5126s\n",
      "Epoch: 0125 loss_train: 1.1661 acc_train: 0.8786 loss_val: 1.4540 acc_val: 0.7740 time: 0.4874s\n",
      "Epoch: 0126 loss_train: 1.1498 acc_train: 0.8357 loss_val: 1.4483 acc_val: 0.7720 time: 0.5695s\n",
      "Epoch: 0127 loss_train: 1.1621 acc_train: 0.8643 loss_val: 1.4426 acc_val: 0.7720 time: 0.5186s\n",
      "Epoch: 0128 loss_train: 1.2148 acc_train: 0.8143 loss_val: 1.4369 acc_val: 0.7720 time: 0.6163s\n",
      "Epoch: 0129 loss_train: 1.1475 acc_train: 0.8500 loss_val: 1.4312 acc_val: 0.7720 time: 0.6213s\n",
      "Epoch: 0130 loss_train: 1.1366 acc_train: 0.8714 loss_val: 1.4254 acc_val: 0.7740 time: 0.4848s\n",
      "Epoch: 0131 loss_train: 1.0509 acc_train: 0.8714 loss_val: 1.4197 acc_val: 0.7740 time: 0.6168s\n",
      "Epoch: 0132 loss_train: 1.1286 acc_train: 0.8429 loss_val: 1.4139 acc_val: 0.7740 time: 0.5053s\n",
      "Epoch: 0133 loss_train: 1.0582 acc_train: 0.8143 loss_val: 1.4080 acc_val: 0.7760 time: 0.5047s\n",
      "Epoch: 0134 loss_train: 1.0979 acc_train: 0.8714 loss_val: 1.4022 acc_val: 0.7800 time: 0.4479s\n",
      "Epoch: 0135 loss_train: 1.0723 acc_train: 0.8571 loss_val: 1.3963 acc_val: 0.7780 time: 0.4368s\n",
      "Epoch: 0136 loss_train: 1.0593 acc_train: 0.8571 loss_val: 1.3905 acc_val: 0.7800 time: 0.4149s\n",
      "Epoch: 0137 loss_train: 1.1256 acc_train: 0.8500 loss_val: 1.3847 acc_val: 0.7800 time: 0.4573s\n",
      "Epoch: 0138 loss_train: 1.0904 acc_train: 0.8429 loss_val: 1.3789 acc_val: 0.7800 time: 0.4269s\n",
      "Epoch: 0139 loss_train: 1.0703 acc_train: 0.8429 loss_val: 1.3733 acc_val: 0.7800 time: 0.4488s\n",
      "Epoch: 0140 loss_train: 1.0740 acc_train: 0.8500 loss_val: 1.3677 acc_val: 0.7820 time: 0.4717s\n",
      "Epoch: 0141 loss_train: 0.9920 acc_train: 0.8857 loss_val: 1.3621 acc_val: 0.7840 time: 0.5276s\n",
      "Epoch: 0142 loss_train: 1.0461 acc_train: 0.8786 loss_val: 1.3565 acc_val: 0.7860 time: 0.5216s\n",
      "Epoch: 0143 loss_train: 1.0248 acc_train: 0.8786 loss_val: 1.3509 acc_val: 0.7860 time: 0.5086s\n",
      "Epoch: 0144 loss_train: 1.0596 acc_train: 0.8571 loss_val: 1.3452 acc_val: 0.7860 time: 0.4827s\n",
      "Epoch: 0145 loss_train: 1.0417 acc_train: 0.8286 loss_val: 1.3397 acc_val: 0.7860 time: 0.5136s\n",
      "Epoch: 0146 loss_train: 1.0609 acc_train: 0.8286 loss_val: 1.3342 acc_val: 0.7840 time: 0.4628s\n",
      "Epoch: 0147 loss_train: 1.0626 acc_train: 0.8714 loss_val: 1.3288 acc_val: 0.7840 time: 0.5027s\n",
      "Epoch: 0148 loss_train: 1.0151 acc_train: 0.8643 loss_val: 1.3235 acc_val: 0.7880 time: 0.4588s\n",
      "Epoch: 0149 loss_train: 1.0143 acc_train: 0.8571 loss_val: 1.3182 acc_val: 0.7900 time: 0.4717s\n",
      "Epoch: 0150 loss_train: 1.0100 acc_train: 0.8786 loss_val: 1.3131 acc_val: 0.7900 time: 0.4458s\n",
      "Epoch: 0151 loss_train: 0.9939 acc_train: 0.8643 loss_val: 1.3081 acc_val: 0.7900 time: 0.5226s\n",
      "Epoch: 0152 loss_train: 0.9480 acc_train: 0.8714 loss_val: 1.3031 acc_val: 0.7900 time: 0.4837s\n",
      "Epoch: 0153 loss_train: 0.9811 acc_train: 0.8643 loss_val: 1.2981 acc_val: 0.7900 time: 0.5605s\n",
      "Epoch: 0154 loss_train: 0.9699 acc_train: 0.8429 loss_val: 1.2933 acc_val: 0.7880 time: 0.4851s\n",
      "Epoch: 0155 loss_train: 0.9391 acc_train: 0.8786 loss_val: 1.2884 acc_val: 0.7880 time: 0.4807s\n",
      "Epoch: 0156 loss_train: 0.9214 acc_train: 0.8929 loss_val: 1.2834 acc_val: 0.7880 time: 0.4259s\n",
      "Epoch: 0157 loss_train: 0.9927 acc_train: 0.8429 loss_val: 1.2785 acc_val: 0.7880 time: 0.4585s\n",
      "Epoch: 0158 loss_train: 0.9346 acc_train: 0.8714 loss_val: 1.2739 acc_val: 0.7880 time: 0.4209s\n",
      "Epoch: 0159 loss_train: 0.9397 acc_train: 0.9000 loss_val: 1.2693 acc_val: 0.7880 time: 0.4229s\n",
      "Epoch: 0160 loss_train: 0.8787 acc_train: 0.8786 loss_val: 1.2648 acc_val: 0.7880 time: 0.4669s\n",
      "Epoch: 0161 loss_train: 0.9457 acc_train: 0.8571 loss_val: 1.2606 acc_val: 0.7880 time: 0.4169s\n",
      "Epoch: 0162 loss_train: 0.9057 acc_train: 0.8714 loss_val: 1.2566 acc_val: 0.7900 time: 0.4618s\n",
      "Epoch: 0163 loss_train: 0.9410 acc_train: 0.8571 loss_val: 1.2530 acc_val: 0.7900 time: 0.4737s\n",
      "Epoch: 0164 loss_train: 0.9118 acc_train: 0.8929 loss_val: 1.2493 acc_val: 0.7860 time: 0.4586s\n",
      "Epoch: 0165 loss_train: 0.8512 acc_train: 0.8643 loss_val: 1.2455 acc_val: 0.7860 time: 0.4159s\n",
      "Epoch: 0166 loss_train: 0.9121 acc_train: 0.9071 loss_val: 1.2417 acc_val: 0.7860 time: 0.4508s\n",
      "Epoch: 0167 loss_train: 0.9163 acc_train: 0.8357 loss_val: 1.2378 acc_val: 0.7860 time: 0.5515s\n",
      "Epoch: 0168 loss_train: 0.9046 acc_train: 0.8357 loss_val: 1.2340 acc_val: 0.7840 time: 0.5406s\n",
      "Epoch: 0169 loss_train: 0.8788 acc_train: 0.9000 loss_val: 1.2300 acc_val: 0.7840 time: 0.4628s\n",
      "Epoch: 0170 loss_train: 0.8882 acc_train: 0.8714 loss_val: 1.2260 acc_val: 0.7840 time: 0.4923s\n",
      "Epoch: 0171 loss_train: 0.8923 acc_train: 0.8643 loss_val: 1.2219 acc_val: 0.7860 time: 0.4259s\n",
      "Epoch: 0172 loss_train: 0.8795 acc_train: 0.8714 loss_val: 1.2179 acc_val: 0.7920 time: 0.4937s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0173 loss_train: 0.8384 acc_train: 0.8857 loss_val: 1.2138 acc_val: 0.7900 time: 0.5236s\n",
      "Epoch: 0174 loss_train: 0.9054 acc_train: 0.8571 loss_val: 1.2096 acc_val: 0.7880 time: 0.5017s\n",
      "Epoch: 0175 loss_train: 0.9198 acc_train: 0.8357 loss_val: 1.2056 acc_val: 0.7880 time: 0.4478s\n",
      "Epoch: 0176 loss_train: 0.8332 acc_train: 0.8929 loss_val: 1.2016 acc_val: 0.7900 time: 0.5066s\n",
      "Epoch: 0177 loss_train: 0.8237 acc_train: 0.9214 loss_val: 1.1976 acc_val: 0.7900 time: 0.4857s\n",
      "Epoch: 0178 loss_train: 0.8544 acc_train: 0.8643 loss_val: 1.1935 acc_val: 0.7900 time: 0.4368s\n",
      "Epoch: 0179 loss_train: 0.8569 acc_train: 0.8786 loss_val: 1.1895 acc_val: 0.7920 time: 0.4867s\n",
      "Epoch: 0180 loss_train: 0.7805 acc_train: 0.8857 loss_val: 1.1855 acc_val: 0.7920 time: 0.4718s\n",
      "Epoch: 0181 loss_train: 0.8470 acc_train: 0.8643 loss_val: 1.1817 acc_val: 0.7920 time: 0.4279s\n",
      "Epoch: 0182 loss_train: 0.7640 acc_train: 0.9357 loss_val: 1.1779 acc_val: 0.7920 time: 0.4318s\n",
      "Epoch: 0183 loss_train: 0.7688 acc_train: 0.9071 loss_val: 1.1741 acc_val: 0.7920 time: 0.4308s\n",
      "Epoch: 0184 loss_train: 0.8281 acc_train: 0.8714 loss_val: 1.1702 acc_val: 0.7920 time: 0.4538s\n",
      "Epoch: 0185 loss_train: 0.8050 acc_train: 0.8786 loss_val: 1.1663 acc_val: 0.7920 time: 0.4199s\n",
      "Epoch: 0186 loss_train: 0.7547 acc_train: 0.9286 loss_val: 1.1623 acc_val: 0.7920 time: 0.4585s\n",
      "Epoch: 0187 loss_train: 0.8254 acc_train: 0.9071 loss_val: 1.1583 acc_val: 0.7920 time: 0.4809s\n",
      "Epoch: 0188 loss_train: 0.8332 acc_train: 0.8786 loss_val: 1.1544 acc_val: 0.7920 time: 0.4630s\n",
      "Epoch: 0189 loss_train: 0.7960 acc_train: 0.8500 loss_val: 1.1506 acc_val: 0.7900 time: 0.4460s\n",
      "Epoch: 0190 loss_train: 0.8200 acc_train: 0.9071 loss_val: 1.1469 acc_val: 0.7900 time: 0.4410s\n",
      "Epoch: 0191 loss_train: 0.8203 acc_train: 0.8786 loss_val: 1.1431 acc_val: 0.7900 time: 0.4532s\n",
      "Epoch: 0192 loss_train: 0.7811 acc_train: 0.9143 loss_val: 1.1395 acc_val: 0.7920 time: 0.4169s\n",
      "Epoch: 0193 loss_train: 0.7987 acc_train: 0.8643 loss_val: 1.1361 acc_val: 0.7920 time: 0.4358s\n",
      "Epoch: 0194 loss_train: 0.7614 acc_train: 0.9000 loss_val: 1.1328 acc_val: 0.7920 time: 0.4503s\n",
      "Epoch: 0195 loss_train: 0.7626 acc_train: 0.8929 loss_val: 1.1296 acc_val: 0.7900 time: 0.4398s\n",
      "Epoch: 0196 loss_train: 0.7401 acc_train: 0.9357 loss_val: 1.1266 acc_val: 0.7900 time: 0.4159s\n",
      "Epoch: 0197 loss_train: 0.7345 acc_train: 0.9071 loss_val: 1.1236 acc_val: 0.7900 time: 0.4239s\n",
      "Epoch: 0198 loss_train: 0.7381 acc_train: 0.9214 loss_val: 1.1208 acc_val: 0.7900 time: 0.4641s\n",
      "Epoch: 0199 loss_train: 0.7208 acc_train: 0.9214 loss_val: 1.1180 acc_val: 0.7920 time: 0.4219s\n",
      "Epoch: 0200 loss_train: 0.7254 acc_train: 0.9286 loss_val: 1.1152 acc_val: 0.7940 time: 0.4881s\n",
      "Optimization Finished!\n",
      "Total time elapsed: 96.7929s\n",
      "Test set results: loss= 1.0738 accuracy= 0.8310\n"
     ]
    }
   ],
   "source": [
    "# Train model for the configured number of epochs, timing the whole run.\n",
    "train_start = time.time()\n",
    "for epoch_idx in range(args.epochs):\n",
    "    train(epoch_idx)\n",
    "print(\"Optimization Finished!\")\n",
    "elapsed = time.time() - train_start\n",
    "print(\"Total time elapsed: {:.4f}s\".format(elapsed))\n",
    "\n",
    "# Testing: final evaluation on the test split.\n",
    "test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
