{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "97e3ac1e",
   "metadata": {},
   "source": [
    "# Imports and helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6aa27f8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Main script for training temporal GNN models.\n",
    "\"\"\"\n",
    "\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import json\n",
    "import shutil\n",
    "import logging\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "from models.TGAT import TGAT\n",
    "from models.CAWN import CAWN\n",
    "from models.TCL import TCL\n",
    "from models.GraphMixer import GraphMixer\n",
    "from models.DyGFormer import DyGFormer\n",
    "from models.MemoryModel import MemoryModel, compute_src_dst_node_time_shifts\n",
    "from models.modules import MergeLayer\n",
    "\n",
    "from utils.utils import (\n",
    "    set_random_seed,\n",
    "    convert_to_gpu,\n",
    "    get_parameter_sizes,\n",
    "    create_optimizer,\n",
    "    get_neighbor_sampler,\n",
    "    NegativeEdgeSampler\n",
    ")\n",
    "from utils.metrics import get_link_prediction_metrics\n",
    "from utils.DataLoader import get_idx_data_loader, get_link_prediction_data\n",
    "from utils.EarlyStopping import EarlyStopping\n",
    "from utils.load_configs import get_link_prediction_args\n",
    "from evaluate_models_utils import evaluate_model_link_prediction\n",
    "\n",
    "\n",
    "def ensure_undirected(train_data):\n",
    "    \"\"\"\n",
    "    Ensures that for every (u, v, t, e) in train_data, (v, u, t, e') also exists.\n",
    "    If not, it adds it.\n",
    "    \"\"\"\n",
    "    # Get original edges\n",
    "    edges = set(zip(\n",
    "        train_data.src_node_ids,\n",
    "        train_data.dst_node_ids,\n",
    "        train_data.node_interact_times\n",
    "    ))\n",
    "\n",
    "    # Collect new reversed edges\n",
    "    new_src, new_dst, new_times, new_edge_ids = [], [], [], []\n",
    "    max_edge_id = max(train_data.edge_ids) + 1\n",
    "\n",
    "    for i in range(len(train_data.src_node_ids)):\n",
    "        u = train_data.src_node_ids[i]\n",
    "        v = train_data.dst_node_ids[i]\n",
    "        t = train_data.node_interact_times[i]\n",
    "        e = train_data.edge_ids[i]\n",
    "\n",
    "        if (v, u, t) not in edges:\n",
    "            new_src.append(v)\n",
    "            new_dst.append(u)\n",
    "            new_times.append(t)\n",
    "            new_edge_ids.append(max_edge_id)\n",
    "            max_edge_id += 1\n",
    "\n",
    "    # Append new reversed edges\n",
    "    if new_src:\n",
    "        train_data.src_node_ids = np.concatenate([train_data.src_node_ids, np.array(new_src)])\n",
    "        train_data.dst_node_ids = np.concatenate([train_data.dst_node_ids, np.array(new_dst)])\n",
    "        train_data.node_interact_times = np.concatenate([train_data.node_interact_times, np.array(new_times)])\n",
    "        train_data.edge_ids = np.concatenate([train_data.edge_ids, np.array(new_edge_ids)])\n",
    "\n",
    "    return train_data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf25106c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "43069536",
   "metadata": {},
   "source": [
    "# Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29fb6489",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Args:\n",
    "    \"\"\"\n",
    "    Configuration class for model and training setup.\n",
    "    \"\"\"\n",
    "\n",
    "    #### Dataset selection ####\n",
    "\n",
    "    # Real-world datasets\n",
    "    #dataset_name = \"brain\"\n",
    "    #dataset_name = \"school\"\n",
    "    #dataset_name = \"stock\"\n",
    "\n",
    "    # Synthetic dataset\n",
    "    dataset_name = \"synthetic_exp2.1\"\n",
    "    #dataset_name = \"synthetic_exp2.2\"\n",
    "    #### Model selection ####\n",
    "\n",
    "    # Choose one model to activate\n",
    "    #model_name = 'JODIE'\n",
    "    #model_name = 'DyRep'\n",
    "    #model_name = 'TGN'\n",
    "    model_name = \"TGAT\"\n",
    "    #model_name = 'DyGFormer'\n",
    "\n",
    "    #### General training configuration ####\n",
    "\n",
    "    batch_size = 32  # Number of samples per training batch\n",
    "\n",
    "    gpu = 0  # GPU index if using CUDA, otherwise set -1 for CPU\n",
    "\n",
    "    #### Neighbor sampling ####\n",
    "\n",
    "    num_neighbors = 5  # Number of temporal neighbors to sample\n",
    "    sample_neighbor_strategy = 'uniform'  # Strategy for sampling neighbors\n",
    "    time_scaling_factor = 10.0  # Scaling factor for time-based features\n",
    "\n",
    "    #### Model-specific hyperparameters ####\n",
    "\n",
    "    num_walk_heads = 4  # For models with random walk or multi-head attention\n",
    "    num_heads = 4  # Number of attention heads\n",
    "    num_layers = 2  # Number of layers (e.g., GNN or transformer layers)\n",
    "    walk_length = 5  # Length of random walks (if applicable)\n",
    "    time_gap = 1.0  # Time gap resolution\n",
    "    time_feat_dim = 32  # Dimension of time features\n",
    "    position_feat_dim = 32  # Dimension of position features\n",
    "\n",
    "    patch_size = 16  # Patch size for temporal sequences    \n",
    "    channel_embedding_dim = 12  # Embedding size per channel\n",
    "    \n",
    "    ### brain,stock\n",
    "    max_input_sequence_length = 10  # Max length of input temporal sequence\n",
    "    ### Else\n",
    "    #max_input_sequence_length = 6  # Max length of input temporal sequence\n",
    "    #### Training settings ####\n",
    "\n",
    "    learning_rate = 0.001  # Optimizer learning rate\n",
    "    dropout = 0.1  # Dropout probability\n",
    "    num_epochs = 50  # Total number of training epochs\n",
    "    optimizer = 'Adam'  # Optimizer choice\n",
    "    weight_decay = 0.0005  # L2 regularization weight\n",
    "    patience = 5  # Early stopping patience (epochs without improvement)\n",
    "\n",
    "    #### Dataset splits ####\n",
    "\n",
    "    # There are no valdiation or test set becuase we simply want to see the node embeddings\n",
    "    val_ratio = 0.0  # Validation set ratio\n",
    "    test_ratio = 0.0  # Test set ratio\n",
    "\n",
    "    #### Experiment settings ####\n",
    "\n",
    "    num_runs = 1  # Number of training runs for averaging\n",
    "    negative_sample_strategy = 'random'  # Strategy for negative sampling\n",
    "    load_best_configs = False  # Whether to load best hyperparameters\n",
    "\n",
    "    #### Device setup ####\n",
    "\n",
    "    device = 'cuda:0' if torch.cuda.is_available() and gpu >= 0 else 'cpu'\n",
    "\n",
    "# Instantiate the args object\n",
    "args = Args()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a008294a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "9e424401",
   "metadata": {},
   "source": [
    "# Run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7114cad1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# NOTE: this silences every warning for the rest of the kernel session\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# get data for training, validation and testing\n",
    "node_raw_features, edge_raw_features, full_data, train_data, val_data, test_data, new_node_val_data, new_node_test_data = \\\n",
    "    get_link_prediction_data(dataset_name=args.dataset_name, val_ratio=args.val_ratio, test_ratio=args.test_ratio)\n",
    "\n",
    "# optional: add the missing reverse edges to the training data. Keep this BEFORE the sort below,\n",
    "# because ensure_undirected appends the reversed edges at the end of the arrays\n",
    "#train_data = ensure_undirected(train_data)\n",
    "\n",
    "# order the training interactions chronologically; a stable sort keeps the original relative order\n",
    "# of interactions that share a timestamp (the default argsort algorithm does not guarantee that)\n",
    "# NOTE(review): only these four arrays are reordered; any other per-edge attribute of train_data is left as is - confirm none is used\n",
    "sorted_indices = np.argsort(train_data.node_interact_times, kind='stable')\n",
    "train_data.src_node_ids = train_data.src_node_ids[sorted_indices]\n",
    "train_data.dst_node_ids = train_data.dst_node_ids[sorted_indices]\n",
    "train_data.node_interact_times = train_data.node_interact_times[sorted_indices]\n",
    "train_data.edge_ids = train_data.edge_ids[sorted_indices]\n",
    "\n",
    "# initialize training neighbor sampler to retrieve temporal graph\n",
    "train_neighbor_sampler = get_neighbor_sampler(data=train_data, sample_neighbor_strategy=args.sample_neighbor_strategy,\n",
    "                                              time_scaling_factor=args.time_scaling_factor, seed=0)\n",
    "\n",
    "# initialize validation and test neighbor sampler to retrieve temporal graph\n",
    "full_neighbor_sampler = get_neighbor_sampler(data=full_data, sample_neighbor_strategy=args.sample_neighbor_strategy,\n",
    "                                             time_scaling_factor=args.time_scaling_factor, seed=1)\n",
    "\n",
    "# initialize negative samplers, set seeds for validation and testing so negatives are the same across different runs\n",
    "# in the inductive setting, negatives are sampled only amongst other new nodes\n",
    "# train negative edge sampler does not need to specify the seed, but evaluation samplers need to do so\n",
    "train_neg_edge_sampler = NegativeEdgeSampler(src_node_ids=train_data.src_node_ids, dst_node_ids=train_data.dst_node_ids)\n",
    "val_neg_edge_sampler = NegativeEdgeSampler(src_node_ids=full_data.src_node_ids, dst_node_ids=full_data.dst_node_ids, seed=0)\n",
    "new_node_val_neg_edge_sampler = NegativeEdgeSampler(src_node_ids=new_node_val_data.src_node_ids, dst_node_ids=new_node_val_data.dst_node_ids, seed=1)\n",
    "test_neg_edge_sampler = NegativeEdgeSampler(src_node_ids=full_data.src_node_ids, dst_node_ids=full_data.dst_node_ids, seed=2)\n",
    "new_node_test_neg_edge_sampler = NegativeEdgeSampler(src_node_ids=new_node_test_data.src_node_ids, dst_node_ids=new_node_test_data.dst_node_ids, seed=3)\n",
    "\n",
    "# get data loaders; shuffle=False keeps the batches in the chronological order established above\n",
    "train_idx_data_loader = get_idx_data_loader(indices_list=list(range(len(train_data.src_node_ids))), batch_size=args.batch_size, shuffle=False)\n",
    "val_idx_data_loader = get_idx_data_loader(indices_list=list(range(len(val_data.src_node_ids))), batch_size=args.batch_size, shuffle=False)\n",
    "new_node_val_idx_data_loader = get_idx_data_loader(indices_list=list(range(len(new_node_val_data.src_node_ids))), batch_size=args.batch_size, shuffle=False)\n",
    "test_idx_data_loader = get_idx_data_loader(indices_list=list(range(len(test_data.src_node_ids))), batch_size=args.batch_size, shuffle=False)\n",
    "new_node_test_idx_data_loader = get_idx_data_loader(indices_list=list(range(len(new_node_test_data.src_node_ids))), batch_size=args.batch_size, shuffle=False)\n",
    "\n",
    "# per-run metric collections; nothing in this cell fills them because no evaluation is performed here\n",
    "val_metric_all_runs, new_node_val_metric_all_runs, test_metric_all_runs, new_node_test_metric_all_runs = [], [], [], []\n",
    "\n",
    "# train one model per run; each run gets its own seed, log file and checkpoint folder\n",
    "for run in range(args.num_runs):\n",
    "\n",
    "    set_random_seed(seed=run)\n",
    "\n",
    "    args.seed = run\n",
    "    args.save_model_name = f'{args.model_name}_seed{args.seed}'\n",
    "\n",
    "    # set up logger\n",
    "    logging.basicConfig(level=logging.INFO)\n",
    "    logger = logging.getLogger()\n",
    "    logger.setLevel(logging.DEBUG)\n",
    "    # drop handlers left behind by an earlier run or by an interrupted execution of this cell,\n",
    "    # otherwise every message is emitted once per leftover handler\n",
    "    for stale_handler in [handler for handler in logger.handlers if getattr(handler, 'is_training_run_handler', False)]:\n",
    "        logger.removeHandler(stale_handler)\n",
    "        stale_handler.close()\n",
    "    os.makedirs(f\"./logs/{args.model_name}/{args.dataset_name}/{args.save_model_name}/\", exist_ok=True)\n",
    "\n",
    "    # create file handler that logs debug and higher level messages\n",
    "    fh = logging.FileHandler(f\"./logs/{args.model_name}/{args.dataset_name}/{args.save_model_name}/{str(time.time())}.log\")\n",
    "    fh.setLevel(logging.DEBUG)\n",
    "    # create console handler with a higher log level\n",
    "    ch = logging.StreamHandler()\n",
    "    ch.setLevel(logging.WARNING)\n",
    "    # create formatter and add it to the handlers\n",
    "    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n",
    "    fh.setFormatter(formatter)\n",
    "    ch.setFormatter(formatter)\n",
    "    # tag the handlers so the clean-up above can recognise them on a later execution\n",
    "    fh.is_training_run_handler = True\n",
    "    ch.is_training_run_handler = True\n",
    "    # add the handlers to logger\n",
    "    logger.addHandler(fh)\n",
    "    logger.addHandler(ch)\n",
    "\n",
    "    run_start_time = time.time()\n",
    "    logger.info(f\"********** Run {run + 1} starts. **********\")\n",
    "\n",
    "    logger.info(f'configuration is {args}')\n",
    "\n",
    "    # create model\n",
    "    if args.model_name == 'TGAT':\n",
    "        dynamic_backbone = TGAT(node_raw_features=node_raw_features, edge_raw_features=edge_raw_features, neighbor_sampler=train_neighbor_sampler,\n",
    "                                time_feat_dim=args.time_feat_dim, num_layers=args.num_layers, num_heads=args.num_heads, dropout=args.dropout, device=args.device)\n",
    "    elif args.model_name in ['JODIE', 'DyRep', 'TGN']:\n",
    "        # four floats that represent the mean and standard deviation of source and destination node time shifts in the training data, which is used for JODIE\n",
    "        src_node_mean_time_shift, src_node_std_time_shift, dst_node_mean_time_shift_dst, dst_node_std_time_shift = \\\n",
    "            compute_src_dst_node_time_shifts(train_data.src_node_ids, train_data.dst_node_ids, train_data.node_interact_times)\n",
    "        dynamic_backbone = MemoryModel(node_raw_features=node_raw_features, edge_raw_features=edge_raw_features, neighbor_sampler=train_neighbor_sampler,\n",
    "                                       time_feat_dim=args.time_feat_dim, model_name=args.model_name, num_layers=args.num_layers, num_heads=args.num_heads,\n",
    "                                       dropout=args.dropout, src_node_mean_time_shift=src_node_mean_time_shift, src_node_std_time_shift=src_node_std_time_shift,\n",
    "                                       dst_node_mean_time_shift_dst=dst_node_mean_time_shift_dst, dst_node_std_time_shift=dst_node_std_time_shift, device=args.device)\n",
    "    elif args.model_name == 'CAWN':\n",
    "        dynamic_backbone = CAWN(node_raw_features=node_raw_features, edge_raw_features=edge_raw_features, neighbor_sampler=train_neighbor_sampler,\n",
    "                                time_feat_dim=args.time_feat_dim, position_feat_dim=args.position_feat_dim, walk_length=args.walk_length,\n",
    "                                num_walk_heads=args.num_walk_heads, dropout=args.dropout, device=args.device)\n",
    "    elif args.model_name == 'TCL':\n",
    "        dynamic_backbone = TCL(node_raw_features=node_raw_features, edge_raw_features=edge_raw_features, neighbor_sampler=train_neighbor_sampler,\n",
    "                               time_feat_dim=args.time_feat_dim, num_layers=args.num_layers, num_heads=args.num_heads,\n",
    "                               num_depths=args.num_neighbors + 1, dropout=args.dropout, device=args.device)\n",
    "    elif args.model_name == 'GraphMixer':\n",
    "        dynamic_backbone = GraphMixer(node_raw_features=node_raw_features, edge_raw_features=edge_raw_features, neighbor_sampler=train_neighbor_sampler,\n",
    "                                      time_feat_dim=args.time_feat_dim, num_tokens=args.num_neighbors, num_layers=args.num_layers, dropout=args.dropout, device=args.device)\n",
    "    elif args.model_name == 'DyGFormer':\n",
    "        dynamic_backbone = DyGFormer(node_raw_features=node_raw_features, edge_raw_features=edge_raw_features, neighbor_sampler=train_neighbor_sampler,\n",
    "                                     time_feat_dim=args.time_feat_dim, channel_embedding_dim=args.channel_embedding_dim, patch_size=args.patch_size,\n",
    "                                     num_layers=args.num_layers, num_heads=args.num_heads, dropout=args.dropout,\n",
    "                                     max_input_sequence_length=args.max_input_sequence_length, device=args.device)\n",
    "    else:\n",
    "        raise ValueError(f\"Wrong value for model_name {args.model_name}!\")\n",
    "    # the link predictor scores a (source embedding, destination embedding) pair with a single output value\n",
    "    link_predictor = MergeLayer(input_dim1=node_raw_features.shape[1], input_dim2=node_raw_features.shape[1],\n",
    "                                hidden_dim=node_raw_features.shape[1], output_dim=1)\n",
    "    model = nn.Sequential(dynamic_backbone, link_predictor)\n",
    "    logger.info(f'model -> {model}')\n",
    "    # the factor 4 presumably stands for 4 bytes per parameter (float32) - confirm against get_parameter_sizes\n",
    "    parameter_bytes = get_parameter_sizes(model) * 4\n",
    "    logger.info(f'model name: {args.model_name}, #parameters: {parameter_bytes} B, '\n",
    "                f'{parameter_bytes / 1024} KB, {parameter_bytes / 1024 / 1024} MB.')\n",
    "\n",
    "    optimizer = create_optimizer(model=model, optimizer_name=args.optimizer, learning_rate=args.learning_rate, weight_decay=args.weight_decay)\n",
    "\n",
    "    model = convert_to_gpu(model, device=args.device)\n",
    "\n",
    "    # start every run from an empty checkpoint folder\n",
    "    save_model_folder = f\"./saved_models/{args.model_name}/{args.dataset_name}/{args.save_model_name}/\"\n",
    "    shutil.rmtree(save_model_folder, ignore_errors=True)\n",
    "    os.makedirs(save_model_folder, exist_ok=True)\n",
    "\n",
    "    # NOTE: early_stopping is created but never consulted in this cell, so training always runs for args.num_epochs epochs\n",
    "    early_stopping = EarlyStopping(patience=args.patience, save_model_folder=save_model_folder,\n",
    "                                   save_model_name=args.save_model_name, logger=logger, model_name=args.model_name)\n",
    "\n",
    "    loss_func = nn.BCELoss()\n",
    "\n",
    "    for epoch in range(args.num_epochs):\n",
    "\n",
    "        model.train()\n",
    "        if args.model_name in ['DyRep', 'TGAT', 'TGN', 'CAWN', 'TCL', 'GraphMixer', 'DyGFormer']:\n",
    "            # training, only use training graph\n",
    "            model[0].set_neighbor_sampler(train_neighbor_sampler)\n",
    "        if args.model_name in ['JODIE', 'DyRep', 'TGN']:\n",
    "            # reinitialize memory of memory-based models at the start of each epoch\n",
    "            model[0].memory_bank.__init_memory_bank__()\n",
    "\n",
    "        # store train losses and metrics\n",
    "        train_losses, train_metrics = [], []\n",
    "        train_idx_data_loader_tqdm = tqdm(train_idx_data_loader, ncols=120)\n",
    "        for batch_idx, train_data_indices in enumerate(train_idx_data_loader_tqdm):\n",
    "            train_data_indices = train_data_indices.numpy()\n",
    "            batch_src_node_ids, batch_dst_node_ids, batch_node_interact_times, batch_edge_ids = \\\n",
    "                train_data.src_node_ids[train_data_indices], train_data.dst_node_ids[train_data_indices], \\\n",
    "                train_data.node_interact_times[train_data_indices], train_data.edge_ids[train_data_indices]\n",
    "\n",
    "            # negatives keep the positive source nodes and swap in sampled destination nodes\n",
    "            _, batch_neg_dst_node_ids = train_neg_edge_sampler.sample(size=len(batch_src_node_ids))\n",
    "            batch_neg_src_node_ids = batch_src_node_ids\n",
    "\n",
    "            # we need to compute for positive and negative edges respectively, because the new sampling strategy (for evaluation) allows the negative source nodes to be\n",
    "            # different from the source nodes, this is different from previous works that just replace destination nodes with negative destination nodes\n",
    "            if args.model_name in ['TGAT', 'CAWN', 'TCL']:\n",
    "                # get temporal embedding of source and destination nodes\n",
    "                # two Tensors, with shape (batch_size, node_feat_dim)\n",
    "                batch_src_node_embeddings, batch_dst_node_embeddings = \\\n",
    "                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,\n",
    "                                                                      dst_node_ids=batch_dst_node_ids,\n",
    "                                                                      node_interact_times=batch_node_interact_times,\n",
    "                                                                      num_neighbors=args.num_neighbors)\n",
    "\n",
    "                # get temporal embedding of negative source and negative destination nodes\n",
    "                # two Tensors, with shape (batch_size, node_feat_dim)\n",
    "                batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \\\n",
    "                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,\n",
    "                                                                      dst_node_ids=batch_neg_dst_node_ids,\n",
    "                                                                      node_interact_times=batch_node_interact_times,\n",
    "                                                                      num_neighbors=args.num_neighbors)\n",
    "            elif args.model_name in ['JODIE', 'DyRep', 'TGN']:\n",
    "                # note that negative nodes do not change the memories while the positive nodes change the memories,\n",
    "                # we need to first compute the embeddings of negative nodes for memory-based models\n",
    "                # get temporal embedding of negative source and negative destination nodes\n",
    "                # two Tensors, with shape (batch_size, node_feat_dim)\n",
    "                batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \\\n",
    "                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,\n",
    "                                                                      dst_node_ids=batch_neg_dst_node_ids,\n",
    "                                                                      node_interact_times=batch_node_interact_times,\n",
    "                                                                      edge_ids=None,\n",
    "                                                                      edges_are_positive=False,\n",
    "                                                                      num_neighbors=args.num_neighbors)\n",
    "                # get temporal embedding of source and destination nodes\n",
    "                # two Tensors, with shape (batch_size, node_feat_dim)\n",
    "                batch_src_node_embeddings, batch_dst_node_embeddings = \\\n",
    "                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,\n",
    "                                                                      dst_node_ids=batch_dst_node_ids,\n",
    "                                                                      node_interact_times=batch_node_interact_times,\n",
    "                                                                      edge_ids=batch_edge_ids,\n",
    "                                                                      edges_are_positive=True,\n",
    "                                                                      num_neighbors=args.num_neighbors)\n",
    "            elif args.model_name in ['GraphMixer']:\n",
    "                # get temporal embedding of source and destination nodes\n",
    "                # two Tensors, with shape (batch_size, node_feat_dim)\n",
    "                batch_src_node_embeddings, batch_dst_node_embeddings = \\\n",
    "                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,\n",
    "                                                                      dst_node_ids=batch_dst_node_ids,\n",
    "                                                                      node_interact_times=batch_node_interact_times,\n",
    "                                                                      num_neighbors=args.num_neighbors,\n",
    "                                                                      time_gap=args.time_gap)\n",
    "\n",
    "                # get temporal embedding of negative source and negative destination nodes\n",
    "                # two Tensors, with shape (batch_size, node_feat_dim)\n",
    "                batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \\\n",
    "                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,\n",
    "                                                                      dst_node_ids=batch_neg_dst_node_ids,\n",
    "                                                                      node_interact_times=batch_node_interact_times,\n",
    "                                                                      num_neighbors=args.num_neighbors,\n",
    "                                                                      time_gap=args.time_gap)\n",
    "            elif args.model_name in ['DyGFormer']:\n",
    "                # get temporal embedding of source and destination nodes\n",
    "                # two Tensors, with shape (batch_size, node_feat_dim)\n",
    "                batch_src_node_embeddings, batch_dst_node_embeddings = \\\n",
    "                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,\n",
    "                                                                      dst_node_ids=batch_dst_node_ids,\n",
    "                                                                      node_interact_times=batch_node_interact_times)\n",
    "\n",
    "                # get temporal embedding of negative source and negative destination nodes\n",
    "                # two Tensors, with shape (batch_size, node_feat_dim)\n",
    "                batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \\\n",
    "                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,\n",
    "                                                                      dst_node_ids=batch_neg_dst_node_ids,\n",
    "                                                                      node_interact_times=batch_node_interact_times)\n",
    "            else:\n",
    "                raise ValueError(f\"Wrong value for model_name {args.model_name}!\")\n",
    "            # get positive and negative probabilities, shape (batch_size, )\n",
    "            positive_probabilities = model[1](input_1=batch_src_node_embeddings, input_2=batch_dst_node_embeddings).squeeze(dim=-1).sigmoid()\n",
    "            negative_probabilities = model[1](input_1=batch_neg_src_node_embeddings, input_2=batch_neg_dst_node_embeddings).squeeze(dim=-1).sigmoid()\n",
    "\n",
    "            # positives are labelled 1 and negatives 0, concatenated in the same order as the predictions\n",
    "            predicts = torch.cat([positive_probabilities, negative_probabilities], dim=0)\n",
    "            labels = torch.cat([torch.ones_like(positive_probabilities), torch.zeros_like(negative_probabilities)], dim=0)\n",
    "\n",
    "            loss = loss_func(input=predicts, target=labels)\n",
    "            train_losses.append(loss.item())\n",
    "            train_metrics.append(get_link_prediction_metrics(predicts=predicts, labels=labels))\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            train_idx_data_loader_tqdm.set_description(f'Epoch: {epoch + 1}, train for the {batch_idx + 1}-th batch, train loss: {loss.item()}')\n",
    "\n",
    "            if args.model_name in ['JODIE', 'DyRep', 'TGN']:\n",
    "                # detach the memories and raw messages of nodes in the memory bank after each batch, so we don't back propagate to the start of time\n",
    "                model[0].memory_bank.detach_memory_bank()\n",
    "\n",
    "        if args.model_name in ['JODIE', 'DyRep', 'TGN']:\n",
    "            # backup memory bank after training so it can be used for new validation nodes\n",
    "            train_backup_memory_bank = model[0].memory_bank.backup_memory_bank()\n",
    "\n",
    "        # report the epoch's training statistics for every model, not only the memory-based ones;\n",
    "        # skipped when the loader produced no batch, since there is nothing to average\n",
    "        if train_losses:\n",
    "            logger.info(f'Epoch: {epoch + 1}, learning rate: {optimizer.param_groups[0][\"lr\"]}, train loss: {np.mean(train_losses):.4f}')\n",
    "            for metric_name in train_metrics[0].keys():\n",
    "                logger.info(f'train {metric_name}, {np.mean([train_metric[metric_name] for train_metric in train_metrics]):.4f}')\n",
    "\n",
    "    # Save the final model once, after the last epoch (presumably the same name/path EarlyStopping would use - confirm)\n",
    "    save_model_path = os.path.join(save_model_folder, f\"{args.save_model_name}.pkl\")\n",
    "    torch.save(model.state_dict(), save_model_path)\n",
    "\n",
    "    # Also save non-parametric memory data if using JODIE, DyRep, or TGN\n",
    "    if args.model_name in ['JODIE', 'DyRep', 'TGN']:\n",
    "        save_memory_path = os.path.join(save_model_folder, f\"{args.save_model_name}_nonparametric_data.pkl\")\n",
    "        torch.save(model[0].memory_bank.node_raw_messages, save_memory_path)\n",
    "\n",
    "    logger.info(f\"Saved model to: {save_model_path}\")\n",
    "\n",
    "    # no evaluation is run in this cell; the message below only marks the end of training for this run\n",
    "    logger.info(f'get final performance on dataset {args.dataset_name}...')\n",
    "    logger.info(f'Run {run + 1} cost {time.time() - run_start_time:.2f} seconds.')\n",
    "\n",
    "    # detach this run's handlers so the next run does not duplicate messages, and release the log file\n",
    "    logger.removeHandler(fh)\n",
    "    logger.removeHandler(ch)\n",
    "    fh.close()\n",
    "    ch.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b28447ca",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
