{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a88e729f-32ac-4e47-9d9d-a50fddbef48f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from graphert.processing_data  import *\n",
    "from graphert.create_random_walks  import *\n",
    "from graphert.train_model  import *\n",
    "from graphert.train_tokenizer  import *\n",
    "from graphert.temporal_embeddings import *\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import scipy as sp\n",
    "import transformers\n",
    "from sklearn.manifold import TSNE\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "efb755d1-067e-44fc-8aa5-78f65aa0ff34",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import make_scorer, precision_score, recall_score, f1_score, accuracy_score\n",
    "from sklearn.model_selection import cross_val_score, StratifiedKFold\n",
    "import numpy as np\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import classification_report, confusion_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "aab59e6b-f4ab-48ac-bbd3-3ff11c314dfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.colors import ListedColormap\n",
    "import plotly.graph_objs as go"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "751beb7d-eafc-4910-8067-5384df82660f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "299ae33d-3f08-4b88-941c-925c86c11ab9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-D t-SNE projector with a fixed random seed.\n",
    "# NOTE(review): scikit-learn 1.5 renamed `n_iter` to `max_iter`; confirm the\n",
    "# installed version still accepts `n_iter`.\n",
    "tsne = TSNE(\n",
    "    n_components=2,\n",
    "    perplexity=40,\n",
    "    n_iter=1000,\n",
    "    random_state=4,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "30cc39fc-a7fb-4fc1-8380-069bb7276efe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset identifier: passed to the graphert helpers and interpolated into the\n",
    "# `datasets_res/{dataset_name}/...` paths built later in this notebook.\n",
    "dataset_name = \"hippocampus_rat\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "f47614e9-060a-4675-a4c1-bc3a92a69d94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this path points at a Facebook wall-post edge list, while\n",
    "# `dataset_name` is 'hippocampus_rat'. The `graph_df` loaded from it is only\n",
    "# displayed and never used to build the graphs - confirm it is still needed.\n",
    "graph_path = 'data/facebook/facebook-wall.txt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fae95555-2041-421a-8d45-e4a4045e4234",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the tab-separated, header-less edge list and label its three columns.\n",
    "graph_df = (\n",
    "    pd.read_table(graph_path, sep='\\t', header=None)\n",
    "    .set_axis(['source', 'target', 'time'], axis=1)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5b108873-9c2c-426d-932e-9184586852bd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>source</th>\n",
       "      <th>target</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>28</td>\n",
       "      <td>28</td>\n",
       "      <td>1095135831</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1015</td>\n",
       "      <td>1017</td>\n",
       "      <td>1097725406</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>959</td>\n",
       "      <td>959</td>\n",
       "      <td>1098387569</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>991</td>\n",
       "      <td>991</td>\n",
       "      <td>1098425204</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1015</td>\n",
       "      <td>1017</td>\n",
       "      <td>1098489762</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>876988</th>\n",
       "      <td>1715</td>\n",
       "      <td>17995</td>\n",
       "      <td>1232597482</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>876989</th>\n",
       "      <td>18616</td>\n",
       "      <td>18616</td>\n",
       "      <td>1232598051</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>876990</th>\n",
       "      <td>28549</td>\n",
       "      <td>31056</td>\n",
       "      <td>1232598370</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>876991</th>\n",
       "      <td>24830</td>\n",
       "      <td>59912</td>\n",
       "      <td>1232598672</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>876992</th>\n",
       "      <td>24830</td>\n",
       "      <td>59912</td>\n",
       "      <td>1232598691</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>876993 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        source  target        time\n",
       "0           28      28  1095135831\n",
       "1         1015    1017  1097725406\n",
       "2          959     959  1098387569\n",
       "3          991     991  1098425204\n",
       "4         1015    1017  1098489762\n",
       "...        ...     ...         ...\n",
       "876988    1715   17995  1232597482\n",
       "876989   18616   18616  1232598051\n",
       "876990   28549   31056  1232598370\n",
       "876991   24830   59912  1232598672\n",
       "876992   24830   59912  1232598691\n",
       "\n",
       "[876993 rows x 3 columns]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the loaded edge list (pandas abbreviates the middle rows).\n",
    "graph_df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "219bd75c-0530-4bc2-8115-428bb124a602",
   "metadata": {},
   "source": [
    "## Constructing temporal graphs from provided adjacency matrices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5b76fac7-6d79-4cde-86a7-ce0a7b1c379b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-time-step connectivity matrices for the selected dataset. The next cell\n",
    "# calls `.toarray()` on each entry, so these are presumably SciPy sparse\n",
    "# matrices - verify.\n",
    "# NOTE(review): unpickling can execute arbitrary code; only load this file\n",
    "# from a trusted source.\n",
    "adj_time_list = pd.read_pickle(f'../../Dataset/adj_time_list_{dataset_name}.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "fdec6aa0-6d8d-4078-be3e-49f6c6f97af9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Entries whose magnitude is below `threshold` are zeroed. With 0 nothing is\n",
    "# dropped and the step only takes absolute values. Choose accordingly.\n",
    "threshold = 0\n",
    "\n",
    "# Dense, non-negative adjacency matrix per time step.\n",
    "adj_matrices = []\n",
    "for connectivity_matrix in adj_time_list:\n",
    "    # Densify once per snapshot and reuse the result.\n",
    "    magnitudes = np.abs(connectivity_matrix.toarray())\n",
    "    adj_matrices.append(np.where(magnitudes < threshold, 0, magnitudes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "3c087d10-1bb7-454e-acef-6b93e73a4e92",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One DataFrame of edge records per snapshot, concatenated once at the end.\n",
    "edge_frames = []\n",
    "\n",
    "# Iterate over each time step and adjacency matrix\n",
    "for time_step, adj_matrix in enumerate(adj_matrices):\n",
    "    # Row/column indices of the non-zero entries, i.e. the snapshot's edges\n",
    "    source, target = np.nonzero(adj_matrix)\n",
    "    if source.size == 0:\n",
    "        continue  # a snapshot without edges contributes no rows\n",
    "\n",
    "    edge_frames.append(pd.DataFrame({\n",
    "        \"source\": source,\n",
    "        \"target\": target,\n",
    "        # Snapshot index mapped to a synthetic year (2000, 2001, ...) so the\n",
    "        # 'years' time granularity can be used when loading the dataset;\n",
    "        # adjust this if real timestamps are available\n",
    "        \"year\": time_step + 2000,\n",
    "        \"weight\": adj_matrix[source, target],\n",
    "    }))\n",
    "\n",
    "# Create the DataFrame (empty but correctly labelled when there are no edges)\n",
    "if edge_frames:\n",
    "    df = pd.concat(edge_frames, ignore_index=True)\n",
    "else:\n",
    "    df = pd.DataFrame(columns=[\"source\", \"target\", \"year\", \"weight\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "a52e4759-11bc-4bb2-b8ab-c76d7f451f37",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "        source  target  year    weight\n",
      "0            0       0  2000  1.000000\n",
      "1            1       1  2000  1.000000\n",
      "2            1       2  2000  0.042241\n",
      "3            1       3  2000  0.033968\n",
      "4            1       5  2000  0.027462\n",
      "...        ...     ...   ...       ...\n",
      "191559     115     115  2084  1.000000\n",
      "191560     116     116  2084  1.000000\n",
      "191561     117     117  2084  1.000000\n",
      "191562     118     118  2084  1.000000\n",
      "191563     119     119  2084  1.000000\n",
      "\n",
      "[191564 rows x 4 columns]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Print the whole edge DataFrame (pandas abbreviates the middle rows in the\n",
    "# text output, so only the head, the tail and the shape are shown)\n",
    "print(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "6f7bfca8-68a1-4484-aa71-e8963315748e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the full graph and the temporal graph object from the edge table;\n",
    "# 'years' matches the synthetic `year` column of `df`.\n",
    "graph_nx, temporal_graph = load_dataset(df, dataset_name, time_granularity='years')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6d30d03d-7a26-4fec-900b-9ffc83076842",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<networkx.classes.multidigraph.MultiDiGraph at 0x7b03b796eb10>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the object returned for the full graph.\n",
    "graph_nx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "01bbbbf8-d899-478f-a3e0-6f2894a5e1a3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{datetime.datetime(2000, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b601d950>, datetime.datetime(2001, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b606e6d0>, datetime.datetime(2002, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b6c69550>, datetime.datetime(2003, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b609e490>, datetime.datetime(2004, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03aa389dd0>, datetime.datetime(2005, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b60a2990>, datetime.datetime(2006, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03aa1ee890>, datetime.datetime(2007, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03aa11b090>, datetime.datetime(2008, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03aa1ee350>, datetime.datetime(2009, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03aa1ede90>, datetime.datetime(2010, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03aa0d3310>, datetime.datetime(2011, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9ec04d0>, datetime.datetime(2012, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03aa11b750>, datetime.datetime(2013, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9ec0510>, datetime.datetime(2014, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9fdc590>, datetime.datetime(2015, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9ec0210>, datetime.datetime(2016, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9cbfd10>, datetime.datetime(2017, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03aa079c10>, datetime.datetime(2018, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9f4fc10>, datetime.datetime(2019, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9e2d910>, datetime.datetime(2020, 1, 1, 
0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9c4b0d0>, datetime.datetime(2021, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9b48050>, datetime.datetime(2022, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9ceb050>, datetime.datetime(2023, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9bdad90>, datetime.datetime(2024, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a99c2650>, datetime.datetime(2025, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9aae390>, datetime.datetime(2026, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9a8ce50>, datetime.datetime(2027, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b601d7d0>, datetime.datetime(2028, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9b48110>, datetime.datetime(2029, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a98d5910>, datetime.datetime(2030, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b68f1f50>, datetime.datetime(2031, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9883510>, datetime.datetime(2032, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a98d5190>, datetime.datetime(2033, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9c4b150>, datetime.datetime(2034, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b6db7090>, datetime.datetime(2035, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9988510>, datetime.datetime(2036, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9959690>, datetime.datetime(2037, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9522210>, datetime.datetime(2038, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03aa1edf90>, datetime.datetime(2039, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b6fb8dd0>, datetime.datetime(2040, 1, 1, 0, 0): 
<networkx.classes.digraph.DiGraph object at 0x7b03aa497750>, datetime.datetime(2041, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9381010>, datetime.datetime(2042, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a93f5890>, datetime.datetime(2043, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a92af810>, datetime.datetime(2044, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9381450>, datetime.datetime(2045, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a8fd3090>, datetime.datetime(2046, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03aa1a3d50>, datetime.datetime(2047, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a8fd2bd0>, datetime.datetime(2048, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a8f1e1d0>, datetime.datetime(2049, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a929acd0>, datetime.datetime(2050, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a8e0e590>, datetime.datetime(2051, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a972e210>, datetime.datetime(2052, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a909c890>, datetime.datetime(2053, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b6035750>, datetime.datetime(2054, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a8e0e250>, datetime.datetime(2055, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03aa11b050>, datetime.datetime(2056, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b722c210>, datetime.datetime(2057, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a8d6d650>, datetime.datetime(2058, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a8b38690>, datetime.datetime(2059, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a8c43c90>, datetime.datetime(2060, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph 
object at 0x7b03a896bf50>, datetime.datetime(2061, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b7039d90>, datetime.datetime(2062, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a89f5dd0>, datetime.datetime(2063, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a88356d0>, datetime.datetime(2064, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a8a34d50>, datetime.datetime(2065, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a8725210>, datetime.datetime(2066, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b6e16010>, datetime.datetime(2067, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a877eed0>, datetime.datetime(2068, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a877f390>, datetime.datetime(2069, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a877ed10>, datetime.datetime(2070, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a83e1ad0>, datetime.datetime(2071, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a8426f10>, datetime.datetime(2072, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a8877b50>, datetime.datetime(2073, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a80e15d0>, datetime.datetime(2074, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a8047f50>, datetime.datetime(2075, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a9479ad0>, datetime.datetime(2076, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a8047d90>, datetime.datetime(2077, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b6c31ad0>, datetime.datetime(2078, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b6c31a90>, datetime.datetime(2079, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a7ce6990>, datetime.datetime(2080, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a7e7ec90>, 
datetime.datetime(2081, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a7c0a250>, datetime.datetime(2082, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a7f12110>, datetime.datetime(2083, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a780fd90>, datetime.datetime(2084, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03a7b871d0>}\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Per-time-step snapshot graphs.\n",
    "# NOTE(review): `min_degree=5` presumably filters out low-degree nodes -\n",
    "# confirm in graphert's TemporalGraph implementation.\n",
    "graphs = temporal_graph.get_temporal_graphs(min_degree=5)\n",
    "# Prints the whole mapping; in the recorded run it holds one DiGraph per\n",
    "# yearly datetime key.\n",
    "print(graphs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "0f4e0b17-a392-4386-afc1-497604e9d041",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of nodes: 120\n",
      "Number of edges: 191564\n"
     ]
    }
   ],
   "source": [
    "# Print the number of nodes and edges of the full graph. In the recorded run\n",
    "# the edge count equals the number of rows in `df`, so self-loops such as\n",
    "# 0 -> 0 are counted as edges.\n",
    "print(f\"Number of nodes: {graph_nx.number_of_nodes()}\")\n",
    "print(f\"Number of edges: {graph_nx.number_of_edges()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "29ba7e3a-49af-4474-97b5-5cc2be1b3b2c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TemporalGraph attributes:\n",
      "Data:\n",
      "   source  target  year    weight time_index       time\n",
      "0       0       0  2000  1.000000 2000-01-01 2000-01-01\n",
      "1       1       1  2000  1.000000 2000-01-01 2000-01-01\n",
      "2       1       2  2000  0.042241 2000-01-01 2000-01-01\n",
      "3       1       3  2000  0.033968 2000-01-01 2000-01-01\n",
      "4       1       5  2000  0.027462 2000-01-01 2000-01-01\n",
      "Time Granularity: years\n",
      "Time Columns: ['year']\n",
      "Step: relativedelta(years=+1)\n"
     ]
    }
   ],
   "source": [
    "# Print the attributes of the temporal_graph object\n",
    "print(\"TemporalGraph attributes:\")\n",
    "print(\"Data:\")\n",
    "# Edge table held by the object; the recorded output shows `time_index` and\n",
    "# `time` columns in addition to the four columns of `df`.\n",
    "print(temporal_graph.data.head())  # Display the first few rows of the data\n",
    "print(\"Time Granularity:\", temporal_graph.time_granularity)\n",
    "print(\"Time Columns:\", temporal_graph.time_columns)\n",
    "# Step between consecutive snapshots (a one-year relativedelta in the recorded run)\n",
    "print(\"Step:\", temporal_graph.step)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d3ffc36-961b-4423-9794-4b4936f7486d",
   "metadata": {},
   "source": [
    "## GraphERT: \n",
    "Transformers-based Temporal Dynamic Graph Embedding\n",
    "\n",
    "![Alt text](GraphERT.png)\n",
    "\n",
    "Moran Beladev, Gilad Katz, Lior Rokach, Uriel Singer, Kira Radinsky.\n",
    "CIKM’23 – October 2023, Birmingham, United Kingdom."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "c6db0965-6270-4dc8-983b-14146d6157c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node set of the largest connected component of the undirected view of the\n",
    "# full graph. `max` selects it in one pass instead of sorting every component.\n",
    "# NOTE(review): `nx` is not imported explicitly in this notebook; presumably it\n",
    "# comes from one of the `graphert` star imports - verify.\n",
    "cc_nodes = max(nx.connected_components(graph_nx.to_undirected()), key=len)  # biggest cc\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f35b35f-7f97-4393-b8a6-6b9946674d74",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "walk_len=32, num_walks=10\n",
      "0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 25/25 [03:39<00:00,  8.77s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 25/25 [00:20<00:00,  1.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 25/25 [00:06<00:00,  3.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 25/25 [00:53<00:00,  2.14s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 25/25 [01:23<00:00,  3.33s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 25/25 [00:35<00:00,  1.43s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 25/25 [03:01<00:00,  7.28s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 25/25 [03:16<00:00,  7.84s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 25/25 [01:43<00:00,  4.13s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 25/25 [02:07<00:00,  5.09s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 25/25 [03:42<00:00,  8.90s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "11\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/25 [00:00<?, ?it/s]"
     ]
    }
   ],
   "source": [
    "# Re-key the snapshots by their position (0, 1, 2, ...) instead of their original keys.\n",
    "graphs = dict(enumerate(graphs.values()))\n",
    "\n",
    "# Grids of p and q values and the walk settings handed to create_random_walks.\n",
    "ps = [0.25, 0.5, 1, 2, 4]\n",
    "qs = [0.25, 0.5, 1, 2, 4]\n",
    "walk_lengths = [32]\n",
    "num_walks_list = [10]\n",
    "\n",
    "create_random_walks(\n",
    "    graphs, ps, qs, walk_lengths, num_walks_list, dataset_name, cc_nodes\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "909a9abf-fc3f-4d30-892a-4ca9738af4fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only one walk length / walk count was configured, so take the first of each.\n",
    "walk_len, num_walks = walk_lengths[0], num_walks_list[0]\n",
    "\n",
    "# CSV of random walks for this configuration (presumably written by\n",
    "# create_random_walks - verify).\n",
    "random_walk_path = (\n",
    "    f'datasets_res/{dataset_name}/'\n",
    "    f'paths_walk_len_{walk_len}_num_walks_{num_walks}.csv'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90cea38f-5084-4d11-ae76-c196f11b19d0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train a node-level tokenizer on the random walks\n",
    "train_graph_tokenizer(random_walk_path, dataset_name, walk_len)\n",
    "# Alternative training routine (disabled):\n",
    "# train_only_temporal_model(random_walk_path, dataset_name, walk_len)\n",
    "# Training routine used here. NOTE(review): presumably it produces the\n",
    "# `mlm_and_temporal_model` that `model_path` points at - verify.\n",
    "train_mlm_temporal_model(random_walk_path, dataset_name, walk_len)\n",
    "# Alternative training routine (disabled):\n",
    "# train_2_steps_model(random_walk_path, dataset_name, walk_len)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8479ec1f-8117-4e36-9382-12277a007c52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the trained model for this dataset.\n",
    "model_path = f'datasets_res/{dataset_name}/models/mlm_and_temporal_model'\n",
    "# get temporal embeddings by the last layer (done in the next cell)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4700b324-16ca-4f16-aa71-da0aa0c41851",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): presumably the last-layer temporal embeddings of the trained\n",
    "# model - confirm in graphert.temporal_embeddings. The result is not referenced\n",
    "# again in this notebook.\n",
    "cls_temporal_embeddings = get_temporal_embeddings(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f7e5cd8-e743-4988-a75b-4c91205a9eef",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9836e35c-84ca-4b03-8105-6c912cc3d5cb",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Temporal embeddings obtained by averaging the path embeddings per time step\n",
    "data_df = pd.read_csv(random_walk_path, index_col=None)\n",
    "(\n",
    "    t_cls_emb_mean,\n",
    "    t_cls_emb_weighted_mean,\n",
    "    t_prob,\n",
    "    t_nodes_emb_mean,\n",
    ") = get_embeddings_by_paths_average(data_df, model_path, dataset_name, walk_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "3927fd57-44b4-4e30-aaf2-d8736848ee22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average embedding per time point: the mean, taken over axis 0, of the values\n",
    "# stored in that time point's inner dict of t_nodes_emb_mean.\n",
    "average_embeddings = {\n",
    "    time_point: np.mean(np.array(list(node_embeddings.values())), axis=0)\n",
    "    for time_point, node_embeddings in t_nodes_emb_mean.items()\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "4c92697c-d753-406c-988a-6f403a2be0cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the time keys and keep each mapping's values, in insertion order, as a list.\n",
    "cls_emb_mean_list, cls_emb_weighted_mean_list, average_embeddings_list = (\n",
    "    list(per_time.values())\n",
    "    for per_time in (t_cls_emb_mean, t_cls_emb_weighted_mean, average_embeddings)\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
