{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train a COSIMO\n",
    "\n",
    "   \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Continuous Simplicial Neural Networks (COSIMO) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Node Classification (run on an A100 GPU cluster)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "import torch\n",
    "# Define a fixed seed value\n",
    "SEED = 1 # We selected the SEEDs in [1,10]\n",
    "test_size = 0.2\n",
    "val_size = 0.2 \n",
    "num_filters = 32\n",
    "num_layers = 4\n",
    "kk = 10\n",
    "kk2 = 10\n",
    "lr = 1e-2\n",
    "num_epochs = 100\n",
    "max_rank = 4\n",
    "\n",
    "# 1. Set the Python built-in random module's seed\n",
    "random.seed(SEED)\n",
    "\n",
    "# 2. Set the NumPy random seed\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# 3. Set the PyTorch seed (for both CPU and GPU)\n",
    "torch.manual_seed(SEED)# Define a fixed seed value\n",
    "print(torch.cuda.is_available())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import toponetx.datasets as datasets\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from sccnn_exp import COSIMO\n",
    "from topomodelx.utils.sparse import from_sparse\n",
    "\n",
    "# %load_ext autoreload\n",
    "# %autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy\n",
    "from scipy import sparse\n",
    "\n",
    "def get_evals_evecs(L, k):\n",
    "    L_sparse = sparse.coo_matrix(L)\n",
    "\n",
    "    evals, evecs = scipy.sparse.linalg.eigs(L_sparse, k=k, ncv=4*k, return_eigenvectors=True)\n",
    "    # evals, evecs = scipy.linalg.eig(L)\n",
    "\n",
    "    evals=torch.tensor(evals.real)\n",
    "    evecs=torch.tensor(evecs.real)\n",
    "\n",
    "    return evals, evecs "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload edited modules (e.g. sccnn_exp, preprocessing) before each cell executes,\n",
    "# so changes to the project code are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pre-processing\n",
    "\n",
    "## Import dataset ##"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from preprocessing.simplicial_construction import get_boundary_matrices, get_neighbors, get_weight_matrix_graph, get_weight_matrix_simplex,generate_triangles,_get_laplacians,_get_simplex_features,augment_simplex,augment_simplex_open\n",
    "import argparse\n",
    "from preprocessing.graph_construction import _get_graph\n",
    "import torch\n",
    "import networkx as nx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "parser = argparse.ArgumentParser(description='TopoSRL')\n",
    "\n",
    "parser.add_argument('--dataname', type=str, default='senate-bills', help='Name of dataset.')\n",
    "parser.add_argument('--gpu', type=int, default=0, help='GPU index.')\n",
    "parser.add_argument('--dim', type=int, default=4, help='Order of the simplicial complex.')\n",
    "\n",
    "args = parser.parse_args(args=[])\n",
    "\n",
    "if args.gpu != -1 and torch.cuda.is_available():\n",
    "    # args.device = 'cuda:{}'.format(args.gpu)\n",
    "    args.device = 'cuda'\n",
    "else:\n",
    "    args.device = 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Read the labels\n",
      "Reading the simplicies\n",
      "Creating tree\n",
      "Computing boundary matrix for dimension 1\n",
      "Computing boundary matrix for dimension 2\n",
      "Computing boundary matrix for dimension 3\n",
      "Computing boundary matrix for dimension 4\n",
      "Got boundaries\n",
      "Anchor nodes initialized\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ids/einizade/anaconda3/envs/tmx/lib/python3.11/site-packages/dgl/backend/pytorch/tensor.py:352: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
      "  assert input.numel() == input.storage().size(), \"Cannot convert view \" \\\n",
      "/tmp/ipykernel_2323687/1129562614.py:5: FutureWarning: adjacency_matrix will return a scipy.sparse array instead of a matrix in Networkx 3.0.\n",
      "  A = nx.adjacency_matrix(netxG).todense()\n"
     ]
    }
   ],
   "source": [
    "# Load the simplicial complex for `args.dataname` (up to dimension `args.dim`) and its labels.\n",
    "# `boundry_matrices` (sic) is the name later cells read, so the spelling is kept.\n",
    "simplex_tree, sc, boundry_matrices, labels =  get_boundary_matrices(args.dataname, args.dim)\n",
    "print(\"Got boundaries\")\n",
    "# Graph built from sc[1]: `g` (which carries ndata['features']) and its networkx counterpart `netxG`.\n",
    "g, netxG = _get_graph(sc[1])\n",
    "g = g.to(args.device)\n",
    "A = nx.adjacency_matrix(netxG).todense()\n",
    "# W0: row-wise softmax of the graph weight matrix, then every entry equal to its row's\n",
    "# minimum is zeroed. NOTE(review): W0 is not referenced by any later cell in this notebook.\n",
    "sm = torch.nn.Softmax(dim=1)\n",
    "W0 = get_weight_matrix_graph(A)\n",
    "W0 = sm(torch.FloatTensor(W0).to(args.device))\n",
    "W0 = W0 * (W0!=W0.min(axis=1).values.unsqueeze(-1))\n",
    "# `laplacians`, `lower_laplacians`, `upper_laplacians` from the boundary matrices;\n",
    "# later cells read laplacians[0] and the lower/upper entries [1] and [2].\n",
    "laplacians, lower_laplacians, upper_laplacians = _get_laplacians(boundry_matrices)\n",
    "# Input features for sc[1], sc[2], sc[3], derived from the node features on `g`;\n",
    "# a later cell reports these as the node, edge and face features.\n",
    "_X = _get_simplex_features(sc[1:4], g.ndata['features'])\n"
   ]
  },
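  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Neighborhood Structures\n",
    "Get the incidence matrices $\\mathbf{B}_1, \\mathbf{B}_2$, the Hodge Laplacian $\\mathbf{L}_0$, and the lower and upper Laplacians $\\mathbf{L}_1^{\\downarrow}, \\mathbf{L}_1^{\\uparrow}, \\mathbf{L}_2^{\\downarrow}, \\mathbf{L}_2^{\\uparrow}$."
   ]
  },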
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The incidence matrix B1 has shape: (294, 12942).\n",
      "The incidence matrix B2 has shape: (12942, 24939).\n"
     ]
    }
   ],
   "source": [
    "incidence_1 = boundry_matrices[0].cpu().detach().numpy()\n",
    "incidence_2 = boundry_matrices[1].cpu().detach().numpy()\n",
    "\n",
    "print(f\"The incidence matrix B1 has shape: {incidence_1.shape}.\")\n",
    "print(f\"The incidence matrix B2 has shape: {incidence_2.shape}.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([294, 294])\n",
      "torch.Size([12942, 12942])\n",
      "torch.Size([12942, 12942])\n",
      "torch.Size([24939, 24939])\n",
      "torch.Size([24939, 24939])\n"
     ]
    }
   ],
   "source": [
    "laplacian_0 = laplacians[0]\n",
    "laplacian_down_1 = lower_laplacians[1]\n",
    "laplacian_up_1 = upper_laplacians[1]\n",
    "laplacian_down_2 = lower_laplacians[2]\n",
    "laplacian_up_2 = upper_laplacians[2]\n",
    "\n",
    "print(laplacian_0.shape)\n",
    "print(laplacian_down_1.shape)\n",
    "print(laplacian_up_1.shape)\n",
    "print(laplacian_down_2.shape)\n",
    "print(laplacian_up_2.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.sparse import coo_matrix\n",
    "\n",
    "# Truncated eigendecompositions: kk eigenpairs for L0, kk2 for each edge/face Laplacian.\n",
    "evals_0, evecs_0 = get_evals_evecs(laplacian_0.cpu().detach(), kk)\n",
    "evals_d1, evecs_d1 = get_evals_evecs(laplacian_down_1.cpu().detach(), kk2)\n",
    "evals_u1, evecs_u1 = get_evals_evecs(laplacian_up_1.cpu().detach(), kk2)\n",
    "evals_d2, evecs_d2 = get_evals_evecs(laplacian_down_2.cpu().detach(), kk2)\n",
    "evals_u2, evecs_u2 = get_evals_evecs(laplacian_up_2.cpu().detach(), kk2)\n",
    "\n",
    "# Sparse torch incidence matrices for the model. They are rebuilt from `boundry_matrices`\n",
    "# instead of overwriting the NumPy `incidence_1`/`incidence_2` in place, so re-running this\n",
    "# cell no longer feeds the already-converted tensors back into coo_matrix.\n",
    "incidence_1 = from_sparse(coo_matrix(boundry_matrices[0].cpu().detach().numpy()))\n",
    "incidence_2 = from_sparse(coo_matrix(boundry_matrices[1].cpu().detach().numpy()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import signal ##\n",
    "\n",
    "We retrieve input signals on the nodes, edges and faces. Each signal has shape $n_\\text{simplices} \\times$ in_channels, where in_channels is the feature dimension of the simplices of that rank. For this dataset the cell below reports in_channels $= 3$ on nodes, edges and faces."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"A function to obtain features based on the input: rank\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "def get_simplicial_features(dataset, rank):\n",
    "    if rank == 0:\n",
    "        which_feat = \"node_feat\"\n",
    "    elif rank == 1:\n",
    "        which_feat = \"edge_feat\"\n",
    "    elif rank == 2:\n",
    "        which_feat = \"face_feat\"\n",
    "    else:\n",
    "        raise ValueError(\n",
    "            \"input dimension must be 0, 1 or 2, because features are supported on nodes, edges and faces\"\n",
    "        )\n",
    "\n",
    "    x = list(dataset.get_simplex_attributes(which_feat).values())\n",
    "    return torch.tensor(np.stack(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 294 nodes with features of dimension 3.\n",
      "There are 12942 edges with features of dimension 3.\n",
      "There are 24939 faces with features of dimension 3.\n"
     ]
    }
   ],
   "source": [
    "x_0 = _X[0]\n",
    "x_1 = _X[1]\n",
    "x_2 = _X[2]\n",
    "print(f\"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.\")\n",
    "print(f\"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.\")\n",
    "print(f\"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define binary labels\n",
    "We retrieve the node labels returned by the preprocessing step. The `senate-bills` dataset used here (`args.dataname`) has two classes; subtracting 1 maps the labels to $\\{0, 1\\}$.\n",
    "\n",
    "We convert the labels to one-hot form and split the nodes in order, without shuffling: the first 80% (`1 - test_size`) are used for training and the last 20% (`test_size`) for testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0 0 1 0 0 0 1 1 1 1 1 1 1 1 1 1 1 0 0 0 1 0 0 0 0 0 0 1 1 1 0 1 1 1 1 1 0\n",
      " 0 0 0 1 0 1 0 0 0 0 1 1 0 0 0 0 0 0 1 0 0 0 0 1 0 1 0 0 0 1 0 1 0 1 0 0 1\n",
      " 1 0 1 0 1 0 0 0 0 0 0 1 0 1 0 1 0 1 1 0 1 0 0 0 1 1 0 0 1 0 0 0 0 0 0 0 1\n",
      " 1 0 0 1 0 0 0 0 1 0 1 1 1 0 1 0 0 0 0 0 0 1 0 1 0 1 1 1 0 1 0 0 0 0 1 1 1\n",
      " 1 1 0 0 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 0 1 1 1 1 0 1 1 1 0 1 1 0 0 0 0 0 1\n",
      " 0 1 0 0 0 0 0 0 0 0 0 1 1 0 1 1 1 1 0 0 0 0 1 0 0 0 1 1 1 0 1 0 0 0 1 1 1\n",
      " 1 0 0 0 0 0 0 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 0 0 0 1 0 1\n",
      " 0 0 1 0 1 1 0 0 1 1 0 1 0 0 0 0 0 0 0 1 0 1 0 0 0 1 1 1 1 1 1 1 1 1 0]\n"
     ]
    }
   ],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "y = np.array(labels-1)\n",
    "print(y)\n",
    "num_classes = 2  # Define the number of classes\n",
    "one_hot_labels = np.array(F.one_hot(torch.tensor(y), num_classes=num_classes))\n",
    "\n",
    "y_train, y_test = train_test_split(one_hot_labels,test_size=test_size, shuffle=False)\n",
    "y_train = torch.from_numpy(y_train).to(args.device)\n",
    "y_test = torch.from_numpy(y_test).to(args.device)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create and Train the Neural Network\n",
    "\n",
    "We specify the model with our pre-made neighborhood structures and specify an optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Network(torch.nn.Module):\n",
    "    \"\"\"Node classifier: a COSIMO backbone followed by a linear read-out of its rank-0 output.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    in_channels_all : tuple of int\n",
    "        Input feature sizes, forwarded to COSIMO.\n",
    "    hidden_channels_all : tuple of three ints\n",
    "        Hidden sizes forwarded to COSIMO; the first entry is the read-out's input size.\n",
    "    out_channels : int\n",
    "        Number of logits per node.\n",
    "    conv_order : int\n",
    "        Forwarded to COSIMO.\n",
    "    max_rank : int\n",
    "        Forwarded to COSIMO as ``sc_order``.\n",
    "    update_func : optional\n",
    "        Forwarded to COSIMO.\n",
    "    n_layers : int, default 2\n",
    "        Forwarded to COSIMO.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        in_channels_all,\n",
    "        hidden_channels_all,\n",
    "        out_channels,\n",
    "        conv_order,\n",
    "        max_rank,\n",
    "        update_func=None,\n",
    "        n_layers=2,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        # Submodules are created in the same order as before, so parameter\n",
    "        # initialisation draws from the RNG identically.\n",
    "        self.base_model = COSIMO(\n",
    "            in_channels_all=in_channels_all,\n",
    "            hidden_channels_all=hidden_channels_all,\n",
    "            conv_order=conv_order,\n",
    "            sc_order=max_rank,\n",
    "            update_func=update_func,\n",
    "            n_layers=n_layers,\n",
    "        )\n",
    "        node_hidden, _, _ = hidden_channels_all\n",
    "        self.out_linear_0 = torch.nn.Linear(node_hidden, out_channels)\n",
    "        # NOTE(review): out_linear_1 is never used in forward(); it only adds\n",
    "        # parameters to model.parameters() and the state_dict.\n",
    "        self.out_linear_1 = torch.nn.Linear(out_channels, out_channels)\n",
    "\n",
    "    def forward(self, x_all, eig_eiv_all, incidence_all):\n",
    "        \"\"\"Return per-node logits computed from the backbone's first (rank-0) output.\"\"\"\n",
    "        node_out, _, _ = self.base_model(x_all, eig_eiv_all, incidence_all)\n",
    "        return self.out_linear_0(node_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initial features on all simplices, moved to the target device.\n",
    "x_all = tuple(feat.to(args.device) for feat in (x_0, x_1, x_2))\n",
    "\n",
    "conv_order = 2\n",
    "in_channels_all = tuple(feat.shape[-1] for feat in (x_0, x_1, x_2))\n",
    "intermediate_channels_all = (num_filters,) * 3\n",
    "out_channels = num_classes\n",
    "\n",
    "# Eigenvalues/eigenvectors in the order the model receives them:\n",
    "# L0, lower L1, upper L1, lower L2, upper L2.\n",
    "eig_eiv_all = tuple(\n",
    "    tensor.to(args.device)\n",
    "    for tensor in (\n",
    "        evals_0, evecs_0,\n",
    "        evals_d1, evecs_d1,\n",
    "        evals_u1, evecs_u1,\n",
    "        evals_d2, evecs_d2,\n",
    "        evals_u2, evecs_u2,\n",
    "    )\n",
    ")\n",
    "\n",
    "incidence_all = tuple(inc.to(args.device) for inc in (incidence_1, incidence_2))\n",
    "\n",
    "model = Network(\n",
    "    in_channels_all=in_channels_all,\n",
    "    hidden_channels_all=intermediate_channels_all,\n",
    "    out_channels=out_channels,\n",
    "    conv_order=conv_order,\n",
    "    max_rank=max_rank,\n",
    "    update_func=None,\n",
    "    n_layers=num_layers,\n",
    ").to(args.device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 loss: 5664.9961 Train_acc: 0.46\n",
      "Epoch: 2 loss: 57913.3945 Train_acc: 0.54\n",
      "Epoch: 3 loss: 5297.8374 Train_acc: 0.46\n",
      "Epoch: 4 loss: 4137.3789 Train_acc: 0.54\n",
      "Epoch: 5 loss: 215.1341 Train_acc: 0.71\n",
      "Epoch: 6 loss: 3311.7224 Train_acc: 0.46\n",
      "Epoch: 7 loss: 1101.0005 Train_acc: 0.54\n",
      "Epoch: 8 loss: 746.4346 Train_acc: 0.57\n",
      "Epoch: 9 loss: 2082.9409 Train_acc: 0.46\n",
      "Epoch: 10 loss: 742.3492 Train_acc: 0.49\n",
      "\n",
      "\n",
      "Test_acc: 0.37\n",
      "\n",
      "\n",
      "Epoch: 11 loss: 2307.6089 Train_acc: 0.54\n",
      "Epoch: 12 loss: 2742.4634 Train_acc: 0.54\n",
      "Epoch: 13 loss: 1139.3888 Train_acc: 0.57\n",
      "Epoch: 14 loss: 603.6898 Train_acc: 0.58\n",
      "Epoch: 15 loss: 1148.1345 Train_acc: 0.49\n",
      "Epoch: 16 loss: 412.4301 Train_acc: 0.65\n",
      "Epoch: 17 loss: 582.3062 Train_acc: 0.59\n",
      "Epoch: 18 loss: 852.7729 Train_acc: 0.58\n",
      "Epoch: 19 loss: 713.5306 Train_acc: 0.59\n",
      "Epoch: 20 loss: 325.5988 Train_acc: 0.67\n",
      "\n",
      "\n",
      "Test_acc: 0.73\n",
      "\n",
      "\n",
      "Epoch: 21 loss: 660.2907 Train_acc: 0.54\n",
      "Epoch: 22 loss: 580.5285 Train_acc: 0.54\n",
      "Epoch: 23 loss: 210.8781 Train_acc: 0.70\n",
      "Epoch: 24 loss: 339.7367 Train_acc: 0.61\n",
      "Epoch: 25 loss: 362.8009 Train_acc: 0.60\n",
      "Epoch: 26 loss: 237.1206 Train_acc: 0.62\n",
      "Epoch: 27 loss: 88.2708 Train_acc: 0.69\n",
      "Epoch: 28 loss: 212.7600 Train_acc: 0.56\n",
      "Epoch: 29 loss: 182.8307 Train_acc: 0.56\n",
      "Epoch: 30 loss: 80.3726 Train_acc: 0.64\n",
      "\n",
      "\n",
      "Test_acc: 0.66\n",
      "\n",
      "\n",
      "Epoch: 31 loss: 208.5088 Train_acc: 0.58\n",
      "Epoch: 32 loss: 160.1957 Train_acc: 0.61\n",
      "Epoch: 33 loss: 92.5760 Train_acc: 0.57\n",
      "Epoch: 34 loss: 134.6582 Train_acc: 0.54\n",
      "Epoch: 35 loss: 71.6410 Train_acc: 0.60\n",
      "Epoch: 36 loss: 44.6038 Train_acc: 0.74\n",
      "Epoch: 37 loss: 83.2694 Train_acc: 0.66\n",
      "Epoch: 38 loss: 76.9983 Train_acc: 0.70\n",
      "Epoch: 39 loss: 63.8286 Train_acc: 0.76\n",
      "Epoch: 40 loss: 65.1535 Train_acc: 0.77\n",
      "\n",
      "\n",
      "Test_acc: 0.76\n",
      "\n",
      "\n",
      "Epoch: 41 loss: 79.2751 Train_acc: 0.71\n",
      "Epoch: 42 loss: 76.8769 Train_acc: 0.70\n",
      "Epoch: 43 loss: 54.7937 Train_acc: 0.75\n",
      "Epoch: 44 loss: 41.1701 Train_acc: 0.78\n",
      "Epoch: 45 loss: 38.7855 Train_acc: 0.76\n",
      "Epoch: 46 loss: 42.3970 Train_acc: 0.74\n",
      "Epoch: 47 loss: 36.2408 Train_acc: 0.74\n",
      "Epoch: 48 loss: 35.6358 Train_acc: 0.70\n",
      "Epoch: 49 loss: 46.3429 Train_acc: 0.62\n",
      "Epoch: 50 loss: 47.8236 Train_acc: 0.60\n",
      "\n",
      "\n",
      "Test_acc: 0.66\n",
      "\n",
      "\n",
      "Epoch: 51 loss: 33.9454 Train_acc: 0.65\n",
      "Epoch: 52 loss: 27.4382 Train_acc: 0.73\n",
      "Epoch: 53 loss: 30.4646 Train_acc: 0.75\n",
      "Epoch: 54 loss: 28.5631 Train_acc: 0.76\n",
      "Epoch: 55 loss: 26.8513 Train_acc: 0.78\n",
      "Epoch: 56 loss: 28.9726 Train_acc: 0.77\n",
      "Epoch: 57 loss: 30.1843 Train_acc: 0.76\n",
      "Epoch: 58 loss: 27.9776 Train_acc: 0.75\n",
      "Epoch: 59 loss: 22.3965 Train_acc: 0.76\n",
      "Epoch: 60 loss: 18.6694 Train_acc: 0.80\n",
      "\n",
      "\n",
      "Test_acc: 0.73\n",
      "\n",
      "\n",
      "Epoch: 61 loss: 20.0843 Train_acc: 0.75\n",
      "Epoch: 62 loss: 21.0534 Train_acc: 0.72\n",
      "Epoch: 63 loss: 21.0247 Train_acc: 0.72\n",
      "Epoch: 64 loss: 20.0381 Train_acc: 0.71\n",
      "Epoch: 65 loss: 18.2268 Train_acc: 0.71\n",
      "Epoch: 66 loss: 15.9149 Train_acc: 0.72\n",
      "Epoch: 67 loss: 14.2818 Train_acc: 0.76\n",
      "Epoch: 68 loss: 14.7747 Train_acc: 0.80\n",
      "Epoch: 69 loss: 15.7956 Train_acc: 0.78\n",
      "Epoch: 70 loss: 15.6661 Train_acc: 0.77\n",
      "\n",
      "\n",
      "Test_acc: 0.76\n",
      "\n",
      "\n",
      "Epoch: 71 loss: 13.8969 Train_acc: 0.80\n",
      "Epoch: 72 loss: 12.3718 Train_acc: 0.80\n",
      "Epoch: 73 loss: 11.5621 Train_acc: 0.75\n",
      "Epoch: 74 loss: 11.3430 Train_acc: 0.74\n",
      "Epoch: 75 loss: 12.1752 Train_acc: 0.74\n",
      "Epoch: 76 loss: 11.8135 Train_acc: 0.73\n",
      "Epoch: 77 loss: 10.4415 Train_acc: 0.74\n",
      "Epoch: 78 loss: 9.7162 Train_acc: 0.75\n",
      "Epoch: 79 loss: 9.7447 Train_acc: 0.82\n",
      "Epoch: 80 loss: 10.2559 Train_acc: 0.81\n",
      "\n",
      "\n",
      "Test_acc: 0.71\n",
      "\n",
      "\n",
      "Epoch: 81 loss: 9.7561 Train_acc: 0.80\n",
      "Epoch: 82 loss: 9.2912 Train_acc: 0.80\n",
      "Epoch: 83 loss: 8.2080 Train_acc: 0.81\n",
      "Epoch: 84 loss: 8.0854 Train_acc: 0.79\n",
      "Epoch: 85 loss: 8.1663 Train_acc: 0.77\n",
      "Epoch: 86 loss: 8.0373 Train_acc: 0.74\n",
      "Epoch: 87 loss: 7.4073 Train_acc: 0.75\n",
      "Epoch: 88 loss: 7.1298 Train_acc: 0.81\n",
      "Epoch: 89 loss: 7.1360 Train_acc: 0.81\n",
      "Epoch: 90 loss: 7.1727 Train_acc: 0.80\n",
      "\n",
      "\n",
      "Test_acc: 0.73\n",
      "\n",
      "\n",
      "Epoch: 91 loss: 6.7692 Train_acc: 0.80\n",
      "Epoch: 92 loss: 6.4378 Train_acc: 0.81\n",
      "Epoch: 93 loss: 6.1237 Train_acc: 0.80\n",
      "Epoch: 94 loss: 5.9914 Train_acc: 0.77\n",
      "Epoch: 95 loss: 5.7624 Train_acc: 0.79\n",
      "Epoch: 96 loss: 5.7963 Train_acc: 0.81\n",
      "Epoch: 97 loss: 5.4186 Train_acc: 0.80\n",
      "Epoch: 98 loss: 5.1912 Train_acc: 0.80\n",
      "Epoch: 99 loss: 4.7805 Train_acc: 0.81\n",
      "Epoch: 100 loss: 5.0224 Train_acc: 0.81\n",
      "\n",
      "\n",
      "Test_acc: 0.73\n",
      "\n",
      "\n",
      "Test_acc: 0.73\n"
     ]
    }
   ],
   "source": [
    "test_interval = 10\n",
    "val_accuracy_best = -1\n",
    "test_accuracy = -1\n",
    "# Define cross-entropy loss function\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "for epoch_i in range(1, num_epochs + 1):\n",
    "    epoch_loss = []\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    y_hat = model(x_all, eig_eiv_all, incidence_all)\n",
    "    \n",
    "    # Compute loss\n",
    "    loss = criterion(y_hat[: len(y_train)], torch.argmax(y_train.float(), dim=1))\n",
    "    \n",
    "    epoch_loss.append(loss.item())\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    probs = torch.softmax(y_hat, dim=1)\n",
    "    # Get predictions (index of the max probability)\n",
    "    y_pred = torch.argmax(probs, dim=1)\n",
    "    correct = (y_pred[: len(y_train)] == torch.argmax(y_train.float(), dim=1)).sum().item()\n",
    "    accuracy = correct / y_train.size(0)\n",
    "    print(\n",
    "        f\"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.2f}\",\n",
    "        flush=True,\n",
    "    )\n",
    "\n",
    "    if epoch_i % test_interval == 0:\n",
    "        with torch.no_grad():\n",
    "            y_hat_test = model(x_all, eig_eiv_all, incidence_all)            \n",
    "            probs = torch.softmax(y_hat_test, dim=1)\n",
    "            # Get predictions (index of the max probability)\n",
    "            y_pred_test = torch.argmax(probs, dim=1)\n",
    "            correct = (y_pred_test[-len(y_test) :] == torch.argmax(y_test.float(), dim=1)).sum().item()\n",
    "            test_accuracy = correct / y_test.size(0)\n",
    "            \n",
    "            print()\n",
    "            print()\n",
    "            print(f\"Test_acc: {test_accuracy:.2f}\", flush=True)\n",
    "            print()\n",
    "            print()\n",
    " \n",
    "print(f\"Test_acc: {test_accuracy:.2f}\", flush=True)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tmx",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
