{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install dgl-cu113 -f https://data.dgl.ai/wheels/repo.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiment settings and random seeds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "import torch\n",
    "\n",
    "# Experiment configuration\n",
    "SEED = 2  # Fixed seed; the experiments select SEED values in [1, 10]\n",
    "test_size = 0.2  # Fraction of nodes held out for testing\n",
    "val_size = 0.2  # Validation fraction -- NOTE(review): not read by any cell visible here\n",
    "num_filters = 16  # Number of features for learnable filters\n",
    "\n",
    "# Seed every RNG the notebook relies on so that runs are reproducible.\n",
    "# 1. Python's built-in `random` module\n",
    "random.seed(SEED)\n",
    "# 2. NumPy's legacy global RNG\n",
    "np.random.seed(SEED)\n",
    "# 3. PyTorch: `torch.manual_seed` seeds the generators of all devices (CPU and CUDA)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "print(torch.cuda.is_available())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import toponetx.datasets as datasets\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from Utils.sccnn_exp import COSIMO\n",
    "from topomodelx.utils.sparse import from_sparse\n",
    "\n",
    "# %load_ext autoreload\n",
    "# %autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
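    "A function that performs a partial eigendecomposition (EVD) of a Laplacian with ARPACK (`scipy.sparse.linalg.eigs`) and returns $k$ eigenvalue-eigenvector pairs (by default, those of largest magnitude):"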
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy\n",
    "from scipy import sparse\n",
    "\n",
    "def get_evals_evecs(L, k, which=\"LM\"):\n",
    "    \"\"\"Compute ``k`` eigenvalue/eigenvector pairs of a square matrix ``L`` with ARPACK.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    L : array-like or scipy sparse matrix, shape (n, n)\n",
    "        Matrix to decompose (in this notebook: a Laplacian moved to CPU).\n",
    "    k : int\n",
    "        Number of eigenpairs to return; ARPACK requires ``k < n - 1``.\n",
    "    which : str, default \"LM\"\n",
    "        Which part of the spectrum ``scipy.sparse.linalg.eigs`` targets. The default\n",
    "        (\"LM\", largest magnitude) is the behaviour this notebook has always used;\n",
    "        pass e.g. \"SM\" for the smallest-magnitude eigenvalues instead.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    evals : torch.Tensor, shape (k,)\n",
    "        Real parts of the eigenvalues, in the order ARPACK returns them (no sorting here).\n",
    "    evecs : torch.Tensor, shape (n, k)\n",
    "        Real parts of the corresponding eigenvectors, one per column.\n",
    "    \"\"\"\n",
    "    # Import explicitly: `import scipy` / `from scipy import sparse` do not\n",
    "    # guarantee that the `scipy.sparse.linalg` submodule has been loaded.\n",
    "    from scipy.sparse.linalg import eigs\n",
    "\n",
    "    L_sparse = sparse.coo_matrix(L)\n",
    "    n = L_sparse.shape[0]\n",
    "    # ARPACK needs k + 1 < ncv <= n: keep 4*k Lanczos vectors, but never more than n.\n",
    "    ncv = min(n, 4 * k)\n",
    "\n",
    "    evals, evecs = eigs(L_sparse, k=k, ncv=ncv, which=which, return_eigenvectors=True)\n",
    "\n",
    "    # eigs returns complex arrays; keep only the real parts.\n",
    "    evals = torch.tensor(evals.real)\n",
    "    evecs = torch.tensor(evecs.real)\n",
    "\n",
    "    return evals, evecs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from Utils.preprocessing.simplicial_construction import get_boundary_matrices, get_neighbors, get_weight_matrix_graph, get_weight_matrix_simplex,generate_triangles,_get_laplacians,_get_simplex_features,augment_simplex,augment_simplex_open\n",
    "import argparse\n",
    "from Utils.preprocessing.graph_construction import _get_graph\n",
    "import torch\n",
    "import networkx as nx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "# Notebook \"command-line\" options: parse_args(args=[]) uses the defaults below\n",
    "# instead of reading the Jupyter kernel's own argv.\n",
    "parser = argparse.ArgumentParser(description='TopoSRL')\n",
    "\n",
    "parser.add_argument('--dataname', type=str, default='contact-high-school', help='Name of dataset.')\n",
    "parser.add_argument('--gpu', type=int, default=0, help='GPU index.')  # -1 forces CPU\n",
    "parser.add_argument('--dim', type=int, default=4, help='Order of the simplicial complex.')\n",
    "\n",
    "args = parser.parse_args(args=[])\n",
    "\n",
    "# Honour --gpu by selecting that CUDA device; fall back to CPU when --gpu is -1\n",
    "# or CUDA is unavailable.\n",
    "if args.gpu != -1 and torch.cuda.is_available():\n",
    "    args.device = f'cuda:{args.gpu}'\n",
    "else:\n",
    "    args.device = 'cpu'\n",
    "print(args.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Read the labels\n",
      "Reading the simplicies\n",
      "Creating tree\n",
      "Computing boundary matrix for dimension 1\n",
      "Computing boundary matrix for dimension 2\n",
      "Computing boundary matrix for dimension 3\n",
      "Computing boundary matrix for dimension 4\n",
      "Got boundaries\n",
      "Anchor nodes initialized\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_31928/3869731468.py:5: FutureWarning: adjacency_matrix will return a scipy.sparse array instead of a matrix in Networkx 3.0.\n",
      "  A = nx.adjacency_matrix(netxG).todense()\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[9 9 3 3 8 8 3 3 7 7 7 8 8 1 1 3 3 8 8 8 9 9 4 4 9 9 9 3 3 2 2 3 3 6 6 3 3\n",
      " 3 9 9 3 3 1 8 8 8 3 3 5 5 4 4 5 5 1 1 2 2 5 5 3 5 5 3 3 3 4 4 4 3 9 5 5 9\n",
      " 9 3 2 2 2 9 9 1 1 2 3 3 1 1 8 8 3 3 9 3 4 4 3 5 5 5 3 9 9 8 8 8 3 3 9 1 1\n",
      " 8 8 3 3 3 3 5 4 4 3 1 1 1 5 5 3 9 3 3 9 2 7 8 1 7 7 2 2 7 7 7 9 7 7 4 6 6\n",
      " 9 7 9 9 9 9 9 4 4 4 4 4 2 4 2 4 4 7 9 4 4 5 5 7 5 5 2 2 5 2 2 5 2 2 5 5 5\n",
      " 2 5 5 5 4 2 2 2 5 5 5 8 2 6 6 5 7 6 6 7 2 6 6 2 2 8 8 6 6 6 8 8 8 8 2 2 8\n",
      " 6 8 1 8 6 8 8 8 5 1 2 6 7 1 1 1 1 1 1 6 8 6 1 1 1 6 8 6 8 1 6 6 6 6 6 7 8\n",
      " 5 1 8 6 1 7 1 1 7 7 4 9 7 7 5 4 7 1 7 9 7 7 6 4 1 7 1 7 7 2 7 7 7 7 6 8 6\n",
      " 6 8 7 7 7 2 7 5 7 8 7 6 2 6 6 1 7 4 7 7 5 9 4 5 9 7 9 2 1 6 4]\n"
     ]
    }
   ],
   "source": [
    "# Load the dataset via the project helper. NOTE(review): judging by the uses below,\n",
    "# `sc` is indexed by simplex order -- confirm against get_boundary_matrices. The\n",
    "# misspelled name `boundry_matrices` is kept because later cells read it.\n",
    "simplex_tree, sc, boundry_matrices, labels =  get_boundary_matrices(args.dataname, args.dim)\n",
    "print(\"Got boundaries\")\n",
    "# Build a DGL graph and a NetworkX copy from sc[1] (presumably the order-1 simplices -- TODO confirm)\n",
    "g, netxG = _get_graph(sc[1])\n",
    "g = g.to(args.device)\n",
    "# Dense adjacency matrix; this call produces the NetworkX FutureWarning shown in the output\n",
    "A = nx.adjacency_matrix(netxG).todense()\n",
    "sm = torch.nn.Softmax(dim=1)\n",
    "W0 = get_weight_matrix_graph(A)\n",
    "# Softmax of the weight matrix along dim=1 (row-wise for a 2-D matrix)...\n",
    "W0 = sm(torch.FloatTensor(W0).to(args.device))\n",
    "# ...then zero out every entry equal to its row minimum\n",
    "W0 = W0 * (W0!=W0.min(axis=1).values.unsqueeze(-1))\n",
    "# Laplacians plus lower/upper Laplacians derived from the boundary matrices\n",
    "laplacians, lower_laplacians, upper_laplacians = _get_laplacians(boundry_matrices)\n",
    "# Initial simplex features built from sc[1:4] and the graph's node features\n",
    "_X = _get_simplex_features(sc[1:4], g.ndata['features'])\n",
    "print(labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
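    "### Define Neighborhood Structures\n",
    "Get the incidence matrices $\\mathbf{B}_1, \\mathbf{B}_2$, the Hodge Laplacian $\\mathbf{L}_0$, and the lower and upper Laplacians $\\mathbf{L}_{1,\\downarrow}, \\mathbf{L}_{1,\\uparrow}$ and $\\mathbf{L}_{2,\\downarrow}, \\mathbf{L}_{2,\\uparrow}$."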
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The incidence matrix B1 has shape: (327, 5818).\n",
      "The incidence matrix B2 has shape: (5818, 2370).\n"
     ]
    }
   ],
   "source": [
    "# Dense NumPy copies of the first two boundary (incidence) matrices.\n",
    "incidence_1, incidence_2 = (boundry_matrices[i].cpu().detach().numpy() for i in (0, 1))\n",
    "\n",
    "print(f\"The incidence matrix B1 has shape: {incidence_1.shape}.\")\n",
    "print(f\"The incidence matrix B2 has shape: {incidence_2.shape}.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([327, 327])\n",
      "torch.Size([5818, 5818])\n",
      "torch.Size([5818, 5818])\n",
      "torch.Size([2370, 2370])\n",
      "torch.Size([2370, 2370])\n"
     ]
    }
   ],
   "source": [
    "# Laplacian of order 0, plus the lower (down) and upper (up) Laplacians of orders 1 and 2.\n",
    "laplacian_0 = laplacians[0]\n",
    "laplacian_down_1, laplacian_down_2 = lower_laplacians[1], lower_laplacians[2]\n",
    "laplacian_up_1, laplacian_up_2 = upper_laplacians[1], upper_laplacians[2]\n",
    "\n",
    "print(*(lap.shape for lap in (laplacian_0, laplacian_down_1, laplacian_up_1, laplacian_down_2, laplacian_up_2)), sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.sparse import coo_matrix\n",
    "\n",
    "kk = 10  # Number of eigenvalue-eigenvector pairs kept per Laplacian\n",
    "\n",
    "# Partial spectra of L0 and of the lower/upper Laplacians of orders 1 and 2 (computed on CPU).\n",
    "evals_0, evecs_0 = get_evals_evecs(laplacian_0.cpu().detach(), kk)\n",
    "evals_d1, evecs_d1 = get_evals_evecs(laplacian_down_1.cpu().detach(), kk)\n",
    "evals_u1, evecs_u1 = get_evals_evecs(laplacian_up_1.cpu().detach(), kk)\n",
    "evals_d2, evecs_d2 = get_evals_evecs(laplacian_down_2.cpu().detach(), kk)\n",
    "evals_u2, evecs_u2 = get_evals_evecs(laplacian_up_2.cpu().detach(), kk)\n",
    "\n",
    "# Sparse torch incidence matrices B1, B2 for the model. They are rebuilt from\n",
    "# `boundry_matrices` instead of overwriting the NumPy `incidence_1/2` in place,\n",
    "# so re-running this cell gives the same result.\n",
    "incidence_1 = from_sparse(coo_matrix(boundry_matrices[0].cpu().detach().numpy()))\n",
    "incidence_2 = from_sparse(coo_matrix(boundry_matrices[1].cpu().detach().numpy()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import signals ##"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_simplicial_features(dataset, rank):\n",
    "    \"\"\"Stack the features stored on the simplices of a given rank into one tensor.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset : object\n",
    "        Any object exposing ``get_simplex_attributes(name)`` that returns a mapping\n",
    "        whose values are per-simplex feature vectors.\n",
    "    rank : int\n",
    "        0 (nodes, attribute \"node_feat\"), 1 (edges, \"edge_feat\") or 2 (faces, \"face_feat\").\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    torch.Tensor\n",
    "        The attribute values stacked with ``np.stack`` (new first axis), in mapping order.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``rank`` is not 0, 1 or 2.\n",
    "    \"\"\"\n",
    "    feature_names = {0: \"node_feat\", 1: \"edge_feat\", 2: \"face_feat\"}\n",
    "    if rank not in feature_names:\n",
    "        raise ValueError(\n",
    "            \"input dimension must be 0, 1 or 2, because features are supported on nodes, edges and faces\"\n",
    "        )\n",
    "\n",
    "    x = list(dataset.get_simplex_attributes(feature_names[rank]).values())\n",
    "    return torch.tensor(np.stack(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 327 nodes with features of dimension 4.\n",
      "There are 5818 edges with features of dimension 4.\n",
      "There are 2370 faces with features of dimension 4.\n"
     ]
    }
   ],
   "source": [
    "x_0, x_1, x_2 = _X[0], _X[1], _X[2]\n",
    "\n",
    "# Report how many simplices of each order carry features, and the feature width.\n",
    "print(f\"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.\")\n",
    "print(f\"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.\")\n",
    "print(f\"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define class labels\n",
    "We retrieve the label associated with each node (rank-0 simplex). The dataset loaded here (`contact-high-school` by default) has 9 classes stored as integers 1-9, so we shift them to 0-based class indices.\n",
    "\n",
    "We convert the labels into one-hot form and split them without shuffling: with `test_size = 0.2`, the last 20% of the nodes are held out for testing and the rest are used for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[8 8 2 2 7 7 2 2 6 6 6 7 7 0 0 2 2 7 7 7 8 8 3 3 8 8 8 2 2 1 1 2 2 5 5 2 2\n",
      " 2 8 8 2 2 0 7 7 7 2 2 4 4 3 3 4 4 0 0 1 1 4 4 2 4 4 2 2 2 3 3 3 2 8 4 4 8\n",
      " 8 2 1 1 1 8 8 0 0 1 2 2 0 0 7 7 2 2 8 2 3 3 2 4 4 4 2 8 8 7 7 7 2 2 8 0 0\n",
      " 7 7 2 2 2 2 4 3 3 2 0 0 0 4 4 2 8 2 2 8 1 6 7 0 6 6 1 1 6 6 6 8 6 6 3 5 5\n",
      " 8 6 8 8 8 8 8 3 3 3 3 3 1 3 1 3 3 6 8 3 3 4 4 6 4 4 1 1 4 1 1 4 1 1 4 4 4\n",
      " 1 4 4 4 3 1 1 1 4 4 4 7 1 5 5 4 6 5 5 6 1 5 5 1 1 7 7 5 5 5 7 7 7 7 1 1 7\n",
      " 5 7 0 7 5 7 7 7 4 0 1 5 6 0 0 0 0 0 0 5 7 5 0 0 0 5 7 5 7 0 5 5 5 5 5 6 7\n",
      " 4 0 7 5 0 6 0 0 6 6 3 8 6 6 4 3 6 0 6 8 6 6 5 3 0 6 0 6 6 1 6 6 6 6 5 7 5\n",
      " 5 7 6 6 6 1 6 4 6 7 6 5 1 5 5 0 6 3 6 6 4 8 3 4 8 6 8 1 0 5 3]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_31928/3220251307.py:6: DeprecationWarning: __array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments.\n",
      "  one_hot_labels = np.array(F.one_hot(torch.tensor(y), num_classes=num_classes))\n"
     ]
    }
   ],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "# The printed labels run from 1 to 9; shift them to 0-based class indices.\n",
    "y = np.array(labels-1)\n",
    "print(y)\n",
    "num_classes = 9  # Define the number of classes\n",
    "assert y.min() >= 0 and y.max() < num_classes, \"labels fall outside [1, num_classes]\"\n",
    "# Tensor.numpy() avoids the `__array__` copy-keyword DeprecationWarning raised by np.array(tensor).\n",
    "one_hot_labels = F.one_hot(torch.tensor(y), num_classes=num_classes).numpy()\n",
    "\n",
    "# shuffle=False: the first (1 - test_size) fraction of nodes is used for training\n",
    "# and the last test_size fraction for testing.\n",
    "y_train, y_test = train_test_split(one_hot_labels,test_size=test_size, shuffle=False)\n",
    "y_train = torch.from_numpy(y_train).to(args.device)\n",
    "y_test = torch.from_numpy(y_test).to(args.device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create and Train the Continuous Simplicial Neural Network (COSIMO)\n",
    "\n",
    "We specify the model with our pre-made neighborhood structures and specify an optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Network(torch.nn.Module):\n",
    "    \"\"\"Node classifier: a COSIMO backbone followed by a linear read-out on its rank-0 output.\n",
    "\n",
    "    Args:\n",
    "        in_channels_all: per-rank input widths, forwarded to COSIMO.\n",
    "        hidden_channels_all: three per-rank hidden widths, forwarded to COSIMO; the\n",
    "            first one is the input width of the linear read-out.\n",
    "        out_channels: number of output logits produced per node.\n",
    "        conv_order: forwarded to COSIMO as ``conv_order``.\n",
    "        max_rank: forwarded to COSIMO as ``sc_order``.\n",
    "        update_func: forwarded to COSIMO as ``update_func``.\n",
    "        n_layers: forwarded to COSIMO as ``n_layers``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels_all, hidden_channels_all, out_channels, conv_order,\n",
    "                 max_rank, update_func=None, n_layers=2):\n",
    "        super().__init__()\n",
    "        self.base_model = COSIMO(\n",
    "            in_channels_all=in_channels_all,\n",
    "            hidden_channels_all=hidden_channels_all,\n",
    "            conv_order=conv_order,\n",
    "            sc_order=max_rank,\n",
    "            update_func=update_func,\n",
    "            n_layers=n_layers,\n",
    "        )\n",
    "        # The read-out only needs the first (rank-0) hidden width.\n",
    "        node_width, _edge_width, _face_width = hidden_channels_all\n",
    "        self.out_linear_0 = torch.nn.Linear(node_width, out_channels)\n",
    "\n",
    "    def forward(self, x_all, eig_eiv_all, incidence_all):\n",
    "        \"\"\"Run the backbone and map its first (rank-0) output to logits; the other two are discarded.\"\"\"\n",
    "        node_emb, _edge_emb, _face_emb = self.base_model(x_all, eig_eiv_all, incidence_all)\n",
    "        return self.out_linear_0(node_emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initial features on all simplices (ranks 0, 1, 2), moved to the compute device.\n",
    "x_all = tuple(x.to(args.device) for x in (x_0, x_1, x_2))\n",
    "\n",
    "# Settings and hyperparameters\n",
    "conv_order = 2\n",
    "in_channels_all = tuple(x.shape[-1] for x in (x_0, x_1, x_2))\n",
    "intermediate_channels_all = (num_filters,) * 3\n",
    "num_layers = 4\n",
    "out_channels = num_classes  # num classes\n",
    "max_rank = 4\n",
    "\n",
    "# Eigenpairs, interleaved as (evals, evecs) for L0, L1 down, L1 up, L2 down, L2 up.\n",
    "eig_eiv_all = tuple(\n",
    "    t.to(args.device)\n",
    "    for t in (\n",
    "        evals_0, evecs_0,\n",
    "        evals_d1, evecs_d1,\n",
    "        evals_u1, evecs_u1,\n",
    "        evals_d2, evecs_d2,\n",
    "        evals_u2, evecs_u2,\n",
    "    )\n",
    ")\n",
    "\n",
    "incidence_all = tuple(B.to(args.device) for B in (incidence_1, incidence_2))\n",
    "\n",
    "model = Network(\n",
    "    in_channels_all=in_channels_all,\n",
    "    hidden_channels_all=intermediate_channels_all,\n",
    "    out_channels=out_channels,\n",
    "    conv_order=conv_order,\n",
    "    max_rank=max_rank,\n",
    "    update_func=None,\n",
    "    n_layers=num_layers,\n",
    ").to(args.device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 loss: 952.1015 Train_acc: 0.11\n",
      "Epoch: 2 loss: 583.0707 Train_acc: 0.17\n",
      "Epoch: 3 loss: 469.8413 Train_acc: 0.15\n",
      "Epoch: 4 loss: 462.0691 Train_acc: 0.11\n",
      "Epoch: 5 loss: 502.5592 Train_acc: 0.14\n",
      "Epoch: 6 loss: 438.2184 Train_acc: 0.14\n",
      "Epoch: 7 loss: 326.4001 Train_acc: 0.11\n",
      "Epoch: 8 loss: 273.2276 Train_acc: 0.14\n",
      "Epoch: 9 loss: 232.0769 Train_acc: 0.25\n",
      "Epoch: 10 loss: 203.6616 Train_acc: 0.21\n",
      "\n",
      "\n",
      "Test_acc: 0.11\n",
      "\n",
      "\n",
      "Epoch: 11 loss: 238.8566 Train_acc: 0.08\n",
      "Epoch: 12 loss: 217.4087 Train_acc: 0.07\n",
      "Epoch: 13 loss: 152.6266 Train_acc: 0.08\n",
      "Epoch: 14 loss: 130.4206 Train_acc: 0.16\n",
      "Epoch: 15 loss: 118.8270 Train_acc: 0.30\n",
      "Epoch: 16 loss: 102.9671 Train_acc: 0.37\n",
      "Epoch: 17 loss: 86.9043 Train_acc: 0.35\n",
      "Epoch: 18 loss: 73.6783 Train_acc: 0.28\n",
      "Epoch: 19 loss: 65.2064 Train_acc: 0.25\n",
      "Epoch: 20 loss: 64.6916 Train_acc: 0.27\n",
      "\n",
      "\n",
      "Test_acc: 0.48\n",
      "\n",
      "\n",
      "Epoch: 21 loss: 78.7434 Train_acc: 0.29\n",
      "Epoch: 22 loss: 84.4630 Train_acc: 0.29\n",
      "Epoch: 23 loss: 76.2616 Train_acc: 0.32\n",
      "Epoch: 24 loss: 66.7467 Train_acc: 0.34\n",
      "Epoch: 25 loss: 65.7768 Train_acc: 0.36\n",
      "Epoch: 26 loss: 66.3114 Train_acc: 0.35\n",
      "Epoch: 27 loss: 60.9964 Train_acc: 0.38\n",
      "Epoch: 28 loss: 51.5145 Train_acc: 0.42\n",
      "Epoch: 29 loss: 40.5620 Train_acc: 0.46\n",
      "Epoch: 30 loss: 32.6030 Train_acc: 0.53\n",
      "\n",
      "\n",
      "Test_acc: 0.48\n",
      "\n",
      "\n",
      "Epoch: 31 loss: 31.9441 Train_acc: 0.51\n",
      "Epoch: 32 loss: 35.2947 Train_acc: 0.42\n",
      "Epoch: 33 loss: 35.9473 Train_acc: 0.40\n",
      "Epoch: 34 loss: 31.3338 Train_acc: 0.45\n",
      "Epoch: 35 loss: 26.0135 Train_acc: 0.50\n",
      "Epoch: 36 loss: 23.9473 Train_acc: 0.56\n",
      "Epoch: 37 loss: 22.8924 Train_acc: 0.56\n",
      "Epoch: 38 loss: 21.9509 Train_acc: 0.55\n",
      "Epoch: 39 loss: 20.7047 Train_acc: 0.54\n",
      "Epoch: 40 loss: 19.0224 Train_acc: 0.54\n",
      "\n",
      "\n",
      "Test_acc: 0.55\n",
      "\n",
      "\n",
      "Epoch: 41 loss: 17.3171 Train_acc: 0.56\n",
      "Epoch: 42 loss: 14.2789 Train_acc: 0.62\n",
      "Epoch: 43 loss: 12.4442 Train_acc: 0.62\n",
      "Epoch: 44 loss: 12.5718 Train_acc: 0.62\n",
      "Epoch: 45 loss: 13.4508 Train_acc: 0.62\n",
      "Epoch: 46 loss: 12.9500 Train_acc: 0.61\n",
      "Epoch: 47 loss: 11.6426 Train_acc: 0.63\n",
      "Epoch: 48 loss: 10.4494 Train_acc: 0.65\n",
      "Epoch: 49 loss: 10.1590 Train_acc: 0.67\n",
      "Epoch: 50 loss: 10.2324 Train_acc: 0.65\n",
      "\n",
      "\n",
      "Test_acc: 0.71\n",
      "\n",
      "\n",
      "Epoch: 51 loss: 9.9834 Train_acc: 0.64\n",
      "Epoch: 52 loss: 9.6351 Train_acc: 0.65\n",
      "Epoch: 53 loss: 9.3733 Train_acc: 0.66\n",
      "Epoch: 54 loss: 8.7374 Train_acc: 0.68\n",
      "Epoch: 55 loss: 8.5248 Train_acc: 0.71\n",
      "Epoch: 56 loss: 8.3955 Train_acc: 0.72\n",
      "Epoch: 57 loss: 7.9774 Train_acc: 0.72\n",
      "Epoch: 58 loss: 7.2754 Train_acc: 0.72\n",
      "Epoch: 59 loss: 6.8163 Train_acc: 0.71\n",
      "Epoch: 60 loss: 6.7290 Train_acc: 0.70\n",
      "\n",
      "\n",
      "Test_acc: 0.79\n",
      "\n",
      "\n",
      "Epoch: 61 loss: 6.6624 Train_acc: 0.71\n",
      "Epoch: 62 loss: 6.3926 Train_acc: 0.72\n",
      "Epoch: 63 loss: 6.2776 Train_acc: 0.73\n",
      "Epoch: 64 loss: 5.9773 Train_acc: 0.75\n",
      "Epoch: 65 loss: 5.6999 Train_acc: 0.75\n",
      "Epoch: 66 loss: 5.7151 Train_acc: 0.75\n",
      "Epoch: 67 loss: 5.7636 Train_acc: 0.76\n",
      "Epoch: 68 loss: 5.6635 Train_acc: 0.78\n",
      "Epoch: 69 loss: 5.4393 Train_acc: 0.78\n",
      "Epoch: 70 loss: 5.2744 Train_acc: 0.79\n",
      "\n",
      "\n",
      "Test_acc: 0.85\n",
      "\n",
      "\n",
      "Epoch: 71 loss: 5.1790 Train_acc: 0.77\n",
      "Epoch: 72 loss: 5.0823 Train_acc: 0.77\n",
      "Epoch: 73 loss: 4.9733 Train_acc: 0.78\n",
      "Epoch: 74 loss: 4.8939 Train_acc: 0.76\n",
      "Epoch: 75 loss: 4.7537 Train_acc: 0.77\n",
      "Epoch: 76 loss: 4.6129 Train_acc: 0.77\n",
      "Epoch: 77 loss: 4.4546 Train_acc: 0.78\n",
      "Epoch: 78 loss: 4.3416 Train_acc: 0.78\n",
      "Epoch: 79 loss: 4.3044 Train_acc: 0.79\n",
      "Epoch: 80 loss: 4.2121 Train_acc: 0.78\n",
      "\n",
      "\n",
      "Test_acc: 0.88\n",
      "\n",
      "\n",
      "Epoch: 81 loss: 4.1698 Train_acc: 0.80\n",
      "Epoch: 82 loss: 4.0897 Train_acc: 0.80\n",
      "Epoch: 83 loss: 4.0074 Train_acc: 0.79\n",
      "Epoch: 84 loss: 3.9272 Train_acc: 0.80\n",
      "Epoch: 85 loss: 3.8294 Train_acc: 0.80\n",
      "Epoch: 86 loss: 3.7371 Train_acc: 0.80\n",
      "Epoch: 87 loss: 3.6990 Train_acc: 0.80\n",
      "Epoch: 88 loss: 3.6396 Train_acc: 0.80\n",
      "Epoch: 89 loss: 3.5657 Train_acc: 0.80\n",
      "Epoch: 90 loss: 3.4825 Train_acc: 0.79\n",
      "\n",
      "\n",
      "Test_acc: 0.85\n",
      "\n",
      "\n",
      "Epoch: 91 loss: 3.4160 Train_acc: 0.82\n",
      "Epoch: 92 loss: 3.3618 Train_acc: 0.82\n",
      "Epoch: 93 loss: 3.2964 Train_acc: 0.82\n",
      "Epoch: 94 loss: 3.2373 Train_acc: 0.82\n",
      "Epoch: 95 loss: 3.1874 Train_acc: 0.82\n",
      "Epoch: 96 loss: 3.1335 Train_acc: 0.83\n",
      "Epoch: 97 loss: 3.0626 Train_acc: 0.84\n",
      "Epoch: 98 loss: 3.0019 Train_acc: 0.83\n",
      "Epoch: 99 loss: 2.9424 Train_acc: 0.83\n",
      "Epoch: 100 loss: 2.8974 Train_acc: 0.83\n",
      "\n",
      "\n",
      "Test_acc: 0.86\n",
      "\n",
      "\n",
      "Epoch: 101 loss: 2.8554 Train_acc: 0.82\n",
      "Epoch: 102 loss: 2.7888 Train_acc: 0.84\n",
      "Epoch: 103 loss: 2.7327 Train_acc: 0.84\n",
      "Epoch: 104 loss: 2.6784 Train_acc: 0.84\n",
      "Epoch: 105 loss: 2.6331 Train_acc: 0.85\n",
      "Epoch: 106 loss: 2.5860 Train_acc: 0.84\n",
      "Epoch: 107 loss: 2.5308 Train_acc: 0.84\n",
      "Epoch: 108 loss: 2.4842 Train_acc: 0.84\n",
      "Epoch: 109 loss: 2.4301 Train_acc: 0.85\n",
      "Epoch: 110 loss: 2.3846 Train_acc: 0.85\n",
      "\n",
      "\n",
      "Test_acc: 0.89\n",
      "\n",
      "\n",
      "Epoch: 111 loss: 2.3407 Train_acc: 0.85\n",
      "Epoch: 112 loss: 2.2901 Train_acc: 0.84\n",
      "Epoch: 113 loss: 2.2398 Train_acc: 0.85\n",
      "Epoch: 114 loss: 2.1957 Train_acc: 0.85\n",
      "Epoch: 115 loss: 2.1454 Train_acc: 0.85\n",
      "Epoch: 116 loss: 2.1056 Train_acc: 0.85\n",
      "Epoch: 117 loss: 2.0602 Train_acc: 0.86\n",
      "Epoch: 118 loss: 2.0165 Train_acc: 0.86\n",
      "Epoch: 119 loss: 1.9756 Train_acc: 0.85\n",
      "Epoch: 120 loss: 1.9317 Train_acc: 0.85\n",
      "\n",
      "\n",
      "Test_acc: 0.89\n",
      "\n",
      "\n",
      "Epoch: 121 loss: 1.8923 Train_acc: 0.86\n",
      "Epoch: 122 loss: 1.8503 Train_acc: 0.87\n",
      "Epoch: 123 loss: 1.8178 Train_acc: 0.85\n",
      "Epoch: 124 loss: 1.7691 Train_acc: 0.87\n",
      "Epoch: 125 loss: 1.7406 Train_acc: 0.87\n",
      "Epoch: 126 loss: 1.6925 Train_acc: 0.87\n",
      "Epoch: 127 loss: 1.6647 Train_acc: 0.86\n",
      "Epoch: 128 loss: 1.6252 Train_acc: 0.87\n",
      "Epoch: 129 loss: 1.5947 Train_acc: 0.87\n",
      "Epoch: 130 loss: 1.5547 Train_acc: 0.87\n",
      "\n",
      "\n",
      "Test_acc: 0.88\n",
      "\n",
      "\n",
      "Epoch: 131 loss: 1.5261 Train_acc: 0.87\n",
      "Epoch: 132 loss: 1.4949 Train_acc: 0.86\n",
      "Epoch: 133 loss: 1.4555 Train_acc: 0.87\n",
      "Epoch: 134 loss: 1.4265 Train_acc: 0.88\n",
      "Epoch: 135 loss: 1.3922 Train_acc: 0.87\n",
      "Epoch: 136 loss: 1.3636 Train_acc: 0.87\n",
      "Epoch: 137 loss: 1.3327 Train_acc: 0.88\n",
      "Epoch: 138 loss: 1.3034 Train_acc: 0.88\n",
      "Epoch: 139 loss: 1.2771 Train_acc: 0.87\n",
      "Epoch: 140 loss: 1.2472 Train_acc: 0.88\n",
      "\n",
      "\n",
      "Test_acc: 0.88\n",
      "\n",
      "\n",
      "Epoch: 141 loss: 1.2223 Train_acc: 0.88\n",
      "Epoch: 142 loss: 1.1927 Train_acc: 0.88\n",
      "Epoch: 143 loss: 1.1668 Train_acc: 0.88\n",
      "Epoch: 144 loss: 1.1407 Train_acc: 0.88\n",
      "Epoch: 145 loss: 1.1140 Train_acc: 0.89\n",
      "Epoch: 146 loss: 1.0890 Train_acc: 0.89\n",
      "Epoch: 147 loss: 1.0635 Train_acc: 0.89\n",
      "Epoch: 148 loss: 1.0379 Train_acc: 0.89\n",
      "Epoch: 149 loss: 1.0129 Train_acc: 0.89\n",
      "Epoch: 150 loss: 0.9876 Train_acc: 0.90\n",
      "\n",
      "\n",
      "Test_acc: 0.89\n",
      "\n",
      "\n",
      "Epoch: 151 loss: 0.9617 Train_acc: 0.89\n",
      "Epoch: 152 loss: 0.9356 Train_acc: 0.90\n",
      "Epoch: 153 loss: 0.9095 Train_acc: 0.90\n",
      "Epoch: 154 loss: 0.8838 Train_acc: 0.90\n",
      "Epoch: 155 loss: 0.8589 Train_acc: 0.90\n",
      "Epoch: 156 loss: 0.8344 Train_acc: 0.90\n",
      "Epoch: 157 loss: 0.8114 Train_acc: 0.90\n",
      "Epoch: 158 loss: 0.7878 Train_acc: 0.90\n",
      "Epoch: 159 loss: 0.7640 Train_acc: 0.90\n",
      "Epoch: 160 loss: 0.7399 Train_acc: 0.90\n",
      "\n",
      "\n",
      "Test_acc: 0.89\n",
      "\n",
      "\n",
      "Epoch: 161 loss: 0.7154 Train_acc: 0.90\n",
      "Epoch: 162 loss: 0.6900 Train_acc: 0.91\n",
      "Epoch: 163 loss: 0.6643 Train_acc: 0.91\n",
      "Epoch: 164 loss: 0.6399 Train_acc: 0.91\n",
      "Epoch: 165 loss: 0.6203 Train_acc: 0.92\n",
      "Epoch: 166 loss: 0.5980 Train_acc: 0.92\n",
      "Epoch: 167 loss: 0.5830 Train_acc: 0.92\n",
      "Epoch: 168 loss: 0.5676 Train_acc: 0.93\n",
      "Epoch: 169 loss: 0.5536 Train_acc: 0.93\n",
      "Epoch: 170 loss: 0.5399 Train_acc: 0.93\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Epoch: 171 loss: 0.5245 Train_acc: 0.93\n",
      "Epoch: 172 loss: 0.5116 Train_acc: 0.93\n",
      "Epoch: 173 loss: 0.4981 Train_acc: 0.93\n",
      "Epoch: 174 loss: 0.4854 Train_acc: 0.93\n",
      "Epoch: 175 loss: 0.4726 Train_acc: 0.93\n",
      "Epoch: 176 loss: 0.4595 Train_acc: 0.93\n",
      "Epoch: 177 loss: 0.4467 Train_acc: 0.93\n",
      "Epoch: 178 loss: 0.4348 Train_acc: 0.93\n",
      "Epoch: 179 loss: 0.4233 Train_acc: 0.93\n",
      "Epoch: 180 loss: 0.4122 Train_acc: 0.93\n",
      "\n",
      "\n",
      "Test_acc: 0.94\n",
      "\n",
      "\n",
      "Epoch: 181 loss: 0.4009 Train_acc: 0.93\n",
      "Epoch: 182 loss: 0.3915 Train_acc: 0.94\n",
      "Epoch: 183 loss: 0.3800 Train_acc: 0.93\n",
      "Epoch: 184 loss: 0.3713 Train_acc: 0.94\n",
      "Epoch: 185 loss: 0.3604 Train_acc: 0.94\n",
      "Epoch: 186 loss: 0.3510 Train_acc: 0.94\n",
      "Epoch: 187 loss: 0.3427 Train_acc: 0.95\n",
      "Epoch: 188 loss: 0.3347 Train_acc: 0.95\n",
      "Epoch: 189 loss: 0.3272 Train_acc: 0.95\n",
      "Epoch: 190 loss: 0.3205 Train_acc: 0.95\n",
      "\n",
      "\n",
      "Test_acc: 0.94\n",
      "\n",
      "\n",
      "Epoch: 191 loss: 0.3161 Train_acc: 0.95\n",
      "Epoch: 192 loss: 0.3368 Train_acc: 0.94\n",
      "Epoch: 193 loss: 0.3156 Train_acc: 0.94\n",
      "Epoch: 194 loss: 0.3336 Train_acc: 0.95\n",
      "Epoch: 195 loss: 0.3339 Train_acc: 0.95\n",
      "Epoch: 196 loss: 0.3366 Train_acc: 0.94\n",
      "Epoch: 197 loss: 0.3140 Train_acc: 0.95\n",
      "Epoch: 198 loss: 0.3108 Train_acc: 0.95\n",
      "Epoch: 199 loss: 0.2952 Train_acc: 0.95\n",
      "Epoch: 200 loss: 0.3004 Train_acc: 0.95\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 201 loss: 0.2865 Train_acc: 0.95\n",
      "Epoch: 202 loss: 0.2946 Train_acc: 0.95\n",
      "Epoch: 203 loss: 0.2694 Train_acc: 0.95\n",
      "Epoch: 204 loss: 0.2912 Train_acc: 0.95\n",
      "Epoch: 205 loss: 0.2690 Train_acc: 0.95\n",
      "Epoch: 206 loss: 0.2802 Train_acc: 0.95\n",
      "Epoch: 207 loss: 0.2662 Train_acc: 0.95\n",
      "Epoch: 208 loss: 0.2625 Train_acc: 0.95\n",
      "Epoch: 209 loss: 0.2678 Train_acc: 0.94\n",
      "Epoch: 210 loss: 0.2650 Train_acc: 0.96\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Epoch: 211 loss: 0.2596 Train_acc: 0.96\n",
      "Epoch: 212 loss: 0.2680 Train_acc: 0.95\n",
      "Epoch: 213 loss: 0.2523 Train_acc: 0.96\n",
      "Epoch: 214 loss: 0.2555 Train_acc: 0.95\n",
      "Epoch: 215 loss: 0.2616 Train_acc: 0.94\n",
      "Epoch: 216 loss: 0.2458 Train_acc: 0.96\n",
      "Epoch: 217 loss: 0.2573 Train_acc: 0.96\n",
      "Epoch: 218 loss: 0.2501 Train_acc: 0.96\n",
      "Epoch: 219 loss: 0.2474 Train_acc: 0.96\n",
      "Epoch: 220 loss: 0.2453 Train_acc: 0.96\n",
      "\n",
      "\n",
      "Test_acc: 0.89\n",
      "\n",
      "\n",
      "Epoch: 221 loss: 0.2475 Train_acc: 0.95\n",
      "Epoch: 222 loss: 0.2381 Train_acc: 0.97\n",
      "Epoch: 223 loss: 0.2411 Train_acc: 0.97\n",
      "Epoch: 224 loss: 0.2421 Train_acc: 0.96\n",
      "Epoch: 225 loss: 0.2372 Train_acc: 0.97\n",
      "Epoch: 226 loss: 0.2319 Train_acc: 0.97\n",
      "Epoch: 227 loss: 0.2379 Train_acc: 0.95\n",
      "Epoch: 228 loss: 0.2293 Train_acc: 0.97\n",
      "Epoch: 229 loss: 0.2314 Train_acc: 0.97\n",
      "Epoch: 230 loss: 0.2312 Train_acc: 0.97\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Epoch: 231 loss: 0.2275 Train_acc: 0.97\n",
      "Epoch: 232 loss: 0.2252 Train_acc: 0.97\n",
      "Epoch: 233 loss: 0.2255 Train_acc: 0.96\n",
      "Epoch: 234 loss: 0.2230 Train_acc: 0.96\n",
      "Epoch: 235 loss: 0.2222 Train_acc: 0.96\n",
      "Epoch: 236 loss: 0.2202 Train_acc: 0.97\n",
      "Epoch: 237 loss: 0.2192 Train_acc: 0.97\n",
      "Epoch: 238 loss: 0.2177 Train_acc: 0.97\n",
      "Epoch: 239 loss: 0.2157 Train_acc: 0.97\n",
      "Epoch: 240 loss: 0.2152 Train_acc: 0.97\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 241 loss: 0.2142 Train_acc: 0.97\n",
      "Epoch: 242 loss: 0.2118 Train_acc: 0.97\n",
      "Epoch: 243 loss: 0.2110 Train_acc: 0.97\n",
      "Epoch: 244 loss: 0.2103 Train_acc: 0.97\n",
      "Epoch: 245 loss: 0.2082 Train_acc: 0.97\n",
      "Epoch: 246 loss: 0.2076 Train_acc: 0.97\n",
      "Epoch: 247 loss: 0.2067 Train_acc: 0.97\n",
      "Epoch: 248 loss: 0.2051 Train_acc: 0.97\n",
      "Epoch: 249 loss: 0.2041 Train_acc: 0.97\n",
      "Epoch: 250 loss: 0.2035 Train_acc: 0.97\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 251 loss: 0.2020 Train_acc: 0.97\n",
      "Epoch: 252 loss: 0.2010 Train_acc: 0.97\n",
      "Epoch: 253 loss: 0.2002 Train_acc: 0.97\n",
      "Epoch: 254 loss: 0.1991 Train_acc: 0.97\n",
      "Epoch: 255 loss: 0.1979 Train_acc: 0.97\n",
      "Epoch: 256 loss: 0.1970 Train_acc: 0.97\n",
      "Epoch: 257 loss: 0.1961 Train_acc: 0.97\n",
      "Epoch: 258 loss: 0.1949 Train_acc: 0.97\n",
      "Epoch: 259 loss: 0.1940 Train_acc: 0.97\n",
      "Epoch: 260 loss: 0.1930 Train_acc: 0.97\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 261 loss: 0.1920 Train_acc: 0.97\n",
      "Epoch: 262 loss: 0.1909 Train_acc: 0.97\n",
      "Epoch: 263 loss: 0.1901 Train_acc: 0.97\n",
      "Epoch: 264 loss: 0.1890 Train_acc: 0.97\n",
      "Epoch: 265 loss: 0.1880 Train_acc: 0.97\n",
      "Epoch: 266 loss: 0.1871 Train_acc: 0.97\n",
      "Epoch: 267 loss: 0.1862 Train_acc: 0.97\n",
      "Epoch: 268 loss: 0.1852 Train_acc: 0.97\n",
      "Epoch: 269 loss: 0.1843 Train_acc: 0.97\n",
      "Epoch: 270 loss: 0.1834 Train_acc: 0.97\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 271 loss: 0.1824 Train_acc: 0.97\n",
      "Epoch: 272 loss: 0.1815 Train_acc: 0.97\n",
      "Epoch: 273 loss: 0.1806 Train_acc: 0.97\n",
      "Epoch: 274 loss: 0.1797 Train_acc: 0.97\n",
      "Epoch: 275 loss: 0.1787 Train_acc: 0.97\n",
      "Epoch: 276 loss: 0.1779 Train_acc: 0.97\n",
      "Epoch: 277 loss: 0.1770 Train_acc: 0.97\n",
      "Epoch: 278 loss: 0.1761 Train_acc: 0.97\n",
      "Epoch: 279 loss: 0.1752 Train_acc: 0.97\n",
      "Epoch: 280 loss: 0.1743 Train_acc: 0.97\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 281 loss: 0.1734 Train_acc: 0.97\n",
      "Epoch: 282 loss: 0.1725 Train_acc: 0.97\n",
      "Epoch: 283 loss: 0.1717 Train_acc: 0.97\n",
      "Epoch: 284 loss: 0.1708 Train_acc: 0.97\n",
      "Epoch: 285 loss: 0.1700 Train_acc: 0.97\n",
      "Epoch: 286 loss: 0.1691 Train_acc: 0.97\n",
      "Epoch: 287 loss: 0.1683 Train_acc: 0.97\n",
      "Epoch: 288 loss: 0.1674 Train_acc: 0.97\n",
      "Epoch: 289 loss: 0.1666 Train_acc: 0.97\n",
      "Epoch: 290 loss: 0.1657 Train_acc: 0.97\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 291 loss: 0.1649 Train_acc: 0.97\n",
      "Epoch: 292 loss: 0.1640 Train_acc: 0.97\n",
      "Epoch: 293 loss: 0.1632 Train_acc: 0.97\n",
      "Epoch: 294 loss: 0.1624 Train_acc: 0.97\n",
      "Epoch: 295 loss: 0.1616 Train_acc: 0.97\n",
      "Epoch: 296 loss: 0.1607 Train_acc: 0.97\n",
      "Epoch: 297 loss: 0.1599 Train_acc: 0.97\n",
      "Epoch: 298 loss: 0.1591 Train_acc: 0.97\n",
      "Epoch: 299 loss: 0.1583 Train_acc: 0.97\n",
      "Epoch: 300 loss: 0.1575 Train_acc: 0.97\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 301 loss: 0.1567 Train_acc: 0.97\n",
      "Epoch: 302 loss: 0.1559 Train_acc: 0.97\n",
      "Epoch: 303 loss: 0.1551 Train_acc: 0.97\n",
      "Epoch: 304 loss: 0.1543 Train_acc: 0.97\n",
      "Epoch: 305 loss: 0.1535 Train_acc: 0.97\n",
      "Epoch: 306 loss: 0.1527 Train_acc: 0.97\n",
      "Epoch: 307 loss: 0.1519 Train_acc: 0.97\n",
      "Epoch: 308 loss: 0.1511 Train_acc: 0.97\n",
      "Epoch: 309 loss: 0.1503 Train_acc: 0.97\n",
      "Epoch: 310 loss: 0.1495 Train_acc: 0.97\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 311 loss: 0.1488 Train_acc: 0.97\n",
      "Epoch: 312 loss: 0.1480 Train_acc: 0.97\n",
      "Epoch: 313 loss: 0.1472 Train_acc: 0.97\n",
      "Epoch: 314 loss: 0.1465 Train_acc: 0.97\n",
      "Epoch: 315 loss: 0.1457 Train_acc: 0.97\n",
      "Epoch: 316 loss: 0.1449 Train_acc: 0.97\n",
      "Epoch: 317 loss: 0.1442 Train_acc: 0.97\n",
      "Epoch: 318 loss: 0.1434 Train_acc: 0.97\n",
      "Epoch: 319 loss: 0.1427 Train_acc: 0.97\n",
      "Epoch: 320 loss: 0.1419 Train_acc: 0.97\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 321 loss: 0.1412 Train_acc: 0.97\n",
      "Epoch: 322 loss: 0.1404 Train_acc: 0.97\n",
      "Epoch: 323 loss: 0.1397 Train_acc: 0.97\n",
      "Epoch: 324 loss: 0.1389 Train_acc: 0.97\n",
      "Epoch: 325 loss: 0.1382 Train_acc: 0.97\n",
      "Epoch: 326 loss: 0.1375 Train_acc: 0.97\n",
      "Epoch: 327 loss: 0.1367 Train_acc: 0.97\n",
      "Epoch: 328 loss: 0.1360 Train_acc: 0.97\n",
      "Epoch: 329 loss: 0.1353 Train_acc: 0.97\n",
      "Epoch: 330 loss: 0.1346 Train_acc: 0.97\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 331 loss: 0.1338 Train_acc: 0.97\n",
      "Epoch: 332 loss: 0.1331 Train_acc: 0.97\n",
      "Epoch: 333 loss: 0.1324 Train_acc: 0.97\n",
      "Epoch: 334 loss: 0.1317 Train_acc: 0.97\n",
      "Epoch: 335 loss: 0.1310 Train_acc: 0.97\n",
      "Epoch: 336 loss: 0.1303 Train_acc: 0.97\n",
      "Epoch: 337 loss: 0.1296 Train_acc: 0.97\n",
      "Epoch: 338 loss: 0.1288 Train_acc: 0.97\n",
      "Epoch: 339 loss: 0.1281 Train_acc: 0.97\n",
      "Epoch: 340 loss: 0.1274 Train_acc: 0.97\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 341 loss: 0.1267 Train_acc: 0.97\n",
      "Epoch: 342 loss: 0.1260 Train_acc: 0.97\n",
      "Epoch: 343 loss: 0.1253 Train_acc: 0.97\n",
      "Epoch: 344 loss: 0.1246 Train_acc: 0.97\n",
      "Epoch: 345 loss: 0.1240 Train_acc: 0.97\n",
      "Epoch: 346 loss: 0.1233 Train_acc: 0.98\n",
      "Epoch: 347 loss: 0.1226 Train_acc: 0.98\n",
      "Epoch: 348 loss: 0.1219 Train_acc: 0.98\n",
      "Epoch: 349 loss: 0.1212 Train_acc: 0.98\n",
      "Epoch: 350 loss: 0.1205 Train_acc: 0.98\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 351 loss: 0.1198 Train_acc: 0.98\n",
      "Epoch: 352 loss: 0.1191 Train_acc: 0.98\n",
      "Epoch: 353 loss: 0.1185 Train_acc: 0.98\n",
      "Epoch: 354 loss: 0.1178 Train_acc: 0.98\n",
      "Epoch: 355 loss: 0.1171 Train_acc: 0.98\n",
      "Epoch: 356 loss: 0.1164 Train_acc: 0.98\n",
      "Epoch: 357 loss: 0.1157 Train_acc: 0.98\n",
      "Epoch: 358 loss: 0.1151 Train_acc: 0.98\n",
      "Epoch: 359 loss: 0.1144 Train_acc: 0.98\n",
      "Epoch: 360 loss: 0.1137 Train_acc: 0.98\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 361 loss: 0.1131 Train_acc: 0.98\n",
      "Epoch: 362 loss: 0.1124 Train_acc: 0.98\n",
      "Epoch: 363 loss: 0.1117 Train_acc: 0.98\n",
      "Epoch: 364 loss: 0.1111 Train_acc: 0.98\n",
      "Epoch: 365 loss: 0.1104 Train_acc: 0.98\n",
      "Epoch: 366 loss: 0.1097 Train_acc: 0.98\n",
      "Epoch: 367 loss: 0.1091 Train_acc: 0.98\n",
      "Epoch: 368 loss: 0.1084 Train_acc: 0.98\n",
      "Epoch: 369 loss: 0.1077 Train_acc: 0.98\n",
      "Epoch: 370 loss: 0.1071 Train_acc: 0.98\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 371 loss: 0.1064 Train_acc: 0.98\n",
      "Epoch: 372 loss: 0.1058 Train_acc: 0.98\n",
      "Epoch: 373 loss: 0.1051 Train_acc: 0.98\n",
      "Epoch: 374 loss: 0.1045 Train_acc: 0.98\n",
      "Epoch: 375 loss: 0.1038 Train_acc: 0.98\n",
      "Epoch: 376 loss: 0.1032 Train_acc: 0.98\n",
      "Epoch: 377 loss: 0.1025 Train_acc: 0.98\n",
      "Epoch: 378 loss: 0.1019 Train_acc: 0.98\n",
      "Epoch: 379 loss: 0.1012 Train_acc: 0.98\n",
      "Epoch: 380 loss: 0.1006 Train_acc: 0.98\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 381 loss: 0.0999 Train_acc: 0.98\n",
      "Epoch: 382 loss: 0.0993 Train_acc: 0.98\n",
      "Epoch: 383 loss: 0.0986 Train_acc: 0.98\n",
      "Epoch: 384 loss: 0.0980 Train_acc: 0.98\n",
      "Epoch: 385 loss: 0.0974 Train_acc: 0.98\n",
      "Epoch: 386 loss: 0.0967 Train_acc: 0.98\n",
      "Epoch: 387 loss: 0.0961 Train_acc: 0.98\n",
      "Epoch: 388 loss: 0.0955 Train_acc: 0.98\n",
      "Epoch: 389 loss: 0.0948 Train_acc: 0.98\n",
      "Epoch: 390 loss: 0.0942 Train_acc: 0.98\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 391 loss: 0.0936 Train_acc: 0.98\n",
      "Epoch: 392 loss: 0.0930 Train_acc: 0.98\n",
      "Epoch: 393 loss: 0.0924 Train_acc: 0.98\n",
      "Epoch: 394 loss: 0.0917 Train_acc: 0.98\n",
      "Epoch: 395 loss: 0.0911 Train_acc: 0.98\n",
      "Epoch: 396 loss: 0.0905 Train_acc: 0.98\n",
      "Epoch: 397 loss: 0.0899 Train_acc: 0.98\n",
      "Epoch: 398 loss: 0.0893 Train_acc: 0.98\n",
      "Epoch: 399 loss: 0.0888 Train_acc: 0.98\n",
      "Epoch: 400 loss: 0.0882 Train_acc: 0.98\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 401 loss: 0.0876 Train_acc: 0.98\n",
      "Epoch: 402 loss: 0.0870 Train_acc: 0.98\n",
      "Epoch: 403 loss: 0.0865 Train_acc: 0.98\n",
      "Epoch: 404 loss: 0.0859 Train_acc: 0.98\n",
      "Epoch: 405 loss: 0.0854 Train_acc: 0.98\n",
      "Epoch: 406 loss: 0.0849 Train_acc: 0.98\n",
      "Epoch: 407 loss: 0.0843 Train_acc: 0.98\n",
      "Epoch: 408 loss: 0.0838 Train_acc: 0.98\n",
      "Epoch: 409 loss: 0.0833 Train_acc: 0.98\n",
      "Epoch: 410 loss: 0.0828 Train_acc: 0.98\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 411 loss: 0.0824 Train_acc: 0.98\n",
      "Epoch: 412 loss: 0.0819 Train_acc: 0.98\n",
      "Epoch: 413 loss: 0.0814 Train_acc: 0.98\n",
      "Epoch: 414 loss: 0.0810 Train_acc: 0.98\n",
      "Epoch: 415 loss: 0.0806 Train_acc: 0.98\n",
      "Epoch: 416 loss: 0.0801 Train_acc: 0.98\n",
      "Epoch: 417 loss: 0.0797 Train_acc: 0.98\n",
      "Epoch: 418 loss: 0.0793 Train_acc: 0.98\n",
      "Epoch: 419 loss: 0.0789 Train_acc: 0.98\n",
      "Epoch: 420 loss: 0.0786 Train_acc: 0.98\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 421 loss: 0.0782 Train_acc: 0.98\n",
      "Epoch: 422 loss: 0.0778 Train_acc: 0.98\n",
      "Epoch: 423 loss: 0.0775 Train_acc: 0.98\n",
      "Epoch: 424 loss: 0.0772 Train_acc: 0.98\n",
      "Epoch: 425 loss: 0.0768 Train_acc: 0.98\n",
      "Epoch: 426 loss: 0.0765 Train_acc: 0.98\n",
      "Epoch: 427 loss: 0.0762 Train_acc: 0.98\n",
      "Epoch: 428 loss: 0.0759 Train_acc: 0.98\n",
      "Epoch: 429 loss: 0.0756 Train_acc: 0.98\n",
      "Epoch: 430 loss: 0.0753 Train_acc: 0.98\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 431 loss: 0.0751 Train_acc: 0.98\n",
      "Epoch: 432 loss: 0.0748 Train_acc: 0.98\n",
      "Epoch: 433 loss: 0.0745 Train_acc: 0.98\n",
      "Epoch: 434 loss: 0.0743 Train_acc: 0.98\n",
      "Epoch: 435 loss: 0.0740 Train_acc: 0.98\n",
      "Epoch: 436 loss: 0.0737 Train_acc: 0.98\n",
      "Epoch: 437 loss: 0.0735 Train_acc: 0.98\n",
      "Epoch: 438 loss: 0.0733 Train_acc: 0.98\n",
      "Epoch: 439 loss: 0.0730 Train_acc: 0.98\n",
      "Epoch: 440 loss: 0.0728 Train_acc: 0.98\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 441 loss: 0.0726 Train_acc: 0.98\n",
      "Epoch: 442 loss: 0.0723 Train_acc: 0.98\n",
      "Epoch: 443 loss: 0.0721 Train_acc: 0.98\n",
      "Epoch: 444 loss: 0.0719 Train_acc: 0.98\n",
      "Epoch: 445 loss: 0.0717 Train_acc: 0.98\n",
      "Epoch: 446 loss: 0.0715 Train_acc: 0.98\n",
      "Epoch: 447 loss: 0.0713 Train_acc: 0.98\n",
      "Epoch: 448 loss: 0.0711 Train_acc: 0.98\n",
      "Epoch: 449 loss: 0.0709 Train_acc: 0.98\n",
      "Epoch: 450 loss: 0.0707 Train_acc: 0.98\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 451 loss: 0.0705 Train_acc: 0.98\n",
      "Epoch: 452 loss: 0.0703 Train_acc: 0.98\n",
      "Epoch: 453 loss: 0.0701 Train_acc: 0.98\n",
      "Epoch: 454 loss: 0.0699 Train_acc: 0.98\n",
      "Epoch: 455 loss: 0.0697 Train_acc: 0.98\n",
      "Epoch: 456 loss: 0.0695 Train_acc: 0.98\n",
      "Epoch: 457 loss: 0.0693 Train_acc: 0.98\n",
      "Epoch: 458 loss: 0.0691 Train_acc: 0.98\n",
      "Epoch: 459 loss: 0.0690 Train_acc: 0.98\n",
      "Epoch: 460 loss: 0.0688 Train_acc: 0.98\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 461 loss: 0.0686 Train_acc: 0.98\n",
      "Epoch: 462 loss: 0.0684 Train_acc: 0.98\n",
      "Epoch: 463 loss: 0.0683 Train_acc: 0.98\n",
      "Epoch: 464 loss: 0.0681 Train_acc: 0.98\n",
      "Epoch: 465 loss: 0.0679 Train_acc: 0.98\n",
      "Epoch: 466 loss: 0.0678 Train_acc: 0.98\n",
      "Epoch: 467 loss: 0.0676 Train_acc: 0.98\n",
      "Epoch: 468 loss: 0.0674 Train_acc: 0.98\n",
      "Epoch: 469 loss: 0.0673 Train_acc: 0.98\n",
      "Epoch: 470 loss: 0.0671 Train_acc: 0.98\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 471 loss: 0.0669 Train_acc: 0.98\n",
      "Epoch: 472 loss: 0.0668 Train_acc: 0.98\n",
      "Epoch: 473 loss: 0.0666 Train_acc: 0.98\n",
      "Epoch: 474 loss: 0.0665 Train_acc: 0.98\n",
      "Epoch: 475 loss: 0.0663 Train_acc: 0.98\n",
      "Epoch: 476 loss: 0.0661 Train_acc: 0.98\n",
      "Epoch: 477 loss: 0.0660 Train_acc: 0.98\n",
      "Epoch: 478 loss: 0.0658 Train_acc: 0.98\n",
      "Epoch: 479 loss: 0.0657 Train_acc: 0.98\n",
      "Epoch: 480 loss: 0.0655 Train_acc: 0.98\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 481 loss: 0.0654 Train_acc: 0.98\n",
      "Epoch: 482 loss: 0.0652 Train_acc: 0.98\n",
      "Epoch: 483 loss: 0.0651 Train_acc: 0.98\n",
      "Epoch: 484 loss: 0.0649 Train_acc: 0.98\n",
      "Epoch: 485 loss: 0.0648 Train_acc: 0.98\n",
      "Epoch: 486 loss: 0.0646 Train_acc: 0.98\n",
      "Epoch: 487 loss: 0.0645 Train_acc: 0.98\n",
      "Epoch: 488 loss: 0.0643 Train_acc: 0.98\n",
      "Epoch: 489 loss: 0.0642 Train_acc: 0.98\n",
      "Epoch: 490 loss: 0.0641 Train_acc: 0.98\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 491 loss: 0.0639 Train_acc: 0.98\n",
      "Epoch: 492 loss: 0.0638 Train_acc: 0.98\n",
      "Epoch: 493 loss: 0.0636 Train_acc: 0.98\n",
      "Epoch: 494 loss: 0.0635 Train_acc: 0.98\n",
      "Epoch: 495 loss: 0.0633 Train_acc: 0.98\n",
      "Epoch: 496 loss: 0.0632 Train_acc: 0.98\n",
      "Epoch: 497 loss: 0.0631 Train_acc: 0.98\n",
      "Epoch: 498 loss: 0.0629 Train_acc: 0.98\n",
      "Epoch: 499 loss: 0.0628 Train_acc: 0.99\n",
      "Epoch: 500 loss: 0.0627 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 501 loss: 0.0625 Train_acc: 0.99\n",
      "Epoch: 502 loss: 0.0624 Train_acc: 0.99\n",
      "Epoch: 503 loss: 0.0622 Train_acc: 0.99\n",
      "Epoch: 504 loss: 0.0621 Train_acc: 0.99\n",
      "Epoch: 505 loss: 0.0620 Train_acc: 0.99\n",
      "Epoch: 506 loss: 0.0618 Train_acc: 0.99\n",
      "Epoch: 507 loss: 0.0617 Train_acc: 0.99\n",
      "Epoch: 508 loss: 0.0616 Train_acc: 0.99\n",
      "Epoch: 509 loss: 0.0615 Train_acc: 0.99\n",
      "Epoch: 510 loss: 0.0613 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 511 loss: 0.0612 Train_acc: 0.99\n",
      "Epoch: 512 loss: 0.0611 Train_acc: 0.99\n",
      "Epoch: 513 loss: 0.0609 Train_acc: 0.99\n",
      "Epoch: 514 loss: 0.0608 Train_acc: 0.99\n",
      "Epoch: 515 loss: 0.0607 Train_acc: 0.99\n",
      "Epoch: 516 loss: 0.0605 Train_acc: 0.99\n",
      "Epoch: 517 loss: 0.0604 Train_acc: 0.99\n",
      "Epoch: 518 loss: 0.0603 Train_acc: 0.99\n",
      "Epoch: 519 loss: 0.0602 Train_acc: 0.99\n",
      "Epoch: 520 loss: 0.0600 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 521 loss: 0.0599 Train_acc: 0.99\n",
      "Epoch: 522 loss: 0.0598 Train_acc: 0.99\n",
      "Epoch: 523 loss: 0.0597 Train_acc: 0.99\n",
      "Epoch: 524 loss: 0.0595 Train_acc: 0.99\n",
      "Epoch: 525 loss: 0.0594 Train_acc: 0.99\n",
      "Epoch: 526 loss: 0.0593 Train_acc: 0.99\n",
      "Epoch: 527 loss: 0.0592 Train_acc: 0.99\n",
      "Epoch: 528 loss: 0.0590 Train_acc: 0.99\n",
      "Epoch: 529 loss: 0.0589 Train_acc: 0.99\n",
      "Epoch: 530 loss: 0.0588 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.92\n",
      "\n",
      "\n",
      "Epoch: 531 loss: 0.0587 Train_acc: 0.99\n",
      "Epoch: 532 loss: 0.0586 Train_acc: 0.99\n",
      "Epoch: 533 loss: 0.0584 Train_acc: 0.99\n",
      "Epoch: 534 loss: 0.0583 Train_acc: 0.99\n",
      "Epoch: 535 loss: 0.0582 Train_acc: 0.99\n",
      "Epoch: 536 loss: 0.0581 Train_acc: 0.99\n",
      "Epoch: 537 loss: 0.0580 Train_acc: 0.99\n",
      "Epoch: 538 loss: 0.0578 Train_acc: 0.99\n",
      "Epoch: 539 loss: 0.0577 Train_acc: 0.99\n",
      "Epoch: 540 loss: 0.0576 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Epoch: 541 loss: 0.0575 Train_acc: 0.99\n",
      "Epoch: 542 loss: 0.0574 Train_acc: 0.99\n",
      "Epoch: 543 loss: 0.0572 Train_acc: 0.99\n",
      "Epoch: 544 loss: 0.0571 Train_acc: 0.99\n",
      "Epoch: 545 loss: 0.0570 Train_acc: 0.99\n",
      "Epoch: 546 loss: 0.0569 Train_acc: 0.99\n",
      "Epoch: 547 loss: 0.0568 Train_acc: 0.99\n",
      "Epoch: 548 loss: 0.0567 Train_acc: 0.99\n",
      "Epoch: 549 loss: 0.0565 Train_acc: 0.99\n",
      "Epoch: 550 loss: 0.0564 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Epoch: 551 loss: 0.0563 Train_acc: 0.99\n",
      "Epoch: 552 loss: 0.0562 Train_acc: 0.99\n",
      "Epoch: 553 loss: 0.0561 Train_acc: 0.99\n",
      "Epoch: 554 loss: 0.0560 Train_acc: 0.99\n",
      "Epoch: 555 loss: 0.0559 Train_acc: 0.99\n",
      "Epoch: 556 loss: 0.0557 Train_acc: 0.99\n",
      "Epoch: 557 loss: 0.0556 Train_acc: 0.99\n",
      "Epoch: 558 loss: 0.0555 Train_acc: 0.99\n",
      "Epoch: 559 loss: 0.0554 Train_acc: 0.99\n",
      "Epoch: 560 loss: 0.0553 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Epoch: 561 loss: 0.0552 Train_acc: 0.99\n",
      "Epoch: 562 loss: 0.0551 Train_acc: 0.99\n",
      "Epoch: 563 loss: 0.0550 Train_acc: 0.99\n",
      "Epoch: 564 loss: 0.0548 Train_acc: 0.99\n",
      "Epoch: 565 loss: 0.0547 Train_acc: 0.99\n",
      "Epoch: 566 loss: 0.0546 Train_acc: 0.99\n",
      "Epoch: 567 loss: 0.0545 Train_acc: 0.99\n",
      "Epoch: 568 loss: 0.0544 Train_acc: 0.99\n",
      "Epoch: 569 loss: 0.0543 Train_acc: 0.99\n",
      "Epoch: 570 loss: 0.0542 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Epoch: 571 loss: 0.0541 Train_acc: 0.99\n",
      "Epoch: 572 loss: 0.0540 Train_acc: 0.99\n",
      "Epoch: 573 loss: 0.0539 Train_acc: 0.99\n",
      "Epoch: 574 loss: 0.0537 Train_acc: 0.99\n",
      "Epoch: 575 loss: 0.0536 Train_acc: 0.99\n",
      "Epoch: 576 loss: 0.0535 Train_acc: 0.99\n",
      "Epoch: 577 loss: 0.0534 Train_acc: 0.99\n",
      "Epoch: 578 loss: 0.0533 Train_acc: 0.99\n",
      "Epoch: 579 loss: 0.0532 Train_acc: 0.99\n",
      "Epoch: 580 loss: 0.0531 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Epoch: 581 loss: 0.0530 Train_acc: 0.99\n",
      "Epoch: 582 loss: 0.0529 Train_acc: 0.99\n",
      "Epoch: 583 loss: 0.0528 Train_acc: 0.99\n",
      "Epoch: 584 loss: 0.0527 Train_acc: 0.99\n",
      "Epoch: 585 loss: 0.0526 Train_acc: 0.99\n",
      "Epoch: 586 loss: 0.0525 Train_acc: 0.99\n",
      "Epoch: 587 loss: 0.0524 Train_acc: 0.99\n",
      "Epoch: 588 loss: 0.0523 Train_acc: 0.99\n",
      "Epoch: 589 loss: 0.0522 Train_acc: 0.99\n",
      "Epoch: 590 loss: 0.0521 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Epoch: 591 loss: 0.0520 Train_acc: 0.99\n",
      "Epoch: 592 loss: 0.0519 Train_acc: 0.99\n",
      "Epoch: 593 loss: 0.0517 Train_acc: 0.99\n",
      "Epoch: 594 loss: 0.0516 Train_acc: 0.99\n",
      "Epoch: 595 loss: 0.0515 Train_acc: 0.99\n",
      "Epoch: 596 loss: 0.0514 Train_acc: 0.99\n",
      "Epoch: 597 loss: 0.0513 Train_acc: 0.99\n",
      "Epoch: 598 loss: 0.0512 Train_acc: 0.99\n",
      "Epoch: 599 loss: 0.0511 Train_acc: 0.99\n",
      "Epoch: 600 loss: 0.0511 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Epoch: 601 loss: 0.0510 Train_acc: 0.99\n",
      "Epoch: 602 loss: 0.0509 Train_acc: 0.99\n",
      "Epoch: 603 loss: 0.0508 Train_acc: 0.99\n",
      "Epoch: 604 loss: 0.0507 Train_acc: 0.99\n",
      "Epoch: 605 loss: 0.0506 Train_acc: 0.99\n",
      "Epoch: 606 loss: 0.0505 Train_acc: 0.99\n",
      "Epoch: 607 loss: 0.0504 Train_acc: 0.99\n",
      "Epoch: 608 loss: 0.0503 Train_acc: 0.99\n",
      "Epoch: 609 loss: 0.0502 Train_acc: 0.99\n",
      "Epoch: 610 loss: 0.0501 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Epoch: 611 loss: 0.0500 Train_acc: 0.99\n",
      "Epoch: 612 loss: 0.0499 Train_acc: 0.99\n",
      "Epoch: 613 loss: 0.0498 Train_acc: 0.99\n",
      "Epoch: 614 loss: 0.0498 Train_acc: 0.99\n",
      "Epoch: 615 loss: 0.0497 Train_acc: 0.99\n",
      "Epoch: 616 loss: 0.0496 Train_acc: 0.99\n",
      "Epoch: 617 loss: 0.0495 Train_acc: 0.99\n",
      "Epoch: 618 loss: 0.0494 Train_acc: 0.99\n",
      "Epoch: 619 loss: 0.0493 Train_acc: 0.99\n",
      "Epoch: 620 loss: 0.0492 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Epoch: 621 loss: 0.0492 Train_acc: 0.99\n",
      "Epoch: 622 loss: 0.0491 Train_acc: 0.99\n",
      "Epoch: 623 loss: 0.0490 Train_acc: 0.99\n",
      "Epoch: 624 loss: 0.0489 Train_acc: 0.99\n",
      "Epoch: 625 loss: 0.0488 Train_acc: 0.99\n",
      "Epoch: 626 loss: 0.0488 Train_acc: 0.99\n",
      "Epoch: 627 loss: 0.0487 Train_acc: 0.99\n",
      "Epoch: 628 loss: 0.0486 Train_acc: 0.99\n",
      "Epoch: 629 loss: 0.0485 Train_acc: 0.99\n",
      "Epoch: 630 loss: 0.0484 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Epoch: 631 loss: 0.0484 Train_acc: 0.99\n",
      "Epoch: 632 loss: 0.0483 Train_acc: 0.99\n",
      "Epoch: 633 loss: 0.0482 Train_acc: 0.99\n",
      "Epoch: 634 loss: 0.0481 Train_acc: 0.99\n",
      "Epoch: 635 loss: 0.0481 Train_acc: 0.99\n",
      "Epoch: 636 loss: 0.0480 Train_acc: 0.99\n",
      "Epoch: 637 loss: 0.0479 Train_acc: 0.99\n",
      "Epoch: 638 loss: 0.0478 Train_acc: 0.99\n",
      "Epoch: 639 loss: 0.0478 Train_acc: 0.99\n",
      "Epoch: 640 loss: 0.0477 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Epoch: 641 loss: 0.0476 Train_acc: 0.99\n",
      "Epoch: 642 loss: 0.0475 Train_acc: 0.99\n",
      "Epoch: 643 loss: 0.0475 Train_acc: 0.99\n",
      "Epoch: 644 loss: 0.0474 Train_acc: 0.99\n",
      "Epoch: 645 loss: 0.0473 Train_acc: 0.99\n",
      "Epoch: 646 loss: 0.0473 Train_acc: 0.99\n",
      "Epoch: 647 loss: 0.0472 Train_acc: 0.99\n",
      "Epoch: 648 loss: 0.0471 Train_acc: 0.99\n",
      "Epoch: 649 loss: 0.0470 Train_acc: 0.99\n",
      "Epoch: 650 loss: 0.0470 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Epoch: 651 loss: 0.0469 Train_acc: 0.99\n",
      "Epoch: 652 loss: 0.0468 Train_acc: 0.99\n",
      "Epoch: 653 loss: 0.0468 Train_acc: 0.99\n",
      "Epoch: 654 loss: 0.0467 Train_acc: 0.99\n",
      "Epoch: 655 loss: 0.0466 Train_acc: 0.99\n",
      "Epoch: 656 loss: 0.0466 Train_acc: 0.99\n",
      "Epoch: 657 loss: 0.0465 Train_acc: 0.99\n",
      "Epoch: 658 loss: 0.0464 Train_acc: 0.99\n",
      "Epoch: 659 loss: 0.0463 Train_acc: 0.99\n",
      "Epoch: 660 loss: 0.0463 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Epoch: 661 loss: 0.0462 Train_acc: 0.99\n",
      "Epoch: 662 loss: 0.0461 Train_acc: 0.99\n",
      "Epoch: 663 loss: 0.0461 Train_acc: 0.99\n",
      "Epoch: 664 loss: 0.0460 Train_acc: 0.99\n",
      "Epoch: 665 loss: 0.0459 Train_acc: 0.99\n",
      "Epoch: 666 loss: 0.0459 Train_acc: 0.99\n",
      "Epoch: 667 loss: 0.0458 Train_acc: 0.99\n",
      "Epoch: 668 loss: 0.0457 Train_acc: 0.99\n",
      "Epoch: 669 loss: 0.0457 Train_acc: 0.99\n",
      "Epoch: 670 loss: 0.0456 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Epoch: 671 loss: 0.0455 Train_acc: 0.99\n",
      "Epoch: 672 loss: 0.0455 Train_acc: 0.99\n",
      "Epoch: 673 loss: 0.0454 Train_acc: 0.99\n",
      "Epoch: 674 loss: 0.0453 Train_acc: 0.99\n",
      "Epoch: 675 loss: 0.0453 Train_acc: 0.99\n",
      "Epoch: 676 loss: 0.0452 Train_acc: 0.99\n",
      "Epoch: 677 loss: 0.0451 Train_acc: 0.99\n",
      "Epoch: 678 loss: 0.0451 Train_acc: 0.99\n",
      "Epoch: 679 loss: 0.0450 Train_acc: 0.99\n",
      "Epoch: 680 loss: 0.0449 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Epoch: 681 loss: 0.0449 Train_acc: 0.99\n",
      "Epoch: 682 loss: 0.0448 Train_acc: 0.99\n",
      "Epoch: 683 loss: 0.0447 Train_acc: 0.99\n",
      "Epoch: 684 loss: 0.0447 Train_acc: 0.99\n",
      "Epoch: 685 loss: 0.0446 Train_acc: 0.99\n",
      "Epoch: 686 loss: 0.0446 Train_acc: 0.99\n",
      "Epoch: 687 loss: 0.0445 Train_acc: 0.99\n",
      "Epoch: 688 loss: 0.0444 Train_acc: 0.99\n",
      "Epoch: 689 loss: 0.0444 Train_acc: 0.99\n",
      "Epoch: 690 loss: 0.0443 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Epoch: 691 loss: 0.0442 Train_acc: 0.99\n",
      "Epoch: 692 loss: 0.0442 Train_acc: 0.99\n",
      "Epoch: 693 loss: 0.0441 Train_acc: 0.99\n",
      "Epoch: 694 loss: 0.0440 Train_acc: 0.99\n",
      "Epoch: 695 loss: 0.0440 Train_acc: 0.99\n",
      "Epoch: 696 loss: 0.0439 Train_acc: 0.99\n",
      "Epoch: 697 loss: 0.0439 Train_acc: 0.99\n",
      "Epoch: 698 loss: 0.0438 Train_acc: 0.99\n",
      "Epoch: 699 loss: 0.0437 Train_acc: 0.99\n",
      "Epoch: 700 loss: 0.0437 Train_acc: 0.99\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n",
      "\n",
      "\n",
      "Test_acc: 0.91\n"
     ]
    }
   ],
   "source": [
    "# Full-batch training loop: each epoch runs one forward/backward pass over the\n",
    "# whole complex and computes the loss on the training rows only. Every\n",
    "# `test_interval` epochs (and after the last epoch) the test accuracy is reported.\n",
    "# NOTE(review): assumes rows of the model output are ordered with the training\n",
    "# samples first and the test samples last - TODO confirm against the cell that\n",
    "# builds `x_all`.\n",
    "test_interval = 10\n",
    "num_epochs = 700\n",
    "test_accuracy = -1\n",
    "# Define cross-entropy loss function\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "# The one-hot targets never change, so convert them to class indices once.\n",
    "y_train_idx = torch.argmax(y_train.float(), dim=1)\n",
    "y_test_idx = torch.argmax(y_test.float(), dim=1)\n",
    "n_train, n_test = len(y_train), len(y_test)\n",
    "\n",
    "for epoch_i in range(1, num_epochs + 1):\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    y_hat = model(x_all, eig_eiv_all, incidence_all)\n",
    "\n",
    "    # Compute loss on the training rows only\n",
    "    loss = criterion(y_hat[:n_train], y_train_idx)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # Training accuracy of this epoch's (pre-step) forward pass. Softmax is\n",
    "    # monotonic, so the argmax of the logits is already the predicted class.\n",
    "    y_pred = torch.argmax(y_hat[:n_train], dim=1)\n",
    "    accuracy = (y_pred == y_train_idx).sum().item() / n_train\n",
    "    print(\n",
    "        f\"Epoch: {epoch_i} loss: {loss.item():.4f} Train_acc: {accuracy:.2f}\",\n",
    "        flush=True,\n",
    "    )\n",
    "\n",
    "    # Also evaluate after the final epoch so the summary below always\n",
    "    # reflects the final model, even if num_epochs % test_interval != 0.\n",
    "    if epoch_i % test_interval == 0 or epoch_i == num_epochs:\n",
    "        # eval() switches layers such as dropout/batch-norm (if the model has\n",
    "        # any) to inference behavior; model.train() above re-enables training.\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            y_hat_test = model(x_all, eig_eiv_all, incidence_all)\n",
    "            # Prediction for the test rows (the last n_test rows of the output).\n",
    "            y_pred_test = torch.argmax(y_hat_test[-n_test:], dim=1)\n",
    "            test_accuracy = (y_pred_test == y_test_idx).sum().item() / n_test\n",
    "\n",
    "        print()\n",
    "        print()\n",
    "        print(f\"Test_acc: {test_accuracy:.2f}\", flush=True)\n",
    "        print()\n",
    "        print()\n",
    "\n",
    "print(f\"Test_acc: {test_accuracy:.2f}\", flush=True)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tmx",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
