{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train a COSIMO\n",
    "\n",
    "### We train the model to perform:\n",
    "    Complex Regression using the shrec16 benchmark dataset.\n",
    "   "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Continuous Simplicial Neural Networks [COSIMO]</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Complex Regression on the Shrec-16 Dataset: Small Version"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "import torch\n",
    "# One can define a fixed seed value (we spanned [1,10])\n",
    "SEED = 4\n",
    "\n",
    "# 1. Set the Python built-in random module's seed\n",
    "random.seed(SEED)\n",
    "\n",
    "# 2. Set the NumPy random seed\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# 3. Set the PyTorch seed (for both CPU and GPU)\n",
    "torch.manual_seed(SEED)# Define a fixed seed value\n",
    "print(torch.cuda.is_available())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import toponetx.datasets as datasets\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from COSIMO import COSIMO\n",
    "from topomodelx.utils.sparse import from_sparse\n",
    "\n",
    "# %load_ext autoreload\n",
    "# %autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy\n",
    "from scipy import sparse\n",
    "\n",
    "def get_evals_evecs(L, k):\n",
    "    L_sparse = sparse.coo_matrix(L)\n",
    "\n",
    "    evals, evecs = scipy.sparse.linalg.eigs(L_sparse, k=k, ncv=4*k, return_eigenvectors=True)\n",
    "    # evals, evecs = scipy.linalg.eig(L)\n",
    "\n",
    "    evals=torch.tensor(evals.real)\n",
    "    evecs=torch.tensor(evecs.real)\n",
    "\n",
    "    return evals, evecs "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pre-processing\n",
    "\n",
    "### Import shrec dataset ##\n",
    "\n",
    "We must first lift our graph dataset into the simplicial complex domain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading shrec 16 small dataset...\n",
      "\n",
      "done!\n"
     ]
    }
   ],
   "source": [
    "shrec, _ = datasets.mesh.shrec_16(size=\"small\")\n",
    "shrec = {key: np.array(value) for key, value in shrec.items()}\n",
    "x_0s = shrec[\"node_feat\"]\n",
    "x_1s = shrec[\"edge_feat\"]\n",
    "x_2s = shrec[\"face_feat\"]\n",
    "\n",
    "ys = shrec[\"label\"]\n",
    "simplexes = shrec[\"complexes\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(6, 10, 7)\n"
     ]
    }
   ],
   "source": [
    "in_channels_0 = x_0s[-1].shape[1]\n",
    "in_channels_1 = x_1s[-1].shape[1]\n",
    "in_channels_2 = x_2s[-1].shape[1]\n",
    "\n",
    "in_channels_all = (in_channels_0, in_channels_1, in_channels_2)\n",
    "print(in_channels_all)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Neighborhood Strctures\n",
    "Get incidence matrices $\\mathbf{B}_1,\\mathbf{B}_2$ and Hodge Laplacians $\\mathbf{L}_0, \\mathbf{L}_1$ and $\\mathbf{L}_2$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_rank = 2  # the order of the SC is two\n",
    "incidence_1_list = []\n",
    "incidence_2_list = []\n",
    "\n",
    "\n",
    "kk_0 = 10\n",
    "kk_1 = 10\n",
    "kk_2 = 10\n",
    "evals_0_list = []\n",
    "evecs_0_list = []\n",
    "evals_d1_list = []\n",
    "evecs_d1_list = []\n",
    "evals_u1_list = []\n",
    "evecs_u1_list = []\n",
    "evals_2_list = []\n",
    "evecs_2_list = []\n",
    "\n",
    "for simplex in simplexes:\n",
    "    incidence_1 = simplex.incidence_matrix(rank=1)\n",
    "    incidence_2 = simplex.incidence_matrix(rank=2)\n",
    "    laplacian_0 = simplex.hodge_laplacian_matrix(rank=0)\n",
    "    laplacian_down_1 = simplex.down_laplacian_matrix(rank=1)\n",
    "    laplacian_up_1 = simplex.up_laplacian_matrix(rank=1)\n",
    "    laplacian_2 = simplex.hodge_laplacian_matrix(rank=2)\n",
    "\n",
    "    incidence_1 = from_sparse(incidence_1)\n",
    "    incidence_2 = from_sparse(incidence_2)\n",
    "    evals_0, evecs_0 = get_evals_evecs(laplacian_0, kk_0)\n",
    "    evals_d1, evecs_d1 = get_evals_evecs(laplacian_down_1, kk_1)\n",
    "    evals_u1, evecs_u1 = get_evals_evecs(laplacian_up_1, kk_1)\n",
    "    evals_2, evecs_2 = get_evals_evecs(laplacian_2, kk_2)\n",
    "\n",
    "    incidence_1_list.append(incidence_1)\n",
    "    incidence_2_list.append(incidence_2)\n",
    "    evals_0_list.append(evals_0)\n",
    "    evecs_0_list.append(evecs_0)\n",
    "    evals_d1_list.append(evals_d1)\n",
    "    evecs_d1_list.append(evecs_d1)\n",
    "    evals_u1_list.append(evals_u1)\n",
    "    evecs_u1_list.append(evecs_u1)\n",
    "    evals_2_list.append(evals_2)\n",
    "    evecs_2_list.append(evecs_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "252\n",
      "750\n",
      "500\n"
     ]
    }
   ],
   "source": [
    "print(laplacian_0.shape[0])\n",
    "print(laplacian_down_1.shape[0])\n",
    "print(laplacian_2.shape[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create and Train the Neural Network\n",
    "\n",
    "We specify the model with our pre-made neighborhood structures and specify an optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Network(torch.nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        in_channels_all,\n",
    "        hidden_channels_all,\n",
    "        out_channels,\n",
    "        conv_order,\n",
    "        max_rank,\n",
    "        n_layers=2,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.base_model = COSIMO(\n",
    "            in_channels_all=in_channels_all,\n",
    "            hidden_channels_all=hidden_channels_all,\n",
    "            conv_order=conv_order,\n",
    "            sc_order=max_rank,\n",
    "            n_layers=n_layers,\n",
    "        )\n",
    "        out_channels_0, out_channels_1, out_channels_2 = hidden_channels_all\n",
    "        self.out_linear_0 = torch.nn.Linear(out_channels_0, out_channels)\n",
    "        self.out_linear_1 = torch.nn.Linear(out_channels_1, out_channels)\n",
    "        self.out_linear_2 = torch.nn.Linear(out_channels_2, out_channels)\n",
    "\n",
    "    def forward(self, x_all, eig_eiv_all, incidence_all):\n",
    "        x_all = self.base_model(x_all, eig_eiv_all, incidence_all)\n",
    "        x_0, x_1, x_2 = x_all\n",
    "\n",
    "        x_0 = self.out_linear_0(x_0)\n",
    "        x_1 = self.out_linear_1(x_1)\n",
    "        x_2 = self.out_linear_2(x_2)\n",
    "\n",
    "        # Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.\n",
    "        two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)\n",
    "        two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0\n",
    "        one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)\n",
    "        one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0\n",
    "        zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)\n",
    "        zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0\n",
    "        # Return the sum of the averages\n",
    "        return (\n",
    "            two_dimensional_cells_mean\n",
    "            + one_dimensional_cells_mean\n",
    "            + zero_dimensional_cells_mean\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Network(\n",
      "  (base_model): COSIMO(\n",
      "    (in_linear_0): Linear(in_features=6, out_features=16, bias=True)\n",
      "    (in_linear_1): Linear(in_features=10, out_features=16, bias=True)\n",
      "    (in_linear_2): Linear(in_features=7, out_features=16, bias=True)\n",
      "    (layers): ModuleList(\n",
      "      (0-1): 2 x COSIMO_Layer_exp(\n",
      "        (diff_derivative_00): Time_derivative_diffusion()\n",
      "        (diff_derivative_10): Time_derivative_diffusion()\n",
      "        (diff_derivative_d1): Time_derivative_diffusion()\n",
      "        (diff_derivative_u1): Time_derivative_diffusion()\n",
      "        (diff_derivative_01): Time_derivative_diffusion()\n",
      "        (diff_derivative_21): Time_derivative_diffusion()\n",
      "        (diff_derivative_2): Time_derivative_diffusion()\n",
      "        (diff_derivative_d2): Time_derivative_diffusion()\n",
      "        (diff_derivative_u2): Time_derivative_diffusion()\n",
      "        (diff_derivative_12): Time_derivative_diffusion()\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (out_linear_0): Linear(in_features=16, out_features=1, bias=True)\n",
      "  (out_linear_1): Linear(in_features=16, out_features=1, bias=True)\n",
      "  (out_linear_2): Linear(in_features=16, out_features=1, bias=True)\n",
      ")\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/aref/anaconda3/envs/tmx/lib/python3.11/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.\n",
      "  warnings.warn(warning.format(ret))\n"
     ]
    }
   ],
   "source": [
    "conv_order = 2\n",
    "intermediate_channels_all = (16, 16, 16)\n",
    "num_layers = 2\n",
    "out_channels = 1  # num classes\n",
    "\n",
    "model = Network(\n",
    "    in_channels_all=in_channels_all,\n",
    "    hidden_channels_all=intermediate_channels_all,\n",
    "    out_channels=out_channels,\n",
    "    conv_order=conv_order,\n",
    "    max_rank=max_rank,\n",
    "    n_layers=num_layers,\n",
    ")\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "loss_fn = torch.nn.MSELoss(size_average=True, reduction=\"mean\")\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_size = 0.2\n",
    "val_size = 0.2\n",
    "x_0_train, x_0_test = train_test_split(x_0s, test_size=test_size, shuffle=False)\n",
    "x_0_train, x_0_val = train_test_split(x_0_train, test_size=val_size, shuffle=False)\n",
    "x_1_train, x_1_test = train_test_split(x_1s, test_size=test_size, shuffle=False)\n",
    "x_1_train, x_1_val = train_test_split(x_1_train, test_size=val_size, shuffle=False)\n",
    "x_2_train, x_2_test = train_test_split(x_2s, test_size=test_size, shuffle=False)\n",
    "x_2_train, x_2_val = train_test_split(x_2_train, test_size=val_size, shuffle=False)\n",
    "\n",
    "incidence_1_train, incidence_1_test = train_test_split(\n",
    "    incidence_1_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "incidence_1_train, incidence_1_val = train_test_split(\n",
    "    incidence_1_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "incidence_2_train, incidence_2_test = train_test_split(\n",
    "    incidence_2_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "incidence_2_train, incidence_2_val = train_test_split(\n",
    "    incidence_2_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "evals_0_train, evals_0_test = train_test_split(\n",
    "    evals_0_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "evals_0_train, evals_0_val = train_test_split(\n",
    "    evals_0_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "evecs_0_train, evecs_0_test = train_test_split(\n",
    "    evecs_0_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "evecs_0_train, evecs_0_val = train_test_split(\n",
    "    evecs_0_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "evals_d1_train, evals_d1_test = train_test_split(\n",
    "    evals_d1_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "evals_d1_train, evals_d1_val = train_test_split(\n",
    "    evals_d1_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "evecs_d1_train, evecs_d1_test = train_test_split(\n",
    "    evecs_d1_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "evecs_d1_train, evecs_d1_val = train_test_split(\n",
    "    evecs_d1_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "evals_u1_train, evals_u1_test = train_test_split(\n",
    "    evals_u1_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "evals_u1_train, evals_u1_val = train_test_split(\n",
    "    evals_u1_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "evecs_u1_train, evecs_u1_test = train_test_split(\n",
    "    evecs_u1_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "evecs_u1_train, evecs_u1_val = train_test_split(\n",
    "    evecs_u1_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "evals_2_train, evals_2_test = train_test_split(\n",
    "    evals_2_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "evals_2_train, evals_2_val = train_test_split(\n",
    "    evals_2_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "evecs_2_train, evecs_2_test = train_test_split(\n",
    "    evecs_2_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "evecs_2_train, evecs_2_val = train_test_split(\n",
    "    evecs_2_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)\n",
    "y_train, y_val = train_test_split(y_train, test_size=val_size, shuffle=False)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We train the COSIMO:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/aref/anaconda3/envs/tmx/lib/python3.11/site-packages/torch/nn/modules/loss.py:535: UserWarning: Using a target size (torch.Size([])) that is different to the input size (torch.Size([1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
      "  return F.mse_loss(input, target, reduction=self.reduction)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 loss: 112.6733\n",
      "Val_loss: 1.7076\n",
      "Test_loss-improved: 0.0292\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 2 loss: 99.6715\n",
      "Val_loss: 0.0085\n",
      "Test_loss-improved: 0.0145\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 3 loss: 98.0700\n",
      "Val_loss: 0.0008\n",
      "Test_loss-still: 0.0145\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 4 loss: 93.9155\n",
      "Val_loss: 0.0053\n",
      "Test_loss-still: 0.0145\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 5 loss: 89.9676\n",
      "Val_loss: 0.5926\n",
      "Test_loss-still: 0.0145\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 6 loss: 84.4574\n",
      "Val_loss: 0.1364\n",
      "Test_loss-still: 0.0145\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 7 loss: 82.1039\n",
      "Val_loss: 0.8782\n",
      "Test_loss-still: 0.0145\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 8 loss: 78.3349\n",
      "Val_loss: 0.2830\n",
      "Test_loss-still: 0.0145\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 9 loss: 75.5327\n",
      "Val_loss: 0.0026\n",
      "Test_loss-still: 0.0145\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 10 loss: 71.6482\n",
      "Val_loss: 1.8854\n",
      "Test_loss-still: 0.0145\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 11 loss: 68.1196\n",
      "Val_loss: 11.8188\n",
      "Test_loss-still: 0.0145\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 12 loss: 64.6845\n",
      "Val_loss: 10.3781\n",
      "Test_loss-still: 0.0145\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 13 loss: 62.3612\n",
      "Val_loss: 5.5939\n",
      "Test_loss-still: 0.0145\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 14 loss: 60.6999\n",
      "Val_loss: 4.6781\n",
      "Test_loss-improved: 0.0147\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 15 loss: 59.6399\n",
      "Val_loss: 3.2743\n",
      "Test_loss-improved: 0.0152\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 16 loss: 58.4341\n",
      "Val_loss: 2.3262\n",
      "Test_loss-improved: 0.0153\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 17 loss: 57.4533\n",
      "Val_loss: 0.5617\n",
      "Test_loss-still: 0.0153\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 18 loss: 56.4842\n",
      "Val_loss: 1.6705\n",
      "Test_loss-improved: 0.0158\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 19 loss: 56.0502\n",
      "Val_loss: 0.3739\n",
      "Test_loss-still: 0.0158\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 20 loss: 54.6196\n",
      "Val_loss: 0.1784\n",
      "Test_loss-still: 0.0158\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 21 loss: 53.4491\n",
      "Val_loss: 0.1156\n",
      "Test_loss-improved: 0.0236\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 22 loss: 52.7288\n",
      "Val_loss: 0.3263\n",
      "Test_loss-improved: 0.0214\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 23 loss: 52.2723\n",
      "Val_loss: 0.0008\n",
      "Test_loss-improved: 0.0234\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 24 loss: 51.5909\n",
      "Val_loss: 0.5341\n",
      "Test_loss-still: 0.0234\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 25 loss: 49.6798\n",
      "Val_loss: 0.1868\n",
      "Test_loss-still: 0.0234\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 26 loss: 49.4936\n",
      "Val_loss: 0.0328\n",
      "Test_loss-still: 0.0234\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 27 loss: 49.2586\n",
      "Val_loss: 0.0073\n",
      "Test_loss-improved: 0.0322\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 28 loss: 48.3507\n",
      "Val_loss: 0.5825\n",
      "Test_loss-still: 0.0322\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 29 loss: 47.0523\n",
      "Val_loss: 0.6791\n",
      "Test_loss-still: 0.0322\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 30 loss: 46.5197\n",
      "Val_loss: 0.6428\n",
      "Test_loss-still: 0.0322\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 31 loss: 45.5422\n",
      "Val_loss: 0.6834\n",
      "Test_loss-still: 0.0322\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 32 loss: 44.5762\n",
      "Val_loss: 0.6425\n",
      "Test_loss-still: 0.0322\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 33 loss: 43.6584\n",
      "Val_loss: 0.1921\n",
      "Test_loss-still: 0.0322\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 34 loss: 43.1811\n",
      "Val_loss: 0.0570\n",
      "Test_loss-still: 0.0322\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 35 loss: 42.5115\n",
      "Val_loss: 0.0176\n",
      "Test_loss-still: 0.0322\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 36 loss: 41.7838\n",
      "Val_loss: 0.0059\n",
      "Test_loss-still: 0.0322\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 37 loss: 41.0705\n",
      "Val_loss: 0.0148\n",
      "Test_loss-still: 0.0322\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 38 loss: 40.8265\n",
      "Val_loss: 0.0171\n",
      "Test_loss-still: 0.0322\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 39 loss: 39.9840\n",
      "Val_loss: 0.0246\n",
      "Test_loss-still: 0.0322\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 40 loss: 39.6874\n",
      "Val_loss: 0.0002\n",
      "Test_loss-still: 0.0322\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 41 loss: 39.4049\n",
      "Val_loss: 0.0014\n",
      "Test_loss-still: 0.0322\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 42 loss: 38.7489\n",
      "Val_loss: 0.0599\n",
      "Test_loss-improved: 0.0583\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 43 loss: 38.6883\n",
      "Val_loss: 0.0001\n",
      "Test_loss-improved: 0.0555\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 44 loss: 38.5343\n",
      "Val_loss: 0.0111\n",
      "Test_loss-still: 0.0555\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 45 loss: 37.9144\n",
      "Val_loss: 0.0009\n",
      "Test_loss-improved: 0.0523\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 46 loss: 37.4784\n",
      "Val_loss: 0.0023\n",
      "Test_loss-improved: 0.0448\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 47 loss: 37.7562\n",
      "Val_loss: 0.0187\n",
      "Test_loss-still: 0.0448\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 48 loss: 37.5315\n",
      "Val_loss: 0.0155\n",
      "Test_loss-improved: 0.0414\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 49 loss: 37.0574\n",
      "Val_loss: 0.0269\n",
      "Test_loss-improved: 0.0368\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 50 loss: 36.7746\n",
      "Val_loss: 0.0261\n",
      "Test_loss-improved: 0.0323\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 51 loss: 36.9117\n",
      "Val_loss: 0.0070\n",
      "Test_loss-improved: 0.0290\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 52 loss: 37.0798\n",
      "Val_loss: 0.0438\n",
      "Test_loss-still: 0.0290\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 53 loss: 36.4991\n",
      "Val_loss: 0.0092\n",
      "Test_loss-improved: 0.0243\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 54 loss: 36.6137\n",
      "Val_loss: 0.0837\n",
      "Test_loss-improved: 0.0169\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 55 loss: 37.2442\n",
      "Val_loss: 0.4390\n",
      "Test_loss-still: 0.0169\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 56 loss: 36.1256\n",
      "Val_loss: 0.0852\n",
      "Test_loss-still: 0.0169\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 57 loss: 35.6306\n",
      "Val_loss: 0.0434\n",
      "Test_loss-improved: 0.0112\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 58 loss: 37.0138\n",
      "Val_loss: 0.0010\n",
      "Test_loss-still: 0.0112\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 59 loss: 37.9350\n",
      "Val_loss: 0.0001\n",
      "Test_loss-still: 0.0112\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 60 loss: 37.4904\n",
      "Val_loss: 0.4498\n",
      "Test_loss-still: 0.0112\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 61 loss: 38.8802\n",
      "Val_loss: 7.8845\n",
      "Test_loss-improved: 0.0048\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 62 loss: 39.2482\n",
      "Val_loss: 8.0984\n",
      "Test_loss-improved: 0.0007\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 63 loss: 36.1061\n",
      "Val_loss: 12.3503\n",
      "Test_loss-improved: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 64 loss: 33.7334\n",
      "Val_loss: 9.4226\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 65 loss: 35.6357\n",
      "Val_loss: 1.8940\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 66 loss: 39.1576\n",
      "Val_loss: 9.9880\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 67 loss: 39.2723\n",
      "Val_loss: 4.8045\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 68 loss: 34.5192\n",
      "Val_loss: 13.5795\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 69 loss: 36.1382\n",
      "Val_loss: 6.9526\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 70 loss: 37.0325\n",
      "Val_loss: 6.9939\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 71 loss: 36.6362\n",
      "Val_loss: 11.2716\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 72 loss: 35.1571\n",
      "Val_loss: 4.7615\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 73 loss: 33.3455\n",
      "Val_loss: 1.5646\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 74 loss: 32.7997\n",
      "Val_loss: 1.6336\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 75 loss: 32.8301\n",
      "Val_loss: 1.3132\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 76 loss: 32.7976\n",
      "Val_loss: 2.1932\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 77 loss: 32.3725\n",
      "Val_loss: 2.4608\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 78 loss: 31.8872\n",
      "Val_loss: 2.2615\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 79 loss: 31.4376\n",
      "Val_loss: 1.7302\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 80 loss: 31.3991\n",
      "Val_loss: 1.6555\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 81 loss: 31.3966\n",
      "Val_loss: 1.8504\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 82 loss: 31.3073\n",
      "Val_loss: 1.7223\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 83 loss: 32.1288\n",
      "Val_loss: 0.5455\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 84 loss: 32.4564\n",
      "Val_loss: 0.6449\n",
      "Test_loss-still: 0.0011\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 85 loss: 33.9302\n",
      "Val_loss: 2.0138\n",
      "Test_loss-improved: 0.0005\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 86 loss: 33.9861\n",
      "Val_loss: 4.4512\n",
      "Test_loss-improved: 0.0001\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 87 loss: 32.4516\n",
      "Val_loss: 6.4022\n",
      "Test_loss-improved: 0.0004\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 88 loss: 32.2553\n",
      "Val_loss: 2.9958\n",
      "Test_loss-improved: 0.0006\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 89 loss: 30.8015\n",
      "Val_loss: 0.1781\n",
      "Test_loss-improved: 0.0034\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 90 loss: 31.1289\n",
      "Val_loss: 0.2638\n",
      "Test_loss-still: 0.0034\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 91 loss: 32.8725\n",
      "Val_loss: 3.8212\n",
      "Test_loss-still: 0.0034\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 92 loss: 29.3485\n",
      "Val_loss: 2.8180\n",
      "Test_loss-still: 0.0034\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 93 loss: 30.3465\n",
      "Val_loss: 3.7297\n",
      "Test_loss-still: 0.0034\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 94 loss: 32.9074\n",
      "Val_loss: 1.9033\n",
      "Test_loss-still: 0.0034\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 95 loss: 30.1751\n",
      "Val_loss: 3.4258\n",
      "Test_loss-still: 0.0034\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 96 loss: 30.9596\n",
      "Val_loss: 3.8200\n",
      "Test_loss-still: 0.0034\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 97 loss: 30.4347\n",
      "Val_loss: 4.1521\n",
      "Test_loss-still: 0.0034\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 98 loss: 29.5979\n",
      "Val_loss: 3.8765\n",
      "Test_loss-improved: 0.0198\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 99 loss: 32.4256\n",
      "Val_loss: 0.7627\n",
      "Test_loss-improved: 0.0154\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 100 loss: 30.8665\n",
      "Val_loss: 2.6988\n",
      "Test_loss-still: 0.0154\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n"
     ]
    }
   ],
   "source": [
    "test_interval = 1  # run validation (and, on improvement, test) every `test_interval` epochs\n",
    "num_epochs = 100\n",
    "Val_loss_Best = float('inf')\n",
    "Test_loss = float('nan')  # normalized test loss at the best-validation checkpoint so far\n",
    "\n",
    "# Each split bundles its per-complex lists in the order the model inputs are built from.\n",
    "train_split = (\n",
    "    x_0_train, x_1_train, x_2_train,\n",
    "    incidence_1_train, incidence_2_train,\n",
    "    evals_0_train, evecs_0_train,\n",
    "    evals_d1_train, evecs_d1_train,\n",
    "    evals_u1_train, evecs_u1_train,\n",
    "    evals_2_train, evecs_2_train,\n",
    "    y_train,\n",
    ")\n",
    "val_split = (\n",
    "    x_0_val, x_1_val, x_2_val,\n",
    "    incidence_1_val, incidence_2_val,\n",
    "    evals_0_val, evecs_0_val,\n",
    "    evals_d1_val, evecs_d1_val,\n",
    "    evals_u1_val, evecs_u1_val,\n",
    "    evals_2_val, evecs_2_val,\n",
    "    y_val,\n",
    ")\n",
    "test_split = (\n",
    "    x_0_test, x_1_test, x_2_test,\n",
    "    incidence_1_test, incidence_2_test,\n",
    "    evals_0_test, evecs_0_test,\n",
    "    evals_d1_test, evecs_d1_test,\n",
    "    evals_u1_test, evecs_u1_test,\n",
    "    evals_2_test, evecs_2_test,\n",
    "    y_test,\n",
    ")\n",
    "\n",
    "\n",
    "def iterate_split(split):\n",
    "    \"\"\"Yield ``(x_all, eig_eiv_all, incidence_all, y)`` model inputs for every complex in ``split``.\"\"\"\n",
    "    for (\n",
    "        x_0, x_1, x_2,\n",
    "        incidence_1, incidence_2,\n",
    "        evals_0, evecs_0,\n",
    "        evals_d1, evecs_d1,\n",
    "        evals_u1, evecs_u1,\n",
    "        evals_2, evecs_2,\n",
    "        y,\n",
    "    ) in zip(*split, strict=False):\n",
    "        x_all = (\n",
    "            torch.tensor(x_0).float(),\n",
    "            torch.tensor(x_1).float(),\n",
    "            torch.tensor(x_2).float(),\n",
    "        )\n",
    "        eig_eiv_all = (evals_0, evecs_0, evals_d1, evecs_d1, evals_u1, evecs_u1, evals_2, evecs_2)\n",
    "        incidence_all = (incidence_1, incidence_2)\n",
    "        yield x_all, eig_eiv_all, incidence_all, torch.tensor(y, dtype=torch.float)\n",
    "\n",
    "\n",
    "def evaluate(split, normalize=False):\n",
    "    \"\"\"Return the mean per-complex loss of ``model`` over ``split``.\n",
    "\n",
    "    The model is switched to eval mode and run without gradients. With ``normalize=True``\n",
    "    each complex's loss is divided by ``||y||_2^2`` (relative error, used for the test metric).\n",
    "    Returns NaN for an empty split.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    losses = []\n",
    "    with torch.no_grad():\n",
    "        for x_all, eig_eiv_all, incidence_all, y in iterate_split(split):\n",
    "            y_hat = model(x_all, eig_eiv_all, incidence_all)\n",
    "            loss_i = loss_fn(y_hat, y)\n",
    "            if normalize:\n",
    "                loss_i = loss_i / (torch.norm(y, 2) ** 2)\n",
    "            losses.append(loss_i.item())\n",
    "    return float(np.mean(losses)) if losses else float('nan')\n",
    "\n",
    "\n",
    "for epoch_i in range(1, num_epochs + 1):\n",
    "    model.train()\n",
    "    epoch_loss = []\n",
    "    for x_all, eig_eiv_all, incidence_all, y in iterate_split(train_split):\n",
    "        optimizer.zero_grad()\n",
    "        y_hat = model(x_all, eig_eiv_all, incidence_all)\n",
    "        loss = loss_fn(y_hat, y)\n",
    "        epoch_loss.append(loss.item())\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    print(f\"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}\", flush=True)\n",
    "\n",
    "    if epoch_i % test_interval != 0:\n",
    "        continue\n",
    "\n",
    "    # Model selection uses the mean loss over the whole validation split.\n",
    "    Val_loss = evaluate(val_split)\n",
    "    print(f\"Val_loss: {Val_loss:.4f}\", flush=True)\n",
    "    if Val_loss < Val_loss_Best:\n",
    "        # New best validation loss: report the test metric at this checkpoint.\n",
    "        Val_loss_Best = Val_loss\n",
    "        Test_loss = evaluate(test_split, normalize=True)\n",
    "        print(f\"Test_loss-improved: {Test_loss:.4f}\", flush=True)\n",
    "    else:\n",
    "        print(f\"Test_loss-still: {Test_loss:.4f}\", flush=True)\n",
    "    print(\">\" * 100)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tmx",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
