{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import time\n",
    "import argparse\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import networkx as nx\n",
    "from scipy import sparse\n",
    "from scipy.linalg import fractional_matrix_power\n",
    "\n",
    "from utils import *\n",
    "from models import Graphsn_GIN\n",
    "from dataset_utils import DataLoader\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_StoreAction(option_strings=['--dataset'], dest='dataset', nargs=None, const=None, default='cora', type=None, choices=None, help='Dataset name.', metavar=None)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument('--no-cuda', action='store_true', default=False,\n",
    "                    help='Disables CUDA training.')\n",
    "parser.add_argument('--fastmode', action='store_true', default=False,\n",
    "                    help='Validate during training pass.')\n",
    "parser.add_argument('--seed', type=int, default=42, help='Random seed.')\n",
    "parser.add_argument('--epochs', type=int, default=200,\n",
    "                    help='Number of epochs to train.')\n",
    "parser.add_argument('--lr', type=float, default=0.002,\n",
    "                    help='Initial learning rate.')\n",
    "parser.add_argument('--weight_decay', type=float, default=9e-3,\n",
    "                    help='Weight decay (L2 loss on parameters).')\n",
    "parser.add_argument('--hidden', type=int, default=64,\n",
    "                    help='Number of hidden units.')\n",
    "parser.add_argument('--dropout', type=float, default=0.88,\n",
    "                    help='Dropout rate (1 - keep probability).')\n",
    "parser.add_argument('--dataset', default='cora', help='Dataset name.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = parser.parse_args(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x1fc94ab5900>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.random.seed(args.seed)\n",
    "torch.manual_seed(args.seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index\n",
      "Processing...\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "dname = args.dataset\n",
    "dataset = DataLoader(dname)\n",
    "data = dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "A_norm, A, X, labels, idx_train, idx_val, idx_test = load_citation_data(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "G = nx.from_numpy_matrix(A)\n",
    "feature_dictionary = {}\n",
    "\n",
    "for i in np.arange(len(labels)):\n",
    "    feature_dictionary[i] = labels[i]\n",
    "\n",
    "nx.set_node_attributes(G, feature_dictionary, \"attr_name\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "sub_graphs = []\n",
    "\n",
    "for i in np.arange(len(A)):\n",
    "    s_indexes = []\n",
    "    for j in np.arange(len(A)):\n",
    "        s_indexes.append(i)\n",
    "        if(A[i][j]==1):\n",
    "            s_indexes.append(j)\n",
    "    sub_graphs.append(G.subgraph(s_indexes))\n",
    "\n",
    "subgraph_nodes_list = []\n",
    "\n",
    "for i in np.arange(len(sub_graphs)):\n",
    "    subgraph_nodes_list.append(list(sub_graphs[i].nodes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "sub_graphs_adj = []\n",
    "for index in np.arange(len(sub_graphs)):\n",
    "    sub_graphs_adj.append(nx.adjacency_matrix(sub_graphs[index]).toarray())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_adj = torch.zeros(A.shape[0], A.shape[0])\n",
    "\n",
    "for node in np.arange(len(subgraph_nodes_list)):\n",
    "    sub_adj = sub_graphs_adj[node]\n",
    "    for neighbors in np.arange(len(subgraph_nodes_list[node])):\n",
    "        index = subgraph_nodes_list[node][neighbors]\n",
    "        count = torch.tensor(0).float()\n",
    "        if(index==node):\n",
    "            continue\n",
    "        else:\n",
    "            c_neighbors = set(subgraph_nodes_list[node]).intersection(subgraph_nodes_list[index])\n",
    "            if index in c_neighbors:\n",
    "                nodes_list = subgraph_nodes_list[node]\n",
    "                sub_graph_index = nodes_list.index(index)\n",
    "                c_neighbors_list = list(c_neighbors)\n",
    "                for i, item1 in enumerate(nodes_list):\n",
    "                    if(item1 in c_neighbors):\n",
    "                        for item2 in c_neighbors_list:\n",
    "                            j = nodes_list.index(item2)\n",
    "                            count += sub_adj[i][j]\n",
    "\n",
    "            new_adj[node][index] = count/2\n",
    "            new_adj[node][index] = new_adj[node][index]/(len(c_neighbors)*(len(c_neighbors)-1))\n",
    "            new_adj[node][index] = new_adj[node][index] * (len(c_neighbors)**1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "features = torch.FloatTensor(X)\n",
    "labels = torch.LongTensor(labels)\n",
    "\n",
    "weight = torch.FloatTensor(new_adj)\n",
    "weight = weight / weight.sum(1, keepdim=True)\n",
    "\n",
    "weight = weight + torch.FloatTensor(A)\n",
    "\n",
    "coeff = weight.sum(1, keepdim=True)\n",
    "coeff = torch.diag((coeff.T)[0])\n",
    "\n",
    "weight = weight + coeff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "weight = weight.detach().numpy()\n",
    "weight = np.nan_to_num(weight, nan=0)\n",
    "adj = torch.FloatTensor(weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model and optimizer\n",
    "model = Graphsn_GIN(nfeat=features.shape[1],\n",
    "                    nhid=args.hidden,\n",
    "                    nclass=labels.max().item() + 1,\n",
    "                    dropout=args.dropout)\n",
    "\n",
    "optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch):\n",
    "    t = time.time()\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    output = model(features, adj)\n",
    "    loss_train = F.nll_loss(output[idx_train], labels[idx_train])\n",
    "    acc_train = accuracy(output[idx_train], labels[idx_train])\n",
    "    loss_train.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if not args.fastmode:\n",
    "        # Evaluate validation set performance separately, deactivates dropout during validation run.\n",
    "        model.eval()\n",
    "        output = model(features, adj)\n",
    "\n",
    "    loss_val = F.nll_loss(output[idx_val], labels[idx_val])\n",
    "    acc_val = accuracy(output[idx_val], labels[idx_val])\n",
    "    print('Epoch: {:04d}'.format(epoch+1),\n",
    "          'loss_train: {:.4f}'.format(loss_train.item()),\n",
    "          'acc_train: {:.4f}'.format(acc_train.item()),\n",
    "          'loss_val: {:.4f}'.format(loss_val.item()),\n",
    "          'acc_val: {:.4f}'.format(acc_val.item()),\n",
    "          'time: {:.4f}s'.format(time.time() - t))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test():\n",
    "    model.eval()\n",
    "    output = model(features, adj)\n",
    "    loss_test = F.nll_loss(output[idx_test], labels[idx_test])\n",
    "    acc_test = accuracy(output[idx_test], labels[idx_test])\n",
    "    print(\"Test set results:\",\n",
    "          \"loss= {:.4f}\".format(loss_test.item()),\n",
    "          \"accuracy= {:.4f}\".format(acc_test.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0001 loss_train: 2.0504 acc_train: 0.1786 loss_val: 1.8746 acc_val: 0.3440 time: 1.0253s\n",
      "Epoch: 0002 loss_train: 1.9849 acc_train: 0.2071 loss_val: 1.8429 acc_val: 0.4140 time: 0.6827s\n",
      "Epoch: 0003 loss_train: 1.8407 acc_train: 0.2500 loss_val: 1.8152 acc_val: 0.4720 time: 0.7530s\n",
      "Epoch: 0004 loss_train: 1.7604 acc_train: 0.2929 loss_val: 1.7882 acc_val: 0.5640 time: 0.7181s\n",
      "Epoch: 0005 loss_train: 1.7793 acc_train: 0.3643 loss_val: 1.7618 acc_val: 0.6020 time: 0.6423s\n",
      "Epoch: 0006 loss_train: 1.6856 acc_train: 0.3857 loss_val: 1.7333 acc_val: 0.6380 time: 0.7679s\n",
      "Epoch: 0007 loss_train: 1.6423 acc_train: 0.4714 loss_val: 1.7020 acc_val: 0.6780 time: 0.6732s\n",
      "Epoch: 0008 loss_train: 1.5223 acc_train: 0.5143 loss_val: 1.6692 acc_val: 0.7160 time: 0.7310s\n",
      "Epoch: 0009 loss_train: 1.6090 acc_train: 0.4000 loss_val: 1.6366 acc_val: 0.7300 time: 0.6293s\n",
      "Epoch: 0010 loss_train: 1.5840 acc_train: 0.4571 loss_val: 1.6031 acc_val: 0.7360 time: 0.6154s\n",
      "Epoch: 0011 loss_train: 1.3787 acc_train: 0.5643 loss_val: 1.5679 acc_val: 0.7480 time: 0.6812s\n",
      "Epoch: 0012 loss_train: 1.3795 acc_train: 0.4786 loss_val: 1.5311 acc_val: 0.7500 time: 0.6328s\n",
      "Epoch: 0013 loss_train: 1.4354 acc_train: 0.4714 loss_val: 1.4935 acc_val: 0.7540 time: 0.6524s\n",
      "Epoch: 0014 loss_train: 1.3258 acc_train: 0.5500 loss_val: 1.4557 acc_val: 0.7500 time: 0.5725s\n",
      "Epoch: 0015 loss_train: 1.4832 acc_train: 0.4071 loss_val: 1.4195 acc_val: 0.7480 time: 0.6141s\n",
      "Epoch: 0016 loss_train: 1.3235 acc_train: 0.5071 loss_val: 1.3839 acc_val: 0.7540 time: 0.7263s\n",
      "Epoch: 0017 loss_train: 1.2251 acc_train: 0.5643 loss_val: 1.3482 acc_val: 0.7580 time: 0.6243s\n",
      "Epoch: 0018 loss_train: 1.1894 acc_train: 0.6071 loss_val: 1.3145 acc_val: 0.7640 time: 0.7261s\n",
      "Epoch: 0019 loss_train: 1.0958 acc_train: 0.6357 loss_val: 1.2821 acc_val: 0.7680 time: 0.7251s\n",
      "Epoch: 0020 loss_train: 1.1191 acc_train: 0.5571 loss_val: 1.2516 acc_val: 0.7660 time: 0.6513s\n",
      "Epoch: 0021 loss_train: 1.1088 acc_train: 0.5929 loss_val: 1.2232 acc_val: 0.7640 time: 0.6084s\n",
      "Epoch: 0022 loss_train: 1.0592 acc_train: 0.6500 loss_val: 1.1960 acc_val: 0.7660 time: 0.6293s\n",
      "Epoch: 0023 loss_train: 1.0461 acc_train: 0.6286 loss_val: 1.1703 acc_val: 0.7700 time: 0.7719s\n",
      "Epoch: 0024 loss_train: 1.0237 acc_train: 0.5929 loss_val: 1.1459 acc_val: 0.7680 time: 0.7091s\n",
      "Epoch: 0025 loss_train: 0.9177 acc_train: 0.6786 loss_val: 1.1227 acc_val: 0.7660 time: 0.7271s\n",
      "Epoch: 0026 loss_train: 1.0344 acc_train: 0.6429 loss_val: 1.1028 acc_val: 0.7640 time: 0.5884s\n",
      "Epoch: 0027 loss_train: 0.9955 acc_train: 0.6286 loss_val: 1.0859 acc_val: 0.7660 time: 0.6931s\n",
      "Epoch: 0028 loss_train: 1.0024 acc_train: 0.6643 loss_val: 1.0709 acc_val: 0.7560 time: 0.7027s\n",
      "Epoch: 0029 loss_train: 1.0464 acc_train: 0.5929 loss_val: 1.0574 acc_val: 0.7580 time: 0.6935s\n",
      "Epoch: 0030 loss_train: 0.9380 acc_train: 0.6786 loss_val: 1.0439 acc_val: 0.7580 time: 0.7630s\n",
      "Epoch: 0031 loss_train: 0.8637 acc_train: 0.6929 loss_val: 1.0303 acc_val: 0.7560 time: 0.7580s\n",
      "Epoch: 0032 loss_train: 0.8651 acc_train: 0.7214 loss_val: 1.0166 acc_val: 0.7640 time: 0.7269s\n",
      "Epoch: 0033 loss_train: 0.9363 acc_train: 0.6286 loss_val: 1.0039 acc_val: 0.7680 time: 0.7719s\n",
      "Epoch: 0034 loss_train: 0.9525 acc_train: 0.6643 loss_val: 0.9926 acc_val: 0.7680 time: 0.6951s\n",
      "Epoch: 0035 loss_train: 0.8380 acc_train: 0.6929 loss_val: 0.9826 acc_val: 0.7700 time: 0.7520s\n",
      "Epoch: 0036 loss_train: 0.7494 acc_train: 0.7357 loss_val: 0.9737 acc_val: 0.7760 time: 0.8048s\n",
      "Epoch: 0037 loss_train: 0.8448 acc_train: 0.6357 loss_val: 0.9663 acc_val: 0.7700 time: 0.7041s\n",
      "Epoch: 0038 loss_train: 0.7765 acc_train: 0.6857 loss_val: 0.9596 acc_val: 0.7640 time: 0.8876s\n",
      "Epoch: 0039 loss_train: 0.8856 acc_train: 0.6571 loss_val: 0.9538 acc_val: 0.7640 time: 0.7390s\n",
      "Epoch: 0040 loss_train: 0.7426 acc_train: 0.7571 loss_val: 0.9481 acc_val: 0.7640 time: 0.9405s\n",
      "Epoch: 0041 loss_train: 0.7228 acc_train: 0.7357 loss_val: 0.9426 acc_val: 0.7640 time: 0.6931s\n",
      "Epoch: 0042 loss_train: 0.8366 acc_train: 0.6500 loss_val: 0.9375 acc_val: 0.7600 time: 0.7524s\n",
      "Epoch: 0043 loss_train: 0.6126 acc_train: 0.7857 loss_val: 0.9323 acc_val: 0.7560 time: 0.7031s\n",
      "Epoch: 0044 loss_train: 0.6633 acc_train: 0.7429 loss_val: 0.9267 acc_val: 0.7560 time: 0.9607s\n",
      "Epoch: 0045 loss_train: 0.7174 acc_train: 0.7286 loss_val: 0.9217 acc_val: 0.7540 time: 0.8848s\n",
      "Epoch: 0046 loss_train: 0.7446 acc_train: 0.7643 loss_val: 0.9164 acc_val: 0.7520 time: 0.8759s\n",
      "Epoch: 0047 loss_train: 0.7418 acc_train: 0.6929 loss_val: 0.9110 acc_val: 0.7480 time: 0.8217s\n",
      "Epoch: 0048 loss_train: 0.5998 acc_train: 0.7714 loss_val: 0.9061 acc_val: 0.7480 time: 0.9470s\n",
      "Epoch: 0049 loss_train: 0.7435 acc_train: 0.6929 loss_val: 0.9011 acc_val: 0.7460 time: 0.9993s\n",
      "Epoch: 0050 loss_train: 0.6674 acc_train: 0.7357 loss_val: 0.8962 acc_val: 0.7460 time: 0.9193s\n",
      "Epoch: 0051 loss_train: 0.7043 acc_train: 0.7286 loss_val: 0.8915 acc_val: 0.7480 time: 0.9130s\n",
      "Epoch: 0052 loss_train: 0.7093 acc_train: 0.7071 loss_val: 0.8868 acc_val: 0.7500 time: 0.8621s\n",
      "Epoch: 0053 loss_train: 0.6644 acc_train: 0.7286 loss_val: 0.8830 acc_val: 0.7500 time: 0.8988s\n",
      "Epoch: 0054 loss_train: 0.5917 acc_train: 0.7500 loss_val: 0.8796 acc_val: 0.7500 time: 0.9325s\n",
      "Epoch: 0055 loss_train: 0.6208 acc_train: 0.7429 loss_val: 0.8766 acc_val: 0.7500 time: 0.9073s\n",
      "Epoch: 0056 loss_train: 0.6208 acc_train: 0.7643 loss_val: 0.8741 acc_val: 0.7500 time: 0.8529s\n",
      "Epoch: 0057 loss_train: 0.7326 acc_train: 0.7286 loss_val: 0.8720 acc_val: 0.7480 time: 0.8961s\n",
      "Epoch: 0058 loss_train: 0.5174 acc_train: 0.7929 loss_val: 0.8702 acc_val: 0.7480 time: 0.8879s\n",
      "Epoch: 0059 loss_train: 0.6037 acc_train: 0.7857 loss_val: 0.8689 acc_val: 0.7480 time: 0.8997s\n",
      "Epoch: 0060 loss_train: 0.5984 acc_train: 0.7643 loss_val: 0.8670 acc_val: 0.7480 time: 0.9106s\n",
      "Epoch: 0061 loss_train: 0.5303 acc_train: 0.7786 loss_val: 0.8646 acc_val: 0.7460 time: 0.9162s\n",
      "Epoch: 0062 loss_train: 0.5705 acc_train: 0.7786 loss_val: 0.8622 acc_val: 0.7460 time: 1.0034s\n",
      "Epoch: 0063 loss_train: 0.7409 acc_train: 0.7500 loss_val: 0.8610 acc_val: 0.7500 time: 0.9707s\n",
      "Epoch: 0064 loss_train: 0.5202 acc_train: 0.7929 loss_val: 0.8595 acc_val: 0.7500 time: 0.9730s\n",
      "Epoch: 0065 loss_train: 0.5584 acc_train: 0.7429 loss_val: 0.8578 acc_val: 0.7520 time: 0.9939s\n",
      "Epoch: 0066 loss_train: 0.4910 acc_train: 0.8143 loss_val: 0.8556 acc_val: 0.7560 time: 1.1228s\n",
      "Epoch: 0067 loss_train: 0.6312 acc_train: 0.7214 loss_val: 0.8532 acc_val: 0.7580 time: 1.0270s\n",
      "Epoch: 0068 loss_train: 0.6254 acc_train: 0.7500 loss_val: 0.8506 acc_val: 0.7560 time: 0.9060s\n",
      "Epoch: 0069 loss_train: 0.5894 acc_train: 0.7500 loss_val: 0.8488 acc_val: 0.7520 time: 1.0084s\n",
      "Epoch: 0070 loss_train: 0.5965 acc_train: 0.7429 loss_val: 0.8465 acc_val: 0.7520 time: 1.0056s\n",
      "Epoch: 0071 loss_train: 0.5613 acc_train: 0.7714 loss_val: 0.8445 acc_val: 0.7480 time: 1.0194s\n",
      "Epoch: 0072 loss_train: 0.4820 acc_train: 0.8000 loss_val: 0.8422 acc_val: 0.7460 time: 1.0070s\n",
      "Epoch: 0073 loss_train: 0.5399 acc_train: 0.8000 loss_val: 0.8401 acc_val: 0.7460 time: 0.9765s\n",
      "Epoch: 0074 loss_train: 0.6258 acc_train: 0.7214 loss_val: 0.8385 acc_val: 0.7480 time: 0.8831s\n",
      "Epoch: 0075 loss_train: 0.5050 acc_train: 0.8000 loss_val: 0.8367 acc_val: 0.7480 time: 0.8636s\n",
      "Epoch: 0076 loss_train: 0.5958 acc_train: 0.7429 loss_val: 0.8352 acc_val: 0.7460 time: 0.9438s\n",
      "Epoch: 0077 loss_train: 0.4834 acc_train: 0.8000 loss_val: 0.8339 acc_val: 0.7460 time: 0.8994s\n",
      "Epoch: 0078 loss_train: 0.5871 acc_train: 0.7357 loss_val: 0.8331 acc_val: 0.7380 time: 0.8843s\n",
      "Epoch: 0079 loss_train: 0.5783 acc_train: 0.7714 loss_val: 0.8331 acc_val: 0.7400 time: 0.8686s\n",
      "Epoch: 0080 loss_train: 0.6139 acc_train: 0.7429 loss_val: 0.8329 acc_val: 0.7400 time: 0.8897s\n",
      "Epoch: 0081 loss_train: 0.4710 acc_train: 0.7929 loss_val: 0.8328 acc_val: 0.7380 time: 0.8835s\n",
      "Epoch: 0082 loss_train: 0.5036 acc_train: 0.7714 loss_val: 0.8312 acc_val: 0.7420 time: 0.8812s\n",
      "Epoch: 0083 loss_train: 0.5239 acc_train: 0.7929 loss_val: 0.8300 acc_val: 0.7420 time: 0.9028s\n",
      "Epoch: 0084 loss_train: 0.6076 acc_train: 0.7429 loss_val: 0.8291 acc_val: 0.7420 time: 0.8679s\n",
      "Epoch: 0085 loss_train: 0.4876 acc_train: 0.8071 loss_val: 0.8275 acc_val: 0.7420 time: 0.9334s\n",
      "Epoch: 0086 loss_train: 0.5151 acc_train: 0.7714 loss_val: 0.8261 acc_val: 0.7400 time: 0.8775s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0087 loss_train: 0.5396 acc_train: 0.7714 loss_val: 0.8238 acc_val: 0.7400 time: 0.9009s\n",
      "Epoch: 0088 loss_train: 0.4567 acc_train: 0.7929 loss_val: 0.8218 acc_val: 0.7420 time: 0.8356s\n",
      "Epoch: 0089 loss_train: 0.5112 acc_train: 0.8214 loss_val: 0.8198 acc_val: 0.7480 time: 0.9037s\n",
      "Epoch: 0090 loss_train: 0.4606 acc_train: 0.8143 loss_val: 0.8178 acc_val: 0.7540 time: 0.8647s\n",
      "Epoch: 0091 loss_train: 0.4749 acc_train: 0.8071 loss_val: 0.8157 acc_val: 0.7540 time: 0.8950s\n",
      "Epoch: 0092 loss_train: 0.5906 acc_train: 0.7143 loss_val: 0.8137 acc_val: 0.7580 time: 0.8814s\n",
      "Epoch: 0093 loss_train: 0.6091 acc_train: 0.7214 loss_val: 0.8119 acc_val: 0.7560 time: 0.8889s\n",
      "Epoch: 0094 loss_train: 0.4917 acc_train: 0.8143 loss_val: 0.8101 acc_val: 0.7560 time: 0.8671s\n",
      "Epoch: 0095 loss_train: 0.5069 acc_train: 0.7500 loss_val: 0.8093 acc_val: 0.7580 time: 0.8948s\n",
      "Epoch: 0096 loss_train: 0.4429 acc_train: 0.8143 loss_val: 0.8081 acc_val: 0.7560 time: 0.8817s\n",
      "Epoch: 0097 loss_train: 0.5176 acc_train: 0.8000 loss_val: 0.8071 acc_val: 0.7620 time: 0.8399s\n",
      "Epoch: 0098 loss_train: 0.5497 acc_train: 0.7500 loss_val: 0.8062 acc_val: 0.7640 time: 0.8872s\n",
      "Epoch: 0099 loss_train: 0.4477 acc_train: 0.8357 loss_val: 0.8057 acc_val: 0.7620 time: 0.8630s\n",
      "Epoch: 0100 loss_train: 0.5028 acc_train: 0.7857 loss_val: 0.8062 acc_val: 0.7620 time: 0.8762s\n",
      "Epoch: 0101 loss_train: 0.4730 acc_train: 0.8143 loss_val: 0.8071 acc_val: 0.7640 time: 0.8624s\n",
      "Epoch: 0102 loss_train: 0.5569 acc_train: 0.7429 loss_val: 0.8089 acc_val: 0.7640 time: 0.8959s\n",
      "Epoch: 0103 loss_train: 0.6106 acc_train: 0.7286 loss_val: 0.8108 acc_val: 0.7600 time: 0.8829s\n",
      "Epoch: 0104 loss_train: 0.3935 acc_train: 0.8214 loss_val: 0.8128 acc_val: 0.7580 time: 0.9264s\n",
      "Epoch: 0105 loss_train: 0.4390 acc_train: 0.8000 loss_val: 0.8140 acc_val: 0.7600 time: 0.8767s\n",
      "Epoch: 0106 loss_train: 0.4999 acc_train: 0.7429 loss_val: 0.8154 acc_val: 0.7600 time: 0.8975s\n",
      "Epoch: 0107 loss_train: 0.5525 acc_train: 0.7500 loss_val: 0.8165 acc_val: 0.7560 time: 0.8530s\n",
      "Epoch: 0108 loss_train: 0.4272 acc_train: 0.8000 loss_val: 0.8179 acc_val: 0.7540 time: 0.9039s\n",
      "Epoch: 0109 loss_train: 0.5076 acc_train: 0.7500 loss_val: 0.8192 acc_val: 0.7540 time: 1.0484s\n",
      "Epoch: 0110 loss_train: 0.5887 acc_train: 0.7357 loss_val: 0.8201 acc_val: 0.7520 time: 0.9135s\n",
      "Epoch: 0111 loss_train: 0.4420 acc_train: 0.8214 loss_val: 0.8206 acc_val: 0.7480 time: 0.8545s\n",
      "Epoch: 0112 loss_train: 0.4882 acc_train: 0.7786 loss_val: 0.8205 acc_val: 0.7460 time: 0.8905s\n",
      "Epoch: 0113 loss_train: 0.4629 acc_train: 0.7786 loss_val: 0.8204 acc_val: 0.7420 time: 0.9012s\n",
      "Epoch: 0114 loss_train: 0.4635 acc_train: 0.8000 loss_val: 0.8205 acc_val: 0.7460 time: 0.8513s\n",
      "Epoch: 0115 loss_train: 0.4603 acc_train: 0.8214 loss_val: 0.8213 acc_val: 0.7460 time: 0.8765s\n",
      "Epoch: 0116 loss_train: 0.5687 acc_train: 0.7214 loss_val: 0.8222 acc_val: 0.7460 time: 0.8700s\n",
      "Epoch: 0117 loss_train: 0.5118 acc_train: 0.7429 loss_val: 0.8220 acc_val: 0.7460 time: 0.9034s\n",
      "Epoch: 0118 loss_train: 0.5038 acc_train: 0.8214 loss_val: 0.8214 acc_val: 0.7440 time: 0.8866s\n",
      "Epoch: 0119 loss_train: 0.5207 acc_train: 0.8000 loss_val: 0.8215 acc_val: 0.7460 time: 0.9128s\n",
      "Epoch: 0120 loss_train: 0.4484 acc_train: 0.8000 loss_val: 0.8214 acc_val: 0.7460 time: 0.8941s\n",
      "Epoch: 0121 loss_train: 0.4580 acc_train: 0.7714 loss_val: 0.8212 acc_val: 0.7460 time: 0.9084s\n",
      "Epoch: 0122 loss_train: 0.4425 acc_train: 0.7929 loss_val: 0.8215 acc_val: 0.7460 time: 0.8752s\n",
      "Epoch: 0123 loss_train: 0.5441 acc_train: 0.7714 loss_val: 0.8219 acc_val: 0.7500 time: 0.9049s\n",
      "Epoch: 0124 loss_train: 0.4044 acc_train: 0.8357 loss_val: 0.8227 acc_val: 0.7540 time: 0.8696s\n",
      "Epoch: 0125 loss_train: 0.5094 acc_train: 0.7857 loss_val: 0.8234 acc_val: 0.7540 time: 0.8961s\n",
      "Epoch: 0126 loss_train: 0.4605 acc_train: 0.7857 loss_val: 0.8236 acc_val: 0.7540 time: 0.8794s\n",
      "Epoch: 0127 loss_train: 0.5000 acc_train: 0.7929 loss_val: 0.8239 acc_val: 0.7540 time: 0.8511s\n",
      "Epoch: 0128 loss_train: 0.5076 acc_train: 0.7429 loss_val: 0.8238 acc_val: 0.7560 time: 0.8860s\n",
      "Epoch: 0129 loss_train: 0.4945 acc_train: 0.7643 loss_val: 0.8234 acc_val: 0.7540 time: 0.8568s\n",
      "Epoch: 0130 loss_train: 0.5386 acc_train: 0.7857 loss_val: 0.8228 acc_val: 0.7560 time: 0.8927s\n",
      "Epoch: 0131 loss_train: 0.4458 acc_train: 0.8214 loss_val: 0.8214 acc_val: 0.7560 time: 0.8794s\n",
      "Epoch: 0132 loss_train: 0.5085 acc_train: 0.7571 loss_val: 0.8206 acc_val: 0.7520 time: 0.9035s\n",
      "Epoch: 0133 loss_train: 0.4797 acc_train: 0.8071 loss_val: 0.8196 acc_val: 0.7520 time: 0.8501s\n",
      "Epoch: 0134 loss_train: 0.4384 acc_train: 0.8286 loss_val: 0.8186 acc_val: 0.7540 time: 0.8695s\n",
      "Epoch: 0135 loss_train: 0.4544 acc_train: 0.8214 loss_val: 0.8173 acc_val: 0.7540 time: 0.8988s\n",
      "Epoch: 0136 loss_train: 0.4495 acc_train: 0.8000 loss_val: 0.8166 acc_val: 0.7540 time: 0.8725s\n",
      "Epoch: 0137 loss_train: 0.4838 acc_train: 0.7857 loss_val: 0.8164 acc_val: 0.7560 time: 0.8793s\n",
      "Epoch: 0138 loss_train: 0.3909 acc_train: 0.8143 loss_val: 0.8162 acc_val: 0.7560 time: 0.8960s\n",
      "Epoch: 0139 loss_train: 0.3938 acc_train: 0.8286 loss_val: 0.8168 acc_val: 0.7540 time: 0.8996s\n",
      "Epoch: 0140 loss_train: 0.3539 acc_train: 0.8286 loss_val: 0.8168 acc_val: 0.7560 time: 0.9423s\n",
      "Epoch: 0141 loss_train: 0.3933 acc_train: 0.8143 loss_val: 0.8170 acc_val: 0.7540 time: 0.8955s\n",
      "Epoch: 0142 loss_train: 0.4215 acc_train: 0.7929 loss_val: 0.8178 acc_val: 0.7540 time: 0.9415s\n",
      "Epoch: 0143 loss_train: 0.4206 acc_train: 0.8214 loss_val: 0.8186 acc_val: 0.7480 time: 0.9866s\n",
      "Epoch: 0144 loss_train: 0.4130 acc_train: 0.8143 loss_val: 0.8207 acc_val: 0.7460 time: 0.9828s\n",
      "Epoch: 0145 loss_train: 0.3957 acc_train: 0.7857 loss_val: 0.8225 acc_val: 0.7460 time: 0.9294s\n",
      "Epoch: 0146 loss_train: 0.4875 acc_train: 0.7714 loss_val: 0.8240 acc_val: 0.7460 time: 1.0891s\n",
      "Epoch: 0147 loss_train: 0.6266 acc_train: 0.7286 loss_val: 0.8247 acc_val: 0.7500 time: 0.9822s\n",
      "Epoch: 0148 loss_train: 0.3676 acc_train: 0.8500 loss_val: 0.8250 acc_val: 0.7500 time: 0.9355s\n",
      "Epoch: 0149 loss_train: 0.4763 acc_train: 0.7714 loss_val: 0.8254 acc_val: 0.7500 time: 0.8892s\n",
      "Epoch: 0150 loss_train: 0.4636 acc_train: 0.7643 loss_val: 0.8257 acc_val: 0.7500 time: 0.9221s\n",
      "Epoch: 0151 loss_train: 0.3700 acc_train: 0.8500 loss_val: 0.8268 acc_val: 0.7500 time: 0.9172s\n",
      "Epoch: 0152 loss_train: 0.4314 acc_train: 0.8286 loss_val: 0.8289 acc_val: 0.7480 time: 0.7259s\n",
      "Epoch: 0153 loss_train: 0.4893 acc_train: 0.7643 loss_val: 0.8303 acc_val: 0.7500 time: 0.7729s\n",
      "Epoch: 0154 loss_train: 0.3782 acc_train: 0.8500 loss_val: 0.8320 acc_val: 0.7460 time: 0.7151s\n",
      "Epoch: 0155 loss_train: 0.4584 acc_train: 0.7500 loss_val: 0.8338 acc_val: 0.7440 time: 0.7320s\n",
      "Epoch: 0156 loss_train: 0.4898 acc_train: 0.7929 loss_val: 0.8348 acc_val: 0.7440 time: 0.6662s\n",
      "Epoch: 0157 loss_train: 0.4225 acc_train: 0.8000 loss_val: 0.8360 acc_val: 0.7460 time: 0.7719s\n",
      "Epoch: 0158 loss_train: 0.4118 acc_train: 0.7786 loss_val: 0.8365 acc_val: 0.7460 time: 0.7480s\n",
      "Epoch: 0159 loss_train: 0.4538 acc_train: 0.7929 loss_val: 0.8373 acc_val: 0.7440 time: 0.7640s\n",
      "Epoch: 0160 loss_train: 0.4589 acc_train: 0.7786 loss_val: 0.8377 acc_val: 0.7480 time: 0.6772s\n",
      "Epoch: 0161 loss_train: 0.4284 acc_train: 0.8000 loss_val: 0.8377 acc_val: 0.7460 time: 0.6752s\n",
      "Epoch: 0162 loss_train: 0.4780 acc_train: 0.7571 loss_val: 0.8379 acc_val: 0.7440 time: 0.7310s\n",
      "Epoch: 0163 loss_train: 0.5651 acc_train: 0.7643 loss_val: 0.8390 acc_val: 0.7440 time: 0.6473s\n",
      "Epoch: 0164 loss_train: 0.4741 acc_train: 0.7857 loss_val: 0.8391 acc_val: 0.7480 time: 0.7240s\n",
      "Epoch: 0165 loss_train: 0.4409 acc_train: 0.7929 loss_val: 0.8397 acc_val: 0.7500 time: 0.7450s\n",
      "Epoch: 0166 loss_train: 0.3935 acc_train: 0.8571 loss_val: 0.8404 acc_val: 0.7460 time: 0.8281s\n",
      "Epoch: 0167 loss_train: 0.5213 acc_train: 0.7571 loss_val: 0.8407 acc_val: 0.7460 time: 0.7300s\n",
      "Epoch: 0168 loss_train: 0.5336 acc_train: 0.7357 loss_val: 0.8405 acc_val: 0.7460 time: 0.7121s\n",
      "Epoch: 0169 loss_train: 0.4845 acc_train: 0.7643 loss_val: 0.8396 acc_val: 0.7460 time: 0.7084s\n",
      "Epoch: 0170 loss_train: 0.4394 acc_train: 0.7857 loss_val: 0.8397 acc_val: 0.7480 time: 0.7211s\n",
      "Epoch: 0171 loss_train: 0.4514 acc_train: 0.8286 loss_val: 0.8395 acc_val: 0.7480 time: 0.7949s\n",
      "Epoch: 0172 loss_train: 0.4061 acc_train: 0.8143 loss_val: 0.8383 acc_val: 0.7480 time: 0.8846s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0173 loss_train: 0.4525 acc_train: 0.7714 loss_val: 0.8362 acc_val: 0.7480 time: 0.7709s\n",
      "Epoch: 0174 loss_train: 0.3852 acc_train: 0.8214 loss_val: 0.8345 acc_val: 0.7480 time: 0.7420s\n",
      "Epoch: 0175 loss_train: 0.4723 acc_train: 0.7643 loss_val: 0.8326 acc_val: 0.7500 time: 0.8418s\n",
      "Epoch: 0176 loss_train: 0.4064 acc_train: 0.8429 loss_val: 0.8311 acc_val: 0.7460 time: 0.7505s\n",
      "Epoch: 0177 loss_train: 0.4650 acc_train: 0.8000 loss_val: 0.8297 acc_val: 0.7480 time: 0.6732s\n",
      "Epoch: 0178 loss_train: 0.5356 acc_train: 0.7571 loss_val: 0.8292 acc_val: 0.7500 time: 0.7377s\n",
      "Epoch: 0179 loss_train: 0.4458 acc_train: 0.7929 loss_val: 0.8282 acc_val: 0.7520 time: 1.0133s\n",
      "Epoch: 0180 loss_train: 0.4756 acc_train: 0.7714 loss_val: 0.8273 acc_val: 0.7520 time: 0.9674s\n",
      "Epoch: 0181 loss_train: 0.4965 acc_train: 0.7857 loss_val: 0.8267 acc_val: 0.7560 time: 0.7829s\n",
      "Epoch: 0182 loss_train: 0.4456 acc_train: 0.7857 loss_val: 0.8258 acc_val: 0.7560 time: 0.7261s\n",
      "Epoch: 0183 loss_train: 0.4305 acc_train: 0.8286 loss_val: 0.8245 acc_val: 0.7560 time: 0.7081s\n",
      "Epoch: 0184 loss_train: 0.4306 acc_train: 0.7786 loss_val: 0.8233 acc_val: 0.7540 time: 0.6926s\n",
      "Epoch: 0185 loss_train: 0.3784 acc_train: 0.8000 loss_val: 0.8226 acc_val: 0.7540 time: 0.7061s\n",
      "Epoch: 0186 loss_train: 0.4254 acc_train: 0.7714 loss_val: 0.8222 acc_val: 0.7560 time: 0.6992s\n",
      "Epoch: 0187 loss_train: 0.4300 acc_train: 0.7643 loss_val: 0.8219 acc_val: 0.7580 time: 0.6662s\n",
      "Epoch: 0188 loss_train: 0.4973 acc_train: 0.8143 loss_val: 0.8219 acc_val: 0.7580 time: 0.8285s\n",
      "Epoch: 0189 loss_train: 0.4614 acc_train: 0.7929 loss_val: 0.8230 acc_val: 0.7580 time: 0.8900s\n",
      "Epoch: 0190 loss_train: 0.4598 acc_train: 0.7714 loss_val: 0.8246 acc_val: 0.7600 time: 0.8748s\n",
      "Epoch: 0191 loss_train: 0.4533 acc_train: 0.8286 loss_val: 0.8263 acc_val: 0.7640 time: 0.8288s\n",
      "Epoch: 0192 loss_train: 0.4277 acc_train: 0.8071 loss_val: 0.8284 acc_val: 0.7660 time: 0.8752s\n",
      "Epoch: 0193 loss_train: 0.5360 acc_train: 0.7714 loss_val: 0.8302 acc_val: 0.7620 time: 0.7271s\n",
      "Epoch: 0194 loss_train: 0.5450 acc_train: 0.7500 loss_val: 0.8311 acc_val: 0.7620 time: 0.7221s\n",
      "Epoch: 0195 loss_train: 0.3776 acc_train: 0.8500 loss_val: 0.8322 acc_val: 0.7620 time: 0.7520s\n",
      "Epoch: 0196 loss_train: 0.4438 acc_train: 0.7929 loss_val: 0.8327 acc_val: 0.7620 time: 0.7140s\n",
      "Epoch: 0197 loss_train: 0.3600 acc_train: 0.8143 loss_val: 0.8332 acc_val: 0.7620 time: 0.7091s\n",
      "Epoch: 0198 loss_train: 0.4469 acc_train: 0.8000 loss_val: 0.8339 acc_val: 0.7600 time: 0.7630s\n",
      "Epoch: 0199 loss_train: 0.3789 acc_train: 0.8357 loss_val: 0.8353 acc_val: 0.7540 time: 0.7251s\n",
      "Epoch: 0200 loss_train: 0.3561 acc_train: 0.8571 loss_val: 0.8371 acc_val: 0.7540 time: 0.6832s\n",
      "Optimization Finished!\n",
      "Total time elapsed: 167.8993s\n",
      "Test set results: loss= 0.7250 accuracy= 0.7990\n"
     ]
    }
   ],
   "source": [
    "# Train model\n",
    "t_total = time.time()\n",
    "for epoch in range(args.epochs):\n",
    "    train(epoch)\n",
    "print(\"Optimization Finished!\")\n",
    "print(\"Total time elapsed: {:.4f}s\".format(time.time() - t_total))\n",
    "\n",
    "# Testing\n",
    "test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
