{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train a COSIMO\n",
    "\n",
    "### We train the model to perform:\n",
    "    Complex Regression using the shrec16 benchmark dataset.\n",
    "   "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Continuous Simplicial Neural Networks [COSIMO]</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Complex Regression on the Shrec-16 Dataset: Full Version"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "import torch\n",
    "# One can define a fixed seed value (we spanned [1,10])\n",
    "SEED = 5\n",
    "\n",
    "# 1. Set the Python built-in random module's seed\n",
    "random.seed(SEED)\n",
    "\n",
    "# 2. Set the NumPy random seed\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# 3. Set the PyTorch seed (for both CPU and GPU)\n",
    "torch.manual_seed(SEED)# Define a fixed seed value\n",
    "print(torch.cuda.is_available())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import toponetx.datasets as datasets\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from COSIMO import COSIMO\n",
    "from topomodelx.utils.sparse import from_sparse\n",
    "\n",
    "# %load_ext autoreload\n",
    "# %autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy\n",
    "from scipy import sparse\n",
    "\n",
    "def get_evals_evecs(L, k):\n",
    "    L_sparse = sparse.coo_matrix(L)\n",
    "\n",
    "    evals, evecs = scipy.sparse.linalg.eigs(L_sparse, k=k, ncv=4*k, return_eigenvectors=True)\n",
    "    # evals, evecs = scipy.linalg.eig(L)\n",
    "\n",
    "    evals=torch.tensor(evals.real)\n",
    "    evecs=torch.tensor(evecs.real)\n",
    "\n",
    "    return evals, evecs "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pre-processing\n",
    "\n",
    "### Import shrec dataset ##\n",
    "\n",
    "We must first lift our graph dataset into the simplicial complex domain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading shrec 16 full dataset...\n",
      "\n",
      "done!\n"
     ]
    }
   ],
   "source": [
    "shrec, _ = datasets.mesh.shrec_16(size=\"full\")\n",
    "shrec = {key: np.array(value) for key, value in shrec.items()}\n",
    "x_0s = shrec[\"node_feat\"]\n",
    "x_1s = shrec[\"edge_feat\"]\n",
    "x_2s = shrec[\"face_feat\"]\n",
    "\n",
    "ys = shrec[\"label\"]\n",
    "simplexes = shrec[\"complexes\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(6, 10, 7)\n"
     ]
    }
   ],
   "source": [
    "in_channels_0 = x_0s[-1].shape[1]\n",
    "in_channels_1 = x_1s[-1].shape[1]\n",
    "in_channels_2 = x_2s[-1].shape[1]\n",
    "\n",
    "in_channels_all = (in_channels_0, in_channels_1, in_channels_2)\n",
    "print(in_channels_all)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Neighborhood Strctures\n",
    "Get incidence matrices $\\mathbf{B}_1,\\mathbf{B}_2$ and Hodge Laplacians $\\mathbf{L}_0, \\mathbf{L}_1$ and $\\mathbf{L}_2$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_rank = 2  # the order of the SC is two\n",
    "incidence_1_list = []\n",
    "incidence_2_list = []\n",
    "\n",
    "\n",
    "kk_0 = 10\n",
    "kk_1 = 10\n",
    "kk_2 = 10\n",
    "evals_0_list = []\n",
    "evecs_0_list = []\n",
    "evals_d1_list = []\n",
    "evecs_d1_list = []\n",
    "evals_u1_list = []\n",
    "evecs_u1_list = []\n",
    "evals_2_list = []\n",
    "evecs_2_list = []\n",
    "\n",
    "for simplex in simplexes:\n",
    "    incidence_1 = simplex.incidence_matrix(rank=1)\n",
    "    incidence_2 = simplex.incidence_matrix(rank=2)\n",
    "    laplacian_0 = simplex.hodge_laplacian_matrix(rank=0)\n",
    "    laplacian_down_1 = simplex.down_laplacian_matrix(rank=1)\n",
    "    laplacian_up_1 = simplex.up_laplacian_matrix(rank=1)\n",
    "    laplacian_2 = simplex.hodge_laplacian_matrix(rank=2)\n",
    "\n",
    "    incidence_1 = from_sparse(incidence_1)\n",
    "    incidence_2 = from_sparse(incidence_2)\n",
    "    evals_0, evecs_0 = get_evals_evecs(laplacian_0, kk_0)\n",
    "    evals_d1, evecs_d1 = get_evals_evecs(laplacian_down_1, kk_1)\n",
    "    evals_u1, evecs_u1 = get_evals_evecs(laplacian_up_1, kk_1)\n",
    "    evals_2, evecs_2 = get_evals_evecs(laplacian_2, kk_2)\n",
    "\n",
    "    incidence_1_list.append(incidence_1)\n",
    "    incidence_2_list.append(incidence_2)\n",
    "    evals_0_list.append(evals_0)\n",
    "    evecs_0_list.append(evecs_0)\n",
    "    evals_d1_list.append(evals_d1)\n",
    "    evecs_d1_list.append(evecs_d1)\n",
    "    evals_u1_list.append(evals_u1)\n",
    "    evecs_u1_list.append(evecs_u1)\n",
    "    evals_2_list.append(evals_2)\n",
    "    evecs_2_list.append(evecs_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "252\n",
      "750\n",
      "500\n"
     ]
    }
   ],
   "source": [
    "print(laplacian_0.shape[0])\n",
    "print(laplacian_down_1.shape[0])\n",
    "print(laplacian_2.shape[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create and Train the Neural Network\n",
    "\n",
    "We specify the model with our pre-made neighborhood structures and specify an optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Network(torch.nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        in_channels_all,\n",
    "        hidden_channels_all,\n",
    "        out_channels,\n",
    "        conv_order,\n",
    "        max_rank,\n",
    "        n_layers=2,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.base_model = COSIMO(\n",
    "            in_channels_all=in_channels_all,\n",
    "            hidden_channels_all=hidden_channels_all,\n",
    "            conv_order=conv_order,\n",
    "            sc_order=max_rank,\n",
    "            n_layers=n_layers,\n",
    "        )\n",
    "        out_channels_0, out_channels_1, out_channels_2 = hidden_channels_all\n",
    "        self.out_linear_0 = torch.nn.Linear(out_channels_0, out_channels)\n",
    "        self.out_linear_1 = torch.nn.Linear(out_channels_1, out_channels)\n",
    "        self.out_linear_2 = torch.nn.Linear(out_channels_2, out_channels)\n",
    "\n",
    "    def forward(self, x_all, eig_eiv_all, incidence_all):\n",
    "        x_all = self.base_model(x_all, eig_eiv_all, incidence_all)\n",
    "        x_0, x_1, x_2 = x_all\n",
    "\n",
    "        x_0 = self.out_linear_0(x_0)\n",
    "        x_1 = self.out_linear_1(x_1)\n",
    "        x_2 = self.out_linear_2(x_2)\n",
    "\n",
    "        # Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.\n",
    "        two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)\n",
    "        two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0\n",
    "        one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)\n",
    "        one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0\n",
    "        zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)\n",
    "        zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0\n",
    "        # Return the sum of the averages\n",
    "        return (\n",
    "            two_dimensional_cells_mean\n",
    "            + one_dimensional_cells_mean\n",
    "            + zero_dimensional_cells_mean\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Network(\n",
      "  (base_model): COSIMO(\n",
      "    (in_linear_0): Linear(in_features=6, out_features=16, bias=True)\n",
      "    (in_linear_1): Linear(in_features=10, out_features=16, bias=True)\n",
      "    (in_linear_2): Linear(in_features=7, out_features=16, bias=True)\n",
      "    (layers): ModuleList(\n",
      "      (0-1): 2 x COSIMO_Layer_exp(\n",
      "        (diff_derivative_00): Time_derivative_diffusion()\n",
      "        (diff_derivative_10): Time_derivative_diffusion()\n",
      "        (diff_derivative_d1): Time_derivative_diffusion()\n",
      "        (diff_derivative_u1): Time_derivative_diffusion()\n",
      "        (diff_derivative_01): Time_derivative_diffusion()\n",
      "        (diff_derivative_21): Time_derivative_diffusion()\n",
      "        (diff_derivative_2): Time_derivative_diffusion()\n",
      "        (diff_derivative_d2): Time_derivative_diffusion()\n",
      "        (diff_derivative_u2): Time_derivative_diffusion()\n",
      "        (diff_derivative_12): Time_derivative_diffusion()\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (out_linear_0): Linear(in_features=16, out_features=1, bias=True)\n",
      "  (out_linear_1): Linear(in_features=16, out_features=1, bias=True)\n",
      "  (out_linear_2): Linear(in_features=16, out_features=1, bias=True)\n",
      ")\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/aref/anaconda3/envs/tmx/lib/python3.11/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.\n",
      "  warnings.warn(warning.format(ret))\n"
     ]
    }
   ],
   "source": [
    "conv_order = 2\n",
    "intermediate_channels_all = (16, 16, 16)\n",
    "num_layers = 2\n",
    "out_channels = 1  # num classes\n",
    "\n",
    "model = Network(\n",
    "    in_channels_all=in_channels_all,\n",
    "    hidden_channels_all=intermediate_channels_all,\n",
    "    out_channels=out_channels,\n",
    "    conv_order=conv_order,\n",
    "    max_rank=max_rank,\n",
    "    n_layers=num_layers,\n",
    ")\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "loss_fn = torch.nn.MSELoss(size_average=True, reduction=\"mean\")\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_size = 0.2\n",
    "val_size = 0.2\n",
    "x_0_train, x_0_test = train_test_split(x_0s, test_size=test_size, shuffle=False)\n",
    "x_0_train, x_0_val = train_test_split(x_0_train, test_size=val_size, shuffle=False)\n",
    "x_1_train, x_1_test = train_test_split(x_1s, test_size=test_size, shuffle=False)\n",
    "x_1_train, x_1_val = train_test_split(x_1_train, test_size=val_size, shuffle=False)\n",
    "x_2_train, x_2_test = train_test_split(x_2s, test_size=test_size, shuffle=False)\n",
    "x_2_train, x_2_val = train_test_split(x_2_train, test_size=val_size, shuffle=False)\n",
    "\n",
    "incidence_1_train, incidence_1_test = train_test_split(\n",
    "    incidence_1_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "incidence_1_train, incidence_1_val = train_test_split(\n",
    "    incidence_1_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "incidence_2_train, incidence_2_test = train_test_split(\n",
    "    incidence_2_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "incidence_2_train, incidence_2_val = train_test_split(\n",
    "    incidence_2_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "evals_0_train, evals_0_test = train_test_split(\n",
    "    evals_0_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "evals_0_train, evals_0_val = train_test_split(\n",
    "    evals_0_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "evecs_0_train, evecs_0_test = train_test_split(\n",
    "    evecs_0_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "evecs_0_train, evecs_0_val = train_test_split(\n",
    "    evecs_0_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "evals_d1_train, evals_d1_test = train_test_split(\n",
    "    evals_d1_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "evals_d1_train, evals_d1_val = train_test_split(\n",
    "    evals_d1_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "evecs_d1_train, evecs_d1_test = train_test_split(\n",
    "    evecs_d1_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "evecs_d1_train, evecs_d1_val = train_test_split(\n",
    "    evecs_d1_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "evals_u1_train, evals_u1_test = train_test_split(\n",
    "    evals_u1_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "evals_u1_train, evals_u1_val = train_test_split(\n",
    "    evals_u1_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "evecs_u1_train, evecs_u1_test = train_test_split(\n",
    "    evecs_u1_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "evecs_u1_train, evecs_u1_val = train_test_split(\n",
    "    evecs_u1_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "evals_2_train, evals_2_test = train_test_split(\n",
    "    evals_2_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "evals_2_train, evals_2_val = train_test_split(\n",
    "    evals_2_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "evecs_2_train, evecs_2_test = train_test_split(\n",
    "    evecs_2_list, test_size=test_size, shuffle=False\n",
    ")\n",
    "evecs_2_train, evecs_2_val = train_test_split(\n",
    "    evecs_2_train, test_size=val_size, shuffle=False\n",
    ")\n",
    "\n",
    "y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)\n",
    "y_train, y_val = train_test_split(y_train, test_size=val_size, shuffle=False)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We train the COSIMO:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/aref/anaconda3/envs/tmx/lib/python3.11/site-packages/torch/nn/modules/loss.py:535: UserWarning: Using a target size (torch.Size([])) that is different to the input size (torch.Size([1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
      "  return F.mse_loss(input, target, reduction=self.reduction)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 loss: 0.6461\n",
      "Val_loss: 0.2544\n",
      "Test_loss-improved: 0.0936\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 2 loss: 4.0011\n",
      "Val_loss: 0.4924\n",
      "Test_loss-still: 0.0936\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 3 loss: 3.2776\n",
      "Val_loss: 1.9659\n",
      "Test_loss-still: 0.0936\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 4 loss: 2.4954\n",
      "Val_loss: 2.1877\n",
      "Test_loss-improved: 0.0756\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 5 loss: 2.7939\n",
      "Val_loss: 1.2523\n",
      "Test_loss-still: 0.0756\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 6 loss: 2.6561\n",
      "Val_loss: 2.0896\n",
      "Test_loss-still: 0.0756\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 7 loss: 2.4616\n",
      "Val_loss: 8.3687\n",
      "Test_loss-improved: 0.0685\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 8 loss: 2.9309\n",
      "Val_loss: 4.7350\n",
      "Test_loss-still: 0.0685\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 9 loss: 3.0539\n",
      "Val_loss: 2.4632\n",
      "Test_loss-still: 0.0685\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 10 loss: 3.1491\n",
      "Val_loss: 3.7103\n",
      "Test_loss-improved: 0.0397\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 11 loss: 3.0098\n",
      "Val_loss: 3.4495\n",
      "Test_loss-still: 0.0397\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 12 loss: 2.6350\n",
      "Val_loss: 3.4933\n",
      "Test_loss-still: 0.0397\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 13 loss: 2.9500\n",
      "Val_loss: 6.8979\n",
      "Test_loss-still: 0.0397\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 14 loss: 3.0485\n",
      "Val_loss: 3.3416\n",
      "Test_loss-still: 0.0397\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 15 loss: 3.0432\n",
      "Val_loss: 4.7525\n",
      "Test_loss-still: 0.0397\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 16 loss: 2.4240\n",
      "Val_loss: 5.5288\n",
      "Test_loss-still: 0.0397\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 17 loss: 2.4512\n",
      "Val_loss: 6.2478\n",
      "Test_loss-still: 0.0397\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 18 loss: 2.6431\n",
      "Val_loss: 6.8130\n",
      "Test_loss-still: 0.0397\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 19 loss: 3.3818\n",
      "Val_loss: 6.1110\n",
      "Test_loss-improved: 0.0341\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 20 loss: 3.7247\n",
      "Val_loss: 6.9157\n",
      "Test_loss-improved: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 21 loss: 3.1384\n",
      "Val_loss: 3.9827\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 22 loss: 3.2525\n",
      "Val_loss: 5.6460\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 23 loss: 2.7550\n",
      "Val_loss: 2.8103\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 24 loss: 2.4461\n",
      "Val_loss: 4.9623\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 25 loss: 2.7608\n",
      "Val_loss: 9.6845\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 26 loss: 3.2348\n",
      "Val_loss: 8.9687\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 27 loss: 3.2933\n",
      "Val_loss: 8.1289\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 28 loss: 3.3622\n",
      "Val_loss: 5.0633\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 29 loss: 3.0146\n",
      "Val_loss: 7.0989\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 30 loss: 2.9178\n",
      "Val_loss: 4.0633\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 31 loss: 3.1370\n",
      "Val_loss: 9.8032\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 32 loss: 3.1946\n",
      "Val_loss: 4.2757\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 33 loss: 3.0266\n",
      "Val_loss: 7.1174\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 34 loss: 3.2315\n",
      "Val_loss: 4.4233\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 35 loss: 3.2978\n",
      "Val_loss: 7.3669\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 36 loss: 3.5156\n",
      "Val_loss: 6.8707\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 37 loss: 3.0461\n",
      "Val_loss: 8.4753\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 38 loss: 3.1088\n",
      "Val_loss: 9.5842\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 39 loss: 3.3194\n",
      "Val_loss: 7.3592\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 40 loss: 3.2130\n",
      "Val_loss: 5.8845\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 41 loss: 3.0744\n",
      "Val_loss: 7.0225\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 42 loss: 3.1776\n",
      "Val_loss: 7.9910\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 43 loss: 3.3954\n",
      "Val_loss: 6.8615\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 44 loss: 3.5696\n",
      "Val_loss: 3.3913\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 45 loss: 3.1750\n",
      "Val_loss: 9.8468\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 46 loss: 2.8179\n",
      "Val_loss: 6.1538\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 47 loss: 3.4671\n",
      "Val_loss: 8.5905\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 48 loss: 3.1088\n",
      "Val_loss: 6.4883\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 49 loss: 2.7647\n",
      "Val_loss: 9.5225\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 50 loss: 3.0053\n",
      "Val_loss: 8.0610\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 51 loss: 3.4745\n",
      "Val_loss: 3.4761\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 52 loss: 3.4282\n",
      "Val_loss: 8.1520\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 53 loss: 3.4680\n",
      "Val_loss: 5.6668\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 54 loss: 3.3746\n",
      "Val_loss: 8.8392\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 55 loss: 3.2692\n",
      "Val_loss: 7.8017\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 56 loss: 3.1790\n",
      "Val_loss: 7.4090\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 57 loss: 3.1758\n",
      "Val_loss: 6.9531\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 58 loss: 3.2768\n",
      "Val_loss: 8.6026\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 59 loss: 3.1534\n",
      "Val_loss: 5.2551\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 60 loss: 2.9224\n",
      "Val_loss: 6.1335\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 61 loss: 2.9524\n",
      "Val_loss: 3.5539\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 62 loss: 3.6621\n",
      "Val_loss: 5.3522\n",
      "Test_loss-still: 0.0343\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 63 loss: 3.5418\n",
      "Val_loss: 3.0890\n",
      "Test_loss-improved: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 64 loss: 3.1683\n",
      "Val_loss: 5.1600\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 65 loss: 3.0678\n",
      "Val_loss: 5.2393\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 66 loss: 3.3334\n",
      "Val_loss: 8.9621\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 67 loss: 3.1180\n",
      "Val_loss: 5.2394\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 68 loss: 3.0192\n",
      "Val_loss: 2.5331\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 69 loss: 2.9688\n",
      "Val_loss: 8.4871\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 70 loss: 3.3310\n",
      "Val_loss: 6.9176\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 71 loss: 2.9922\n",
      "Val_loss: 7.8359\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 72 loss: 3.1533\n",
      "Val_loss: 6.4589\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 73 loss: 2.9852\n",
      "Val_loss: 7.5691\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 74 loss: 3.1303\n",
      "Val_loss: 10.6083\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 75 loss: 3.3934\n",
      "Val_loss: 5.0665\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 76 loss: 3.2035\n",
      "Val_loss: 7.5824\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 77 loss: 3.2779\n",
      "Val_loss: 9.4619\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 78 loss: 3.3205\n",
      "Val_loss: 5.6420\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 79 loss: 3.4431\n",
      "Val_loss: 7.3482\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 80 loss: 3.4739\n",
      "Val_loss: 5.9536\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 81 loss: 3.2747\n",
      "Val_loss: 7.9884\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 82 loss: 3.1383\n",
      "Val_loss: 8.9659\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 83 loss: 2.7256\n",
      "Val_loss: 8.4524\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 84 loss: 2.9663\n",
      "Val_loss: 6.4258\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 85 loss: 3.1572\n",
      "Val_loss: 10.3378\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 86 loss: 3.3724\n",
      "Val_loss: 12.2056\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 87 loss: 3.0338\n",
      "Val_loss: 5.7402\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 88 loss: 3.1651\n",
      "Val_loss: 7.1348\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 89 loss: 3.6704\n",
      "Val_loss: 4.5395\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 90 loss: 3.6040\n",
      "Val_loss: 8.8042\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 91 loss: 2.9558\n",
      "Val_loss: 9.1890\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 92 loss: 3.3881\n",
      "Val_loss: 9.4795\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 93 loss: 3.0557\n",
      "Val_loss: 7.9044\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 94 loss: 3.1414\n",
      "Val_loss: 4.9878\n",
      "Test_loss-still: 0.0238\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 95 loss: 3.2951\n",
      "Val_loss: 9.0626\n",
      "Test_loss-improved: 0.0260\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 96 loss: 3.3261\n",
      "Val_loss: 8.9716\n",
      "Test_loss-still: 0.0260\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 97 loss: 3.2751\n",
      "Val_loss: 7.0408\n",
      "Test_loss-still: 0.0260\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 98 loss: 3.5601\n",
      "Val_loss: 3.0570\n",
      "Test_loss-still: 0.0260\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 99 loss: 3.3377\n",
      "Val_loss: 5.6261\n",
      "Test_loss-still: 0.0260\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "Epoch: 100 loss: 3.5533\n",
      "Val_loss: 7.3089\n",
      "Test_loss-still: 0.0260\n",
      ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n"
     ]
    }
   ],
   "source": [
    "test_interval = 1\n",
    "num_epochs = 100\n",
    "Val_loss_Best = float('inf')\n",
    "\n",
    "for epoch_i in range(1, num_epochs + 1):\n",
    "    epoch_loss = []\n",
    "    model.train()\n",
    "    for (\n",
    "        x_0,\n",
    "        x_1,\n",
    "        x_2,\n",
    "        incidence_1,\n",
    "        incidence_2,\n",
    "        evals_0, evecs_0,\n",
    "        evals_d1, evecs_d1,\n",
    "        evals_u1, evecs_u1,\n",
    "        evals_2, evecs_2,\n",
    "        y,\n",
    "    ) in zip(\n",
    "        x_0_train,\n",
    "        x_1_train,\n",
    "        x_2_train,\n",
    "        incidence_1_train,\n",
    "        incidence_2_train,\n",
    "        evals_0_train, evecs_0_train,\n",
    "        evals_d1_train, evecs_d1_train,\n",
    "        evals_u1_train, evecs_u1_train,\n",
    "        evals_2_train, evecs_2_train,\n",
    "        y_train,\n",
    "        strict=False,\n",
    "    ):\n",
    "        x_0 = torch.tensor(x_0)\n",
    "        x_1 = torch.tensor(x_1)\n",
    "        x_2 = torch.tensor(x_2)\n",
    "        y = torch.tensor(y, dtype=torch.float)\n",
    "        optimizer.zero_grad()\n",
    "        x_all = (x_0.float(), x_1.float(), x_2.float())\n",
    "        eig_eiv_all = (evals_0, evecs_0, evals_d1, evecs_d1, evals_u1, evecs_u1, evals_2, evecs_2)\n",
    "        incidence_all = (incidence_1, incidence_2)\n",
    "\n",
    "        y_hat = model(x_all, eig_eiv_all, incidence_all)\n",
    "\n",
    "        # print(y_hat)\n",
    "        loss = loss_fn(y_hat, y)\n",
    "\n",
    "        epoch_loss.append(loss.item())\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    print(\n",
    "        f\"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}\",\n",
    "        flush=True,\n",
    "    )\n",
    "    with torch.no_grad():\n",
    "            for (\n",
    "                x_0,\n",
    "                x_1,\n",
    "                x_2,\n",
    "                incidence_1,\n",
    "                incidence_2,\n",
    "                evals_0, evecs_0,\n",
    "                evals_d1, evecs_d1,\n",
    "                evals_u1, evecs_u1,\n",
    "                evals_2, evecs_2,\n",
    "                y,\n",
    "            ) in zip(\n",
    "                x_0_val,\n",
    "                x_1_val,\n",
    "                x_2_val,\n",
    "                incidence_1_val,\n",
    "                incidence_2_val,\n",
    "                evals_0_val, evecs_0_val,\n",
    "                evals_d1_val, evecs_d1_val,\n",
    "                evals_u1_val, evecs_u1_val,\n",
    "                evals_2_val, evecs_2_val,\n",
    "                y_val,\n",
    "                strict=False,\n",
    "            ):\n",
    "                x_0 = torch.tensor(x_0)\n",
    "                x_1 = torch.tensor(x_1)\n",
    "                x_2 = torch.tensor(x_2)\n",
    "                y = torch.tensor(y, dtype=torch.float)\n",
    "                optimizer.zero_grad()\n",
    "                x_all = (x_0.float(), x_1.float(), x_2.float())\n",
    "                eig_eiv_all = (\n",
    "                    evals_0, evecs_0,\n",
    "                    evals_d1, evecs_d1,\n",
    "                    evals_u1, evecs_u1,\n",
    "                    evals_2, evecs_2,\n",
    "                )\n",
    "                incidence_all = (incidence_1, incidence_2)\n",
    "\n",
    "                y_hat = model(x_all, eig_eiv_all, incidence_all)\n",
    "\n",
    "                Val_loss = loss_fn(y_hat, y)\n",
    "            print(f\"Val_loss: {loss:.4f}\", flush=True)\n",
    "            if Val_loss < Val_loss_Best:\n",
    "                for (\n",
    "                    x_0,\n",
    "                    x_1,\n",
    "                    x_2,\n",
    "                    incidence_1,\n",
    "                    incidence_2,\n",
    "                    evals_0, evecs_0,\n",
    "                    evals_d1, evecs_d1,\n",
    "                    evals_u1, evecs_u1,\n",
    "                    evals_2, evecs_2,\n",
    "                    y,\n",
    "                ) in zip(\n",
    "                    x_0_test,\n",
    "                    x_1_test,\n",
    "                    x_2_test,\n",
    "                    incidence_1_test,\n",
    "                    incidence_2_test,\n",
    "                    evals_0_test, evecs_0_test,\n",
    "                    evals_d1_test, evecs_d1_test,\n",
    "                    evals_u1_test, evecs_u1_test,\n",
    "                    evals_2_test, evecs_2_test,\n",
    "                    y_test,\n",
    "                    strict=False,\n",
    "                ):\n",
    "                    x_0 = torch.tensor(x_0)\n",
    "                    x_1 = torch.tensor(x_1)\n",
    "                    x_2 = torch.tensor(x_2)\n",
    "                    y = torch.tensor(y, dtype=torch.float)\n",
    "                    optimizer.zero_grad()\n",
    "                    x_all = (x_0.float(), x_1.float(), x_2.float())\n",
    "                    eig_eiv_all = (\n",
    "                        evals_0, evecs_0,\n",
    "                        evals_d1, evecs_d1,\n",
    "                        evals_u1, evecs_u1,\n",
    "                        evals_2, evecs_2,\n",
    "                    )\n",
    "                    incidence_all = (incidence_1, incidence_2)\n",
    "\n",
    "                    y_hat = model(x_all, eig_eiv_all, incidence_all)\n",
    "\n",
    "                    Test_loss = loss_fn(y_hat, y)/(torch.norm(y,2)**2)\n",
    "                print(f\"Test_loss-improved: {Test_loss:.4f}\", flush=True)\n",
    "                Val_loss_Best = Val_loss\n",
    "            else:\n",
    "                print(f\"Test_loss-still: {Test_loss:.4f}\", flush=True)\n",
    "                \n",
    "            print(\">\"*100)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tmx",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
