{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import time\n",
    "import argparse\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import networkx as nx\n",
    "from scipy import sparse\n",
    "from scipy.linalg import fractional_matrix_power\n",
    "\n",
    "from utils import *\n",
    "from models import Graphsn_GIN\n",
    "from dataset_utils import DataLoader\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\""
   ]
  },
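  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hiding every GPU via `CUDA_VISIBLE_DEVICES` pins this run to CPU. If a GPU run were wanted instead, a minimal sketch (hypothetical; it assumes the override above is removed and that `model`, `features`, `adj`, and `labels` from the later cells get moved as well) would select a device explicitly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hypothetical GPU variant; only meaningful with the\n",
    "# CUDA_VISIBLE_DEVICES override above removed\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# later: model.to(device); features, adj, labels = [t.to(device) for t in (features, adj, labels)]"
   ]
  },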
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_StoreAction(option_strings=['--dataset'], dest='dataset', nargs=None, const=None, default='citeseer', type=None, choices=None, help='Dataset name.', metavar=None)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument('--no-cuda', action='store_true', default=False,\n",
    "                    help='Disables CUDA training.')\n",
    "parser.add_argument('--fastmode', action='store_true', default=False,\n",
    "                    help='Validate during training pass.')\n",
    "parser.add_argument('--seed', type=int, default=42, help='Random seed.')\n",
    "parser.add_argument('--epochs', type=int, default=200,\n",
    "                    help='Number of epochs to train.')\n",
    "parser.add_argument('--lr', type=float, default=0.002,\n",
    "                    help='Initial learning rate.')\n",
    "parser.add_argument('--weight_decay', type=float, default=9e-3,\n",
    "                    help='Weight decay (L2 loss on parameters).')\n",
    "parser.add_argument('--hidden', type=int, default=64,\n",
    "                    help='Number of hidden units.')\n",
    "parser.add_argument('--dropout', type=float, default=0.88,\n",
    "                    help='Dropout rate (1 - keep probability).')\n",
    "parser.add_argument('--dataset', default='citeseer', help='Dataset name.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = parser.parse_args(\"\")"
   ]
  },
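  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`parse_args([])` keeps every default defined above. Other hyperparameters can be tried by passing flag/value pairs in the same list; `demo_args` below is purely illustrative and is not used by the rest of the notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative only: parse an alternative configuration\n",
    "demo_args = parser.parse_args(['--dataset', 'cora', '--dropout', '0.5'])\n",
    "print(demo_args.dataset, demo_args.dropout)"
   ]
  },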
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x22923cd3900>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.random.seed(args.seed)\n",
    "torch.manual_seed(args.seed)"
   ]
  },
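  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Seeding NumPy and Torch is enough for this CPU run. If the notebook were moved to a GPU, stricter reproducibility would also want the CUDA generators seeded and cuDNN forced into deterministic mode; a sketch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# extra determinism knobs; no-ops on a CPU-only run\n",
    "torch.cuda.manual_seed_all(args.seed)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False"
   ]
  },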
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "dname = args.dataset\n",
    "dataset = DataLoader(dname)\n",
    "data = dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "A_norm, A, X, labels, idx_train, idx_val, idx_test = load_citation_data(data)"
   ]
  },
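  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check on what `load_citation_data` hands back (it is used below as dense arrays, so `A.shape` and `X.shape` are assumed to exist):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('A:', A.shape, ' X:', X.shape, ' classes:', labels.max() + 1)\n",
    "print('split sizes:', len(idx_train), len(idx_val), len(idx_test))"
   ]
  },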
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "G = nx.from_numpy_matrix(A)\n",
    "feature_dictionary = {}\n",
    "\n",
    "for i in np.arange(len(labels)):\n",
    "    feature_dictionary[i] = labels[i]\n",
    "\n",
    "nx.set_node_attributes(G, feature_dictionary, \"attr_name\")"
   ]
  },
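  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each node now carries its class label as the `attr_name` attribute. A small check of the rebuilt graph:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(G.number_of_nodes(), 'nodes,', G.number_of_edges(), 'edges')\n",
    "print('label of node 0:', G.nodes[0]['attr_name'])"
   ]
  },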
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "sub_graphs = []\n",
    "\n",
    "for i in np.arange(len(A)):\n",
    "    s_indexes = []\n",
    "    for j in np.arange(len(A)):\n",
    "        s_indexes.append(i)\n",
    "        if(A[i][j]==1):\n",
    "            s_indexes.append(j)\n",
    "    sub_graphs.append(G.subgraph(s_indexes))\n",
    "\n",
    "subgraph_nodes_list = []\n",
    "\n",
    "for i in np.arange(len(sub_graphs)):\n",
    "    subgraph_nodes_list.append(list(sub_graphs[i].nodes))"
   ]
  },
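  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`sub_graphs[i]` is exactly the subgraph induced on node `i` plus its direct neighbors, which NetworkX also exposes as an ego graph; a quick cross-check:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# nx.ego_graph(G, i) induces on node i and its 1-hop neighbors,\n",
    "# i.e. the same node set as sub_graphs[i]\n",
    "ego = nx.ego_graph(G, 0, radius=1)\n",
    "assert set(ego.nodes) == set(sub_graphs[0].nodes)"
   ]
  },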
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "sub_graphs_adj = []\n",
    "for index in np.arange(len(sub_graphs)):\n",
    "    sub_graphs_adj.append(nx.adjacency_matrix(sub_graphs[index]).toarray())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_adj = torch.zeros(A.shape[0], A.shape[0])\n",
    "\n",
    "for node in np.arange(len(subgraph_nodes_list)):\n",
    "    sub_adj = sub_graphs_adj[node]\n",
    "    for neighbors in np.arange(len(subgraph_nodes_list[node])):\n",
    "        index = subgraph_nodes_list[node][neighbors]\n",
    "        count = torch.tensor(0).float()\n",
    "        if(index==node):\n",
    "            continue\n",
    "        else:\n",
    "            c_neighbors = set(subgraph_nodes_list[node]).intersection(subgraph_nodes_list[index])\n",
    "            if index in c_neighbors:\n",
    "                nodes_list = subgraph_nodes_list[node]\n",
    "                sub_graph_index = nodes_list.index(index)\n",
    "                c_neighbors_list = list(c_neighbors)\n",
    "                for i, item1 in enumerate(nodes_list):\n",
    "                    if(item1 in c_neighbors):\n",
    "                        for item2 in c_neighbors_list:\n",
    "                            j = nodes_list.index(item2)\n",
    "                            count += sub_adj[i][j]\n",
    "\n",
    "            new_adj[node][index] = count/2\n",
    "            new_adj[node][index] = new_adj[node][index]/(len(c_neighbors)*(len(c_neighbors)-1))\n",
    "            new_adj[node][index] = new_adj[node][index] * (len(c_neighbors)**1)"
   ]
  },
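  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For an edge `(u, v)` with common closed neighborhood `C`, the loop above reduces to `m / (|C| - 1)`, where `m` is the number of edges among `C`. A vectorized sketch of the same quantity straight from the dense adjacency, useful for spot-checking individual entries (the loop stays the source of truth):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def structural_coeff(A, u, v):\n",
    "    # common closed neighbors of u and v\n",
    "    c = np.intersect1d(np.append(np.flatnonzero(A[u]), u),\n",
    "                       np.append(np.flatnonzero(A[v]), v))\n",
    "    m = A[np.ix_(c, c)].sum() / 2  # edges among the common neighbors\n",
    "    return m / (len(c) - 1)\n",
    "\n",
    "# for any edge (u, v), structural_coeff(A, u, v) should match new_adj[u][v]"
   ]
  },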
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "features = torch.FloatTensor(X)\n",
    "labels = torch.LongTensor(labels)\n",
    "\n",
    "weight = torch.FloatTensor(new_adj)\n",
    "weight = weight / weight.sum(1, keepdim=True)\n",
    "\n",
    "weight = weight + torch.FloatTensor(A)\n",
    "\n",
    "coeff = weight.sum(1, keepdim=True)\n",
    "coeff = torch.diag((coeff.T)[0])\n",
    "\n",
    "weight = weight + coeff"
   ]
  },
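  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The NaNs could also be avoided up front by guarding the division; a sketch (`weight_alt` is a cross-check only, the pipeline keeps the `nan_to_num` route below):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# guarded row normalization: all-zero rows stay zero instead of NaN\n",
    "row_sum = new_adj.sum(1, keepdim=True)\n",
    "row_sum[row_sum == 0] = 1.0\n",
    "weight_alt = new_adj / row_sum"
   ]
  },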
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "weight = weight.detach().numpy()\n",
    "weight = np.nan_to_num(weight, nan=0)\n",
    "adj = torch.FloatTensor(weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model and optimizer\n",
    "model = Graphsn_GIN(nfeat=features.shape[1],\n",
    "                    nhid=args.hidden,\n",
    "                    nclass=labels.max().item() + 1,\n",
    "                    dropout=args.dropout)\n",
    "\n",
    "optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch):\n",
    "    t = time.time()\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    output = model(features, adj)\n",
    "    loss_train = F.nll_loss(output[idx_train], labels[idx_train])\n",
    "    acc_train = accuracy(output[idx_train], labels[idx_train])\n",
    "    loss_train.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if not args.fastmode:\n",
    "        # Evaluate validation set performance separately, deactivates dropout during validation run.\n",
    "        model.eval()\n",
    "        output = model(features, adj)\n",
    "\n",
    "    loss_val = F.nll_loss(output[idx_val], labels[idx_val])\n",
    "    acc_val = accuracy(output[idx_val], labels[idx_val])\n",
    "    print('Epoch: {:04d}'.format(epoch+1),\n",
    "          'loss_train: {:.4f}'.format(loss_train.item()),\n",
    "          'acc_train: {:.4f}'.format(acc_train.item()),\n",
    "          'loss_val: {:.4f}'.format(loss_val.item()),\n",
    "          'acc_val: {:.4f}'.format(acc_val.item()),\n",
    "          'time: {:.4f}s'.format(time.time() - t))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test():\n",
    "    model.eval()\n",
    "    output = model(features, adj)\n",
    "    loss_test = F.nll_loss(output[idx_test], labels[idx_test])\n",
    "    acc_test = accuracy(output[idx_test], labels[idx_test])\n",
    "    print(\"Test set results:\",\n",
    "          \"loss= {:.4f}\".format(loss_test.item()),\n",
    "          \"accuracy= {:.4f}\".format(acc_test.item()))"
   ]
  },
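  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell below trains for a fixed `args.epochs`; in the trace the validation loss bottoms out around epoch 125 and then drifts upward, so early stopping is a natural refinement. A minimal sketch, under the assumption that `train(epoch)` is changed to return `loss_val`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hypothetical early-stopping loop; assumes train() returns loss_val\n",
    "best_val, patience, bad_epochs = float('inf'), 20, 0\n",
    "for epoch in range(args.epochs):\n",
    "    loss_val = train(epoch)\n",
    "    if loss_val < best_val:\n",
    "        best_val, bad_epochs = loss_val, 0\n",
    "        torch.save(model.state_dict(), 'best_model.pt')\n",
    "    else:\n",
    "        bad_epochs += 1\n",
    "        if bad_epochs >= patience:\n",
    "            break"
   ]
  },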
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0001 loss_train: 1.8148 acc_train: 0.1333 loss_val: 1.7548 acc_val: 0.2220 time: 2.2916s\n",
      "Epoch: 0002 loss_train: 1.7776 acc_train: 0.1667 loss_val: 1.7285 acc_val: 0.2960 time: 2.1459s\n",
      "Epoch: 0003 loss_train: 1.7175 acc_train: 0.3083 loss_val: 1.7092 acc_val: 0.3440 time: 2.1690s\n",
      "Epoch: 0004 loss_train: 1.6715 acc_train: 0.2917 loss_val: 1.6865 acc_val: 0.4000 time: 2.1391s\n",
      "Epoch: 0005 loss_train: 1.5913 acc_train: 0.3833 loss_val: 1.6625 acc_val: 0.4420 time: 2.0691s\n",
      "Epoch: 0006 loss_train: 1.6095 acc_train: 0.3333 loss_val: 1.6351 acc_val: 0.4660 time: 2.1318s\n",
      "Epoch: 0007 loss_train: 1.5363 acc_train: 0.4250 loss_val: 1.6037 acc_val: 0.5220 time: 2.1828s\n",
      "Epoch: 0008 loss_train: 1.5636 acc_train: 0.3833 loss_val: 1.5699 acc_val: 0.5820 time: 2.2963s\n",
      "Epoch: 0009 loss_train: 1.4269 acc_train: 0.4167 loss_val: 1.5387 acc_val: 0.6140 time: 2.2797s\n",
      "Epoch: 0010 loss_train: 1.3135 acc_train: 0.4583 loss_val: 1.5073 acc_val: 0.6300 time: 2.1398s\n",
      "Epoch: 0011 loss_train: 1.3647 acc_train: 0.4667 loss_val: 1.4789 acc_val: 0.6440 time: 2.0382s\n",
      "Epoch: 0012 loss_train: 1.3631 acc_train: 0.4583 loss_val: 1.4539 acc_val: 0.6480 time: 2.3629s\n",
      "Epoch: 0013 loss_train: 1.2775 acc_train: 0.5083 loss_val: 1.4309 acc_val: 0.6540 time: 2.7361s\n",
      "Epoch: 0014 loss_train: 1.3010 acc_train: 0.4667 loss_val: 1.4095 acc_val: 0.6480 time: 2.2435s\n",
      "Epoch: 0015 loss_train: 1.2071 acc_train: 0.5000 loss_val: 1.3882 acc_val: 0.6460 time: 2.0927s\n",
      "Epoch: 0016 loss_train: 1.2646 acc_train: 0.5583 loss_val: 1.3660 acc_val: 0.6600 time: 2.0529s\n",
      "Epoch: 0017 loss_train: 1.0925 acc_train: 0.5667 loss_val: 1.3442 acc_val: 0.6620 time: 2.1007s\n",
      "Epoch: 0018 loss_train: 1.0970 acc_train: 0.5167 loss_val: 1.3233 acc_val: 0.6700 time: 2.0690s\n",
      "Epoch: 0019 loss_train: 1.1321 acc_train: 0.6000 loss_val: 1.3034 acc_val: 0.6760 time: 2.1029s\n",
      "Epoch: 0020 loss_train: 1.0490 acc_train: 0.5750 loss_val: 1.2841 acc_val: 0.6800 time: 2.0702s\n",
      "Epoch: 0021 loss_train: 0.9323 acc_train: 0.6917 loss_val: 1.2664 acc_val: 0.6800 time: 2.0794s\n",
      "Epoch: 0022 loss_train: 1.0538 acc_train: 0.5167 loss_val: 1.2497 acc_val: 0.6780 time: 2.1856s\n",
      "Epoch: 0023 loss_train: 1.0000 acc_train: 0.5917 loss_val: 1.2351 acc_val: 0.6820 time: 2.0154s\n",
      "Epoch: 0024 loss_train: 0.9735 acc_train: 0.6000 loss_val: 1.2223 acc_val: 0.6820 time: 2.2882s\n",
      "Epoch: 0025 loss_train: 0.8866 acc_train: 0.6833 loss_val: 1.2103 acc_val: 0.6860 time: 2.8049s\n",
      "Epoch: 0026 loss_train: 0.9343 acc_train: 0.6167 loss_val: 1.1984 acc_val: 0.6860 time: 2.0699s\n",
      "Epoch: 0027 loss_train: 0.9043 acc_train: 0.6583 loss_val: 1.1876 acc_val: 0.6860 time: 2.0480s\n",
      "Epoch: 0028 loss_train: 0.8010 acc_train: 0.7000 loss_val: 1.1764 acc_val: 0.6860 time: 2.1162s\n",
      "Epoch: 0029 loss_train: 0.7970 acc_train: 0.6500 loss_val: 1.1662 acc_val: 0.6820 time: 2.0703s\n",
      "Epoch: 0030 loss_train: 0.8290 acc_train: 0.6750 loss_val: 1.1578 acc_val: 0.6820 time: 2.0852s\n",
      "Epoch: 0031 loss_train: 0.7779 acc_train: 0.6917 loss_val: 1.1493 acc_val: 0.6760 time: 2.0973s\n",
      "Epoch: 0032 loss_train: 0.8756 acc_train: 0.6000 loss_val: 1.1410 acc_val: 0.6800 time: 2.0748s\n",
      "Epoch: 0033 loss_train: 0.6820 acc_train: 0.7750 loss_val: 1.1336 acc_val: 0.6800 time: 2.0705s\n",
      "Epoch: 0034 loss_train: 0.7395 acc_train: 0.6917 loss_val: 1.1260 acc_val: 0.6780 time: 2.0499s\n",
      "Epoch: 0035 loss_train: 0.7653 acc_train: 0.7000 loss_val: 1.1192 acc_val: 0.6780 time: 2.1173s\n",
      "Epoch: 0036 loss_train: 0.6668 acc_train: 0.7667 loss_val: 1.1125 acc_val: 0.6820 time: 2.1262s\n",
      "Epoch: 0037 loss_train: 0.6283 acc_train: 0.7833 loss_val: 1.1072 acc_val: 0.6820 time: 2.1489s\n",
      "Epoch: 0038 loss_train: 0.7234 acc_train: 0.6667 loss_val: 1.1026 acc_val: 0.6840 time: 2.2315s\n",
      "Epoch: 0039 loss_train: 0.6754 acc_train: 0.7250 loss_val: 1.0996 acc_val: 0.6820 time: 2.1446s\n",
      "Epoch: 0040 loss_train: 0.7199 acc_train: 0.7417 loss_val: 1.0970 acc_val: 0.6800 time: 2.1086s\n",
      "Epoch: 0041 loss_train: 0.5950 acc_train: 0.7500 loss_val: 1.0942 acc_val: 0.6820 time: 2.1279s\n",
      "Epoch: 0042 loss_train: 0.6414 acc_train: 0.7917 loss_val: 1.0921 acc_val: 0.6840 time: 2.0997s\n",
      "Epoch: 0043 loss_train: 0.6235 acc_train: 0.8000 loss_val: 1.0905 acc_val: 0.6820 time: 2.1009s\n",
      "Epoch: 0044 loss_train: 0.5653 acc_train: 0.7750 loss_val: 1.0898 acc_val: 0.6780 time: 2.1321s\n",
      "Epoch: 0045 loss_train: 0.5629 acc_train: 0.8000 loss_val: 1.0898 acc_val: 0.6740 time: 2.0695s\n",
      "Epoch: 0046 loss_train: 0.6146 acc_train: 0.7333 loss_val: 1.0895 acc_val: 0.6740 time: 2.1257s\n",
      "Epoch: 0047 loss_train: 0.5655 acc_train: 0.7750 loss_val: 1.0879 acc_val: 0.6680 time: 2.0696s\n",
      "Epoch: 0048 loss_train: 0.6473 acc_train: 0.7333 loss_val: 1.0865 acc_val: 0.6640 time: 2.0452s\n",
      "Epoch: 0049 loss_train: 0.6147 acc_train: 0.8083 loss_val: 1.0846 acc_val: 0.6660 time: 2.0815s\n",
      "Epoch: 0050 loss_train: 0.6837 acc_train: 0.7333 loss_val: 1.0829 acc_val: 0.6640 time: 2.1082s\n",
      "Epoch: 0051 loss_train: 0.6563 acc_train: 0.7833 loss_val: 1.0814 acc_val: 0.6640 time: 2.0526s\n",
      "Epoch: 0052 loss_train: 0.4592 acc_train: 0.8417 loss_val: 1.0797 acc_val: 0.6640 time: 2.2825s\n",
      "Epoch: 0053 loss_train: 0.5141 acc_train: 0.7750 loss_val: 1.0781 acc_val: 0.6620 time: 2.7300s\n",
      "Epoch: 0054 loss_train: 0.6046 acc_train: 0.7417 loss_val: 1.0777 acc_val: 0.6620 time: 2.8575s\n",
      "Epoch: 0055 loss_train: 0.6073 acc_train: 0.7417 loss_val: 1.0778 acc_val: 0.6640 time: 2.9832s\n",
      "Epoch: 0056 loss_train: 0.6164 acc_train: 0.7333 loss_val: 1.0793 acc_val: 0.6620 time: 3.0211s\n",
      "Epoch: 0057 loss_train: 0.5682 acc_train: 0.7917 loss_val: 1.0811 acc_val: 0.6580 time: 3.0366s\n",
      "Epoch: 0058 loss_train: 0.5667 acc_train: 0.7250 loss_val: 1.0827 acc_val: 0.6560 time: 3.0547s\n",
      "Epoch: 0059 loss_train: 0.6146 acc_train: 0.6917 loss_val: 1.0839 acc_val: 0.6580 time: 2.9817s\n",
      "Epoch: 0060 loss_train: 0.4687 acc_train: 0.8167 loss_val: 1.0847 acc_val: 0.6580 time: 2.9935s\n",
      "Epoch: 0061 loss_train: 0.5095 acc_train: 0.8000 loss_val: 1.0856 acc_val: 0.6560 time: 3.0496s\n",
      "Epoch: 0062 loss_train: 0.4226 acc_train: 0.8417 loss_val: 1.0865 acc_val: 0.6560 time: 3.0398s\n",
      "Epoch: 0063 loss_train: 0.4326 acc_train: 0.8417 loss_val: 1.0874 acc_val: 0.6580 time: 3.0089s\n",
      "Epoch: 0064 loss_train: 0.5823 acc_train: 0.7750 loss_val: 1.0886 acc_val: 0.6580 time: 3.1120s\n",
      "Epoch: 0065 loss_train: 0.4822 acc_train: 0.8250 loss_val: 1.0893 acc_val: 0.6600 time: 3.0500s\n",
      "Epoch: 0066 loss_train: 0.5429 acc_train: 0.7667 loss_val: 1.0911 acc_val: 0.6600 time: 2.9898s\n",
      "Epoch: 0067 loss_train: 0.5416 acc_train: 0.7333 loss_val: 1.0931 acc_val: 0.6580 time: 2.9851s\n",
      "Epoch: 0068 loss_train: 0.5575 acc_train: 0.7250 loss_val: 1.0952 acc_val: 0.6580 time: 2.9825s\n",
      "Epoch: 0069 loss_train: 0.4385 acc_train: 0.8250 loss_val: 1.0969 acc_val: 0.6560 time: 3.0361s\n",
      "Epoch: 0070 loss_train: 0.3706 acc_train: 0.8583 loss_val: 1.0985 acc_val: 0.6540 time: 3.0435s\n",
      "Epoch: 0071 loss_train: 0.4474 acc_train: 0.8083 loss_val: 1.0995 acc_val: 0.6580 time: 3.0395s\n",
      "Epoch: 0072 loss_train: 0.5213 acc_train: 0.7333 loss_val: 1.1000 acc_val: 0.6620 time: 3.1788s\n",
      "Epoch: 0073 loss_train: 0.4616 acc_train: 0.7750 loss_val: 1.0978 acc_val: 0.6620 time: 3.0169s\n",
      "Epoch: 0074 loss_train: 0.4075 acc_train: 0.8167 loss_val: 1.0952 acc_val: 0.6640 time: 2.9744s\n",
      "Epoch: 0075 loss_train: 0.4850 acc_train: 0.8250 loss_val: 1.0925 acc_val: 0.6640 time: 3.0156s\n",
      "Epoch: 0076 loss_train: 0.6089 acc_train: 0.6833 loss_val: 1.0893 acc_val: 0.6600 time: 2.9828s\n",
      "Epoch: 0077 loss_train: 0.5372 acc_train: 0.7667 loss_val: 1.0859 acc_val: 0.6620 time: 2.9891s\n",
      "Epoch: 0078 loss_train: 0.4454 acc_train: 0.7583 loss_val: 1.0830 acc_val: 0.6640 time: 3.0036s\n",
      "Epoch: 0079 loss_train: 0.3925 acc_train: 0.8167 loss_val: 1.0803 acc_val: 0.6660 time: 3.1400s\n",
      "Epoch: 0080 loss_train: 0.3654 acc_train: 0.8833 loss_val: 1.0791 acc_val: 0.6680 time: 3.1041s\n",
      "Epoch: 0081 loss_train: 0.4729 acc_train: 0.8083 loss_val: 1.0780 acc_val: 0.6700 time: 3.2437s\n",
      "Epoch: 0082 loss_train: 0.5274 acc_train: 0.7333 loss_val: 1.0776 acc_val: 0.6700 time: 2.4476s\n",
      "Epoch: 0083 loss_train: 0.3897 acc_train: 0.8250 loss_val: 1.0771 acc_val: 0.6720 time: 2.0880s\n",
      "Epoch: 0084 loss_train: 0.4419 acc_train: 0.8000 loss_val: 1.0761 acc_val: 0.6740 time: 2.1264s\n",
      "Epoch: 0085 loss_train: 0.5429 acc_train: 0.7417 loss_val: 1.0755 acc_val: 0.6760 time: 2.1499s\n",
      "Epoch: 0086 loss_train: 0.4969 acc_train: 0.7583 loss_val: 1.0757 acc_val: 0.6780 time: 2.0976s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0087 loss_train: 0.4607 acc_train: 0.7917 loss_val: 1.0772 acc_val: 0.6800 time: 2.1857s\n",
      "Epoch: 0088 loss_train: 0.4612 acc_train: 0.7833 loss_val: 1.0790 acc_val: 0.6760 time: 2.0693s\n",
      "Epoch: 0089 loss_train: 0.4038 acc_train: 0.8083 loss_val: 1.0796 acc_val: 0.6740 time: 2.0378s\n",
      "Epoch: 0090 loss_train: 0.3484 acc_train: 0.8500 loss_val: 1.0807 acc_val: 0.6720 time: 2.0844s\n",
      "Epoch: 0091 loss_train: 0.4106 acc_train: 0.8083 loss_val: 1.0811 acc_val: 0.6720 time: 2.0847s\n",
      "Epoch: 0092 loss_train: 0.4071 acc_train: 0.8500 loss_val: 1.0814 acc_val: 0.6740 time: 2.0846s\n",
      "Epoch: 0093 loss_train: 0.4340 acc_train: 0.7833 loss_val: 1.0821 acc_val: 0.6780 time: 2.1078s\n",
      "Epoch: 0094 loss_train: 0.3514 acc_train: 0.8667 loss_val: 1.0826 acc_val: 0.6780 time: 2.2033s\n",
      "Epoch: 0095 loss_train: 0.3943 acc_train: 0.8250 loss_val: 1.0826 acc_val: 0.6700 time: 2.7529s\n",
      "Epoch: 0096 loss_train: 0.3664 acc_train: 0.8667 loss_val: 1.0819 acc_val: 0.6680 time: 3.1116s\n",
      "Epoch: 0097 loss_train: 0.4203 acc_train: 0.7833 loss_val: 1.0815 acc_val: 0.6660 time: 2.1632s\n",
      "Epoch: 0098 loss_train: 0.4276 acc_train: 0.8250 loss_val: 1.0817 acc_val: 0.6620 time: 2.0537s\n",
      "Epoch: 0099 loss_train: 0.3808 acc_train: 0.8250 loss_val: 1.0820 acc_val: 0.6600 time: 2.1318s\n",
      "Epoch: 0100 loss_train: 0.4496 acc_train: 0.8000 loss_val: 1.0826 acc_val: 0.6580 time: 2.1434s\n",
      "Epoch: 0101 loss_train: 0.4715 acc_train: 0.7833 loss_val: 1.0828 acc_val: 0.6600 time: 2.1547s\n",
      "Epoch: 0102 loss_train: 0.4238 acc_train: 0.7833 loss_val: 1.0825 acc_val: 0.6600 time: 2.0555s\n",
      "Epoch: 0103 loss_train: 0.4232 acc_train: 0.7833 loss_val: 1.0817 acc_val: 0.6640 time: 2.0382s\n",
      "Epoch: 0104 loss_train: 0.4756 acc_train: 0.7750 loss_val: 1.0804 acc_val: 0.6680 time: 2.1009s\n",
      "Epoch: 0105 loss_train: 0.4290 acc_train: 0.8167 loss_val: 1.0795 acc_val: 0.6700 time: 2.0547s\n",
      "Epoch: 0106 loss_train: 0.4897 acc_train: 0.7750 loss_val: 1.0776 acc_val: 0.6700 time: 2.1125s\n",
      "Epoch: 0107 loss_train: 0.4647 acc_train: 0.7750 loss_val: 1.0758 acc_val: 0.6740 time: 2.1022s\n",
      "Epoch: 0108 loss_train: 0.4284 acc_train: 0.8250 loss_val: 1.0740 acc_val: 0.6740 time: 2.0992s\n",
      "Epoch: 0109 loss_train: 0.4978 acc_train: 0.7583 loss_val: 1.0721 acc_val: 0.6720 time: 2.1938s\n",
      "Epoch: 0110 loss_train: 0.3239 acc_train: 0.8750 loss_val: 1.0698 acc_val: 0.6720 time: 2.1615s\n",
      "Epoch: 0111 loss_train: 0.3955 acc_train: 0.8083 loss_val: 1.0676 acc_val: 0.6700 time: 2.0364s\n",
      "Epoch: 0112 loss_train: 0.4305 acc_train: 0.7750 loss_val: 1.0652 acc_val: 0.6720 time: 2.1201s\n",
      "Epoch: 0113 loss_train: 0.4323 acc_train: 0.8083 loss_val: 1.0632 acc_val: 0.6720 time: 2.0088s\n",
      "Epoch: 0114 loss_train: 0.4694 acc_train: 0.7667 loss_val: 1.0614 acc_val: 0.6720 time: 2.1338s\n",
      "Epoch: 0115 loss_train: 0.4467 acc_train: 0.8250 loss_val: 1.0599 acc_val: 0.6700 time: 2.3154s\n",
      "Epoch: 0116 loss_train: 0.4520 acc_train: 0.8000 loss_val: 1.0585 acc_val: 0.6740 time: 2.0883s\n",
      "Epoch: 0117 loss_train: 0.4672 acc_train: 0.7583 loss_val: 1.0578 acc_val: 0.6740 time: 2.0993s\n",
      "Epoch: 0118 loss_train: 0.4845 acc_train: 0.7667 loss_val: 1.0573 acc_val: 0.6760 time: 2.0342s\n",
      "Epoch: 0119 loss_train: 0.4289 acc_train: 0.8500 loss_val: 1.0567 acc_val: 0.6760 time: 2.0835s\n",
      "Epoch: 0120 loss_train: 0.5296 acc_train: 0.7417 loss_val: 1.0562 acc_val: 0.6760 time: 2.0783s\n",
      "Epoch: 0121 loss_train: 0.4722 acc_train: 0.7667 loss_val: 1.0557 acc_val: 0.6760 time: 2.1072s\n",
      "Epoch: 0122 loss_train: 0.3947 acc_train: 0.7833 loss_val: 1.0553 acc_val: 0.6780 time: 2.0512s\n",
      "Epoch: 0123 loss_train: 0.4185 acc_train: 0.7750 loss_val: 1.0553 acc_val: 0.6740 time: 2.0765s\n",
      "Epoch: 0124 loss_train: 0.4128 acc_train: 0.8083 loss_val: 1.0552 acc_val: 0.6740 time: 2.8544s\n",
      "Epoch: 0125 loss_train: 0.3566 acc_train: 0.8500 loss_val: 1.0545 acc_val: 0.6740 time: 2.7968s\n",
      "Epoch: 0126 loss_train: 0.3750 acc_train: 0.8167 loss_val: 1.0540 acc_val: 0.6740 time: 3.0649s\n",
      "Epoch: 0127 loss_train: 0.3437 acc_train: 0.8583 loss_val: 1.0543 acc_val: 0.6760 time: 3.2563s\n",
      "Epoch: 0128 loss_train: 0.2719 acc_train: 0.8750 loss_val: 1.0553 acc_val: 0.6780 time: 2.9664s\n",
      "Epoch: 0129 loss_train: 0.5191 acc_train: 0.7500 loss_val: 1.0560 acc_val: 0.6760 time: 2.9897s\n",
      "Epoch: 0130 loss_train: 0.3538 acc_train: 0.8417 loss_val: 1.0566 acc_val: 0.6740 time: 2.9950s\n",
      "Epoch: 0131 loss_train: 0.3322 acc_train: 0.8167 loss_val: 1.0571 acc_val: 0.6720 time: 2.9893s\n",
      "Epoch: 0132 loss_train: 0.3972 acc_train: 0.8250 loss_val: 1.0574 acc_val: 0.6720 time: 2.9774s\n",
      "Epoch: 0133 loss_train: 0.4288 acc_train: 0.8667 loss_val: 1.0586 acc_val: 0.6740 time: 3.0088s\n",
      "Epoch: 0134 loss_train: 0.3662 acc_train: 0.8750 loss_val: 1.0598 acc_val: 0.6740 time: 2.9844s\n",
      "Epoch: 0135 loss_train: 0.3857 acc_train: 0.8417 loss_val: 1.0603 acc_val: 0.6760 time: 2.9673s\n",
      "Epoch: 0136 loss_train: 0.3227 acc_train: 0.8583 loss_val: 1.0612 acc_val: 0.6780 time: 2.9847s\n",
      "Epoch: 0137 loss_train: 0.4191 acc_train: 0.8167 loss_val: 1.0621 acc_val: 0.6760 time: 3.1684s\n",
      "Epoch: 0138 loss_train: 0.4217 acc_train: 0.8167 loss_val: 1.0634 acc_val: 0.6740 time: 3.3130s\n",
      "Epoch: 0139 loss_train: 0.5117 acc_train: 0.7667 loss_val: 1.0654 acc_val: 0.6740 time: 3.1209s\n",
      "Epoch: 0140 loss_train: 0.3731 acc_train: 0.8417 loss_val: 1.0676 acc_val: 0.6740 time: 3.0037s\n",
      "Epoch: 0141 loss_train: 0.4850 acc_train: 0.7417 loss_val: 1.0691 acc_val: 0.6720 time: 3.0149s\n",
      "Epoch: 0142 loss_train: 0.3064 acc_train: 0.8750 loss_val: 1.0703 acc_val: 0.6720 time: 3.0150s\n",
      "Epoch: 0143 loss_train: 0.3376 acc_train: 0.7917 loss_val: 1.0712 acc_val: 0.6740 time: 3.2104s\n",
      "Epoch: 0144 loss_train: 0.3819 acc_train: 0.8667 loss_val: 1.0726 acc_val: 0.6740 time: 3.0395s\n",
      "Epoch: 0145 loss_train: 0.3444 acc_train: 0.8583 loss_val: 1.0744 acc_val: 0.6740 time: 3.0046s\n",
      "Epoch: 0146 loss_train: 0.5623 acc_train: 0.7583 loss_val: 1.0755 acc_val: 0.6720 time: 3.1068s\n",
      "Epoch: 0147 loss_train: 0.3261 acc_train: 0.8583 loss_val: 1.0765 acc_val: 0.6720 time: 3.0917s\n",
      "Epoch: 0148 loss_train: 0.4237 acc_train: 0.7583 loss_val: 1.0772 acc_val: 0.6720 time: 3.0600s\n",
      "Epoch: 0149 loss_train: 0.4416 acc_train: 0.7917 loss_val: 1.0774 acc_val: 0.6740 time: 3.0544s\n",
      "Epoch: 0150 loss_train: 0.4490 acc_train: 0.7667 loss_val: 1.0775 acc_val: 0.6740 time: 3.0892s\n",
      "Epoch: 0151 loss_train: 0.3193 acc_train: 0.8417 loss_val: 1.0770 acc_val: 0.6760 time: 3.0100s\n",
      "Epoch: 0152 loss_train: 0.3876 acc_train: 0.8083 loss_val: 1.0774 acc_val: 0.6700 time: 3.0254s\n",
      "Epoch: 0153 loss_train: 0.3621 acc_train: 0.8250 loss_val: 1.0784 acc_val: 0.6720 time: 3.1563s\n",
      "Epoch: 0154 loss_train: 0.2894 acc_train: 0.9083 loss_val: 1.0781 acc_val: 0.6720 time: 2.1132s\n",
      "Epoch: 0155 loss_train: 0.3714 acc_train: 0.8250 loss_val: 1.0772 acc_val: 0.6740 time: 2.1124s\n",
      "Epoch: 0156 loss_train: 0.4159 acc_train: 0.8250 loss_val: 1.0761 acc_val: 0.6740 time: 2.2048s\n",
      "Epoch: 0157 loss_train: 0.3595 acc_train: 0.8417 loss_val: 1.0750 acc_val: 0.6740 time: 2.0377s\n",
      "Epoch: 0158 loss_train: 0.4288 acc_train: 0.8083 loss_val: 1.0749 acc_val: 0.6720 time: 2.0900s\n",
      "Epoch: 0159 loss_train: 0.4088 acc_train: 0.8083 loss_val: 1.0753 acc_val: 0.6700 time: 2.2290s\n",
      "Epoch: 0160 loss_train: 0.3808 acc_train: 0.8000 loss_val: 1.0763 acc_val: 0.6700 time: 2.1596s\n",
      "Epoch: 0161 loss_train: 0.4787 acc_train: 0.7667 loss_val: 1.0780 acc_val: 0.6700 time: 2.0888s\n",
      "Epoch: 0162 loss_train: 0.4260 acc_train: 0.8333 loss_val: 1.0798 acc_val: 0.6720 time: 2.1127s\n",
      "Epoch: 0163 loss_train: 0.4653 acc_train: 0.7667 loss_val: 1.0823 acc_val: 0.6740 time: 2.0673s\n",
      "Epoch: 0164 loss_train: 0.3204 acc_train: 0.8583 loss_val: 1.0838 acc_val: 0.6740 time: 2.0794s\n",
      "Epoch: 0165 loss_train: 0.4070 acc_train: 0.8083 loss_val: 1.0846 acc_val: 0.6740 time: 2.0832s\n",
      "Epoch: 0166 loss_train: 0.3769 acc_train: 0.7917 loss_val: 1.0853 acc_val: 0.6720 time: 2.8133s\n",
      "Epoch: 0167 loss_train: 0.4627 acc_train: 0.7833 loss_val: 1.0869 acc_val: 0.6720 time: 2.6668s\n",
      "Epoch: 0168 loss_train: 0.5171 acc_train: 0.7250 loss_val: 1.0893 acc_val: 0.6720 time: 2.0823s\n",
      "Epoch: 0169 loss_train: 0.3163 acc_train: 0.8500 loss_val: 1.0913 acc_val: 0.6720 time: 2.2224s\n",
      "Epoch: 0170 loss_train: 0.4312 acc_train: 0.8083 loss_val: 1.0929 acc_val: 0.6720 time: 2.2397s\n",
      "Epoch: 0171 loss_train: 0.3819 acc_train: 0.8333 loss_val: 1.0943 acc_val: 0.6660 time: 2.3715s\n",
      "Epoch: 0172 loss_train: 0.3465 acc_train: 0.8333 loss_val: 1.0961 acc_val: 0.6660 time: 2.3492s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0173 loss_train: 0.3928 acc_train: 0.8583 loss_val: 1.0962 acc_val: 0.6680 time: 2.5139s\n",
      "Epoch: 0174 loss_train: 0.3568 acc_train: 0.8583 loss_val: 1.0961 acc_val: 0.6700 time: 2.5219s\n",
      "Epoch: 0175 loss_train: 0.4383 acc_train: 0.8000 loss_val: 1.0961 acc_val: 0.6680 time: 2.5468s\n",
      "Epoch: 0176 loss_train: 0.4136 acc_train: 0.8167 loss_val: 1.0957 acc_val: 0.6680 time: 2.4608s\n",
      "Epoch: 0177 loss_train: 0.3687 acc_train: 0.8083 loss_val: 1.0951 acc_val: 0.6680 time: 2.5738s\n",
      "Epoch: 0178 loss_train: 0.4437 acc_train: 0.8083 loss_val: 1.0962 acc_val: 0.6680 time: 2.4260s\n",
      "Epoch: 0179 loss_train: 0.3606 acc_train: 0.8250 loss_val: 1.0966 acc_val: 0.6680 time: 2.4033s\n",
      "Epoch: 0180 loss_train: 0.4558 acc_train: 0.7917 loss_val: 1.0956 acc_val: 0.6660 time: 2.2115s\n",
      "Epoch: 0181 loss_train: 0.3556 acc_train: 0.8083 loss_val: 1.0941 acc_val: 0.6680 time: 2.1554s\n",
      "Epoch: 0182 loss_train: 0.4045 acc_train: 0.8250 loss_val: 1.0933 acc_val: 0.6680 time: 2.1688s\n",
      "Epoch: 0183 loss_train: 0.3833 acc_train: 0.8333 loss_val: 1.0938 acc_val: 0.6600 time: 2.1939s\n",
      "Epoch: 0184 loss_train: 0.3691 acc_train: 0.8000 loss_val: 1.0923 acc_val: 0.6740 time: 2.1450s\n",
      "Epoch: 0185 loss_train: 0.4023 acc_train: 0.8083 loss_val: 1.0918 acc_val: 0.6700 time: 2.2922s\n",
      "Epoch: 0186 loss_train: 0.3346 acc_train: 0.8417 loss_val: 1.0925 acc_val: 0.6700 time: 2.4050s\n",
      "Epoch: 0187 loss_train: 0.3348 acc_train: 0.8417 loss_val: 1.0927 acc_val: 0.6700 time: 2.3188s\n",
      "Epoch: 0188 loss_train: 0.3266 acc_train: 0.8583 loss_val: 1.0923 acc_val: 0.6700 time: 2.0870s\n",
      "Epoch: 0189 loss_train: 0.3212 acc_train: 0.8833 loss_val: 1.0917 acc_val: 0.6680 time: 2.0832s\n",
      "Epoch: 0190 loss_train: 0.3432 acc_train: 0.8417 loss_val: 1.0909 acc_val: 0.6700 time: 2.1449s\n",
      "Epoch: 0191 loss_train: 0.3644 acc_train: 0.8167 loss_val: 1.0908 acc_val: 0.6700 time: 2.0978s\n",
      "Epoch: 0192 loss_train: 0.3843 acc_train: 0.8500 loss_val: 1.0919 acc_val: 0.6680 time: 2.3952s\n",
      "Epoch: 0193 loss_train: 0.3052 acc_train: 0.8083 loss_val: 1.0930 acc_val: 0.6720 time: 2.6783s\n",
      "Epoch: 0194 loss_train: 0.3690 acc_train: 0.8417 loss_val: 1.0941 acc_val: 0.6720 time: 2.8039s\n",
      "Epoch: 0195 loss_train: 0.4179 acc_train: 0.8167 loss_val: 1.0970 acc_val: 0.6680 time: 3.4986s\n",
      "Epoch: 0196 loss_train: 0.3902 acc_train: 0.8000 loss_val: 1.0996 acc_val: 0.6680 time: 3.3892s\n",
      "Epoch: 0197 loss_train: 0.5190 acc_train: 0.7833 loss_val: 1.1027 acc_val: 0.6680 time: 3.1226s\n",
      "Epoch: 0198 loss_train: 0.3682 acc_train: 0.8250 loss_val: 1.1051 acc_val: 0.6640 time: 3.7620s\n",
      "Epoch: 0199 loss_train: 0.3824 acc_train: 0.8083 loss_val: 1.1082 acc_val: 0.6620 time: 3.0896s\n",
      "Epoch: 0200 loss_train: 0.3730 acc_train: 0.8000 loss_val: 1.1119 acc_val: 0.6640 time: 3.0363s\n",
      "Optimization Finished!\n",
      "Total time elapsed: 498.1889s\n",
      "Test set results: loss= 1.0734 accuracy= 0.6870\n"
     ]
    }
   ],
   "source": [
    "# Train model\n",
    "t_total = time.time()\n",
    "for epoch in range(args.epochs):\n",
    "    train(epoch)\n",
    "print(\"Optimization Finished!\")\n",
    "print(\"Total time elapsed: {:.4f}s\".format(time.time() - t_total))\n",
    "\n",
    "# Testing\n",
    "test()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
