{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0a455b63-f38a-4b3e-9ff2-8de16eaeef97",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sujan/miniforge3/envs/tf-gpu/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloaded 0.08 GB: 100%|███████████████████████| 81/81 [03:37<00:00,  2.68s/it]\n",
      "Processing...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting /Users/sujan/Downloads/ReVar/data/arxiv/arxiv.zip\n",
      "Loading necessary files...\n",
      "This might take a while.\n",
      "Processing graphs...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████| 1/1 [00:00<00:00, 43240.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converting graphs into PyG objects...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████| 1/1 [00:00<00:00, 4854.52it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Done!\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data(num_nodes=169343, edge_index=[2, 1166243], x=[169343, 150], node_year=[169343, 1], y=[169343])\n",
      "Number of nodes: 169343, Number of features: 150, Number of classes: 40\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from ogb.nodeproppred import PygNodePropPredDataset\n",
    "import os\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.data.data import DataEdgeAttr, DataTensorAttr\n",
    "from torch_geometric.data.storage import GlobalStorage\n",
    "from torch.serialization import safe_globals\n",
    "\n",
    "# Set dataset path; the REVAR_DATA_DIR environment variable overrides the local default\n",
    "dataset_path = os.environ.get('REVAR_DATA_DIR', '/Users/sujan/Downloads/ReVar/data/arxiv')\n",
    "os.makedirs(dataset_path, exist_ok=True)\n",
    "\n",
    "# Load ArXiv dataset safely: the listed PyG classes are allow-listed for torch's\n",
    "# deserialization while the processed dataset file is read\n",
    "with safe_globals([Data, DataEdgeAttr, DataTensorAttr, GlobalStorage]):\n",
    "    dataset = PygNodePropPredDataset(name=\"ogbn-arxiv\", root=dataset_path)\n",
    "\n",
    "data = dataset[0]\n",
    "hidden_dim = 150\n",
    "\n",
    "# NOTE(review): the dataset's own node features are discarded and replaced by\n",
    "# Gaussian noise of width hidden_dim -- presumably intentional; confirm.\n",
    "# A dedicated, seeded generator makes the features identical on every run without\n",
    "# touching the global RNG. 46 matches the mask seed selected by the next cell.\n",
    "feature_seed = 46\n",
    "feature_rng = torch.Generator().manual_seed(feature_seed)\n",
    "data.x = torch.randn(data.num_nodes, hidden_dim, generator=feature_rng)  # Random features\n",
    "data.y = data.y.squeeze()  # Flatten labels\n",
    "\n",
    "print(data)\n",
    "print(f\"Number of nodes: {data.num_nodes}, Number of features: {data.num_node_features}, Number of classes: {dataset.num_classes}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "67361f15-e5ad-4419-9e8a-8817e0fe73ca",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      ">>> Running for mask: Arxiv_70_30_masked_indices_seed46.npy\n",
      "\n",
      "=== Running GCN with mask file: Arxiv_70_30_masked_indices_seed46.npy ===\n",
      "Train mask size: 118541, Test mask size: 50802\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sujan/miniforge3/envs/tf-gpu/lib/python3.10/site-packages/torch_geometric/deprecation.py:26: UserWarning: 'dropout_adj' is deprecated, use 'dropout_edge' instead\n",
      "  warnings.warn(out)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[GCN Encoder] Epoch 020 | Contrastive Loss: -0.9329\n",
      "[GCN Encoder] Epoch 040 | Contrastive Loss: -0.9731\n",
      "[GCN Encoder] Epoch 060 | Contrastive Loss: -0.9887\n",
      "[GCN Encoder] Epoch 080 | Contrastive Loss: -0.9957\n",
      "[GCN Encoder] Epoch 100 | Contrastive Loss: -0.9986\n",
      "[GCN Encoder] Epoch 120 | Contrastive Loss: -0.9997\n",
      "[GCN Encoder] Epoch 140 | Contrastive Loss: -0.9999\n",
      "[GCN Encoder] Epoch 160 | Contrastive Loss: -1.0000\n",
      "[GCN Encoder] Epoch 180 | Contrastive Loss: -1.0000\n",
      "[GCN Encoder] Epoch 200 | Contrastive Loss: -1.0000\n",
      "[Linear Classifier] Epoch 020 | Loss: 3.0376 | Acc: 0.1635 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 040 | Loss: 3.0098 | Acc: 0.1636 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 060 | Loss: 3.0072 | Acc: 0.1637 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 080 | Loss: 3.0063 | Acc: 0.1637 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 100 | Loss: 3.0057 | Acc: 0.1637 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 120 | Loss: 3.0051 | Acc: 0.1637 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 140 | Loss: 3.0046 | Acc: 0.1637 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 160 | Loss: 3.0040 | Acc: 0.1637 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 180 | Loss: 3.0034 | Acc: 0.1637 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 200 | Loss: 3.0028 | Acc: 0.1637 | F1: 0.0070\n",
      "\n",
      "=== Running GAT with mask file: Arxiv_70_30_masked_indices_seed46.npy ===\n",
      "Train mask size: 118541, Test mask size: 50802\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sujan/miniforge3/envs/tf-gpu/lib/python3.10/site-packages/torch_geometric/deprecation.py:26: UserWarning: 'dropout_adj' is deprecated, use 'dropout_edge' instead\n",
      "  warnings.warn(out)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[GAT Encoder] Epoch 020 | Contrastive Loss: -0.9401\n",
      "[GAT Encoder] Epoch 040 | Contrastive Loss: -0.9598\n",
      "[GAT Encoder] Epoch 060 | Contrastive Loss: -0.9777\n",
      "[GAT Encoder] Epoch 080 | Contrastive Loss: -0.9918\n",
      "[GAT Encoder] Epoch 100 | Contrastive Loss: -0.9974\n",
      "[GAT Encoder] Epoch 120 | Contrastive Loss: -0.9993\n",
      "[GAT Encoder] Epoch 140 | Contrastive Loss: -0.9998\n",
      "[GAT Encoder] Epoch 160 | Contrastive Loss: -0.9999\n",
      "[GAT Encoder] Epoch 180 | Contrastive Loss: -1.0000\n",
      "[GAT Encoder] Epoch 200 | Contrastive Loss: -1.0000\n",
      "[Linear Classifier] Epoch 020 | Loss: 3.0286 | Acc: 0.1637 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 040 | Loss: 2.9995 | Acc: 0.1637 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 060 | Loss: 2.9950 | Acc: 0.1637 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 080 | Loss: 2.9943 | Acc: 0.1637 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 100 | Loss: 2.9939 | Acc: 0.1637 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 120 | Loss: 2.9937 | Acc: 0.1637 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 140 | Loss: 2.9935 | Acc: 0.1637 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 160 | Loss: 2.9933 | Acc: 0.1637 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 180 | Loss: 2.9931 | Acc: 0.1637 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 200 | Loss: 2.9930 | Acc: 0.1637 | F1: 0.0070\n",
      "\n",
      "=== Running SAGE with mask file: Arxiv_70_30_masked_indices_seed46.npy ===\n",
      "Train mask size: 118541, Test mask size: 50802\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sujan/miniforge3/envs/tf-gpu/lib/python3.10/site-packages/torch_geometric/deprecation.py:26: UserWarning: 'dropout_adj' is deprecated, use 'dropout_edge' instead\n",
      "  warnings.warn(out)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[SAGE Encoder] Epoch 020 | Contrastive Loss: -0.9970\n",
      "[SAGE Encoder] Epoch 040 | Contrastive Loss: -0.9994\n",
      "[SAGE Encoder] Epoch 060 | Contrastive Loss: -0.9998\n",
      "[SAGE Encoder] Epoch 080 | Contrastive Loss: -0.9999\n",
      "[SAGE Encoder] Epoch 100 | Contrastive Loss: -0.9999\n",
      "[SAGE Encoder] Epoch 120 | Contrastive Loss: -0.9999\n",
      "[SAGE Encoder] Epoch 140 | Contrastive Loss: -1.0000\n",
      "[SAGE Encoder] Epoch 160 | Contrastive Loss: -1.0000\n",
      "[SAGE Encoder] Epoch 180 | Contrastive Loss: -1.0000\n",
      "[SAGE Encoder] Epoch 200 | Contrastive Loss: -1.0000\n",
      "[Linear Classifier] Epoch 020 | Loss: 3.0438 | Acc: 0.1631 | F1: 0.0076\n",
      "[Linear Classifier] Epoch 040 | Loss: 3.0123 | Acc: 0.1525 | F1: 0.0133\n",
      "[Linear Classifier] Epoch 060 | Loss: 3.0024 | Acc: 0.1581 | F1: 0.0111\n",
      "[Linear Classifier] Epoch 080 | Loss: 2.9979 | Acc: 0.1605 | F1: 0.0099\n",
      "[Linear Classifier] Epoch 100 | Loss: 2.9946 | Acc: 0.1611 | F1: 0.0095\n",
      "[Linear Classifier] Epoch 120 | Loss: 2.9920 | Acc: 0.1614 | F1: 0.0093\n",
      "[Linear Classifier] Epoch 140 | Loss: 2.9899 | Acc: 0.1615 | F1: 0.0093\n",
      "[Linear Classifier] Epoch 160 | Loss: 2.9880 | Acc: 0.1616 | F1: 0.0092\n",
      "[Linear Classifier] Epoch 180 | Loss: 2.9865 | Acc: 0.1613 | F1: 0.0091\n",
      "[Linear Classifier] Epoch 200 | Loss: 2.9851 | Acc: 0.1613 | F1: 0.0092\n",
      "\n",
      "===== Summary across seeds (70_30) =====\n",
      "GAT -> Acc: 16.3675 ± nan, BAcc: 2.5000 ± nan, F1: 0.7033 ± nan\n",
      "GCN -> Acc: 16.3675 ± nan, BAcc: 2.5000 ± nan, F1: 0.7033 ± nan\n",
      "SAGE -> Acc: 16.1253 ± nan, BAcc: 2.4914 ± nan, F1: 0.9169 ± nan\n",
      "\n",
      "===== Average Execution Times (70_30) =====\n",
      "  encoder  avg_execution_time\n",
      "0     GAT         9462.589315\n",
      "1     GCN          432.879682\n",
      "2    SAGE          321.530078\n",
      "\n",
      ">>> Running for mask: Arxiv_30_70_masked_indices_seed46.npy\n",
      "\n",
      "=== Running GCN with mask file: Arxiv_30_70_masked_indices_seed46.npy ===\n",
      "Train mask size: 50803, Test mask size: 118540\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sujan/miniforge3/envs/tf-gpu/lib/python3.10/site-packages/torch_geometric/deprecation.py:26: UserWarning: 'dropout_adj' is deprecated, use 'dropout_edge' instead\n",
      "  warnings.warn(out)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[GCN Encoder] Epoch 020 | Contrastive Loss: -0.9307\n",
      "[GCN Encoder] Epoch 040 | Contrastive Loss: -0.9718\n",
      "[GCN Encoder] Epoch 060 | Contrastive Loss: -0.9879\n",
      "[GCN Encoder] Epoch 080 | Contrastive Loss: -0.9951\n",
      "[GCN Encoder] Epoch 100 | Contrastive Loss: -0.9982\n",
      "[GCN Encoder] Epoch 120 | Contrastive Loss: -0.9995\n",
      "[GCN Encoder] Epoch 140 | Contrastive Loss: -0.9999\n",
      "[GCN Encoder] Epoch 160 | Contrastive Loss: -1.0000\n",
      "[GCN Encoder] Epoch 180 | Contrastive Loss: -1.0000\n",
      "[GCN Encoder] Epoch 200 | Contrastive Loss: -1.0000\n",
      "[Linear Classifier] Epoch 020 | Loss: 3.0363 | Acc: 0.1608 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 040 | Loss: 3.0086 | Acc: 0.1609 | F1: 0.0069\n",
      "[Linear Classifier] Epoch 060 | Loss: 3.0043 | Acc: 0.1610 | F1: 0.0069\n",
      "[Linear Classifier] Epoch 080 | Loss: 3.0032 | Acc: 0.1611 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 100 | Loss: 3.0027 | Acc: 0.1613 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 120 | Loss: 3.0023 | Acc: 0.1613 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 140 | Loss: 3.0018 | Acc: 0.1613 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 160 | Loss: 3.0014 | Acc: 0.1613 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 180 | Loss: 3.0009 | Acc: 0.1613 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 200 | Loss: 3.0004 | Acc: 0.1614 | F1: 0.0070\n",
      "\n",
      "=== Running GAT with mask file: Arxiv_30_70_masked_indices_seed46.npy ===\n",
      "Train mask size: 50803, Test mask size: 118540\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sujan/miniforge3/envs/tf-gpu/lib/python3.10/site-packages/torch_geometric/deprecation.py:26: UserWarning: 'dropout_adj' is deprecated, use 'dropout_edge' instead\n",
      "  warnings.warn(out)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[GAT Encoder] Epoch 020 | Contrastive Loss: -0.9383\n",
      "[GAT Encoder] Epoch 040 | Contrastive Loss: -0.9582\n",
      "[GAT Encoder] Epoch 060 | Contrastive Loss: -0.9762\n",
      "[GAT Encoder] Epoch 080 | Contrastive Loss: -0.9912\n",
      "[GAT Encoder] Epoch 100 | Contrastive Loss: -0.9970\n",
      "[GAT Encoder] Epoch 120 | Contrastive Loss: -0.9991\n",
      "[GAT Encoder] Epoch 140 | Contrastive Loss: -0.9997\n",
      "[GAT Encoder] Epoch 160 | Contrastive Loss: -0.9999\n",
      "[GAT Encoder] Epoch 180 | Contrastive Loss: -0.9999\n",
      "[GAT Encoder] Epoch 200 | Contrastive Loss: -1.0000\n",
      "[Linear Classifier] Epoch 020 | Loss: 3.0326 | Acc: 0.1616 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 040 | Loss: 2.9961 | Acc: 0.1616 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 060 | Loss: 2.9912 | Acc: 0.1616 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 080 | Loss: 2.9903 | Acc: 0.1616 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 100 | Loss: 2.9897 | Acc: 0.1616 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 120 | Loss: 2.9893 | Acc: 0.1616 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 140 | Loss: 2.9890 | Acc: 0.1616 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 160 | Loss: 2.9886 | Acc: 0.1616 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 180 | Loss: 2.9884 | Acc: 0.1616 | F1: 0.0070\n",
      "[Linear Classifier] Epoch 200 | Loss: 2.9882 | Acc: 0.1616 | F1: 0.0070\n",
      "\n",
      "=== Running SAGE with mask file: Arxiv_30_70_masked_indices_seed46.npy ===\n",
      "Train mask size: 50803, Test mask size: 118540\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sujan/miniforge3/envs/tf-gpu/lib/python3.10/site-packages/torch_geometric/deprecation.py:26: UserWarning: 'dropout_adj' is deprecated, use 'dropout_edge' instead\n",
      "  warnings.warn(out)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[SAGE Encoder] Epoch 020 | Contrastive Loss: -0.9964\n",
      "[SAGE Encoder] Epoch 040 | Contrastive Loss: -0.9994\n",
      "[SAGE Encoder] Epoch 060 | Contrastive Loss: -0.9998\n",
      "[SAGE Encoder] Epoch 080 | Contrastive Loss: -0.9999\n",
      "[SAGE Encoder] Epoch 100 | Contrastive Loss: -0.9999\n",
      "[SAGE Encoder] Epoch 120 | Contrastive Loss: -0.9999\n",
      "[SAGE Encoder] Epoch 140 | Contrastive Loss: -1.0000\n",
      "[SAGE Encoder] Epoch 160 | Contrastive Loss: -1.0000\n",
      "[SAGE Encoder] Epoch 180 | Contrastive Loss: -1.0000\n",
      "[SAGE Encoder] Epoch 200 | Contrastive Loss: -1.0000\n",
      "[Linear Classifier] Epoch 020 | Loss: 3.0450 | Acc: 0.1604 | F1: 0.0083\n",
      "[Linear Classifier] Epoch 040 | Loss: 3.0022 | Acc: 0.1486 | F1: 0.0131\n",
      "[Linear Classifier] Epoch 060 | Loss: 2.9888 | Acc: 0.1559 | F1: 0.0118\n",
      "[Linear Classifier] Epoch 080 | Loss: 2.9819 | Acc: 0.1584 | F1: 0.0108\n",
      "[Linear Classifier] Epoch 100 | Loss: 2.9767 | Acc: 0.1585 | F1: 0.0107\n",
      "[Linear Classifier] Epoch 120 | Loss: 2.9723 | Acc: 0.1590 | F1: 0.0107\n",
      "[Linear Classifier] Epoch 140 | Loss: 2.9686 | Acc: 0.1586 | F1: 0.0107\n",
      "[Linear Classifier] Epoch 160 | Loss: 2.9654 | Acc: 0.1584 | F1: 0.0109\n",
      "[Linear Classifier] Epoch 180 | Loss: 2.9627 | Acc: 0.1583 | F1: 0.0111\n",
      "[Linear Classifier] Epoch 200 | Loss: 2.9602 | Acc: 0.1581 | F1: 0.0112\n",
      "\n",
      "===== Summary across seeds (30_70) =====\n",
      "GAT -> Acc: 16.1574 ± nan, BAcc: 2.5001 ± nan, F1: 0.6965 ± nan\n",
      "GCN -> Acc: 16.1439 ± nan, BAcc: 2.4989 ± nan, F1: 0.6985 ± nan\n",
      "SAGE -> Acc: 15.8132 ± nan, BAcc: 2.5136 ± nan, F1: 1.1208 ± nan\n",
      "\n",
      "===== Average Execution Times (30_70) =====\n",
      "  encoder  avg_execution_time\n",
      "0     GAT         9265.353171\n",
      "1     GCN          468.545609\n",
      "2    SAGE          285.551053\n"
     ]
    }
   ],
   "source": [
    "import glob\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import time\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "from src.utils import compute_accuracy\n",
    "from src.transform import get_graph_drop_transform\n",
    "from src.imbalance import Imbalance_\n",
    "from layers import GNN\n",
    "from layers.Classifier import Classifier\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "\n",
    "def train_and_evaluate_revar(data, mask_file, emb_dir, split_name,\n",
    "                             net=\"GCN\", hidden_dim=150, n_head=1, epochs=200,\n",
    "                             clf_epochs=200):\n",
    "    \"\"\"Pre-train a GNN encoder on two augmented views, then fit a linear probe.\n",
    "\n",
    "    Steps: build train/test masks from ``mask_file``, create two augmented\n",
    "    views of ``data``, train the encoder by maximising the cosine similarity\n",
    "    between the two views, save the full-graph embeddings to ``emb_dir`` and\n",
    "    train/evaluate a ``Classifier`` on those frozen embeddings.\n",
    "\n",
    "    Args:\n",
    "        data: Graph object exposing ``num_nodes``, ``num_features`` and ``y``.\n",
    "            Its ``train_mask`` / ``test_mask`` attributes are overwritten.\n",
    "        mask_file: Mask file handed to ``Imbalance_.split_semi_dataset``; the\n",
    "            text after ``seed`` in its file name becomes the seed id.\n",
    "        emb_dir: Directory (must exist) that receives the embedding ``.npy``.\n",
    "        split_name: Split label used in the embedding file name.\n",
    "        net: Encoder type forwarded to ``GNN``.\n",
    "        hidden_dim: Size of both encoder layers and of the classifier input.\n",
    "        n_head: Forwarded to ``GNN`` as ``n_head``.\n",
    "        epochs: Number of encoder pre-training epochs.\n",
    "        clf_epochs: Number of classifier training epochs.\n",
    "\n",
    "    Returns:\n",
    "        dict with keys ``encoder``, ``seed``, ``accuracy``, ``balanced_acc``,\n",
    "        ``f1`` (test-split values returned by ``compute_accuracy``) and\n",
    "        ``time`` (encoder plus classifier training time in seconds).\n",
    "    \"\"\"\n",
    "    # -------------------------------\n",
    "    # Step 1: Handle Imbalance\n",
    "    # -------------------------------\n",
    "    num_nodes = data.num_nodes\n",
    "    data.train_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "    data.test_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "\n",
    "    imb = Imbalance_(name=\"arxiv\", data=data, ratio=0)\n",
    "    # The helper's return value is used as the test mask; training uses its complement.\n",
    "    test_mask = imb.split_semi_dataset(mask_file=mask_file)\n",
    "    train_mask = ~test_mask\n",
    "    data.train_mask = train_mask\n",
    "    data.test_mask = test_mask\n",
    "\n",
    "    print(f\"\\n=== Running {net} with mask file: {os.path.basename(mask_file)} ===\")\n",
    "    print(f\"Train mask size: {train_mask.sum().item()}, Test mask size: {test_mask.sum().item()}\")\n",
    "\n",
    "    # The models below are moved to `device`, so the graph must live there too;\n",
    "    # otherwise the forward passes fail on a CUDA machine. This is a no-op on CPU.\n",
    "    # NOTE(review): assumes `data.to` behaves like torch_geometric's Data.to, which\n",
    "    # may move the caller's object in place -- confirm.\n",
    "    data = data.to(device)\n",
    "\n",
    "    # -------------------------------\n",
    "    # Step 2: Augmentations\n",
    "    # -------------------------------\n",
    "    drop_edge_p, drop_feat_p = 0.2, 0.3\n",
    "    transform1 = get_graph_drop_transform(drop_edge_p, drop_feat_p)\n",
    "    transform2 = get_graph_drop_transform(drop_edge_p, drop_feat_p)\n",
    "    # The two views are created once and reused for every epoch.\n",
    "    data_aug1 = transform1(data)\n",
    "    data_aug2 = transform2(data)\n",
    "\n",
    "    # -------------------------------\n",
    "    # Step 3: Encoder (GNN)\n",
    "    # -------------------------------\n",
    "    encoder_model = GNN(\n",
    "        layer_sizes=[data.num_features, hidden_dim, hidden_dim],\n",
    "        net=net,\n",
    "        n_head=n_head\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.Adam(encoder_model.parameters(), lr=0.01)\n",
    "\n",
    "    start_time = time.time()\n",
    "    for epoch in range(epochs):\n",
    "        encoder_model.train()\n",
    "        optimizer.zero_grad()\n",
    "        z1 = encoder_model(data_aug1)\n",
    "        z2 = encoder_model(data_aug2)\n",
    "        # NOTE(review): this objective has no negative pairs and no stop-gradient,\n",
    "        # so a constant embedding minimises it. The logged runs reach -1.0000 while\n",
    "        # test accuracy stays near 16% -- verify the objective against the method.\n",
    "        sim_loss = -torch.cosine_similarity(z1, z2, dim=-1).mean()\n",
    "        sim_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if (epoch + 1) % 20 == 0:\n",
    "            print(f\"[{net} Encoder] Epoch {epoch+1:03d} | Contrastive Loss: {sim_loss.item():.4f}\")\n",
    "    enc_time = time.time() - start_time\n",
    "\n",
    "    # -------------------------------\n",
    "    # Step 4: Get embeddings\n",
    "    # -------------------------------\n",
    "    encoder_model.eval()\n",
    "    with torch.no_grad():\n",
    "        embeddings = encoder_model(data).cpu().numpy()\n",
    "\n",
    "    # Save embeddings with split name\n",
    "    seed_id = os.path.basename(mask_file).split(\"seed\")[-1].replace(\".npy\", \"\")\n",
    "    emb_file = os.path.join(emb_dir, f\"Arxiv_{net}_{split_name}_seed{seed_id}.npy\")\n",
    "    np.save(emb_file, embeddings)\n",
    "\n",
    "    # -------------------------------\n",
    "    # Step 5: Classifier (Linear)\n",
    "    # -------------------------------\n",
    "    num_classes = int(data.y.max().item() + 1)\n",
    "    clf = Classifier(hidden_size=hidden_dim, num_class=num_classes).to(device)\n",
    "    optimizer = torch.optim.Adam(clf.parameters(), lr=0.01)\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "    clf_time_start = time.time()\n",
    "    # The embeddings are frozen, so the input tensor and the test labels are\n",
    "    # built once instead of on every epoch.\n",
    "    emb_tensor = torch.tensor(embeddings).to(device)\n",
    "    y_test = data.y[data.test_mask].cpu()\n",
    "    for epoch in range(clf_epochs):\n",
    "        clf.train()\n",
    "        optimizer.zero_grad()\n",
    "        logits, preds = clf(emb_tensor)\n",
    "        loss = criterion(logits[data.train_mask], data.y[data.train_mask])\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if (epoch + 1) % 20 == 0:\n",
    "            clf.eval()\n",
    "            with torch.no_grad():\n",
    "                _, preds_eval = clf(emb_tensor)\n",
    "                preds_test = preds_eval[data.test_mask].cpu()\n",
    "                acc = accuracy_score(y_test, preds_test)\n",
    "                f1 = f1_score(y_test, preds_test, average=\"macro\")\n",
    "            print(f\"[Linear Classifier] Epoch {epoch+1:03d} | \"\n",
    "                  f\"Loss: {loss.item():.4f} | Acc: {acc:.4f} | F1: {f1:.4f}\")\n",
    "\n",
    "    clf_time = time.time() - clf_time_start\n",
    "\n",
    "    # -------------------------------\n",
    "    # Step 6: Final Evaluation\n",
    "    # -------------------------------\n",
    "    clf.eval()\n",
    "    with torch.no_grad():\n",
    "        logits, preds = clf(emb_tensor)\n",
    "        (train_acc, test_acc, train_bacc, test_bacc,\n",
    "         train_f1, test_f1) = compute_accuracy(\n",
    "            preds, data.y, data.train_mask, data.test_mask)\n",
    "\n",
    "    total_time = enc_time + clf_time\n",
    "    return {\n",
    "        \"encoder\": net,\n",
    "        \"seed\": seed_id,\n",
    "        \"accuracy\": test_acc,\n",
    "        \"balanced_acc\": test_bacc,\n",
    "        \"f1\": test_f1,\n",
    "        \"time\": total_time\n",
    "    }\n",
    "\n",
    "\n",
    "def run_pipeline(data, split_name, mask_dir, emb_dir, save_dir,\n",
    "                 mask_pattern=\"*seed46.npy\"):\n",
    "    \"\"\"Run GCN, GAT and SAGE on every matching mask file and write summaries.\n",
    "\n",
    "    Args:\n",
    "        data: Graph object passed unchanged to ``train_and_evaluate_revar``.\n",
    "        split_name: Split label used in printed headers and output file names.\n",
    "        mask_dir: Directory searched for mask files.\n",
    "        emb_dir: Directory for saved embeddings (created if missing).\n",
    "        save_dir: Directory for the result CSV files (created if missing).\n",
    "        mask_pattern: Glob pattern selecting mask files inside ``mask_dir``.\n",
    "            The default keeps the seed-46 masks only.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: If no file in ``mask_dir`` matches ``mask_pattern``.\n",
    "    \"\"\"\n",
    "    os.makedirs(emb_dir, exist_ok=True)\n",
    "    os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "    mask_files = sorted(glob.glob(os.path.join(mask_dir, mask_pattern)))\n",
    "    if not mask_files:\n",
    "        # Without this guard the empty result table fails later with an\n",
    "        # unhelpful KeyError on the 'encoder' column.\n",
    "        raise FileNotFoundError(\n",
    "            f\"No mask files matching {mask_pattern!r} found in {mask_dir!r}\")\n",
    "\n",
    "    # Encoder name plus the extra keyword arguments it needs.\n",
    "    encoder_configs = [(\"GCN\", {}), (\"GAT\", {\"n_head\": 8}), (\"SAGE\", {})]\n",
    "\n",
    "    results = []\n",
    "    for mask_file in mask_files:\n",
    "        print(f\"\\n>>> Running for mask: {os.path.basename(mask_file)}\")\n",
    "        for net, extra_kwargs in encoder_configs:\n",
    "            results.append(train_and_evaluate_revar(\n",
    "                data, mask_file, emb_dir, split_name, net=net, **extra_kwargs))\n",
    "\n",
    "    # One row per (mask file, encoder) run\n",
    "    df = pd.DataFrame(results)\n",
    "\n",
    "    # ---- (1) Compute summary results ----\n",
    "    summary = df.groupby(\"encoder\").agg({\n",
    "        \"accuracy\": [\"mean\", \"std\"],\n",
    "        \"balanced_acc\": [\"mean\", \"std\"],\n",
    "        \"f1\": [\"mean\", \"std\"]\n",
    "    }).reset_index()\n",
    "\n",
    "    summary.columns = [\"encoder\",\n",
    "                       \"accuracy_mean\", \"accuracy_std\",\n",
    "                       \"balanced_acc_mean\", \"balanced_acc_std\",\n",
    "                       \"f1_mean\", \"f1_std\"]\n",
    "\n",
    "    def _fmt_std(value):\n",
    "        # The sample standard deviation is NaN when an encoder has a single run.\n",
    "        return \"n/a\" if pd.isna(value) else f\"{value:.4f}\"\n",
    "\n",
    "    # Pretty print results\n",
    "    print(f\"\\n===== Summary across seeds ({split_name}) =====\")\n",
    "    for _, row in summary.iterrows():\n",
    "        print(f\"{row['encoder']} -> \"\n",
    "              f\"Acc: {row['accuracy_mean']:.4f} ± {_fmt_std(row['accuracy_std'])}, \"\n",
    "              f\"BAcc: {row['balanced_acc_mean']:.4f} ± {_fmt_std(row['balanced_acc_std'])}, \"\n",
    "              f\"F1: {row['f1_mean']:.4f} ± {_fmt_std(row['f1_std'])}\")\n",
    "\n",
    "    # Save summary table\n",
    "    summary_file = os.path.join(save_dir, f\"arxiv_{split_name}_results.csv\")\n",
    "    summary.to_csv(summary_file, index=False)\n",
    "\n",
    "    # ---- (2) Save average execution times ----\n",
    "    exec_times = (\n",
    "        df.groupby(\"encoder\")[\"time\"].mean().reset_index()\n",
    "        .rename(columns={\"time\": \"avg_execution_time\"})\n",
    "    )\n",
    "\n",
    "    exec_file = os.path.join(save_dir, f\"arxiv_{split_name}_execution_times.csv\")\n",
    "    exec_times.to_csv(exec_file, index=False)\n",
    "\n",
    "    print(f\"\\n===== Average Execution Times ({split_name}) =====\")\n",
    "    print(exec_times)\n",
    "\n",
    "\n",
    "\n",
    "# ============================\n",
    "# Run for both splits\n",
    "# ============================\n",
    "# Root folders default to the original local layout; set REVAR_MASK_ROOT and\n",
    "# REVAR_OUTPUT_ROOT to run the notebook on another machine.\n",
    "MASK_ROOT = os.environ.get(\n",
    "    \"REVAR_MASK_ROOT\",\n",
    "    \"/Users/sujan/Modularity based semi supervised learning/masks/Arxiv\")\n",
    "OUTPUT_ROOT = os.environ.get(\"REVAR_OUTPUT_ROOT\", \"/Users/sujan/Downloads/ReVar\")\n",
    "\n",
    "# One entry per split, in run order: mask input folder plus the two output folders.\n",
    "splits = {\n",
    "    split: {\n",
    "        \"mask_dir\": os.path.join(MASK_ROOT, split),\n",
    "        \"emb_dir\": os.path.join(OUTPUT_ROOT, \"embeddings\", \"Arxiv\", split),\n",
    "        \"save_dir\": os.path.join(OUTPUT_ROOT, \"results\", \"Arxiv\", split),\n",
    "    }\n",
    "    for split in (\"70_30\", \"30_70\")\n",
    "}\n",
    "\n",
    "for split_name, paths in splits.items():\n",
    "    run_pipeline(data, split_name, paths[\"mask_dir\"], paths[\"emb_dir\"], paths[\"save_dir\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fe024d1-2476-434e-8906-399ed6000469",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (tf-gpu)",
   "language": "python",
   "name": "tf-gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
