{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "701405a4-565f-4576-9634-4e60988889d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\user\\anaconda3\\envs\\modularity_node_embedding\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# Notebook setup: imports, warning policy, and the benchmark output directory.\n",
    "import os\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Presumably a project-local module; adjust the import path if it is not importable from here.\n",
    "from benchmarking_utils_arxiv import run_benchmark\n",
    "\n",
    "# Silences every warning raised from this point on; warnings emitted while the\n",
    "# imports above were executing are not affected by this filter.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# Directory used as the benchmark output location; created when it does not exist yet.\n",
    "OUT_DIR = \"./benchmark_outputs\"\n",
    "os.makedirs(OUT_DIR, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9cd5bb2e-0414-4035-a2eb-07b20dd1b9d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Benchmark configuration (every name here is passed to run_benchmark) ---\n",
    "\n",
    "# Dataset names handled by the benchmarking module, and the random seeds to run.\n",
    "datasets = [\"arxiv\"]\n",
    "seeds = [46]\n",
    "\n",
    "# One mask_frac per split. Stated convention: mask_frac is the fraction of nodes\n",
    "# that is MASKED (unlabeled), i.e. 0.7 -> 30-70 split (30% known), 0.3 -> 70-30 split.\n",
    "# NOTE(review): the recorded output of the benchmark cell shows mask_frac=0.7 loading\n",
    "# the 70_30 custom mask with Masked=50802 / Unmasked=118541 (about 30% masked), and\n",
    "# mask_frac=0.3 loading the 30_70 mask with Masked=118540. Confirm which convention\n",
    "# the custom masks actually follow.\n",
    "mask_fracs = [0.7, 0.3]\n",
    "\n",
    "# Embedding methods and classifiers to benchmark (per the original note, None selects\n",
    "# the module defaults - TODO confirm). For a quick trial, shrink embedding_methods\n",
    "# to a single entry such as [\"fuse\"].\n",
    "embedding_methods = [\n",
    "    \"random\", \"given\", \"deepwalk\", \"node2vec\", \"dgi\", \"fuse\", \"vgae\",\n",
    "]\n",
    "classifiers = [\"gcn\", \"gat\", \"graphsage\"]\n",
    "\n",
    "# Embedding dimensionality.\n",
    "emb_dim = 150\n",
    "\n",
    "# Training budgets for the learned embeddings; all three share the same value.\n",
    "vgae_epochs = dgi_epochs = fuse_iterations = 200\n",
    "\n",
    "# Device for the PyG models: \"cpu\", or \"cuda\" if available and configured.\n",
    "device = \"cpu\"\n",
    "\n",
    "# Where outputs will be stored (same directory as OUT_DIR from the setup cell).\n",
    "save_dir = OUT_DIR\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1d93c591-5d04-4e27-86aa-c8175aefc80c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using custom mask: ./masks\\Arxiv\\70_30\\Arxiv_70_30_masked_indices_seed46.npy\n",
      "[arxiv][seed=46][mf=0.7] Masked=50802, Unmasked=118541\n",
      "[arxiv][seed=46][mask_frac=0.7] Running random …\n",
      "Embedding: random, Classifier: gcn\n",
      "Accuracy: 0.3319\n",
      "F1-Score: 0.0863\n",
      "Kappa: 0.2418\n",
      "Embedding generation time: 0.46s\n",
      "Classifier runtime: 10.07s\n",
      "--------------------------------------------------\n",
      "Embedding: random, Classifier: gat\n",
      "Accuracy: 0.2237\n",
      "F1-Score: 0.0300\n",
      "Kappa: 0.0926\n",
      "Embedding generation time: 0.46s\n",
      "Classifier runtime: 44.33s\n",
      "--------------------------------------------------\n",
      "Embedding: random, Classifier: graphsage\n",
      "Accuracy: 0.1615\n",
      "F1-Score: 0.0163\n",
      "Kappa: 0.0151\n",
      "Embedding generation time: 0.46s\n",
      "Classifier runtime: 22.77s\n",
      "--------------------------------------------------\n",
      "[arxiv][seed=46][mask_frac=0.7] Running given …\n",
      "Embedding: given, Classifier: gcn\n",
      "Accuracy: 0.3819\n",
      "F1-Score: 0.0794\n",
      "Kappa: 0.2976\n",
      "Embedding generation time: 0.05s\n",
      "Classifier runtime: 9.29s\n",
      "--------------------------------------------------\n",
      "Embedding: given, Classifier: gat\n",
      "Accuracy: 0.5874\n",
      "F1-Score: 0.3294\n",
      "Kappa: 0.5462\n",
      "Embedding generation time: 0.05s\n",
      "Classifier runtime: 45.53s\n",
      "--------------------------------------------------\n",
      "Embedding: given, Classifier: graphsage\n",
      "Accuracy: 0.5416\n",
      "F1-Score: 0.1792\n",
      "Kappa: 0.4910\n",
      "Embedding generation time: 0.05s\n",
      "Classifier runtime: 21.20s\n",
      "--------------------------------------------------\n",
      "[arxiv][seed=46][mask_frac=0.7] Running deepwalk …\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Computing transition probabilities: 100%|██████████████████████████████████████| 169343/169343 [43:55<00:00, 64.26it/s]\n",
      "Generating walks (CPU: 1): 100%|████████████████████████████████████████████████████| 10/10 [1:04:02<00:00, 384.28s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embedding: deepwalk, Classifier: gcn\n",
      "Accuracy: 0.5004\n",
      "F1-Score: 0.2180\n",
      "Kappa: 0.4411\n",
      "Embedding generation time: 13029.78s\n",
      "Classifier runtime: 9.46s\n",
      "--------------------------------------------------\n",
      "Embedding: deepwalk, Classifier: gat\n",
      "Accuracy: 0.6858\n",
      "F1-Score: 0.4601\n",
      "Kappa: 0.6562\n",
      "Embedding generation time: 13029.78s\n",
      "Classifier runtime: 45.20s\n",
      "--------------------------------------------------\n",
      "Embedding: deepwalk, Classifier: graphsage\n",
      "Accuracy: 0.6239\n",
      "F1-Score: 0.2421\n",
      "Kappa: 0.5856\n",
      "Embedding generation time: 13029.78s\n",
      "Classifier runtime: 22.82s\n",
      "--------------------------------------------------\n",
      "[arxiv][seed=46][mask_frac=0.7] Running node2vec …\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Computing transition probabilities: 100%|██████████████████████████████████████| 169343/169343 [44:03<00:00, 64.07it/s]\n",
      "Generating walks (CPU: 1): 100%|████████████████████████████████████████████████████| 10/10 [1:05:03<00:00, 390.34s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embedding: node2vec, Classifier: gcn\n",
      "Accuracy: 0.4920\n",
      "F1-Score: 0.1992\n",
      "Kappa: 0.4311\n",
      "Embedding generation time: 12899.23s\n",
      "Classifier runtime: 9.09s\n",
      "--------------------------------------------------\n",
      "Embedding: node2vec, Classifier: gat\n",
      "Accuracy: 0.6787\n",
      "F1-Score: 0.4506\n",
      "Kappa: 0.6484\n",
      "Embedding generation time: 12899.23s\n",
      "Classifier runtime: 43.89s\n",
      "--------------------------------------------------\n",
      "Embedding: node2vec, Classifier: graphsage\n",
      "Accuracy: 0.6203\n",
      "F1-Score: 0.2421\n",
      "Kappa: 0.5818\n",
      "Embedding generation time: 12899.23s\n",
      "Classifier runtime: 22.47s\n",
      "--------------------------------------------------\n",
      "[arxiv][seed=46][mask_frac=0.7] Running dgi …\n",
      "Embedding: dgi, Classifier: gcn\n",
      "Accuracy: 0.1637\n",
      "F1-Score: 0.0070\n",
      "Kappa: 0.0000\n",
      "Embedding generation time: 633.06s\n",
      "Classifier runtime: 8.97s\n",
      "--------------------------------------------------\n",
      "Embedding: dgi, Classifier: gat\n",
      "Accuracy: 0.1313\n",
      "F1-Score: 0.0073\n",
      "Kappa: 0.0051\n",
      "Embedding generation time: 633.06s\n",
      "Classifier runtime: 43.36s\n",
      "--------------------------------------------------\n",
      "Embedding: dgi, Classifier: graphsage\n",
      "Accuracy: 0.1637\n",
      "F1-Score: 0.0070\n",
      "Kappa: 0.0000\n",
      "Embedding generation time: 633.06s\n",
      "Classifier runtime: 21.75s\n",
      "--------------------------------------------------\n",
      "[arxiv][seed=46][mask_frac=0.7] Running fuse …\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Random walks per node: 100%|█████████████████████████████████████████████████| 169343/169343 [01:54<00:00, 1475.35it/s]\n",
      "Computing attention weights: 100%|██████████████████████████████████████████| 169343/169343 [00:16<00:00, 10417.63it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [20:22<00:00,  6.11s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embedding: fuse, Classifier: gcn\n",
      "Accuracy: 0.4997\n",
      "F1-Score: 0.2353\n",
      "Kappa: 0.4378\n",
      "Embedding generation time: 1360.30s\n",
      "Classifier runtime: 9.20s\n",
      "--------------------------------------------------\n",
      "Embedding: fuse, Classifier: gat\n",
      "Accuracy: 0.6745\n",
      "F1-Score: 0.4682\n",
      "Kappa: 0.6414\n",
      "Embedding generation time: 1360.30s\n",
      "Classifier runtime: 43.41s\n",
      "--------------------------------------------------\n",
      "Embedding: fuse, Classifier: graphsage\n",
      "Accuracy: 0.6191\n",
      "F1-Score: 0.2344\n",
      "Kappa: 0.5795\n",
      "Embedding generation time: 1360.30s\n",
      "Classifier runtime: 22.34s\n",
      "--------------------------------------------------\n",
      "[arxiv][seed=46][mask_frac=0.7] Running vgae …\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [17:50<00:00,  5.35s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embedding: vgae, Classifier: gcn\n",
      "Accuracy: 0.1312\n",
      "F1-Score: 0.0058\n",
      "Kappa: 0.0000\n",
      "Embedding generation time: 1072.42s\n",
      "Classifier runtime: 9.20s\n",
      "--------------------------------------------------\n",
      "Embedding: vgae, Classifier: gat\n",
      "Accuracy: 0.1344\n",
      "F1-Score: 0.0082\n",
      "Kappa: 0.0082\n",
      "Embedding generation time: 1072.42s\n",
      "Classifier runtime: 43.72s\n",
      "--------------------------------------------------\n",
      "Embedding: vgae, Classifier: graphsage\n",
      "Accuracy: 0.1653\n",
      "F1-Score: 0.0092\n",
      "Kappa: 0.0036\n",
      "Embedding generation time: 1072.42s\n",
      "Classifier runtime: 22.64s\n",
      "--------------------------------------------------\n",
      "Using custom mask: ./masks\\Arxiv\\30_70\\Arxiv_30_70_masked_indices_seed46.npy\n",
      "[arxiv][seed=46][mf=0.3] Masked=118540, Unmasked=50803\n",
      "[arxiv][seed=46][mask_frac=0.3] Running random …\n",
      "Embedding: random, Classifier: gcn\n",
      "Accuracy: 0.2258\n",
      "F1-Score: 0.0528\n",
      "Kappa: 0.1302\n",
      "Embedding generation time: 0.43s\n",
      "Classifier runtime: 5.25s\n",
      "--------------------------------------------------\n",
      "Embedding: random, Classifier: gat\n",
      "Accuracy: 0.1908\n",
      "F1-Score: 0.0224\n",
      "Kappa: 0.0558\n",
      "Embedding generation time: 0.43s\n",
      "Classifier runtime: 24.83s\n",
      "--------------------------------------------------\n",
      "Embedding: random, Classifier: graphsage\n",
      "Accuracy: 0.1513\n",
      "F1-Score: 0.0165\n",
      "Kappa: 0.0099\n",
      "Embedding generation time: 0.43s\n",
      "Classifier runtime: 16.47s\n",
      "--------------------------------------------------\n",
      "[arxiv][seed=46][mask_frac=0.3] Running given …\n",
      "Embedding: given, Classifier: gcn\n",
      "Accuracy: 0.4122\n",
      "F1-Score: 0.0973\n",
      "Kappa: 0.3355\n",
      "Embedding generation time: 0.05s\n",
      "Classifier runtime: 5.24s\n",
      "--------------------------------------------------\n",
      "Embedding: given, Classifier: gat\n",
      "Accuracy: 0.5674\n",
      "F1-Score: 0.2834\n",
      "Kappa: 0.5211\n",
      "Embedding generation time: 0.05s\n",
      "Classifier runtime: 24.26s\n",
      "--------------------------------------------------\n",
      "Embedding: given, Classifier: graphsage\n",
      "Accuracy: 0.5365\n",
      "F1-Score: 0.1775\n",
      "Kappa: 0.4857\n",
      "Embedding generation time: 0.05s\n",
      "Classifier runtime: 15.56s\n",
      "--------------------------------------------------\n",
      "[arxiv][seed=46][mask_frac=0.3] Running deepwalk …\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Computing transition probabilities: 100%|██████████████████████████████████████| 169343/169343 [48:13<00:00, 58.52it/s]\n",
      "Generating walks (CPU: 1): 100%|████████████████████████████████████████████████████| 10/10 [1:01:27<00:00, 368.74s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embedding: deepwalk, Classifier: gcn\n",
      "Accuracy: 0.5143\n",
      "F1-Score: 0.2684\n",
      "Kappa: 0.4618\n",
      "Embedding generation time: 12996.76s\n",
      "Classifier runtime: 4.45s\n",
      "--------------------------------------------------\n",
      "Embedding: deepwalk, Classifier: gat\n",
      "Accuracy: 0.6765\n",
      "F1-Score: 0.4444\n",
      "Kappa: 0.6456\n",
      "Embedding generation time: 12996.76s\n",
      "Classifier runtime: 19.42s\n",
      "--------------------------------------------------\n",
      "Embedding: deepwalk, Classifier: graphsage\n",
      "Accuracy: 0.6195\n",
      "F1-Score: 0.2339\n",
      "Kappa: 0.5807\n",
      "Embedding generation time: 12996.76s\n",
      "Classifier runtime: 13.45s\n",
      "--------------------------------------------------\n",
      "[arxiv][seed=46][mask_frac=0.3] Running node2vec …\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Computing transition probabilities: 100%|██████████████████████████████████████| 169343/169343 [43:33<00:00, 64.80it/s]\n",
      "Generating walks (CPU: 1): 100%|██████████████████████████████████████████████████████| 10/10 [58:02<00:00, 348.24s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embedding: node2vec, Classifier: gcn\n",
      "Accuracy: 0.5032\n",
      "F1-Score: 0.2524\n",
      "Kappa: 0.4486\n",
      "Embedding generation time: 12038.33s\n",
      "Classifier runtime: 4.36s\n",
      "--------------------------------------------------\n",
      "Embedding: node2vec, Classifier: gat\n",
      "Accuracy: 0.6686\n",
      "F1-Score: 0.4370\n",
      "Kappa: 0.6368\n",
      "Embedding generation time: 12038.33s\n",
      "Classifier runtime: 19.16s\n",
      "--------------------------------------------------\n",
      "Embedding: node2vec, Classifier: graphsage\n",
      "Accuracy: 0.6175\n",
      "F1-Score: 0.2421\n",
      "Kappa: 0.5791\n",
      "Embedding generation time: 12038.33s\n",
      "Classifier runtime: 13.21s\n",
      "--------------------------------------------------\n",
      "[arxiv][seed=46][mask_frac=0.3] Running dgi …\n",
      "Embedding: dgi, Classifier: gcn\n",
      "Accuracy: 0.1616\n",
      "F1-Score: 0.0070\n",
      "Kappa: 0.0000\n",
      "Embedding generation time: 758.04s\n",
      "Classifier runtime: 4.30s\n",
      "--------------------------------------------------\n",
      "Embedding: dgi, Classifier: gat\n",
      "Accuracy: 0.1616\n",
      "F1-Score: 0.0070\n",
      "Kappa: 0.0000\n",
      "Embedding generation time: 758.04s\n",
      "Classifier runtime: 18.97s\n",
      "--------------------------------------------------\n",
      "Embedding: dgi, Classifier: graphsage\n",
      "Accuracy: 0.1616\n",
      "F1-Score: 0.0070\n",
      "Kappa: 0.0000\n",
      "Embedding generation time: 758.04s\n",
      "Classifier runtime: 13.05s\n",
      "--------------------------------------------------\n",
      "[arxiv][seed=46][mask_frac=0.3] Running fuse …\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Random walks per node: 100%|█████████████████████████████████████████████████| 169343/169343 [02:27<00:00, 1149.34it/s]\n",
      "Computing attention weights: 100%|██████████████████████████████████████████| 169343/169343 [00:15<00:00, 10981.06it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [25:29<00:00,  7.65s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embedding: fuse, Classifier: gcn\n",
      "Accuracy: 0.4503\n",
      "F1-Score: 0.1403\n",
      "Kappa: 0.3782\n",
      "Embedding generation time: 1698.52s\n",
      "Classifier runtime: 4.42s\n",
      "--------------------------------------------------\n",
      "Embedding: fuse, Classifier: gat\n",
      "Accuracy: 0.6383\n",
      "F1-Score: 0.4318\n",
      "Kappa: 0.6003\n",
      "Embedding generation time: 1698.52s\n",
      "Classifier runtime: 19.22s\n",
      "--------------------------------------------------\n",
      "Embedding: fuse, Classifier: graphsage\n",
      "Accuracy: 0.5965\n",
      "F1-Score: 0.2479\n",
      "Kappa: 0.5559\n",
      "Embedding generation time: 1698.52s\n",
      "Classifier runtime: 13.62s\n",
      "--------------------------------------------------\n",
      "[arxiv][seed=46][mask_frac=0.3] Running vgae …\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [18:16<00:00,  5.48s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embedding: vgae, Classifier: gcn\n",
      "Accuracy: 0.1617\n",
      "F1-Score: 0.0070\n",
      "Kappa: 0.0001\n",
      "Embedding generation time: 1098.25s\n",
      "Classifier runtime: 4.49s\n",
      "--------------------------------------------------\n",
      "Embedding: vgae, Classifier: gat\n",
      "Accuracy: 0.1616\n",
      "F1-Score: 0.0071\n",
      "Kappa: 0.0002\n",
      "Embedding generation time: 1098.25s\n",
      "Classifier runtime: 19.00s\n",
      "--------------------------------------------------\n",
      "Embedding: vgae, Classifier: graphsage\n",
      "Accuracy: 0.1616\n",
      "F1-Score: 0.0070\n",
      "Kappa: 0.0000\n",
      "Embedding generation time: 1098.25s\n",
      "Classifier runtime: 13.29s\n",
      "--------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "res = run_benchmark(\n",
    "    datasets=datasets,\n",
    "    seeds=seeds,\n",
    "    mask_fracs=mask_fracs,\n",
    "    emb_dim=emb_dim,\n",
    "    embedding_methods=embedding_methods,\n",
    "    classifiers=classifiers,\n",
    "    vgae_epochs=vgae_epochs,\n",
    "    dgi_epochs=dgi_epochs,\n",
    "    fuse_iterations=fuse_iterations,\n",
    "    save_dir=save_dir,\n",
    "    device=device,\n",
    "    masks_root=\"./masks\",\n",
    "    verbose=True\n",
    ")"
   ]
  },
  {
   "cell_type": "raw",
   "id": "7f5d3f6f-6e5a-498d-9d0a-5b42c8ac5eb7",
   "metadata": {},
   "source": [
    "# Preview the tables held in `res`, the value returned by run_benchmark.\n",
    "# (This is a raw cell: convert it to a code cell to execute it.)\n",
    "import pandas as pd\n",
    "\n",
    "print(\"Per-run results saved at:\", save_dir)\n",
    "\n",
    "# First rows of the per-run table.\n",
    "per_run = res['per_run']\n",
    "display(per_run.head())\n",
    "\n",
    "# First 20 rows of the averaged table, ordered by dataset, mask_frac, embedding, classifier.\n",
    "metric_sort_keys = ['dataset', 'mask_frac', 'embedding', 'classifier']\n",
    "avg_metrics = res['avg_by_model_and_classifier'].sort_values(metric_sort_keys)\n",
    "display(avg_metrics.head(20))\n",
    "\n",
    "# First 20 rows of the embedding-time table, ordered by dataset, mask_frac, then time.\n",
    "time_sort_keys = ['dataset', 'mask_frac', 'avg_embedding_time']\n",
    "avg_times = res['avg_embedding_times'].sort_values(time_sort_keys)\n",
    "display(avg_times.head(20))\n",
    "\n",
    "# These DataFrames can also be written to separate CSV files; per the original note\n",
    "# the module already saves them (not verifiable from this notebook alone)."
   ]
  },
  {
   "cell_type": "raw",
   "id": "ed9e7684-6a6e-4995-89b6-f325d22a2090",
   "metadata": {},
   "source": [
    "# Reload one saved embedding from disk and report its shape.\n",
    "# (This is a raw cell: convert it to a code cell to execute it.)\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Take dataset and seed from the configuration cell so the path refers to a run this\n",
    "# notebook actually performed (the previous hard-coded \"cora\" / 42 were never run here).\n",
    "dataset = datasets[0]\n",
    "seed = seeds[0]\n",
    "emb_name = \"fuse\"\n",
    "\n",
    "# NOTE(review): the <save_dir>/<dataset>/seed_<seed>/embedding_<name>.npy layout is\n",
    "# carried over from the original snippet - confirm it against benchmarking_utils_arxiv.\n",
    "path = os.path.join(save_dir, dataset, f\"seed_{seed}\", f\"embedding_{emb_name}.npy\")\n",
    "E = np.load(path)\n",
    "print(\"Loaded embedding shape:\", E.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dfe209d-8cef-4078-870e-f26e3ff9c998",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
