{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1b28fb5",
   "metadata": {},
   "source": [
    "# Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2d30355",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Main script for training and evaluating temporal GNN models for node classification tasks.\n",
    "\"\"\"\n",
    "\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import json\n",
    "import shutil\n",
    "import logging\n",
    "import warnings\n",
    "from collections import defaultdict\n",
    "import pickle\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Suppress unnecessary warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "logging.getLogger('matplotlib').setLevel(logging.WARNING)\n",
    "\n",
    "from models.TGAT import TGAT\n",
    "from models.MemoryModel import MemoryModel, compute_src_dst_node_time_shifts\n",
    "from models.CAWN import CAWN\n",
    "from models.TCL import TCL\n",
    "from models.GraphMixer import GraphMixer\n",
    "from models.DyGFormer import DyGFormer\n",
    "from models.modules import MergeLayer, MLPClassifier\n",
    "\n",
    "from utils.utils import (\n",
    "    set_random_seed, convert_to_gpu, get_parameter_sizes, create_optimizer,\n",
    "    get_neighbor_sampler\n",
    ")\n",
    "from utils.DataLoader import get_idx_data_loader, get_node_classification_data\n",
    "from utils.metrics import get_node_classification_metrics\n",
    "from utils.EarlyStopping import EarlyStopping\n",
    "from utils.load_configs import get_node_classification_args\n",
    "from evaluate_models_utils import evaluate_model_node_classification\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f2fb90c",
   "metadata": {},
   "source": [
    "# Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c7353d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Args:\n",
    "    \"\"\"\n",
    "    Configuration class for model and training setup.\n",
    "\n",
    "    Every setting is a plain class attribute; the notebook later attaches\n",
    "    run-specific values (e.g. ``seed``) to the instance.\n",
    "    \"\"\"\n",
    "\n",
    "    #### Dataset selection ####\n",
    "\n",
    "    # Real-world datasets\n",
    "    #dataset_name = \"brain\"\n",
    "    dataset_name = \"school\"\n",
    "    #dataset_name = \"stock\"\n",
    "\n",
    "    # Synthetic dataset\n",
    "    #dataset_name = \"synthetic_exp2.1\"\n",
    "    #dataset_name = \"synthetic_exp2.2\"\n",
    "\n",
    "    #### Model selection ####\n",
    "\n",
    "    # Choose one model to activate\n",
    "    #model_name = 'JODIE'\n",
    "    model_name = 'DyRep'\n",
    "    #model_name = 'TGN'\n",
    "    #model_name = \"TGAT\"\n",
    "    #model_name = 'DyGFormer'\n",
    "\n",
    "    #### General training configuration ####\n",
    "\n",
    "    batch_size = 128  # Number of samples per training batch\n",
    "\n",
    "    gpu = 0  # GPU index if using CUDA, otherwise set -1 for CPU\n",
    "\n",
    "    #### Neighbor sampling ####\n",
    "\n",
    "    num_neighbors = 5  # Number of temporal neighbors to sample\n",
    "    sample_neighbor_strategy = 'uniform'  # Strategy for sampling neighbors\n",
    "    time_scaling_factor = 10.0  # Scaling factor for time-based features\n",
    "\n",
    "    #### Model-specific hyperparameters ####\n",
    "\n",
    "    num_walk_heads = 4  # For models with random walk or multi-head attention\n",
    "    num_heads = 4  # Number of attention heads\n",
    "    num_layers = 2  # Number of layers (e.g., GNN or transformer layers)\n",
    "    walk_length = 5  # Length of random walks (if applicable)\n",
    "    time_gap = 1.0  # Time gap resolution\n",
    "    time_feat_dim = 32  # Dimension of time features\n",
    "    position_feat_dim = 32  # Dimension of position features\n",
    "\n",
    "    patch_size = 16  # Patch size for temporal sequences\n",
    "    channel_embedding_dim = 12  # Embedding size per channel\n",
    "\n",
    "    # NOTE(review): 10 is labelled as the brain/stock value while dataset_name is\n",
    "    # \"school\" - confirm it matches the checkpoint being loaded.\n",
    "    ### brain,stock\n",
    "    max_input_sequence_length = 10  # Max length of input temporal sequence\n",
    "    ### Else\n",
    "    #max_input_sequence_length = 6  # Max length of input temporal sequence\n",
    "\n",
    "    #### Training settings ####\n",
    "\n",
    "    learning_rate = 0.001  # Optimizer learning rate\n",
    "    dropout = 0.1  # Dropout probability\n",
    "    num_epochs = 50  # Total number of training epochs\n",
    "    optimizer = 'Adam'  # Optimizer choice\n",
    "    weight_decay = 0.0005  # L2 regularization weight\n",
    "    patience = 5  # Early stopping patience (epochs without improvement)\n",
    "\n",
    "    #### Dataset splits ####\n",
    "\n",
    "    # There is no validation or test set because we only want to extract the node embeddings\n",
    "    val_ratio = 0.0  # Validation set ratio\n",
    "    test_ratio = 0.0  # Test set ratio\n",
    "\n",
    "    #### Experiment settings ####\n",
    "\n",
    "    num_runs = 1  # Number of training runs for averaging\n",
    "    negative_sample_strategy = 'random'  # Strategy for negative sampling\n",
    "    load_best_configs = False  # Whether to load best hyperparameters\n",
    "\n",
    "    #### Device setup ####\n",
    "\n",
    "    # Honour the configured GPU index; fall back to CPU when CUDA is unavailable or gpu < 0.\n",
    "    device = f'cuda:{gpu}' if torch.cuda.is_available() and gpu >= 0 else 'cpu'\n",
    "\n",
    "    def __repr__(self):\n",
    "        \"\"\"Render every setting as ``Args(name=value, ...)`` so a logged config is readable.\"\"\"\n",
    "        settings = {\n",
    "            name: value for name, value in vars(type(self)).items()\n",
    "            if not name.startswith('_') and not callable(value)\n",
    "        }\n",
    "        # Values set on the instance override the class-level defaults.\n",
    "        settings.update(vars(self))\n",
    "        rendered = ', '.join(f'{name}={value!r}' for name, value in settings.items())\n",
    "        return f'{type(self).__name__}({rendered})'\n",
    "\n",
    "# Instantiate the args object\n",
    "args = Args()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3409679c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "cfb9ba86",
   "metadata": {},
   "source": [
    "# Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60d4a35a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Load node/edge features plus the full/train/val/test interaction sets for the configured dataset.\n",
    "node_raw_features, edge_raw_features, full_data, train_data, val_data, test_data = get_node_classification_data(\n",
    "    dataset_name=args.dataset_name,\n",
    "    val_ratio=args.val_ratio,\n",
    "    test_ratio=args.test_ratio,\n",
    ")\n",
    "\n",
    "# `data` aliases `full_data`, so the in-place re-ordering below is also what the\n",
    "# neighbor sampler is built from.\n",
    "data = full_data\n",
    "\n",
    "# Order interactions chronologically. The sort is stable so that interactions sharing\n",
    "# a timestamp keep their relative order and re-running this cell on already-sorted\n",
    "# data changes nothing.\n",
    "# NOTE(review): only these four arrays are re-ordered; any other per-interaction array\n",
    "# carried by `data` stays in its loaded order - confirm none is used downstream.\n",
    "sorted_indices = np.argsort(data.node_interact_times, kind=\"stable\")\n",
    "data.src_node_ids = data.src_node_ids[sorted_indices]\n",
    "data.dst_node_ids = data.dst_node_ids[sorted_indices]\n",
    "data.node_interact_times = data.node_interact_times[sorted_indices]\n",
    "data.edge_ids = data.edge_ids[sorted_indices]\n",
    "\n",
    "full_neighbor_sampler = get_neighbor_sampler(\n",
    "    data=full_data,\n",
    "    sample_neighbor_strategy=args.sample_neighbor_strategy,\n",
    "    time_scaling_factor=args.time_scaling_factor,\n",
    "    seed=1\n",
    ")\n",
    "\n",
    "# Index loaders over each split, in order (shuffle=False).\n",
    "train_idx_data_loader = get_idx_data_loader(list(range(len(train_data.src_node_ids))), batch_size=args.batch_size, shuffle=False)\n",
    "val_idx_data_loader = get_idx_data_loader(list(range(len(val_data.src_node_ids))), batch_size=args.batch_size, shuffle=False)\n",
    "test_idx_data_loader = get_idx_data_loader(list(range(len(test_data.src_node_ids))), batch_size=args.batch_size, shuffle=False)\n",
    "\n",
    "val_metric_all_runs, test_metric_all_runs = [], []\n",
    "\n",
    "# NOTE: `model` is rebound on every iteration, so only the last run's model is\n",
    "# available to the following cells.\n",
    "for run in range(args.num_runs):\n",
    "    set_random_seed(seed=run)\n",
    "\n",
    "    args.seed = run\n",
    "    args.load_model_name = f'{args.model_name}_seed{args.seed}'\n",
    "    args.save_model_name = f'node_classification_{args.model_name}_seed{args.seed}'\n",
    "\n",
    "    # === Logging: one log file per run, warnings echoed to the console ===\n",
    "    logging.basicConfig(level=logging.INFO)\n",
    "    logger = logging.getLogger()\n",
    "    logger.setLevel(logging.DEBUG)\n",
    "    # Detach handlers left behind by an earlier run or an earlier execution of this\n",
    "    # cell; otherwise each message is emitted once per stale handler and the old\n",
    "    # log files stay open.\n",
    "    for stale_handler in [h for h in logger.handlers if (h.get_name() or '').startswith('node_embedding_run_')]:\n",
    "        logger.removeHandler(stale_handler)\n",
    "        stale_handler.close()\n",
    "    log_dir = f\"./logs/{args.model_name}/{args.dataset_name}/{args.save_model_name}/\"\n",
    "    os.makedirs(log_dir, exist_ok=True)\n",
    "    fh = logging.FileHandler(f\"{log_dir}{str(time.time())}.log\")\n",
    "    fh.setLevel(logging.DEBUG)\n",
    "    fh.set_name('node_embedding_run_file')\n",
    "    ch = logging.StreamHandler()\n",
    "    ch.setLevel(logging.WARNING)\n",
    "    ch.set_name('node_embedding_run_console')\n",
    "    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n",
    "    fh.setFormatter(formatter)\n",
    "    ch.setFormatter(formatter)\n",
    "    logger.addHandler(fh)\n",
    "    logger.addHandler(ch)\n",
    "\n",
    "    run_start_time = time.time()\n",
    "    logger.info(f\"********** Run {run + 1} starts. **********\")\n",
    "    logger.info(f'configuration is {args}')\n",
    "\n",
    "    # === Build the dynamic backbone selected by args.model_name ===\n",
    "    if args.model_name == 'TGAT':\n",
    "        dynamic_backbone = TGAT(node_raw_features, edge_raw_features, full_neighbor_sampler, args.time_feat_dim,\n",
    "                                args.num_layers, args.num_heads, args.dropout, args.device)\n",
    "    elif args.model_name in ['JODIE', 'DyRep', 'TGN']:\n",
    "        # Time-shift statistics are computed from the training interactions.\n",
    "        (src_node_mean_time_shift, src_node_std_time_shift,\n",
    "         dst_node_mean_time_shift_dst, dst_node_std_time_shift) = compute_src_dst_node_time_shifts(\n",
    "            train_data.src_node_ids, train_data.dst_node_ids, train_data.node_interact_times)\n",
    "        dynamic_backbone = MemoryModel(node_raw_features, edge_raw_features, full_neighbor_sampler, args.time_feat_dim,\n",
    "                                       args.model_name, args.num_layers, args.num_heads, args.dropout,\n",
    "                                       src_node_mean_time_shift, src_node_std_time_shift,\n",
    "                                       dst_node_mean_time_shift_dst, dst_node_std_time_shift, args.device)\n",
    "    elif args.model_name == 'CAWN':\n",
    "        dynamic_backbone = CAWN(node_raw_features, edge_raw_features, full_neighbor_sampler, args.time_feat_dim,\n",
    "                                args.position_feat_dim, args.walk_length, args.num_walk_heads, args.dropout, args.device)\n",
    "    elif args.model_name == 'TCL':\n",
    "        dynamic_backbone = TCL(node_raw_features, edge_raw_features, full_neighbor_sampler, args.time_feat_dim,\n",
    "                               args.num_layers, args.num_heads, args.num_neighbors + 1, args.dropout, args.device)\n",
    "    elif args.model_name == 'GraphMixer':\n",
    "        dynamic_backbone = GraphMixer(node_raw_features, edge_raw_features, full_neighbor_sampler, args.time_feat_dim,\n",
    "                                      args.num_neighbors, args.num_layers, args.dropout, args.device)\n",
    "    elif args.model_name == 'DyGFormer':\n",
    "        dynamic_backbone = DyGFormer(node_raw_features, edge_raw_features, full_neighbor_sampler, args.time_feat_dim,\n",
    "                                     args.channel_embedding_dim, args.patch_size, args.num_layers, args.num_heads,\n",
    "                                     args.dropout, args.max_input_sequence_length, args.device)\n",
    "    else:\n",
    "        raise ValueError(f\"Wrong value for model_name {args.model_name}!\")\n",
    "\n",
    "    # The checkpoint is loaded into a (backbone, link predictor) pair, so the same\n",
    "    # Sequential layout is rebuilt here before loading.\n",
    "    link_predictor = MergeLayer(node_raw_features.shape[1], node_raw_features.shape[1],\n",
    "                                node_raw_features.shape[1], 1)\n",
    "    model = nn.Sequential(dynamic_backbone, link_predictor)\n",
    "\n",
    "    load_model_folder = f\"./saved_models/{args.model_name}/{args.dataset_name}/{args.load_model_name}\"\n",
    "    # Not used further in this cell; the state dict is loaded directly below.\n",
    "    early_stopping = EarlyStopping(patience=0, save_model_folder=load_model_folder,\n",
    "                                   save_model_name=args.load_model_name, logger=logger, model_name=args.model_name)\n",
    "\n",
    "    # === Load full model state_dict ===\n",
    "    checkpoint_path = os.path.join(load_model_folder, f\"{args.load_model_name}.pkl\")\n",
    "    model.load_state_dict(torch.load(checkpoint_path, map_location=args.device))\n",
    "\n",
    "    # === Move to correct device ===\n",
    "    model = convert_to_gpu(model, device=args.device)\n",
    "\n",
    "    # === Keep only the dynamic backbone (first element of the Sequential) ===\n",
    "    model = model[0]\n",
    "\n",
    "    # === Inference only: evaluation mode disables dropout ===\n",
    "    model.eval()\n",
    "\n",
    "    # === Reset memory (crucial before inference) ===\n",
    "    if args.model_name in ['JODIE', 'DyRep', 'TGN']:\n",
    "        model.memory_bank.__init_memory_bank__()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "261791c7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "6c0bfec6",
   "metadata": {},
   "source": [
    "# Calculate node embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bde5988",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# Inference only: evaluation mode switches off dropout so the extracted embeddings\n",
    "# are deterministic (torch.no_grad() alone does not disable dropout).\n",
    "model.eval()\n",
    "\n",
    "# Model families differ in the arguments their embedding method takes.\n",
    "is_memory_model = args.model_name in ['JODIE', 'DyRep', 'TGN']\n",
    "uses_num_neighbors_only = args.model_name in ['TGAT', 'CAWN', 'TCL', 'GraphMixer']\n",
    "\n",
    "# === Group interactions by integer time (once) ===\n",
    "interactions_by_time = defaultdict(list)\n",
    "for i in range(len(data.src_node_ids)):\n",
    "    interactions_by_time[int(data.node_interact_times[i])].append(i)\n",
    "\n",
    "# === Extract embeddings per time step ===\n",
    "all_embeddings = []\n",
    "unique_times = sorted(interactions_by_time.keys())\n",
    "\n",
    "for t in tqdm(unique_times, desc=\"Time points\"):\n",
    "    # Embeddings produced for each node by the interactions of time step t.\n",
    "    node_to_embs = defaultdict(list)\n",
    "\n",
    "    for idx in interactions_by_time[t]:\n",
    "        src = data.src_node_ids[idx]\n",
    "        dst = data.dst_node_ids[idx]\n",
    "        # Deliberately not named `time`: cell variables are notebook globals, and\n",
    "        # rebinding `time` would break the `time.time()` calls of the model-loading\n",
    "        # cell when it is re-run.\n",
    "        interact_time = data.node_interact_times[idx]\n",
    "        edge_id = data.edge_ids[idx]\n",
    "\n",
    "        # The model is fed one interaction at a time, as length-1 arrays.\n",
    "        src_np = np.array([src])\n",
    "        dst_np = np.array([dst])\n",
    "        t_np = np.array([interact_time])\n",
    "        e_np = np.array([edge_id])\n",
    "\n",
    "        with torch.no_grad():\n",
    "            if is_memory_model:\n",
    "                src_emb, dst_emb = model.compute_src_dst_node_temporal_embeddings(\n",
    "                    src_node_ids=src_np,\n",
    "                    dst_node_ids=dst_np,\n",
    "                    node_interact_times=t_np,\n",
    "                    edge_ids=e_np,\n",
    "                    edges_are_positive=True,\n",
    "                    num_neighbors=args.num_neighbors\n",
    "                )\n",
    "            elif uses_num_neighbors_only:\n",
    "                src_emb, dst_emb = model.compute_src_dst_node_temporal_embeddings(\n",
    "                    src_node_ids=src_np,\n",
    "                    dst_node_ids=dst_np,\n",
    "                    node_interact_times=t_np,\n",
    "                    num_neighbors=args.num_neighbors\n",
    "                )\n",
    "            elif args.model_name == 'DyGFormer':\n",
    "                src_emb, dst_emb = model.compute_src_dst_node_temporal_embeddings(\n",
    "                    src_node_ids=src_np,\n",
    "                    dst_node_ids=dst_np,\n",
    "                    node_interact_times=t_np\n",
    "                )\n",
    "            else:\n",
    "                raise ValueError(f\"Unsupported model {args.model_name} in embedding extraction.\")\n",
    "\n",
    "        # Drop the length-1 batch dimension and keep the vectors on the CPU.\n",
    "        node_to_embs[src].append(src_emb.squeeze(0).cpu())\n",
    "        node_to_embs[dst].append(dst_emb.squeeze(0).cpu())\n",
    "\n",
    "        if is_memory_model:\n",
    "            model.memory_bank.detach_memory_bank()\n",
    "\n",
    "    # Average per node: one embedding per (time step, node).\n",
    "    for node_id, embs in node_to_embs.items():\n",
    "        avg_emb = torch.stack(embs).mean(dim=0).numpy()\n",
    "        all_embeddings.append({\n",
    "            'time': int(t),\n",
    "            'node_id': int(node_id),\n",
    "            'embedding': avg_emb\n",
    "        })\n",
    "\n",
    "# === Save (pickle) ===\n",
    "output_path = f\"node_embeddings_discrete_avg_{args.model_name}_{args.dataset_name}.pkl\"\n",
    "with open(output_path, \"wb\") as f:\n",
    "    pickle.dump(all_embeddings, f)\n",
    "\n",
    "# === Save (plain text) ===\n",
    "# Rows are ordered by time first, then by node_id.\n",
    "sorted_embeddings = sorted(all_embeddings, key=lambda e: (e['time'], e['node_id']))\n",
    "text_output_path = f\"node_embeddings_discrete_avg_{args.model_name}_{args.dataset_name}.txt\"\n",
    "\n",
    "with open(text_output_path, \"w\", encoding=\"utf-8\") as f_txt:\n",
    "    # Header: number of distinct (node, time) rows, then number of distinct time steps.\n",
    "    num_node_time_rows = len(set((e['node_id'], e['time']) for e in all_embeddings))\n",
    "    num_timesteps = len(set(e['time'] for e in all_embeddings))\n",
    "    f_txt.write(f\"{num_node_time_rows} {num_timesteps}\\n\")\n",
    "    # Each row holds the node id followed by its embedding; the time is not written.\n",
    "    for entry in sorted_embeddings:\n",
    "        emb_str = \" \".join(f\"{x:.6f}\" for x in entry['embedding'])\n",
    "        f_txt.write(f\"{entry['node_id']} {emb_str}\\n\")\n",
    "\n",
    "print(f\"\\n? Saved {len(all_embeddings)} node-time embeddings to: {output_path}\")\n",
    "print(f\"\\n? Saved {len(all_embeddings)} node-time embeddings text to: {text_output_path}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d54a67c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
