{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Protein Graph Classification by Training a COSIMO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2887809/1151068085.py:123: FutureWarning: adjacency_matrix will return a scipy.sparse array instead of a matrix in Networkx 3.0.\n",
      "  node_attributes.append(nx.adjacency_matrix(g).todense().sum(1))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Length of dataset: 923\n",
      "(923,)\n",
      "923\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|██████▉   | 646/923 [00:02<00:00, 289.93it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "592\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 923/923 [00:03<00:00, 255.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5\n",
      "9\n",
      "7\n",
      "922\n",
      "torch.Size([56, 1])\n",
      "922\n",
      "torch.Size([79, 1])\n",
      "922\n",
      "torch.Size([8, 1])\n",
      "922\n",
      "Network(\n",
      "  (base_model): SCCNN(\n",
      "    (in_linear_0): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_1): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_2): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (layers): ModuleList(\n",
      "      (0-1): 2 x SCCNNLayer_exp(\n",
      "        (diff_derivative_00): Time_derivative_diffusion()\n",
      "        (diff_derivative_10): Time_derivative_diffusion()\n",
      "        (diff_derivative_d1): Time_derivative_diffusion()\n",
      "        (diff_derivative_u1): Time_derivative_diffusion()\n",
      "        (diff_derivative_01): Time_derivative_diffusion()\n",
      "        (diff_derivative_21): Time_derivative_diffusion()\n",
      "        (diff_derivative_2): Time_derivative_diffusion()\n",
      "        (diff_derivative_d2): Time_derivative_diffusion()\n",
      "        (diff_derivative_u2): Time_derivative_diffusion()\n",
      "        (diff_derivative_12): Time_derivative_diffusion()\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (out_linear_0): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_1): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_2): Linear(in_features=16, out_features=2, bias=True)\n",
      ")\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2887809/1151068085.py:382: DeprecationWarning: __array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments.\n",
      "  one_hot_labels = np.array(F.one_hot(torch.tensor(y), num_classes=num_classes))\n",
      "/tmp/ipykernel_2887809/1151068085.py:433: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  x_0 = torch.tensor(x_0)\n",
      "/tmp/ipykernel_2887809/1151068085.py:434: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  x_1 = torch.tensor(x_1)\n",
      "/tmp/ipykernel_2887809/1151068085.py:435: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  x_2 = torch.tensor(x_2)\n",
      "/tmp/ipykernel_2887809/1151068085.py:436: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  y = torch.tensor(y, dtype=torch.float)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter: 0 Epoch: 1 loss: 2.0379 Train_acc: 0.61\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2887809/1151068085.py:491: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  x_0 = torch.tensor(x_0)\n",
      "/tmp/ipykernel_2887809/1151068085.py:492: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  x_1 = torch.tensor(x_1)\n",
      "/tmp/ipykernel_2887809/1151068085.py:493: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  x_2 = torch.tensor(x_2)\n",
      "/tmp/ipykernel_2887809/1151068085.py:494: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  y = torch.tensor(y, dtype=torch.float)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Val_acc: 0.6689\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2887809/1151068085.py:543: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  x_0 = torch.tensor(x_0)\n",
      "/tmp/ipykernel_2887809/1151068085.py:544: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  x_1 = torch.tensor(x_1)\n",
      "/tmp/ipykernel_2887809/1151068085.py:545: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  x_2 = torch.tensor(x_2)\n",
      "/tmp/ipykernel_2887809/1151068085.py:546: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  y = torch.tensor(y, dtype=torch.float)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test_acc-improved: 0.7243\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 2 loss: 0.8380 Train_acc: 0.62\n",
      "Val_acc: 0.6081\n",
      "Test_acc-still: 0.7243\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 3 loss: 0.7105 Train_acc: 0.63\n",
      "Val_acc: 0.6014\n",
      "Test_acc-still: 0.7243\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 4 loss: 0.6699 Train_acc: 0.65\n",
      "Val_acc: 0.6149\n",
      "Test_acc-still: 0.7243\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 5 loss: 0.6478 Train_acc: 0.66\n",
      "Val_acc: 0.6486\n",
      "Test_acc-still: 0.7243\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 6 loss: 0.6349 Train_acc: 0.66\n",
      "Val_acc: 0.6757\n",
      "Test_acc-improved: 0.7730\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 7 loss: 0.6288 Train_acc: 0.66\n",
      "Val_acc: 0.6959\n",
      "Test_acc-improved: 0.8054\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 8 loss: 0.6253 Train_acc: 0.67\n",
      "Val_acc: 0.7095\n",
      "Test_acc-improved: 0.8054\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 9 loss: 0.6209 Train_acc: 0.67\n",
      "Val_acc: 0.7230\n",
      "Test_acc-improved: 0.8108\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 10 loss: 0.6178 Train_acc: 0.67\n",
      "Val_acc: 0.7095\n",
      "Test_acc-still: 0.8108\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 11 loss: 0.6092 Train_acc: 0.68\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.8108\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 12 loss: 0.6133 Train_acc: 0.67\n",
      "Val_acc: 0.7095\n",
      "Test_acc-still: 0.8108\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 13 loss: 0.6059 Train_acc: 0.68\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.8108\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 14 loss: 0.5988 Train_acc: 0.69\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.8108\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 15 loss: 0.6020 Train_acc: 0.69\n",
      "Val_acc: 0.7297\n",
      "Test_acc-improved: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 16 loss: 0.5963 Train_acc: 0.69\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 17 loss: 0.5950 Train_acc: 0.69\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 18 loss: 0.5930 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 19 loss: 0.5892 Train_acc: 0.70\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 20 loss: 0.5909 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 21 loss: 0.5858 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 22 loss: 0.5872 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 23 loss: 0.5885 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 24 loss: 0.5834 Train_acc: 0.71\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 25 loss: 0.5832 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 26 loss: 0.5829 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 27 loss: 0.5820 Train_acc: 0.70\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 28 loss: 0.5812 Train_acc: 0.70\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 29 loss: 0.5811 Train_acc: 0.69\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 0 Epoch: 30 loss: 0.5805 Train_acc: 0.69\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "True\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2887809/1151068085.py:123: FutureWarning: adjacency_matrix will return a scipy.sparse array instead of a matrix in Networkx 3.0.\n",
      "  node_attributes.append(nx.adjacency_matrix(g).todense().sum(1))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Length of dataset: 923\n",
      "(923,)\n",
      "923\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 69%|██████▊   | 633/923 [00:02<00:01, 289.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "592\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 923/923 [00:03<00:00, 255.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5\n",
      "9\n",
      "7\n",
      "922\n",
      "torch.Size([56, 1])\n",
      "922\n",
      "torch.Size([79, 1])\n",
      "922\n",
      "torch.Size([8, 1])\n",
      "922\n",
      "Network(\n",
      "  (base_model): SCCNN(\n",
      "    (in_linear_0): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_1): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_2): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (layers): ModuleList(\n",
      "      (0-1): 2 x SCCNNLayer_exp(\n",
      "        (diff_derivative_00): Time_derivative_diffusion()\n",
      "        (diff_derivative_10): Time_derivative_diffusion()\n",
      "        (diff_derivative_d1): Time_derivative_diffusion()\n",
      "        (diff_derivative_u1): Time_derivative_diffusion()\n",
      "        (diff_derivative_01): Time_derivative_diffusion()\n",
      "        (diff_derivative_21): Time_derivative_diffusion()\n",
      "        (diff_derivative_2): Time_derivative_diffusion()\n",
      "        (diff_derivative_d2): Time_derivative_diffusion()\n",
      "        (diff_derivative_u2): Time_derivative_diffusion()\n",
      "        (diff_derivative_12): Time_derivative_diffusion()\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (out_linear_0): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_1): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_2): Linear(in_features=16, out_features=2, bias=True)\n",
      ")\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]\n",
      "iter: 1 Epoch: 1 loss: 2.5695 Train_acc: 0.62\n",
      "Val_acc: 0.5405\n",
      "Test_acc-improved: 0.6216\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 2 loss: 1.1264 Train_acc: 0.60\n",
      "Val_acc: 0.5878\n",
      "Test_acc-improved: 0.6811\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 3 loss: 0.8458 Train_acc: 0.62\n",
      "Val_acc: 0.6486\n",
      "Test_acc-improved: 0.7405\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 4 loss: 0.7331 Train_acc: 0.65\n",
      "Val_acc: 0.6622\n",
      "Test_acc-improved: 0.7351\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 5 loss: 0.6789 Train_acc: 0.65\n",
      "Val_acc: 0.6486\n",
      "Test_acc-still: 0.7351\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 6 loss: 0.6525 Train_acc: 0.65\n",
      "Val_acc: 0.6486\n",
      "Test_acc-still: 0.7351\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 7 loss: 0.6401 Train_acc: 0.66\n",
      "Val_acc: 0.6757\n",
      "Test_acc-improved: 0.7622\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 8 loss: 0.6327 Train_acc: 0.65\n",
      "Val_acc: 0.6959\n",
      "Test_acc-improved: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 9 loss: 0.6271 Train_acc: 0.66\n",
      "Val_acc: 0.7162\n",
      "Test_acc-improved: 0.8000\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 10 loss: 0.6230 Train_acc: 0.66\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.8000\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 11 loss: 0.6233 Train_acc: 0.66\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.8000\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 12 loss: 0.6185 Train_acc: 0.67\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.8000\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 13 loss: 0.6082 Train_acc: 0.67\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.8000\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 14 loss: 0.6047 Train_acc: 0.68\n",
      "Val_acc: 0.7230\n",
      "Test_acc-improved: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 15 loss: 0.6036 Train_acc: 0.68\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 16 loss: 0.6059 Train_acc: 0.68\n",
      "Val_acc: 0.7297\n",
      "Test_acc-improved: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 17 loss: 0.6023 Train_acc: 0.68\n",
      "Val_acc: 0.7432\n",
      "Test_acc-improved: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 18 loss: 0.5946 Train_acc: 0.69\n",
      "Val_acc: 0.7365\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 19 loss: 0.5940 Train_acc: 0.69\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 20 loss: 0.5924 Train_acc: 0.70\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 21 loss: 0.5910 Train_acc: 0.69\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 22 loss: 0.5897 Train_acc: 0.69\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 23 loss: 0.5899 Train_acc: 0.69\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 24 loss: 0.5871 Train_acc: 0.69\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 25 loss: 0.5866 Train_acc: 0.69\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 26 loss: 0.5833 Train_acc: 0.70\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 27 loss: 0.5826 Train_acc: 0.70\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 28 loss: 0.5855 Train_acc: 0.68\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 29 loss: 0.5829 Train_acc: 0.69\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 1 Epoch: 30 loss: 0.5823 Train_acc: 0.70\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "True\n",
      "Length of dataset: 923\n",
      "(923,)\n",
      "923\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 67%|██████▋   | 619/923 [00:02<00:01, 268.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "592\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 923/923 [00:03<00:00, 261.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5\n",
      "9\n",
      "7\n",
      "922\n",
      "torch.Size([56, 1])\n",
      "922\n",
      "torch.Size([79, 1])\n",
      "922\n",
      "torch.Size([8, 1])\n",
      "922\n",
      "Network(\n",
      "  (base_model): SCCNN(\n",
      "    (in_linear_0): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_1): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_2): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (layers): ModuleList(\n",
      "      (0-1): 2 x SCCNNLayer_exp(\n",
      "        (diff_derivative_00): Time_derivative_diffusion()\n",
      "        (diff_derivative_10): Time_derivative_diffusion()\n",
      "        (diff_derivative_d1): Time_derivative_diffusion()\n",
      "        (diff_derivative_u1): Time_derivative_diffusion()\n",
      "        (diff_derivative_01): Time_derivative_diffusion()\n",
      "        (diff_derivative_21): Time_derivative_diffusion()\n",
      "        (diff_derivative_2): Time_derivative_diffusion()\n",
      "        (diff_derivative_d2): Time_derivative_diffusion()\n",
      "        (diff_derivative_u2): Time_derivative_diffusion()\n",
      "        (diff_derivative_12): Time_derivative_diffusion()\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (out_linear_0): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_1): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_2): Linear(in_features=16, out_features=2, bias=True)\n",
      ")\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]\n",
      "iter: 2 Epoch: 1 loss: 2.2666 Train_acc: 0.61\n",
      "Val_acc: 0.6419\n",
      "Test_acc-improved: 0.7243\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 2 loss: 0.9318 Train_acc: 0.62\n",
      "Val_acc: 0.6284\n",
      "Test_acc-still: 0.7243\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 3 loss: 0.7317 Train_acc: 0.62\n",
      "Val_acc: 0.5743\n",
      "Test_acc-still: 0.7243\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 4 loss: 0.6618 Train_acc: 0.65\n",
      "Val_acc: 0.6014\n",
      "Test_acc-still: 0.7243\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 5 loss: 0.6355 Train_acc: 0.67\n",
      "Val_acc: 0.6351\n",
      "Test_acc-still: 0.7243\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 6 loss: 0.6235 Train_acc: 0.67\n",
      "Val_acc: 0.6622\n",
      "Test_acc-improved: 0.7622\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 7 loss: 0.6181 Train_acc: 0.68\n",
      "Val_acc: 0.6622\n",
      "Test_acc-still: 0.7622\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 8 loss: 0.6153 Train_acc: 0.68\n",
      "Val_acc: 0.6757\n",
      "Test_acc-improved: 0.7676\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 9 loss: 0.6135 Train_acc: 0.67\n",
      "Val_acc: 0.6824\n",
      "Test_acc-improved: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 10 loss: 0.6107 Train_acc: 0.67\n",
      "Val_acc: 0.6757\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 11 loss: 0.6076 Train_acc: 0.68\n",
      "Val_acc: 0.6959\n",
      "Test_acc-improved: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 12 loss: 0.6083 Train_acc: 0.68\n",
      "Val_acc: 0.7027\n",
      "Test_acc-improved: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 13 loss: 0.6035 Train_acc: 0.68\n",
      "Val_acc: 0.7027\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 14 loss: 0.6006 Train_acc: 0.69\n",
      "Val_acc: 0.7095\n",
      "Test_acc-improved: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 15 loss: 0.6000 Train_acc: 0.68\n",
      "Val_acc: 0.7162\n",
      "Test_acc-improved: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 16 loss: 0.5976 Train_acc: 0.69\n",
      "Val_acc: 0.7095\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 17 loss: 0.5938 Train_acc: 0.69\n",
      "Val_acc: 0.7230\n",
      "Test_acc-improved: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 18 loss: 0.5915 Train_acc: 0.70\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 19 loss: 0.5921 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 20 loss: 0.5896 Train_acc: 0.69\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 21 loss: 0.5866 Train_acc: 0.70\n",
      "Val_acc: 0.7297\n",
      "Test_acc-improved: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 22 loss: 0.5862 Train_acc: 0.70\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 23 loss: 0.5843 Train_acc: 0.70\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 24 loss: 0.5838 Train_acc: 0.70\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 25 loss: 0.5831 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 26 loss: 0.5828 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 27 loss: 0.5811 Train_acc: 0.70\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 28 loss: 0.5799 Train_acc: 0.70\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 29 loss: 0.5786 Train_acc: 0.71\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 2 Epoch: 30 loss: 0.5794 Train_acc: 0.70\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "True\n",
      "Length of dataset: 923\n",
      "(923,)\n",
      "923\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|██████▉   | 642/923 [00:02<00:00, 295.00it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "592\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 923/923 [00:03<00:00, 266.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5\n",
      "9\n",
      "7\n",
      "922\n",
      "torch.Size([56, 1])\n",
      "922\n",
      "torch.Size([79, 1])\n",
      "922\n",
      "torch.Size([8, 1])\n",
      "922\n",
      "Network(\n",
      "  (base_model): SCCNN(\n",
      "    (in_linear_0): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_1): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_2): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (layers): ModuleList(\n",
      "      (0-1): 2 x SCCNNLayer_exp(\n",
      "        (diff_derivative_00): Time_derivative_diffusion()\n",
      "        (diff_derivative_10): Time_derivative_diffusion()\n",
      "        (diff_derivative_d1): Time_derivative_diffusion()\n",
      "        (diff_derivative_u1): Time_derivative_diffusion()\n",
      "        (diff_derivative_01): Time_derivative_diffusion()\n",
      "        (diff_derivative_21): Time_derivative_diffusion()\n",
      "        (diff_derivative_2): Time_derivative_diffusion()\n",
      "        (diff_derivative_d2): Time_derivative_diffusion()\n",
      "        (diff_derivative_u2): Time_derivative_diffusion()\n",
      "        (diff_derivative_12): Time_derivative_diffusion()\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (out_linear_0): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_1): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_2): Linear(in_features=16, out_features=2, bias=True)\n",
      ")\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]\n",
      "iter: 3 Epoch: 1 loss: 2.4381 Train_acc: 0.61\n",
      "Val_acc: 0.3649\n",
      "Test_acc-improved: 0.4541\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 2 loss: 0.9815 Train_acc: 0.66\n",
      "Val_acc: 0.6959\n",
      "Test_acc-improved: 0.7622\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 3 loss: 0.7307 Train_acc: 0.67\n",
      "Val_acc: 0.6351\n",
      "Test_acc-still: 0.7622\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 4 loss: 0.6656 Train_acc: 0.67\n",
      "Val_acc: 0.6486\n",
      "Test_acc-still: 0.7622\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 5 loss: 0.6400 Train_acc: 0.68\n",
      "Val_acc: 0.6622\n",
      "Test_acc-still: 0.7622\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 6 loss: 0.6291 Train_acc: 0.68\n",
      "Val_acc: 0.6824\n",
      "Test_acc-still: 0.7622\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 7 loss: 0.6276 Train_acc: 0.68\n",
      "Val_acc: 0.7230\n",
      "Test_acc-improved: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 8 loss: 0.6244 Train_acc: 0.68\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 9 loss: 0.6230 Train_acc: 0.69\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 10 loss: 0.6220 Train_acc: 0.68\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 11 loss: 0.6142 Train_acc: 0.68\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 12 loss: 0.6135 Train_acc: 0.68\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 13 loss: 0.6074 Train_acc: 0.68\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 14 loss: 0.6111 Train_acc: 0.68\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 15 loss: 0.6035 Train_acc: 0.69\n",
      "Val_acc: 0.7365\n",
      "Test_acc-improved: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 16 loss: 0.6012 Train_acc: 0.69\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 17 loss: 0.6010 Train_acc: 0.69\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 18 loss: 0.5942 Train_acc: 0.69\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 19 loss: 0.5974 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 20 loss: 0.5933 Train_acc: 0.69\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 21 loss: 0.5918 Train_acc: 0.69\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 22 loss: 0.5888 Train_acc: 0.69\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 23 loss: 0.5913 Train_acc: 0.70\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 24 loss: 0.5884 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 25 loss: 0.5857 Train_acc: 0.69\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 26 loss: 0.5874 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 27 loss: 0.5837 Train_acc: 0.70\n",
      "Val_acc: 0.7365\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 28 loss: 0.5841 Train_acc: 0.70\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 29 loss: 0.5811 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 3 Epoch: 30 loss: 0.5823 Train_acc: 0.70\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "True\n",
      "Length of dataset: 923\n",
      "(923,)\n",
      "923\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 67%|██████▋   | 618/923 [00:02<00:01, 256.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "592\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 923/923 [00:03<00:00, 235.34it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5\n",
      "9\n",
      "7\n",
      "922\n",
      "torch.Size([56, 1])\n",
      "922\n",
      "torch.Size([79, 1])\n",
      "922\n",
      "torch.Size([8, 1])\n",
      "922\n",
      "Network(\n",
      "  (base_model): SCCNN(\n",
      "    (in_linear_0): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_1): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_2): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (layers): ModuleList(\n",
      "      (0-1): 2 x SCCNNLayer_exp(\n",
      "        (diff_derivative_00): Time_derivative_diffusion()\n",
      "        (diff_derivative_10): Time_derivative_diffusion()\n",
      "        (diff_derivative_d1): Time_derivative_diffusion()\n",
      "        (diff_derivative_u1): Time_derivative_diffusion()\n",
      "        (diff_derivative_01): Time_derivative_diffusion()\n",
      "        (diff_derivative_21): Time_derivative_diffusion()\n",
      "        (diff_derivative_2): Time_derivative_diffusion()\n",
      "        (diff_derivative_d2): Time_derivative_diffusion()\n",
      "        (diff_derivative_u2): Time_derivative_diffusion()\n",
      "        (diff_derivative_12): Time_derivative_diffusion()\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (out_linear_0): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_1): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_2): Linear(in_features=16, out_features=2, bias=True)\n",
      ")\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]\n",
      "iter: 4 Epoch: 1 loss: 2.5242 Train_acc: 0.59\n",
      "Val_acc: 0.6554\n",
      "Test_acc-improved: 0.7514\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 2 loss: 0.9388 Train_acc: 0.61\n",
      "Val_acc: 0.6892\n",
      "Test_acc-improved: 0.7730\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 3 loss: 0.7207 Train_acc: 0.65\n",
      "Val_acc: 0.6284\n",
      "Test_acc-still: 0.7730\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 4 loss: 0.6683 Train_acc: 0.65\n",
      "Val_acc: 0.6419\n",
      "Test_acc-still: 0.7730\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 5 loss: 0.6470 Train_acc: 0.66\n",
      "Val_acc: 0.6554\n",
      "Test_acc-still: 0.7730\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 6 loss: 0.6432 Train_acc: 0.66\n",
      "Val_acc: 0.6689\n",
      "Test_acc-still: 0.7730\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 7 loss: 0.6379 Train_acc: 0.67\n",
      "Val_acc: 0.6892\n",
      "Test_acc-still: 0.7730\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 8 loss: 0.6345 Train_acc: 0.68\n",
      "Val_acc: 0.7095\n",
      "Test_acc-improved: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 9 loss: 0.6379 Train_acc: 0.68\n",
      "Val_acc: 0.7095\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 10 loss: 0.6272 Train_acc: 0.69\n",
      "Val_acc: 0.7095\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 11 loss: 0.6185 Train_acc: 0.69\n",
      "Val_acc: 0.6959\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 12 loss: 0.6208 Train_acc: 0.68\n",
      "Val_acc: 0.6824\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 13 loss: 0.6145 Train_acc: 0.69\n",
      "Val_acc: 0.7432\n",
      "Test_acc-improved: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 14 loss: 0.6089 Train_acc: 0.69\n",
      "Val_acc: 0.7027\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 15 loss: 0.6035 Train_acc: 0.69\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 16 loss: 0.6006 Train_acc: 0.70\n",
      "Val_acc: 0.7095\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 17 loss: 0.6007 Train_acc: 0.69\n",
      "Val_acc: 0.7027\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 18 loss: 0.5969 Train_acc: 0.69\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 19 loss: 0.5961 Train_acc: 0.70\n",
      "Val_acc: 0.7095\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 20 loss: 0.5912 Train_acc: 0.70\n",
      "Val_acc: 0.7027\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 21 loss: 0.5904 Train_acc: 0.70\n",
      "Val_acc: 0.7095\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 22 loss: 0.5897 Train_acc: 0.70\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 23 loss: 0.5872 Train_acc: 0.70\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 24 loss: 0.5845 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 25 loss: 0.5849 Train_acc: 0.70\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 26 loss: 0.5809 Train_acc: 0.71\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 27 loss: 0.5829 Train_acc: 0.71\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 28 loss: 0.5790 Train_acc: 0.71\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 29 loss: 0.5787 Train_acc: 0.71\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 4 Epoch: 30 loss: 0.5779 Train_acc: 0.70\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "True\n",
      "Length of dataset: 923\n",
      "(923,)\n",
      "923\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 649/923 [00:02<00:00, 292.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "592\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 923/923 [00:03<00:00, 253.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5\n",
      "9\n",
      "7\n",
      "922\n",
      "torch.Size([56, 1])\n",
      "922\n",
      "torch.Size([79, 1])\n",
      "922\n",
      "torch.Size([8, 1])\n",
      "922\n",
      "Network(\n",
      "  (base_model): SCCNN(\n",
      "    (in_linear_0): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_1): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_2): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (layers): ModuleList(\n",
      "      (0-1): 2 x SCCNNLayer_exp(\n",
      "        (diff_derivative_00): Time_derivative_diffusion()\n",
      "        (diff_derivative_10): Time_derivative_diffusion()\n",
      "        (diff_derivative_d1): Time_derivative_diffusion()\n",
      "        (diff_derivative_u1): Time_derivative_diffusion()\n",
      "        (diff_derivative_01): Time_derivative_diffusion()\n",
      "        (diff_derivative_21): Time_derivative_diffusion()\n",
      "        (diff_derivative_2): Time_derivative_diffusion()\n",
      "        (diff_derivative_d2): Time_derivative_diffusion()\n",
      "        (diff_derivative_u2): Time_derivative_diffusion()\n",
      "        (diff_derivative_12): Time_derivative_diffusion()\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (out_linear_0): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_1): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_2): Linear(in_features=16, out_features=2, bias=True)\n",
      ")\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]\n",
      "iter: 5 Epoch: 1 loss: 3.0999 Train_acc: 0.60\n",
      "Val_acc: 0.6351\n",
      "Test_acc-improved: 0.7459\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 2 loss: 1.0455 Train_acc: 0.60\n",
      "Val_acc: 0.6284\n",
      "Test_acc-still: 0.7459\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 3 loss: 0.8704 Train_acc: 0.61\n",
      "Val_acc: 0.6351\n",
      "Test_acc-still: 0.7459\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 4 loss: 0.7701 Train_acc: 0.63\n",
      "Val_acc: 0.6351\n",
      "Test_acc-still: 0.7459\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 5 loss: 0.7201 Train_acc: 0.64\n",
      "Val_acc: 0.6419\n",
      "Test_acc-improved: 0.7405\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 6 loss: 0.6975 Train_acc: 0.65\n",
      "Val_acc: 0.6554\n",
      "Test_acc-improved: 0.7568\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 7 loss: 0.6835 Train_acc: 0.66\n",
      "Val_acc: 0.6554\n",
      "Test_acc-still: 0.7568\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 8 loss: 0.6697 Train_acc: 0.66\n",
      "Val_acc: 0.6554\n",
      "Test_acc-still: 0.7568\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 9 loss: 0.6611 Train_acc: 0.65\n",
      "Val_acc: 0.6824\n",
      "Test_acc-improved: 0.7676\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 10 loss: 0.6553 Train_acc: 0.65\n",
      "Val_acc: 0.6959\n",
      "Test_acc-improved: 0.7676\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 11 loss: 0.6443 Train_acc: 0.65\n",
      "Val_acc: 0.6892\n",
      "Test_acc-still: 0.7676\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 12 loss: 0.6400 Train_acc: 0.66\n",
      "Val_acc: 0.6959\n",
      "Test_acc-still: 0.7676\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 13 loss: 0.6299 Train_acc: 0.67\n",
      "Val_acc: 0.7027\n",
      "Test_acc-improved: 0.7676\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 14 loss: 0.6287 Train_acc: 0.66\n",
      "Val_acc: 0.7027\n",
      "Test_acc-still: 0.7676\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 15 loss: 0.6354 Train_acc: 0.65\n",
      "Val_acc: 0.7162\n",
      "Test_acc-improved: 0.7730\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 16 loss: 0.6154 Train_acc: 0.67\n",
      "Val_acc: 0.7095\n",
      "Test_acc-still: 0.7730\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 17 loss: 0.6059 Train_acc: 0.67\n",
      "Val_acc: 0.7230\n",
      "Test_acc-improved: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 18 loss: 0.6033 Train_acc: 0.67\n",
      "Val_acc: 0.7297\n",
      "Test_acc-improved: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 19 loss: 0.6042 Train_acc: 0.68\n",
      "Val_acc: 0.7365\n",
      "Test_acc-improved: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 20 loss: 0.6114 Train_acc: 0.69\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 21 loss: 0.6012 Train_acc: 0.68\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 22 loss: 0.5923 Train_acc: 0.69\n",
      "Val_acc: 0.7365\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 23 loss: 0.5978 Train_acc: 0.69\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 24 loss: 0.5955 Train_acc: 0.69\n",
      "Val_acc: 0.7365\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 25 loss: 0.5895 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 26 loss: 0.5905 Train_acc: 0.70\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 27 loss: 0.5887 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 28 loss: 0.5899 Train_acc: 0.70\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 29 loss: 0.5833 Train_acc: 0.70\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 5 Epoch: 30 loss: 0.5868 Train_acc: 0.71\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "True\n",
      "Length of dataset: 923\n",
      "(923,)\n",
      "923\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 69%|██████▉   | 638/923 [00:02<00:00, 289.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "592\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 923/923 [00:03<00:00, 252.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5\n",
      "9\n",
      "7\n",
      "922\n",
      "torch.Size([56, 1])\n",
      "922\n",
      "torch.Size([79, 1])\n",
      "922\n",
      "torch.Size([8, 1])\n",
      "922\n",
      "Network(\n",
      "  (base_model): SCCNN(\n",
      "    (in_linear_0): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_1): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_2): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (layers): ModuleList(\n",
      "      (0-1): 2 x SCCNNLayer_exp(\n",
      "        (diff_derivative_00): Time_derivative_diffusion()\n",
      "        (diff_derivative_10): Time_derivative_diffusion()\n",
      "        (diff_derivative_d1): Time_derivative_diffusion()\n",
      "        (diff_derivative_u1): Time_derivative_diffusion()\n",
      "        (diff_derivative_01): Time_derivative_diffusion()\n",
      "        (diff_derivative_21): Time_derivative_diffusion()\n",
      "        (diff_derivative_2): Time_derivative_diffusion()\n",
      "        (diff_derivative_d2): Time_derivative_diffusion()\n",
      "        (diff_derivative_u2): Time_derivative_diffusion()\n",
      "        (diff_derivative_12): Time_derivative_diffusion()\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (out_linear_0): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_1): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_2): Linear(in_features=16, out_features=2, bias=True)\n",
      ")\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]\n",
      "iter: 6 Epoch: 1 loss: 1.9096 Train_acc: 0.62\n",
      "Val_acc: 0.5541\n",
      "Test_acc-improved: 0.6649\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 2 loss: 0.9569 Train_acc: 0.62\n",
      "Val_acc: 0.5878\n",
      "Test_acc-improved: 0.6973\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 3 loss: 0.7623 Train_acc: 0.64\n",
      "Val_acc: 0.6149\n",
      "Test_acc-improved: 0.7568\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 4 loss: 0.6844 Train_acc: 0.65\n",
      "Val_acc: 0.6216\n",
      "Test_acc-improved: 0.7514\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 5 loss: 0.6445 Train_acc: 0.67\n",
      "Val_acc: 0.6216\n",
      "Test_acc-still: 0.7514\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 6 loss: 0.6229 Train_acc: 0.68\n",
      "Val_acc: 0.6622\n",
      "Test_acc-improved: 0.7784\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 7 loss: 0.6118 Train_acc: 0.68\n",
      "Val_acc: 0.6689\n",
      "Test_acc-improved: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 8 loss: 0.6071 Train_acc: 0.69\n",
      "Val_acc: 0.6959\n",
      "Test_acc-improved: 0.7784\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 9 loss: 0.6043 Train_acc: 0.69\n",
      "Val_acc: 0.6959\n",
      "Test_acc-still: 0.7784\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 10 loss: 0.6015 Train_acc: 0.69\n",
      "Val_acc: 0.6959\n",
      "Test_acc-still: 0.7784\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 11 loss: 0.5994 Train_acc: 0.70\n",
      "Val_acc: 0.6959\n",
      "Test_acc-still: 0.7784\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 12 loss: 0.5978 Train_acc: 0.69\n",
      "Val_acc: 0.7027\n",
      "Test_acc-improved: 0.8054\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 13 loss: 0.5959 Train_acc: 0.69\n",
      "Val_acc: 0.6959\n",
      "Test_acc-still: 0.8054\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 14 loss: 0.5945 Train_acc: 0.70\n",
      "Val_acc: 0.7027\n",
      "Test_acc-still: 0.8054\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 15 loss: 0.5908 Train_acc: 0.70\n",
      "Val_acc: 0.7095\n",
      "Test_acc-improved: 0.8054\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 16 loss: 0.5930 Train_acc: 0.70\n",
      "Val_acc: 0.7027\n",
      "Test_acc-still: 0.8054\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 17 loss: 0.5886 Train_acc: 0.69\n",
      "Val_acc: 0.7095\n",
      "Test_acc-still: 0.8054\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 18 loss: 0.5882 Train_acc: 0.70\n",
      "Val_acc: 0.7162\n",
      "Test_acc-improved: 0.8000\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 19 loss: 0.5876 Train_acc: 0.70\n",
      "Val_acc: 0.7297\n",
      "Test_acc-improved: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 20 loss: 0.5851 Train_acc: 0.70\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 21 loss: 0.5833 Train_acc: 0.70\n",
      "Val_acc: 0.7365\n",
      "Test_acc-improved: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 22 loss: 0.5831 Train_acc: 0.70\n",
      "Val_acc: 0.7432\n",
      "Test_acc-improved: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 23 loss: 0.5824 Train_acc: 0.71\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 24 loss: 0.5816 Train_acc: 0.71\n",
      "Val_acc: 0.7365\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 25 loss: 0.5816 Train_acc: 0.71\n",
      "Val_acc: 0.7365\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 26 loss: 0.5802 Train_acc: 0.70\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 27 loss: 0.5795 Train_acc: 0.71\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 28 loss: 0.5788 Train_acc: 0.71\n",
      "Val_acc: 0.7365\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 29 loss: 0.5784 Train_acc: 0.71\n",
      "Val_acc: 0.7365\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 6 Epoch: 30 loss: 0.5777 Train_acc: 0.71\n",
      "Val_acc: 0.7365\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "True\n",
      "Length of dataset: 923\n",
      "(923,)\n",
      "923\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 68%|██████▊   | 626/923 [00:02<00:01, 274.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "592\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 923/923 [00:03<00:00, 241.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5\n",
      "9\n",
      "7\n",
      "922\n",
      "torch.Size([56, 1])\n",
      "922\n",
      "torch.Size([79, 1])\n",
      "922\n",
      "torch.Size([8, 1])\n",
      "922\n",
      "Network(\n",
      "  (base_model): SCCNN(\n",
      "    (in_linear_0): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_1): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_2): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (layers): ModuleList(\n",
      "      (0-1): 2 x SCCNNLayer_exp(\n",
      "        (diff_derivative_00): Time_derivative_diffusion()\n",
      "        (diff_derivative_10): Time_derivative_diffusion()\n",
      "        (diff_derivative_d1): Time_derivative_diffusion()\n",
      "        (diff_derivative_u1): Time_derivative_diffusion()\n",
      "        (diff_derivative_01): Time_derivative_diffusion()\n",
      "        (diff_derivative_21): Time_derivative_diffusion()\n",
      "        (diff_derivative_2): Time_derivative_diffusion()\n",
      "        (diff_derivative_d2): Time_derivative_diffusion()\n",
      "        (diff_derivative_u2): Time_derivative_diffusion()\n",
      "        (diff_derivative_12): Time_derivative_diffusion()\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (out_linear_0): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_1): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_2): Linear(in_features=16, out_features=2, bias=True)\n",
      ")\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]\n",
      "iter: 7 Epoch: 1 loss: 2.3171 Train_acc: 0.61\n",
      "Val_acc: 0.6216\n",
      "Test_acc-improved: 0.7243\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 2 loss: 0.9398 Train_acc: 0.63\n",
      "Val_acc: 0.6216\n",
      "Test_acc-still: 0.7243\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 3 loss: 0.7339 Train_acc: 0.65\n",
      "Val_acc: 0.6284\n",
      "Test_acc-improved: 0.7459\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 4 loss: 0.6606 Train_acc: 0.67\n",
      "Val_acc: 0.6081\n",
      "Test_acc-still: 0.7459\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 5 loss: 0.6409 Train_acc: 0.68\n",
      "Val_acc: 0.6622\n",
      "Test_acc-improved: 0.7730\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 6 loss: 0.6340 Train_acc: 0.68\n",
      "Val_acc: 0.6824\n",
      "Test_acc-improved: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 7 loss: 0.6294 Train_acc: 0.67\n",
      "Val_acc: 0.6892\n",
      "Test_acc-improved: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 8 loss: 0.6271 Train_acc: 0.68\n",
      "Val_acc: 0.7027\n",
      "Test_acc-improved: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 9 loss: 0.6227 Train_acc: 0.68\n",
      "Val_acc: 0.7095\n",
      "Test_acc-improved: 0.7730\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 10 loss: 0.6190 Train_acc: 0.68\n",
      "Val_acc: 0.7162\n",
      "Test_acc-improved: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 11 loss: 0.6150 Train_acc: 0.68\n",
      "Val_acc: 0.7297\n",
      "Test_acc-improved: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 12 loss: 0.6112 Train_acc: 0.68\n",
      "Val_acc: 0.7365\n",
      "Test_acc-improved: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 13 loss: 0.6065 Train_acc: 0.69\n",
      "Val_acc: 0.7500\n",
      "Test_acc-improved: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 14 loss: 0.6043 Train_acc: 0.69\n",
      "Val_acc: 0.7027\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 15 loss: 0.5957 Train_acc: 0.69\n",
      "Val_acc: 0.7500\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 16 loss: 0.5939 Train_acc: 0.69\n",
      "Val_acc: 0.7500\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 17 loss: 0.5915 Train_acc: 0.69\n",
      "Val_acc: 0.7500\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 18 loss: 0.5870 Train_acc: 0.70\n",
      "Val_acc: 0.7500\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 19 loss: 0.5849 Train_acc: 0.70\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 20 loss: 0.5843 Train_acc: 0.70\n",
      "Val_acc: 0.7500\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 21 loss: 0.5806 Train_acc: 0.70\n",
      "Val_acc: 0.7500\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 22 loss: 0.5796 Train_acc: 0.70\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 23 loss: 0.5786 Train_acc: 0.70\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 24 loss: 0.5775 Train_acc: 0.70\n",
      "Val_acc: 0.7500\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 25 loss: 0.5784 Train_acc: 0.71\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 26 loss: 0.5745 Train_acc: 0.71\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 27 loss: 0.5745 Train_acc: 0.71\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 28 loss: 0.5742 Train_acc: 0.72\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 29 loss: 0.5731 Train_acc: 0.72\n",
      "Val_acc: 0.7432\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 7 Epoch: 30 loss: 0.5722 Train_acc: 0.72\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "True\n",
      "Length of dataset: 923\n",
      "(923,)\n",
      "923\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 68%|██████▊   | 630/923 [00:02<00:01, 287.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "592\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 923/923 [00:03<00:00, 260.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5\n",
      "9\n",
      "7\n",
      "922\n",
      "torch.Size([56, 1])\n",
      "922\n",
      "torch.Size([79, 1])\n",
      "922\n",
      "torch.Size([8, 1])\n",
      "922\n",
      "Network(\n",
      "  (base_model): SCCNN(\n",
      "    (in_linear_0): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_1): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_2): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (layers): ModuleList(\n",
      "      (0-1): 2 x SCCNNLayer_exp(\n",
      "        (diff_derivative_00): Time_derivative_diffusion()\n",
      "        (diff_derivative_10): Time_derivative_diffusion()\n",
      "        (diff_derivative_d1): Time_derivative_diffusion()\n",
      "        (diff_derivative_u1): Time_derivative_diffusion()\n",
      "        (diff_derivative_01): Time_derivative_diffusion()\n",
      "        (diff_derivative_21): Time_derivative_diffusion()\n",
      "        (diff_derivative_2): Time_derivative_diffusion()\n",
      "        (diff_derivative_d2): Time_derivative_diffusion()\n",
      "        (diff_derivative_u2): Time_derivative_diffusion()\n",
      "        (diff_derivative_12): Time_derivative_diffusion()\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (out_linear_0): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_1): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_2): Linear(in_features=16, out_features=2, bias=True)\n",
      ")\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]\n",
      "iter: 8 Epoch: 1 loss: 3.0451 Train_acc: 0.61\n",
      "Val_acc: 0.6486\n",
      "Test_acc-improved: 0.7676\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 2 loss: 1.0195 Train_acc: 0.63\n",
      "Val_acc: 0.6486\n",
      "Test_acc-still: 0.7676\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 3 loss: 0.7508 Train_acc: 0.63\n",
      "Val_acc: 0.6284\n",
      "Test_acc-still: 0.7676\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 4 loss: 0.6686 Train_acc: 0.65\n",
      "Val_acc: 0.5878\n",
      "Test_acc-still: 0.7676\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 5 loss: 0.6414 Train_acc: 0.66\n",
      "Val_acc: 0.6014\n",
      "Test_acc-still: 0.7676\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 6 loss: 0.6339 Train_acc: 0.68\n",
      "Val_acc: 0.6622\n",
      "Test_acc-improved: 0.8000\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 7 loss: 0.6323 Train_acc: 0.67\n",
      "Val_acc: 0.6757\n",
      "Test_acc-improved: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 8 loss: 0.6353 Train_acc: 0.66\n",
      "Val_acc: 0.6757\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 9 loss: 0.6317 Train_acc: 0.67\n",
      "Val_acc: 0.6892\n",
      "Test_acc-improved: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 10 loss: 0.6256 Train_acc: 0.66\n",
      "Val_acc: 0.6959\n",
      "Test_acc-improved: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 11 loss: 0.6193 Train_acc: 0.67\n",
      "Val_acc: 0.6959\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 12 loss: 0.6166 Train_acc: 0.68\n",
      "Val_acc: 0.7027\n",
      "Test_acc-improved: 0.8000\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 13 loss: 0.6123 Train_acc: 0.68\n",
      "Val_acc: 0.7095\n",
      "Test_acc-improved: 0.8000\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 14 loss: 0.6081 Train_acc: 0.68\n",
      "Val_acc: 0.6959\n",
      "Test_acc-still: 0.8000\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 15 loss: 0.6064 Train_acc: 0.68\n",
      "Val_acc: 0.6959\n",
      "Test_acc-still: 0.8000\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 16 loss: 0.6015 Train_acc: 0.69\n",
      "Val_acc: 0.7027\n",
      "Test_acc-still: 0.8000\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 17 loss: 0.5944 Train_acc: 0.69\n",
      "Val_acc: 0.7162\n",
      "Test_acc-improved: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 18 loss: 0.5906 Train_acc: 0.70\n",
      "Val_acc: 0.7095\n",
      "Test_acc-still: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 19 loss: 0.5954 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-improved: 0.7892\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 20 loss: 0.5913 Train_acc: 0.70\n",
      "Val_acc: 0.7365\n",
      "Test_acc-improved: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 21 loss: 0.5849 Train_acc: 0.71\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 22 loss: 0.5889 Train_acc: 0.71\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 23 loss: 0.5863 Train_acc: 0.70\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 24 loss: 0.5866 Train_acc: 0.71\n",
      "Val_acc: 0.7365\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 25 loss: 0.5824 Train_acc: 0.71\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 26 loss: 0.5846 Train_acc: 0.71\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 27 loss: 0.5812 Train_acc: 0.71\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 28 loss: 0.5808 Train_acc: 0.71\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 29 loss: 0.5804 Train_acc: 0.71\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 8 Epoch: 30 loss: 0.5796 Train_acc: 0.71\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7838\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "True\n",
      "Length of dataset: 923\n",
      "(923,)\n",
      "923\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 69%|██████▉   | 639/923 [00:02<00:00, 296.97it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "592\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 923/923 [00:03<00:00, 265.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5\n",
      "9\n",
      "7\n",
      "922\n",
      "torch.Size([56, 1])\n",
      "922\n",
      "torch.Size([79, 1])\n",
      "922\n",
      "torch.Size([8, 1])\n",
      "922\n",
      "Network(\n",
      "  (base_model): SCCNN(\n",
      "    (in_linear_0): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_1): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (in_linear_2): Linear(in_features=1, out_features=16, bias=True)\n",
      "    (layers): ModuleList(\n",
      "      (0-1): 2 x SCCNNLayer_exp(\n",
      "        (diff_derivative_00): Time_derivative_diffusion()\n",
      "        (diff_derivative_10): Time_derivative_diffusion()\n",
      "        (diff_derivative_d1): Time_derivative_diffusion()\n",
      "        (diff_derivative_u1): Time_derivative_diffusion()\n",
      "        (diff_derivative_01): Time_derivative_diffusion()\n",
      "        (diff_derivative_21): Time_derivative_diffusion()\n",
      "        (diff_derivative_2): Time_derivative_diffusion()\n",
      "        (diff_derivative_d2): Time_derivative_diffusion()\n",
      "        (diff_derivative_u2): Time_derivative_diffusion()\n",
      "        (diff_derivative_12): Time_derivative_diffusion()\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (out_linear_0): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_1): Linear(in_features=16, out_features=2, bias=True)\n",
      "  (out_linear_2): Linear(in_features=16, out_features=2, bias=True)\n",
      ")\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]\n",
      "iter: 9 Epoch: 1 loss: 3.7332 Train_acc: 0.61\n",
      "Val_acc: 0.5878\n",
      "Test_acc-improved: 0.6595\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 2 loss: 1.5265 Train_acc: 0.61\n",
      "Val_acc: 0.6419\n",
      "Test_acc-improved: 0.7405\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 3 loss: 0.9182 Train_acc: 0.62\n",
      "Val_acc: 0.6351\n",
      "Test_acc-still: 0.7405\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 4 loss: 0.7165 Train_acc: 0.64\n",
      "Val_acc: 0.6554\n",
      "Test_acc-improved: 0.7676\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 5 loss: 0.6573 Train_acc: 0.66\n",
      "Val_acc: 0.6351\n",
      "Test_acc-still: 0.7676\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 6 loss: 0.6398 Train_acc: 0.67\n",
      "Val_acc: 0.6554\n",
      "Test_acc-still: 0.7676\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 7 loss: 0.6311 Train_acc: 0.68\n",
      "Val_acc: 0.6959\n",
      "Test_acc-improved: 0.7784\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 8 loss: 0.6279 Train_acc: 0.68\n",
      "Val_acc: 0.7230\n",
      "Test_acc-improved: 0.7784\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 9 loss: 0.6293 Train_acc: 0.69\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7784\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 10 loss: 0.6331 Train_acc: 0.67\n",
      "Val_acc: 0.7095\n",
      "Test_acc-still: 0.7784\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 11 loss: 0.6341 Train_acc: 0.68\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7784\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 12 loss: 0.6200 Train_acc: 0.68\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7784\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 13 loss: 0.6221 Train_acc: 0.68\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7784\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 14 loss: 0.6138 Train_acc: 0.68\n",
      "Val_acc: 0.7027\n",
      "Test_acc-still: 0.7784\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 15 loss: 0.6181 Train_acc: 0.69\n",
      "Val_acc: 0.6824\n",
      "Test_acc-still: 0.7784\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 16 loss: 0.6049 Train_acc: 0.69\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7784\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 17 loss: 0.6135 Train_acc: 0.68\n",
      "Val_acc: 0.6892\n",
      "Test_acc-still: 0.7784\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 18 loss: 0.5941 Train_acc: 0.70\n",
      "Val_acc: 0.7297\n",
      "Test_acc-improved: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 19 loss: 0.5993 Train_acc: 0.69\n",
      "Val_acc: 0.7027\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 20 loss: 0.5936 Train_acc: 0.70\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 21 loss: 0.5914 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 22 loss: 0.5902 Train_acc: 0.69\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 23 loss: 0.5881 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 24 loss: 0.5883 Train_acc: 0.71\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 25 loss: 0.5867 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 26 loss: 0.5842 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 27 loss: 0.5871 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 28 loss: 0.5811 Train_acc: 0.70\n",
      "Val_acc: 0.7162\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 29 loss: 0.5802 Train_acc: 0.70\n",
      "Val_acc: 0.7230\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "iter: 9 Epoch: 30 loss: 0.5803 Train_acc: 0.69\n",
      "Val_acc: 0.7297\n",
      "Test_acc-still: 0.7946\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n"
     ]
    }
   ],
   "source": [
    "acc_test = []\n",
    "for itter in range(10):\n",
    "    Test_accuracy = -1\n",
    "    import numpy as np\n",
    "    import random\n",
    "    import torch \n",
    "    # Define a fixed seed value\n",
    "    SEED = 2\n",
    "\n",
    "    # 1. Set the Python built-in random module's seed\n",
    "    random.seed(SEED)\n",
    "\n",
    "    # 2. Set the NumPy random seed\n",
    "    np.random.seed(SEED)\n",
    "\n",
    "    # 3. Set the PyTorch seed (for both CPU and GPU)\n",
    "    torch.manual_seed(itter)# Define a fixed seed value\n",
    "    print(torch.cuda.is_available())\n",
    "\n",
    "    import toponetx.datasets as datasets\n",
    "    from sklearn.model_selection import train_test_split\n",
    "\n",
    "    from Utils.sccnn_exp import SCCNN\n",
    "    from topomodelx.utils.sparse import from_sparse\n",
    "\n",
    "    # %load_ext autoreload\n",
    "    # %autoreload 2\n",
    "\n",
    "    import scipy\n",
    "    from scipy import sparse\n",
    "\n",
    "    def get_evals_evecs(L, k):\n",
    "        L_sparse = sparse.coo_matrix(L)\n",
    "\n",
    "        evals, evecs = scipy.sparse.linalg.eigs(L_sparse, k=k, ncv=4*k, return_eigenvectors=True)\n",
    "        # evals, evecs = scipy.linalg.eig(L)\n",
    "\n",
    "        evals=torch.tensor(evals.real)\n",
    "        evecs=torch.tensor(evecs.real)\n",
    "\n",
    "        return evals, evecs \n",
    "\n",
    "    import argparse\n",
    "\n",
    "    import numpy as np\n",
    "    from tqdm import tqdm\n",
    "    import gudhi\n",
    "    import torch\n",
    "    from Utils.preprocessing.simplicial_construction import get_boundary_matrices, get_boundary_matrices_from_processed_tree, process_simplex_tree, get_neighbors, get_weight_matrix_graph, get_weight_matrix_simplex,generate_triangles, augment_simplex_open_gc, _get_laplacians,_get_simplex_features_gc\n",
    "    from Utils.preprocessing.graph_construction import _get_graph\n",
    "    # from model.model import MPSN,SCNN,SAN\n",
    "    import torch.nn as nn\n",
    "    # from model.loss import l_rel, l_sub\n",
    "    import copy\n",
    "    import time\n",
    "    import numpy as np\n",
    "    from matplotlib.lines import Line2D\n",
    "    import matplotlib.pyplot as plt\n",
    "    import torch\n",
    "    import networkx as nx\n",
    "    from sklearn import metrics\n",
    "    from sklearn.metrics import classification_report,f1_score, accuracy_score\n",
    "    import sys\n",
    "\n",
    "    from sklearn.linear_model import LogisticRegression, RidgeClassifier\n",
    "    from sklearn.svm import SVC\n",
    "    from sklearn.model_selection import train_test_split\n",
    "\n",
    "    import gc\n",
    "    gc.enable()\n",
    "\n",
    "\n",
    "\n",
    "    parser = argparse.ArgumentParser(description='TopoSRL')\n",
    "\n",
    "    parser.add_argument('--dataname', type=str, default='proteins', help='Name of dataset.')\n",
    "    parser.add_argument('--gpu', type=int, default=0, help='GPU index.')\n",
    "    parser.add_argument('--epochs', type=int, default=1, help='Training epochs.')\n",
    "    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate of TopoSRL encoder.')\n",
    "    parser.add_argument('--wd', type=float, default=0, help='Weight decay of TopoSRL encoder.')\n",
    "\n",
    "    parser.add_argument('--dim', type=int, default=4, help='Order of the simplicial complex.')\n",
    "    parser.add_argument('--alpha', type=float, default=0.5, help='alpha.')\n",
    "    parser.add_argument('--snn', type=str, default='MPSN', help='Type of SNN')\n",
    "    parser.add_argument('--delta', type=int, default=20, help='Number of samples to calculate L_rel')\n",
    "    parser.add_argument('--augmentation', type=str,  default='open', help='Type of agumentation')\n",
    "    parser.add_argument('--rho', type=float, default=0.1, help='Simplex removing and adding ratio.')\n",
    "\n",
    "    args = parser.parse_args(args=[])\n",
    "\n",
    "    if args.gpu != -1 and torch.cuda.is_available():\n",
    "        args.device = 'cuda'\n",
    "    else:\n",
    "        args.device = 'cpu'\n",
    "\n",
    "\n",
    "    data = args.dataname\n",
    "    alpha = args.alpha\n",
    "    delta = args.delta\n",
    "    epochs = args.epochs\n",
    "    labels = np.load('data/graph classification/'+data+'/label_sets_'+data+'.npy', allow_pickle=True)\n",
    "    simplicial = np.load('data/graph classification/'+data+'/simplicial_sets_'+data+'.npy', allow_pickle=True)\n",
    "    SCs = []\n",
    "    INDs = []\n",
    "    _labels = []\n",
    "    simplex_trees = []\n",
    "    node_attributes = []\n",
    "    netxG = []\n",
    "    for p in range(len(simplicial)):\n",
    "        for q in range(len(simplicial[p])):\n",
    "            simplex_tree = gudhi.SimplexTree()\n",
    "            sc = [[] for i in range(4)]\n",
    "            for i in range(4):\n",
    "                for j in simplicial[p][q][i]:\n",
    "                    sc[i].append(list(j))\n",
    "                    simplex_tree.insert(list(j))\n",
    "            for i in range(len(sc)):\n",
    "                sc[i] = np.array(sc[i])\n",
    "            if(len(sc[3])):\n",
    "                INDs.append(simplicial[p][q])\n",
    "                g = nx.from_edgelist(sc[1])\n",
    "                _labels.append(labels[p][q])\n",
    "                node_attributes.append(nx.adjacency_matrix(g).todense().sum(1))\n",
    "                netxG.append(g)\n",
    "                SCs.append(sc)    \n",
    "                simplex_trees.append(simplex_tree)\n",
    "    labels = np.array(_labels)\n",
    "    print(\"Length of dataset:\", len(SCs))\n",
    "    print(labels.shape)\n",
    "    print(len(SCs))\n",
    "\n",
    "\n",
    "    from scipy.sparse import coo_matrix\n",
    "\n",
    "\n",
    "    max_rank = 2  # the order of the SC is two\n",
    "    incidence_1_list = []\n",
    "    incidence_2_list = []\n",
    "\n",
    "\n",
    "    kk_0 = 2\n",
    "    kk_1 = 2\n",
    "    kk_2 = 2\n",
    "    evals_0_list = []\n",
    "    evecs_0_list = []\n",
    "    evals_d1_list = []\n",
    "    evecs_d1_list = []\n",
    "    evals_u1_list = []\n",
    "    evecs_u1_list = []\n",
    "    evals_2_list = []\n",
    "    evecs_2_list = []\n",
    "\n",
    "\n",
    "    _labels = []\n",
    "    skipped = []\n",
    "    x_0s = []\n",
    "    x_1s = []\n",
    "    x_2s = []\n",
    "    for i in tqdm(range(len(labels))):\n",
    "        _,_,bm, = get_boundary_matrices_from_processed_tree(simplex_trees[i], SCs[i], INDs[i], 3)\n",
    "        l1, l1_d, l1_u = _get_laplacians(bm)\n",
    "        _X = torch.FloatTensor(node_attributes[i]).to(args.device).view(len(node_attributes[i]),1)\n",
    "        X1 = _get_simplex_features_gc(SCs[i][1:3],_X)\n",
    "        if (X1[0].shape[0] != l1[0].shape[0]) or (X1[1].shape[0] != l1[1].shape[0]) or (X1[2].shape[0] != l1[2].shape[0]):\n",
    "            print(i)\n",
    "            continue\n",
    "        try:\n",
    "            x_0s.append(X1[0].to(args.device))\n",
    "            x_1s.append(X1[1].to(args.device))\n",
    "            x_2s.append(X1[2].to(args.device))\n",
    "            incidence_1 = bm[0].cpu().detach().numpy()\n",
    "            incidence_2 = bm[1].cpu().detach().numpy()\n",
    "            laplacian_0 = l1[0].cpu().detach().numpy()\n",
    "            laplacian_down_1 = l1_d[1].cpu().detach().numpy()\n",
    "            laplacian_up_1 = l1_u[1].cpu().detach().numpy()\n",
    "            laplacian_2 = l1[2].cpu().detach().numpy()\n",
    "\n",
    "            if X1[0].shape[0] != laplacian_0[0].shape[0]:\n",
    "                print(i)\n",
    "                print(X1[0].shape[0])\n",
    "                print(laplacian_0[0].shape[0])\n",
    "\n",
    "            incidence_1 = coo_matrix(incidence_1) # Convert NumPy array to COO sparse format\n",
    "            incidence_2 = coo_matrix(incidence_2)  # Convert NumPy array to COO sparse format\n",
    "            incidence_1 = from_sparse(incidence_1).to(args.device)\n",
    "            incidence_2 = from_sparse(incidence_2).to(args.device)\n",
    "            \n",
    "            evals_0, evecs_0 = get_evals_evecs(laplacian_0, kk_0)\n",
    "            evals_d1, evecs_d1 = get_evals_evecs(laplacian_down_1, kk_1)\n",
    "            evals_u1, evecs_u1 = get_evals_evecs(laplacian_up_1, kk_1)\n",
    "            evals_2, evecs_2 = get_evals_evecs(laplacian_2, kk_2)\n",
    "\n",
    "            incidence_1_list.append(incidence_1)\n",
    "            incidence_2_list.append(incidence_2)\n",
    "            evals_0_list.append(evals_0.to(args.device))\n",
    "            evecs_0_list.append(evecs_0.to(args.device))\n",
    "            evals_d1_list.append(evals_d1.to(args.device))\n",
    "            evecs_d1_list.append(evecs_d1.to(args.device))\n",
    "            evals_u1_list.append(evals_u1.to(args.device))\n",
    "            evecs_u1_list.append(evecs_u1.to(args.device))\n",
    "            evals_2_list.append(evals_2.to(args.device))\n",
    "            evecs_2_list.append(evecs_2.to(args.device))\n",
    "            _labels.append(labels[i])\n",
    "            # print(i)\n",
    "            # print(laplacian_0)\n",
    "        except RuntimeError:\n",
    "            skipped.append(i)\n",
    "    _labels = np.array(_labels)         \n",
    "\n",
    "\n",
    "    print(laplacian_0.shape[0])\n",
    "    print(laplacian_down_1.shape[0])\n",
    "    print(laplacian_2.shape[0])\n",
    "    print(len(x_0s))\n",
    "    print(x_0s[1].shape)\n",
    "    print(len(x_1s))\n",
    "    print(x_1s[1].shape)\n",
    "    print(len(x_2s))\n",
    "    print(x_2s[1].shape)\n",
    "    print(_labels.shape[0])\n",
    "\n",
    "    class Network(torch.nn.Module):\n",
    "        \"\"\"SCCNN backbone followed by one linear readout per cell rank.\n",
    "\n",
    "        Each rank's cell embeddings are projected to `out_channels` logits and\n",
    "        averaged over the cells of that rank (NaNs ignored; a NaN average, e.g.\n",
    "        for a rank with no cells, is replaced by 0). The three per-rank averages\n",
    "        are summed into one logit vector for the whole complex.\n",
    "        \"\"\"\n",
    "\n",
    "        def __init__(\n",
    "            self,\n",
    "            in_channels_all,\n",
    "            hidden_channels_all,\n",
    "            out_channels,\n",
    "            conv_order,\n",
    "            max_rank,\n",
    "            n_layers=2,\n",
    "        ):\n",
    "            super().__init__()\n",
    "            self.base_model = SCCNN(\n",
    "                in_channels_all=in_channels_all,\n",
    "                hidden_channels_all=hidden_channels_all,\n",
    "                conv_order=conv_order,\n",
    "                sc_order=max_rank,\n",
    "                n_layers=n_layers,\n",
    "            )\n",
    "            hidden_0, hidden_1, hidden_2 = hidden_channels_all\n",
    "            # Readouts are created in rank order 0, 1, 2 so parameter registration\n",
    "            # (and hence initialisation and state_dict order) is unchanged.\n",
    "            self.out_linear_0 = torch.nn.Linear(hidden_0, out_channels)\n",
    "            self.out_linear_1 = torch.nn.Linear(hidden_1, out_channels)\n",
    "            self.out_linear_2 = torch.nn.Linear(hidden_2, out_channels)\n",
    "\n",
    "        @staticmethod\n",
    "        def _nan_safe_cell_mean(cell_logits):\n",
    "            \"\"\"Mean over cells (dim 0) ignoring NaNs; NaN entries of the result become 0.\"\"\"\n",
    "            avg = torch.nanmean(cell_logits, dim=0)\n",
    "            avg[torch.isnan(avg)] = 0\n",
    "            return avg\n",
    "\n",
    "        def forward(self, x_all, eig_eiv_all, incidence_all):\n",
    "            \"\"\"Return the sum of the per-rank mean logits for one complex.\"\"\"\n",
    "            feats_0, feats_1, feats_2 = self.base_model(x_all, eig_eiv_all, incidence_all)\n",
    "            logits_0 = self.out_linear_0(feats_0)\n",
    "            logits_1 = self.out_linear_1(feats_1)\n",
    "            logits_2 = self.out_linear_2(feats_2)\n",
    "            pooled_2 = self._nan_safe_cell_mean(logits_2)\n",
    "            pooled_1 = self._nan_safe_cell_mean(logits_1)\n",
    "            pooled_0 = self._nan_safe_cell_mean(logits_0)\n",
    "            # Same summation order as before: rank 2 + rank 1 + rank 0.\n",
    "            return pooled_2 + pooled_1 + pooled_0\n",
    "            \n",
    "            \n",
    "    conv_order = 2\n",
    "    intermediate_channels_all = (16, 16, 16)\n",
    "    num_layers = 2\n",
    "    out_channels = 2  # num classes\n",
    "\n",
    "    # Input channels for the rank-0, rank-1 and rank-2 cell features.\n",
    "    in_channels_all = (1, 1, 1)\n",
    "\n",
    "    model = Network(\n",
    "        in_channels_all=in_channels_all,\n",
    "        hidden_channels_all=intermediate_channels_all,\n",
    "        out_channels=out_channels,\n",
    "        conv_order=conv_order,\n",
    "        max_rank=max_rank,\n",
    "        n_layers=num_layers,\n",
    "    ).to(args.device)\n",
    "\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "    print(model)\n",
    "\n",
    "    test_size = 0.2\n",
    "    val_size = 0.2\n",
    "\n",
    "    # Every per-complex list must stay index-aligned with the labels; a length\n",
    "    # mismatch would otherwise silently pair features with the wrong complex.\n",
    "    n_complexes = len(_labels)\n",
    "    per_complex_lists = (\n",
    "        x_0s, x_1s, x_2s, incidence_1_list, incidence_2_list,\n",
    "        evals_0_list, evecs_0_list, evals_d1_list, evecs_d1_list,\n",
    "        evals_u1_list, evecs_u1_list, evals_2_list, evecs_2_list,\n",
    "    )\n",
    "    assert all(len(seq) == n_complexes for seq in per_complex_lists), (\n",
    "        \"per-complex lists and _labels have different lengths\"\n",
    "    )\n",
    "\n",
    "    # Split the sample indices once and reuse them for every list. The shuffled\n",
    "    # partition produced by train_test_split depends only on the number of samples\n",
    "    # and random_state, so this gives exactly the same train/val/test sub-lists as\n",
    "    # calling train_test_split on each list separately with these arguments.\n",
    "    trainval_idx, test_idx = train_test_split(\n",
    "        np.arange(n_complexes), test_size=test_size, shuffle=True, random_state=SEED\n",
    "    )\n",
    "    train_idx, val_idx = train_test_split(\n",
    "        trainval_idx, test_size=val_size, shuffle=True, random_state=SEED\n",
    "    )\n",
    "\n",
    "    def _split3(seq):\n",
    "        \"\"\"Return the (train, val, test) sub-lists of `seq` selected by the shared indices.\"\"\"\n",
    "        return tuple([seq[i] for i in idx] for idx in (train_idx, val_idx, test_idx))\n",
    "\n",
    "    x_0_train, x_0_val, x_0_test = _split3(x_0s)\n",
    "    x_1_train, x_1_val, x_1_test = _split3(x_1s)\n",
    "    x_2_train, x_2_val, x_2_test = _split3(x_2s)\n",
    "    incidence_1_train, incidence_1_val, incidence_1_test = _split3(incidence_1_list)\n",
    "    incidence_2_train, incidence_2_val, incidence_2_test = _split3(incidence_2_list)\n",
    "    evals_0_train, evals_0_val, evals_0_test = _split3(evals_0_list)\n",
    "    evecs_0_train, evecs_0_val, evecs_0_test = _split3(evecs_0_list)\n",
    "    evals_d1_train, evals_d1_val, evals_d1_test = _split3(evals_d1_list)\n",
    "    evecs_d1_train, evecs_d1_val, evecs_d1_test = _split3(evecs_d1_list)\n",
    "    evals_u1_train, evals_u1_val, evals_u1_test = _split3(evals_u1_list)\n",
    "    evecs_u1_train, evecs_u1_val, evecs_u1_test = _split3(evecs_u1_list)\n",
    "    evals_2_train, evals_2_val, evals_2_test = _split3(evals_2_list)\n",
    "    evecs_2_train, evecs_2_val, evecs_2_test = _split3(evecs_2_list)\n",
    "\n",
    "    import torch.nn.functional as F\n",
    "\n",
    "    y = np.array(_labels)\n",
    "    print(y)\n",
    "    num_classes = out_channels  # the model emits one logit per class\n",
    "    one_hot_labels = np.array(F.one_hot(torch.tensor(y), num_classes=num_classes))\n",
    "\n",
    "    # Same index split as the features, so every target stays with its complex.\n",
    "    y_train, y_val, y_test = (\n",
    "        torch.from_numpy(one_hot_labels[idx]).to(args.device)\n",
    "        for idx in (train_idx, val_idx, test_idx)\n",
    "    )\n",
    "\n",
    "    test_interval = 1\n",
    "    num_epochs = 30\n",
    "    Val_loss_Best = float('inf')\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    Val_accuracy_Best = -1\n",
    "\n",
    "    # Index-aligned per-complex inputs of each split, in the order _complex_inputs unpacks them.\n",
    "    train_split = (\n",
    "        x_0_train, x_1_train, x_2_train, incidence_1_train, incidence_2_train,\n",
    "        evals_0_train, evecs_0_train, evals_d1_train, evecs_d1_train,\n",
    "        evals_u1_train, evecs_u1_train, evals_2_train, evecs_2_train,\n",
    "    )\n",
    "    val_split = (\n",
    "        x_0_val, x_1_val, x_2_val, incidence_1_val, incidence_2_val,\n",
    "        evals_0_val, evecs_0_val, evals_d1_val, evecs_d1_val,\n",
    "        evals_u1_val, evecs_u1_val, evals_2_val, evecs_2_val,\n",
    "    )\n",
    "    test_split = (\n",
    "        x_0_test, x_1_test, x_2_test, incidence_1_test, incidence_2_test,\n",
    "        evals_0_test, evecs_0_test, evals_d1_test, evecs_d1_test,\n",
    "        evals_u1_test, evecs_u1_test, evals_2_test, evecs_2_test,\n",
    "    )\n",
    "\n",
    "    def _complex_inputs(split, one_hot_targets):\n",
    "        \"\"\"Yield (model_inputs, one_hot_target) for every complex of a split.\n",
    "\n",
    "        model_inputs is the (x_all, eig_eiv_all, incidence_all) triple passed to\n",
    "        Network.forward. zip(strict=True) raises instead of silently truncating\n",
    "        when the split's lists or its targets differ in length.\n",
    "        \"\"\"\n",
    "        for (x_0, x_1, x_2, inc_1, inc_2, *eig_eiv), target in zip(\n",
    "            zip(*split, strict=True), one_hot_targets, strict=True\n",
    "        ):\n",
    "            # as_tensor avoids the copy (and UserWarning) torch.tensor() makes for tensor inputs.\n",
    "            x_all = (\n",
    "                torch.as_tensor(x_0).float(),\n",
    "                torch.as_tensor(x_1).float(),\n",
    "                torch.as_tensor(x_2).float(),\n",
    "            )\n",
    "            yield (x_all, tuple(eig_eiv), (inc_1, inc_2)), target\n",
    "\n",
    "    def _accuracy(logits, one_hot_targets):\n",
    "        \"\"\"Fraction of complexes whose argmax prediction matches the one-hot target.\"\"\"\n",
    "        y_pred = torch.argmax(torch.softmax(logits, dim=1), dim=1)\n",
    "        correct = (y_pred == torch.argmax(one_hot_targets.float(), dim=1)).sum().item()\n",
    "        return correct / one_hot_targets.size(0)\n",
    "\n",
    "    def _evaluate(split, one_hot_targets):\n",
    "        \"\"\"Accuracy of `model` on a split, computed in eval mode without gradients.\"\"\"\n",
    "        model.eval()  # disable training-only layer behaviour (e.g. dropout) while evaluating\n",
    "        with torch.no_grad():\n",
    "            logits = torch.stack(\n",
    "                [model(*inputs) for inputs, _ in _complex_inputs(split, one_hot_targets)], 0\n",
    "            )\n",
    "        return _accuracy(logits, one_hot_targets)\n",
    "\n",
    "    for epoch_i in range(1, num_epochs + 1):\n",
    "        epoch_loss = []\n",
    "        y_train_pred = []\n",
    "        model.train()\n",
    "        for inputs, y in _complex_inputs(train_split, y_train):\n",
    "            optimizer.zero_grad()\n",
    "            y_hat = model(*inputs)\n",
    "            # Only the values are needed for the train accuracy; detaching avoids\n",
    "            # keeping every sample's autograd graph referenced for the whole epoch.\n",
    "            y_train_pred.append(y_hat.detach())\n",
    "            loss = criterion(y_hat, torch.argmax(y.float()))\n",
    "            epoch_loss.append(loss.item())\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        accuracy = _accuracy(torch.stack(y_train_pred, 0), y_train)\n",
    "        print(\n",
    "            f\"iter: {itter} Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.2f}\",\n",
    "            flush=True,\n",
    "        )\n",
    "\n",
    "        Val_accuracy = _evaluate(val_split, y_val)\n",
    "        print(f\"Val_acc: {Val_accuracy:.4f}\", flush=True)\n",
    "        if Val_accuracy > Val_accuracy_Best:\n",
    "            # New best validation accuracy: re-measure test accuracy for this checkpoint.\n",
    "            Val_accuracy_Best = Val_accuracy\n",
    "            Test_accuracy = _evaluate(test_split, y_test)\n",
    "            print(f\"Test_acc-improved: {Test_accuracy:.4f}\", flush=True)\n",
    "        else:\n",
    "            print(f\"Test_acc-still: {Test_accuracy:.4f}\", flush=True)\n",
    "\n",
    "        print(\">\"*100)\n",
    "    acc_test.append(Test_accuracy)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean: 0.7897\n",
      "Standard Deviation: 0.0045\n"
     ]
    }
   ],
   "source": [
    "mean_value, std_value = np.mean(acc_test), np.std(acc_test)\n",
    "\n",
    "# Summary of the collected test accuracies, rounded to 4 decimals.\n",
    "for label, stat in ((\"Mean\", mean_value), (\"Standard Deviation\", std_value)):\n",
    "    print(f\"{label}:\", round(stat, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.7945945945945946,\n",
       " 0.7837837837837838,\n",
       " 0.7945945945945946,\n",
       " 0.7837837837837838,\n",
       " 0.7891891891891892,\n",
       " 0.7945945945945946,\n",
       " 0.7891891891891892,\n",
       " 0.7891891891891892,\n",
       " 0.7837837837837838,\n",
       " 0.7945945945945946]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "acc_test  # one test accuracy per outer-loop iteration, taken at the best-validation epoch"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tmx",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
