{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:Limited tf.compat.v2.summary API due to missing TensorBoard installation.\n",
      "WARNING:root:Limited tf.compat.v2.summary API due to missing TensorBoard installation.\n",
      "WARNING:root:Limited tf.compat.v2.summary API due to missing TensorBoard installation.\n",
      "WARNING:root:Limited tf.summary API due to missing TensorBoard installation.\n",
      "WARNING:root:Limited tf.compat.v2.summary API due to missing TensorBoard installation.\n",
      "WARNING:root:Limited tf.compat.v2.summary API due to missing TensorBoard installation.\n",
      "WARNING:root:Limited tf.compat.v2.summary API due to missing TensorBoard installation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[name: \"/device:CPU:0\"\n",
      "device_type: \"CPU\"\n",
      "memory_limit: 268435456\n",
      "locality {\n",
      "}\n",
      "incarnation: 9654898608392036199\n",
      ", name: \"/device:GPU:0\"\n",
      "device_type: \"GPU\"\n",
      "memory_limit: 4160159744\n",
      "locality {\n",
      "  bus_id: 1\n",
      "  links {\n",
      "  }\n",
      "}\n",
      "incarnation: 11372315678574588484\n",
      "physical_device_desc: \"device: 0, name: NVIDIA GeForce RTX 2060, pci bus id: 0000:01:00.0, compute capability: 7.5\"\n",
      "]\n"
     ]
    }
   ],
   "source": [
    "#imports \n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow_addons as tfa\n",
    "import random\n",
    "from spektral.utils import tic, toc\n",
    "from models import *\n",
    "from utils import *\n",
    "#check if the gpu is available\n",
    "from tensorflow.python.client import device_lib\n",
    "print(device_lib.list_local_devices())\n",
    "#set seed so results are reproducible\n",
    "random.seed(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#read the dataset\n",
    "def read_data(id: str): #DBLP3, DBLP5, Brain, Reddit, DBLPE\n",
    "    #pick which dataset to load\n",
    "    dataset_dict=dict()\n",
    "    dataset_dict[\"DBLP3\"]=\"Datasets/DBLP3.npz\"\n",
    "    dataset_dict[\"DBLP5\"]=\"Datasets/DBLP5.npz\"\n",
    "    dataset_dict[\"Brain\"]=\"Datasets/Brain.npz\"\n",
    "    dataset_dict[\"Reddit\"]=\"Datasets/reddit.npz\"\n",
    "    dataset_dict[\"DBLPE\"]=\"Datasets/DBLPE.npz\"\n",
    "\n",
    "    dataset = np.load(dataset_dict[id])\n",
    "\n",
    "    #get the adjacency matrix\n",
    "    adjs = dataset[\"adjs\"] #(time, node, node)\n",
    "\n",
    "    #Remove nodes with no connections at any timestep\n",
    "    #this shrinks the data considerably\n",
    "    temporal_sum = tf.math.reduce_sum(adjs, axis=0, keepdims=False, name=None)\n",
    "    row_sum = tf.math.reduce_sum(temporal_sum, axis=0, keepdims=False, name=None)\n",
    "    non_zero_indices = np.flatnonzero(row_sum)\n",
    "    adjs = adjs[:,non_zero_indices,:]\n",
    "    adjs = adjs[:,:,non_zero_indices]\n",
    "\n",
    "    #DBLPE is a dynamic featureless graph\n",
    "    if id==\"DBLPE\":\n",
    "        labels = dataset[\"labels\"] #(nodes, time, class)\n",
    "\n",
    "        # labels = np.argmax(labels,axis=2)\n",
    "        labels=labels[non_zero_indices]\n",
    "        feats=np.zeros([adjs.shape[1], adjs.shape[0], adjs.shape[2]])\n",
    "\n",
    "        #since there are no features just fill in the identity matrix\n",
    "        for i in range(feats.shape[1]):\n",
    "            feats[:,i,:]=np.eye(feats.shape[0])\n",
    "      \n",
    "    #All others are static feature-full graphs\n",
    "    else:\n",
    "        labels = dataset[\"labels\"] #(nodes, class)\n",
    "        feats = dataset[\"attmats\"] #(node, time, feat)\n",
    "\n",
    "        # labels = np.argmax(labels, axis=1)\n",
    "        labels = labels[non_zero_indices]\n",
    "        feats = feats[non_zero_indices]\n",
    "\n",
    "    #Other important variables\n",
    "    n_nodes = adjs.shape[1]\n",
    "    n_timesteps = adjs.shape[0]\n",
    "    n_class = int(labels.shape[1])\n",
    "    n_feat = feats.shape[2]\n",
    "\n",
    "    #Train Val Test split\n",
    "    nodes_id = list(range(n_nodes))\n",
    "    random.shuffle(nodes_id)\n",
    "    idx_train = nodes_id[:(7*n_nodes)//10]\n",
    "    idx_train = [True if i in idx_train else False for i in list(range(n_nodes))]\n",
    "    idx_val = nodes_id[(7*n_nodes)//10: (9*n_nodes)//10]\n",
    "    idx_val = [True if i in idx_val else False for i in list(range(n_nodes))]\n",
    "    idx_test = nodes_id[(9*n_nodes)//10: n_nodes]\n",
    "    idx_test = [True if i in idx_test else False for i in list(range(n_nodes))]\n",
    "\n",
    "    #custom data type that holds everything i might need\n",
    "    return STG_Dataset(tf.convert_to_tensor(adjs,dtype=tf.float32),\n",
    "                        tf.convert_to_tensor(adjs,dtype=tf.float32),\n",
    "                        tf.convert_to_tensor(feats,dtype=tf.float32), \n",
    "                        tf.convert_to_tensor(feats,dtype=tf.float32), \n",
    "                        tf.convert_to_tensor(labels,dtype=tf.float32), \n",
    "                        tf.convert_to_tensor(labels,dtype=tf.float32), \n",
    "                        n_nodes, n_timesteps, n_class, n_feat, \n",
    "                        np.array(idx_train),\n",
    "                        np.array(idx_val),\n",
    "                        np.array(idx_test))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training step\n",
    "@tf.function\n",
    "def train(feats, adjs, labels, idx_train, idx_val, model, loss_fn, optimizer, acc):\n",
    "    #training\n",
    "    with tf.GradientTape() as tape:\n",
    "        predictions = model([feats, adjs], training=True)\n",
    "        loss_train = loss_fn(labels[idx_train], predictions[idx_train])\n",
    "    gradients = tape.gradient(loss_train, model.trainable_variables)\n",
    "    optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "\n",
    "    #evaluating\n",
    "    predictions = model([feats, adjs], training=False)\n",
    "    loss_val = loss_fn(labels[idx_val], predictions[idx_val])\n",
    "    acc.update_state(labels[idx_val], predictions[idx_val])\n",
    "    return loss_train\n",
    "\n",
    "\n",
    "@tf.function\n",
    "#testing step\n",
    "def test(feats, adjs, labels, idx_test, model, loss_fn, optimizer, acc, auc, f1):\n",
    "    predictions = model([feats, adjs], training=False)\n",
    "    loss_test = loss_fn(labels[idx_test], predictions[idx_test])\n",
    "\n",
    "    #updating metrics state\n",
    "    acc.update_state(labels[idx_test], predictions[idx_test])\n",
    "    auc.update_state(labels[idx_test], predictions[idx_test])\n",
    "    f1.update_state(labels[idx_test], predictions[idx_test])\n",
    "    return loss_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "#a single epoch of a train and test loop\n",
    "def timestep_train_test(epochs, model, data, loss_fn, optimizer, val_acc, acc, auc, f1):\n",
    "    best_val=0\n",
    "    tic()\n",
    "    #for each epoch\n",
    "    for epoch in range(1, epochs + 1):\n",
    "        #calculate the loss\n",
    "        loss_train = train(data.feats_timestep, data.adjs_timestep, data.labels_timestep, data.idx_train, data.idx_val, model, loss_fn, optimizer, val_acc)\n",
    "        #keep track of best val acc\n",
    "        if val_acc.result() > best_val:\n",
    "            best_val = val_acc.result()\n",
    "        val_acc.reset_state()\n",
    "    print(f\"Best Training Loss {loss_train}\")\n",
    "    print(f\"Best Val Acc: {best_val}\")\n",
    "\n",
    "    #after training test the data\n",
    "    loss_test = test(data.feats_timestep, data.adjs_timestep, data.labels_timestep, data.idx_test, model, loss_fn, optimizer, acc, auc, f1)\n",
    "    print(f\"Test Loss: {loss_test}, Test Acc: {acc.result()}, Test F1 score: {f1.result()}, Auc Test: {auc.result()}\")\n",
    "    # print(f\"lambda: {model.trainable_weights[0]}\")\n",
    "    toc(f\"{model.name} ({epochs} epochs)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "#wrapper over train and test loop. Takes in data name and model name and then trains and evaluates\n",
    "def overall_train_test(data_id, model_id):\n",
    "    #Constant parameters\n",
    "    epochs = 500\n",
    "    dropout_rate = 0.5\n",
    "    lr = 25e-4\n",
    "    weight_decay = 5e-4\n",
    "    ignores_temporal_data = [\"GAT\", \"GCN\", \"GraphSage\"]\n",
    "\n",
    "    #read the data\n",
    "    data = read_data(data_id)\n",
    "    #Each model takes in different parameters so this is where i decide which model to build\n",
    "    if model_id.__name__ not in [\"TRNNGCN\", \"EGCN\"]:\n",
    "        model = model_id(data.n_class, data.n_class, dropout_rate)\n",
    "    elif model_id.__name__ == \"EGCN\":\n",
    "        model = model_id(data.n_feat, data.n_class, data.n_class)\n",
    "    else:\n",
    "        model = model_id(data.n_nodes, data.n_class, data.n_class, dropout_rate)\n",
    "\n",
    "    #If the model ignores temporal data it only takes in 2 dimensions\n",
    "    if (model_id.__name__ in ignores_temporal_data):\n",
    "        model.build([(data.n_nodes, data.n_feat), (data.n_nodes, data.n_nodes)])\n",
    "    #else it needs to take in time\n",
    "    else:\n",
    "        model.build([(data.n_nodes, data.n_timestamps, data.n_feat), (data.n_timestamps, data.n_nodes, data.n_nodes)])\n",
    "    model.summary()\n",
    "    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=weight_decay)\n",
    "    loss_fn = tf.keras.losses.CategoricalCrossentropy()\n",
    "    #Metrics can only be created once\n",
    "    val_acc = tf.keras.metrics.CategoricalAccuracy()\n",
    "    acc = tf.keras.metrics.CategoricalAccuracy()\n",
    "    auc = tf.keras.metrics.AUC(num_thresholds=data.adjs.shape[0], multi_label=False)\n",
    "    f1 = tfa.metrics.F1Score(data.labels.shape[1], average=\"weighted\")\n",
    "\n",
    "    #preprocessing\n",
    "    #for each timestep\n",
    "    for timestep in range(1, data.n_timestamps+1):\n",
    "        #keep track of the timesteps current position\n",
    "        data.adjs_timestep = tf.identity(data.adjs[:timestep,:,:])\n",
    "        data.feats_timestep = tf.identity(data.feats)\n",
    "        data.labels_timestep = tf.identity(data.labels)\n",
    "        if (data_id == \"DBLPE\"):\n",
    "            data.labels_timestep = data.labels_timestep[:,timestep-1]\n",
    "        \n",
    "        #If the model ignores temporal data, accumulate adj matrices\n",
    "        if (model_id.__name__ in ignores_temporal_data):\n",
    "            data.adjs_timestep = tf.math.reduce_sum(data.adjs_timestep, axis=0, keepdims=False, name=None)\n",
    "            data.feats_timestep = data.feats_timestep[:, -1, :]\n",
    "\n",
    "            #normalize the adj matrix\n",
    "            data.adjs_timestep += tf.eye(data.adjs_timestep.shape[0])\n",
    "            d = tf.reduce_sum(data.adjs_timestep, axis=1)\n",
    "            normalizing_matrix = np.zeros((data.adjs_timestep.shape[0], data.adjs_timestep.shape[0]))\n",
    "            normalizing_matrix[range(len(normalizing_matrix)), range(len(normalizing_matrix))] = d**(-0.5)\n",
    "            normalizing_matrix = tf.convert_to_tensor(normalizing_matrix, dtype=tf.float32)\n",
    "            data.adjs_timestep = tf.matmul(normalizing_matrix,data.adjs_timestep)\n",
    "            data.adjs_timestep=tf.matmul(tf.matmul(normalizing_matrix,data.adjs_timestep), normalizing_matrix)\n",
    "\n",
    "        timestep_train_test(epochs, model, data, loss_fn, optimizer, val_acc, acc, auc, f1)\n",
    "            \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"gat\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dropout (Dropout)            multiple                  0         \n",
      "_________________________________________________________________\n",
      "gat_conv (GATConv)           multiple                  309       \n",
      "_________________________________________________________________\n",
      "gat_conv_1 (GATConv)         multiple                  18        \n",
      "=================================================================\n",
      "Total params: 327\n",
      "Trainable params: 327\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Best Training Loss 0.6628769040107727\n",
      "Best Val Acc: 0.7402135133743286\n",
      "Test Loss: 0.7425931692123413, Test Acc: 0.7234042286872864, Test F1 score: 0.6073023676872253, Auc Test: 0.8087747097015381\n",
      "gat (500 epochs)\n",
      "Elapsed: 8.82s\n",
      "Best Training Loss 0.6197769045829773\n",
      "Best Val Acc: 0.7402135133743286\n",
      "Test Loss: 0.7074084281921387, Test Acc: 0.7234042286872864, Test F1 score: 0.6073023676872253, Auc Test: 0.8272753953933716\n",
      "gat (500 epochs)\n",
      "Elapsed: 6.96s\n",
      "Best Training Loss 0.6053073406219482\n",
      "Best Val Acc: 0.7402135133743286\n",
      "Test Loss: 0.6964904069900513, Test Acc: 0.7234042286872864, Test F1 score: 0.6073023676872253, Auc Test: 0.8346579670906067\n",
      "gat (500 epochs)\n",
      "Elapsed: 6.70s\n",
      "Best Training Loss 0.5971313118934631\n",
      "Best Val Acc: 0.7402135133743286\n",
      "Test Loss: 0.6909966468811035, Test Acc: 0.7234042286872864, Test F1 score: 0.6073023676872253, Auc Test: 0.8388945460319519\n",
      "gat (500 epochs)\n",
      "Elapsed: 6.66s\n",
      "Best Training Loss 0.5895549058914185\n",
      "Best Val Acc: 0.7402135133743286\n",
      "Test Loss: 0.6810287237167358, Test Acc: 0.7234042286872864, Test F1 score: 0.6073023676872253, Auc Test: 0.8416849970817566\n",
      "gat (500 epochs)\n",
      "Elapsed: 6.63s\n",
      "Best Training Loss 0.5826338529586792\n",
      "Best Val Acc: 0.7402135133743286\n",
      "Test Loss: 0.684391438961029, Test Acc: 0.7234042286872864, Test F1 score: 0.6073023676872253, Auc Test: 0.8432013988494873\n",
      "gat (500 epochs)\n",
      "Elapsed: 6.99s\n",
      "Best Training Loss 0.573632001876831\n",
      "Best Val Acc: 0.7402135133743286\n",
      "Test Loss: 0.6862150430679321, Test Acc: 0.7234042286872864, Test F1 score: 0.6073023676872253, Auc Test: 0.8443452715873718\n",
      "gat (500 epochs)\n",
      "Elapsed: 7.00s\n",
      "Best Training Loss 0.5777983069419861\n",
      "Best Val Acc: 0.7437722682952881\n",
      "Test Loss: 0.6787505745887756, Test Acc: 0.7234042286872864, Test F1 score: 0.6073023676872253, Auc Test: 0.8453857898712158\n",
      "gat (500 epochs)\n",
      "Elapsed: 7.37s\n",
      "Best Training Loss 0.5833646059036255\n",
      "Best Val Acc: 0.7437722682952881\n",
      "Test Loss: 0.6788987517356873, Test Acc: 0.7234042286872864, Test F1 score: 0.6073023676872253, Auc Test: 0.8464246392250061\n",
      "gat (500 epochs)\n",
      "Elapsed: 7.45s\n",
      "Best Training Loss 0.5690937042236328\n",
      "Best Val Acc: 0.7473309636116028\n",
      "Test Loss: 0.6796329021453857, Test Acc: 0.7234042286872864, Test F1 score: 0.6073023676872253, Auc Test: 0.8473614454269409\n",
      "gat (500 epochs)\n",
      "Elapsed: 6.99s\n"
     ]
    }
   ],
   "source": [
    "# tf.config.run_functions_eagerly(True)\n",
    "\n",
    "#input model and dataset and everything will run itself\n",
    "overall_train_test(\"DBLP3\", GAT)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "246acc2e86da06095fd9b40717c2339d803f4778ff32430783cde5272a877e12"
  },
  "kernelspec": {
   "display_name": "Python 3.8.11 64-bit ('deep_learning': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
