{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "eb81e1aa-1d17-4313-a2bc-fd553805810d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sujan/Downloads/GraFN/grafn-env/lib/python3.12/site-packages/outdated/__init__.py:36: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n",
      "  from pkg_resources import parse_version\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data(x=[169343, 150], edge_index=[2, 2315598], y=[169343])\n",
      "Number of nodes: 169343\n",
      "Number of features: 150\n",
      "Number of classes: 40\n",
      "\n",
      "=== Processing Mask File: Arxiv_70_30_masked_indices_seed46.npy (70_30, GCN) ===\n",
      "Train nodes: 118541 | Test nodes: 50802\n",
      "Epoch 001 | Train Loss: 5.0897 | Val Loss: 5.0897\n",
      "Epoch 020 | Train Loss: 3.1748 | Val Loss: 3.1980\n",
      "Epoch 040 | Train Loss: 2.9998 | Val Loss: 3.0362\n",
      "Epoch 060 | Train Loss: 2.8308 | Val Loss: 2.8776\n",
      "Epoch 080 | Train Loss: 2.7340 | Val Loss: 2.7986\n",
      "Epoch 100 | Train Loss: 2.6888 | Val Loss: 2.7646\n",
      "Epoch 120 | Train Loss: 2.6535 | Val Loss: 2.7395\n",
      "Epoch 140 | Train Loss: 2.6236 | Val Loss: 2.7183\n",
      "Epoch 160 | Train Loss: 2.5972 | Val Loss: 2.7003\n",
      "Epoch 180 | Train Loss: 2.5732 | Val Loss: 2.6839\n",
      "Epoch 200 | Train Loss: 2.5510 | Val Loss: 2.6688\n",
      "Saved embeddings to /Users/sujan/Downloads/GraFN/embeddings/Arxiv/70_30/ogbn-arxiv_GCN_70_30_seed46.pt\n",
      "\n",
      "=== Processing Mask File: Arxiv_70_30_masked_indices_seed46.npy (70_30, GAT) ===\n",
      "Train nodes: 118541 | Test nodes: 50802\n",
      "Epoch 001 | Train Loss: 5.0246 | Val Loss: 5.0237\n",
      "Epoch 020 | Train Loss: 2.6254 | Val Loss: 2.6371\n",
      "Epoch 040 | Train Loss: 2.0397 | Val Loss: 2.1050\n",
      "Epoch 060 | Train Loss: 1.7045 | Val Loss: 1.8299\n",
      "Epoch 080 | Train Loss: 1.5130 | Val Loss: 1.7036\n",
      "Epoch 100 | Train Loss: 1.4061 | Val Loss: 1.6505\n",
      "Epoch 120 | Train Loss: 1.3263 | Val Loss: 1.6273\n",
      "Epoch 140 | Train Loss: 1.2505 | Val Loss: 1.5992\n",
      "Epoch 160 | Train Loss: 1.1847 | Val Loss: 1.5918\n",
      "Epoch 180 | Train Loss: 1.2058 | Val Loss: 1.5908\n",
      "Epoch 200 | Train Loss: 1.0906 | Val Loss: 1.5590\n",
      "Saved embeddings to /Users/sujan/Downloads/GraFN/embeddings/Arxiv/70_30/ogbn-arxiv_GAT_70_30_seed46.pt\n",
      "\n",
      "=== Processing Mask File: Arxiv_70_30_masked_indices_seed46.npy (70_30, GraphSAGE) ===\n",
      "Train nodes: 118541 | Test nodes: 50802\n",
      "Epoch 001 | Train Loss: 5.1182 | Val Loss: 5.1160\n",
      "Epoch 020 | Train Loss: 2.7563 | Val Loss: 2.8273\n",
      "Epoch 040 | Train Loss: 2.4524 | Val Loss: 2.6034\n",
      "Epoch 060 | Train Loss: 2.2377 | Val Loss: 2.4745\n",
      "Epoch 080 | Train Loss: 2.0888 | Val Loss: 2.3994\n",
      "Epoch 100 | Train Loss: 1.9803 | Val Loss: 2.3543\n",
      "Epoch 120 | Train Loss: 1.8996 | Val Loss: 2.3211\n",
      "Epoch 140 | Train Loss: 1.8443 | Val Loss: 2.3040\n",
      "Epoch 160 | Train Loss: 1.7999 | Val Loss: 2.2869\n",
      "Epoch 180 | Train Loss: 1.7656 | Val Loss: 2.2742\n",
      "Epoch 200 | Train Loss: 1.7395 | Val Loss: 2.2659\n",
      "Saved embeddings to /Users/sujan/Downloads/GraFN/embeddings/Arxiv/70_30/ogbn-arxiv_GraphSAGE_70_30_seed46.pt\n",
      "Saved results and execution times for ogbn-arxiv split 70_30, seed46 only.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import glob\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import time\n",
    "import pandas as pd\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "from torch_geometric.transforms import ToUndirected\n",
    "from torch_geometric.nn import GCNConv, GATConv, SAGEConv\n",
    "import torch.nn as nn\n",
    "from layers import Classifier\n",
    "from models.GraFN import GraFN\n",
    "from ogb.nodeproppred import PygNodePropPredDataset\n",
    "\n",
    "# -------------------------------\n",
    "# Encoder definitions\n",
    "# -------------------------------\n",
    "class GCNEncoder(nn.Module):\n",
    "    \"\"\"Two-layer GCN encoder: GCNConv -> ReLU -> GCNConv.\"\"\"\n",
    "\n",
    "    def __init__(self, in_dim, hidden_dim, out_dim):\n",
    "        super().__init__()\n",
    "        self.conv1 = GCNConv(in_dim, hidden_dim)\n",
    "        self.conv2 = GCNConv(hidden_dim, out_dim)\n",
    "\n",
    "    def forward(self, data):\n",
    "        \"\"\"Return the second layer's output for ``data.x`` over ``data.edge_index``.\"\"\"\n",
    "        hidden = F.relu(self.conv1(data.x, data.edge_index))\n",
    "        return self.conv2(hidden, data.edge_index)\n",
    "\n",
    "\n",
    "class GATEncoder(nn.Module):\n",
    "    \"\"\"Two-layer GAT encoder: multi-head GATConv -> ELU -> single-head GATConv.\"\"\"\n",
    "\n",
    "    def __init__(self, in_dim, hidden_dim, out_dim, heads=8):\n",
    "        super().__init__()\n",
    "        # The first layer concatenates its heads, so the second layer's input is\n",
    "        # hidden_dim * heads wide.\n",
    "        self.conv1 = GATConv(in_dim, hidden_dim, heads=heads, concat=True)\n",
    "        self.conv2 = GATConv(hidden_dim * heads, out_dim, heads=1, concat=False)\n",
    "\n",
    "    def forward(self, data):\n",
    "        \"\"\"Return the second layer's output for ``data.x`` over ``data.edge_index``.\"\"\"\n",
    "        hidden = F.elu(self.conv1(data.x, data.edge_index))\n",
    "        return self.conv2(hidden, data.edge_index)\n",
    "\n",
    "\n",
    "class GraphSAGEEncoder(nn.Module):\n",
    "    \"\"\"Two-layer GraphSAGE encoder: SAGEConv -> ReLU -> SAGEConv.\"\"\"\n",
    "\n",
    "    def __init__(self, in_dim, hidden_dim, out_dim):\n",
    "        super().__init__()\n",
    "        self.conv1 = SAGEConv(in_dim, hidden_dim)\n",
    "        self.conv2 = SAGEConv(hidden_dim, out_dim)\n",
    "\n",
    "    def forward(self, data):\n",
    "        \"\"\"Return the second layer's output for ``data.x`` over ``data.edge_index``.\"\"\"\n",
    "        hidden = F.relu(self.conv1(data.x, data.edge_index))\n",
    "        return self.conv2(hidden, data.edge_index)\n",
    "\n",
    "\n",
    "# -------------------------------\n",
    "# Paths and configuration\n",
    "# -------------------------------\n",
    "dataset_name = \"ogbn-arxiv\"\n",
    "# To also run the 30_70 split, use: splits = [\"30_70\", \"70_30\"]\n",
    "splits = [\"70_30\"]\n",
    "embedding_dim = 150  # passed to run_pipeline as the embedding width\n",
    "epochs = 200\n",
    "\n",
    "# The {split} placeholder is filled in later with str.format, once per entry of `splits`.\n",
    "mask_dirs_template = \"/Users/sujan/Modularity based semi supervised learning/masks/Arxiv/{split}\"\n",
    "base_results_dir = \"/Users/sujan/Downloads/GraFN/results/Arxiv\"\n",
    "base_embeddings_dir = \"/Users/sujan/Downloads/GraFN/embeddings/Arxiv\"\n",
    "\n",
    "# -------------------------------\n",
    "# Pipeline per mask and encoder\n",
    "# -------------------------------\n",
    "def run_pipeline(data, mask_file, split_type, encoder_class, encoder_name, embedding_dim=150, epochs=200):\n",
    "    \"\"\"Train one GraFN model for a single test-mask file and score it on the masked nodes.\n",
    "\n",
    "    Args:\n",
    "        data: graph object providing ``x``, ``edge_index``, ``y`` and ``num_nodes``;\n",
    "            its ``train_mask`` / ``test_mask`` attributes are overwritten here.\n",
    "        mask_file: path to a ``.npy`` file holding the indices of the test nodes.\n",
    "        split_type: split label (e.g. ``70_30``) used for the output folder and file name.\n",
    "        encoder_class: encoder constructor, called as ``encoder_class(in_dim, hidden_dim, out_dim)``.\n",
    "        encoder_name: label used in log lines, the output file name and the result dict.\n",
    "        embedding_dim: output width of the encoder.\n",
    "        epochs: number of full-batch optimisation steps.\n",
    "\n",
    "    Returns:\n",
    "        dict with keys ``encoder``, ``accuracy``, ``f1_macro``, ``f1_micro`` and ``time``\n",
    "        (wall-clock seconds spent in the training loop).\n",
    "    \"\"\"\n",
    "    num_nodes = data.num_nodes\n",
    "    test_indices = np.load(mask_file)\n",
    "    test_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "    test_mask[test_indices] = True\n",
    "    # Every node that is not listed in the mask file is used for training.\n",
    "    train_mask = ~test_mask\n",
    "    data.train_mask = train_mask\n",
    "    data.test_mask = test_mask\n",
    "\n",
    "    print(f\"\\n=== Processing Mask File: {os.path.basename(mask_file)} ({split_type}, {encoder_name}) ===\")\n",
    "    print(f\"Train nodes: {train_mask.sum().item()} | Test nodes: {test_mask.sum().item()}\")\n",
    "\n",
    "    # Size the first layer from the data itself rather than assuming the input\n",
    "    # features happen to be exactly `embedding_dim` wide.\n",
    "    in_features = data.num_node_features\n",
    "    hidden_dim = 128\n",
    "    out_dim = embedding_dim\n",
    "    num_classes = data.y.max().item() + 1\n",
    "    encoder = encoder_class(in_features, hidden_dim, out_dim)\n",
    "    classifier = Classifier(out_dim, num_classes)\n",
    "    model = GraFN(encoder, classifier, unique_labels=list(range(num_classes)))\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n",
    "\n",
    "    model.train()\n",
    "    start_time = time.time()\n",
    "    for epoch in range(epochs):\n",
    "        optimizer.zero_grad()\n",
    "        out = model(data)\n",
    "        loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if (epoch+1) % 20 == 0 or epoch == 0:\n",
    "            # Reported on the test-mask nodes, reusing this step's training forward pass\n",
    "            # (train mode, weights from before the update). No autograd graph is needed.\n",
    "            with torch.no_grad():\n",
    "                val_loss = F.cross_entropy(out[data.test_mask], data.y[data.test_mask])\n",
    "            print(f\"Epoch {epoch+1:03d} | Train Loss: {loss.item():.4f} | Val Loss: {val_loss.item():.4f}\")\n",
    "\n",
    "    elapsed_time = time.time() - start_time\n",
    "\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        # NOTE(review): this stores the full model output (the same kind of tensor that is\n",
    "        # argmax'ed below), not the encoder output -- confirm that is the intended embedding.\n",
    "        embeddings = model(data)\n",
    "\n",
    "    # Take the seed tag from the mask file name; fall back to seed42 when there is none.\n",
    "    mask_basename = os.path.basename(mask_file)\n",
    "    seed_str = \"seed42\"\n",
    "    if \"seed\" in mask_basename:\n",
    "        parts = mask_basename.split(\"seed\")\n",
    "        if len(parts) > 1:\n",
    "            seed_str = \"seed\" + parts[1].replace(\".npy\", \"\")\n",
    "\n",
    "    embeddings_subdir = os.path.join(base_embeddings_dir, split_type)\n",
    "    os.makedirs(embeddings_subdir, exist_ok=True)\n",
    "    embedding_filename = f\"{dataset_name}_{encoder_name}_{split_type}_{seed_str}.pt\"\n",
    "    embedding_path = os.path.join(embeddings_subdir, embedding_filename)\n",
    "    torch.save(embeddings, embedding_path)\n",
    "    print(f\"Saved embeddings to {embedding_path}\")\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # NOTE(review): this repeats the forward pass above; if GraFN is deterministic in\n",
    "        # eval mode, `embeddings` could be reused here instead.\n",
    "        out = model(data)\n",
    "        preds = out.argmax(dim=1).cpu().numpy()\n",
    "        true_labels = data.y.cpu().numpy()\n",
    "        test_mask_np = data.test_mask.cpu().numpy()\n",
    "        test_preds = preds[test_mask_np]\n",
    "        test_labels = true_labels[test_mask_np]\n",
    "\n",
    "        acc = accuracy_score(test_labels, test_preds)\n",
    "        f1_macro = f1_score(test_labels, test_preds, average=\"macro\")\n",
    "        f1_micro = f1_score(test_labels, test_preds, average=\"micro\")\n",
    "\n",
    "    return {\n",
    "        \"encoder\": encoder_name,\n",
    "        \"accuracy\": acc,\n",
    "        \"f1_macro\": f1_macro,\n",
    "        \"f1_micro\": f1_micro,\n",
    "        \"time\": elapsed_time\n",
    "    }\n",
    "\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "import scipy.sparse as sp\n",
    "from spektral.data import Dataset, Graph\n",
    "import torch\n",
    "\n",
    "# Patch torch.load to bypass weights_only restriction.\n",
    "# SECURITY: weights_only=False lets torch.load unpickle arbitrary objects, so only\n",
    "# load files from a trusted source while this patch is active.\n",
    "if getattr(torch.load, \"_weights_only_patched\", False):\n",
    "    # The cell was re-run and torch.load is already a wrapper: reuse the function it\n",
    "    # wraps so that wrappers do not pile up on top of each other.\n",
    "    torch_load_old = torch.load.__wrapped__\n",
    "else:\n",
    "    torch_load_old = torch.load\n",
    "\n",
    "\n",
    "def torch_load_patched(*args, **kwargs):\n",
    "    \"\"\"Call the saved ``torch.load`` with ``weights_only`` forced to ``False``.\"\"\"\n",
    "    kwargs[\"weights_only\"] = False\n",
    "    return torch_load_old(*args, **kwargs)\n",
    "\n",
    "\n",
    "# Marker and back-reference used by the re-run check above.\n",
    "torch_load_patched._weights_only_patched = True\n",
    "torch_load_patched.__wrapped__ = torch_load_old\n",
    "torch.load = torch_load_patched\n",
    "\n",
    "from ogb.nodeproppred import NodePropPredDataset\n",
    "\n",
    "class ArxivDataset(Dataset):\n",
    "    \"\"\"Dataset subclass exposing ogbn-arxiv as a single Graph with one-hot labels.\"\"\"\n",
    "\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def read(self):\n",
    "        \"\"\"Load ogbn-arxiv through OGB and return it as a one-element list of Graph.\"\"\"\n",
    "        graph_dict, node_labels = NodePropPredDataset(name=\"ogbn-arxiv\")[0]\n",
    "\n",
    "        features = graph_dict[\"node_feat\"]  # (num_nodes, num_features)\n",
    "        edges = graph_dict[\"edge_index\"]\n",
    "        node_count = features.shape[0]\n",
    "\n",
    "        # Sparse adjacency with a weight of 1 for every listed (row, col) edge.\n",
    "        edge_weights = np.ones(edges.shape[1])\n",
    "        adjacency = sp.coo_matrix(\n",
    "            (edge_weights, (edges[0], edges[1])),\n",
    "            shape=(node_count, node_count),\n",
    "        )\n",
    "\n",
    "        # One-hot encode the labels by picking rows of an identity matrix.\n",
    "        flat_labels = node_labels.squeeze()\n",
    "        class_count = flat_labels.max() + 1\n",
    "        one_hot = np.eye(class_count)[flat_labels]\n",
    "\n",
    "        return [Graph(x=features, a=adjacency, y=one_hot)]\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.utils import to_undirected\n",
    "from ogb.nodeproppred import NodePropPredDataset\n",
    "import numpy as np\n",
    "\n",
    "dataset_path = '/Users/sujan/Downloads/GraFN/data/Arxiv'\n",
    "\n",
    "# Load the dataset safely without PygNodePropPredDataset\n",
    "dataset = NodePropPredDataset(name='ogbn-arxiv', root=dataset_path)\n",
    "graph, labels = dataset[0]\n",
    "\n",
    "# Node count and connectivity\n",
    "num_nodes = graph[\"node_feat\"].shape[0]\n",
    "edge_index = torch.tensor(graph[\"edge_index\"], dtype=torch.long)\n",
    "edge_index = to_undirected(edge_index)  # Make graph undirected\n",
    "\n",
    "# Labels, flattened to one entry per node\n",
    "y = torch.tensor(labels, dtype=torch.long).view(-1)\n",
    "\n",
    "# Randomized features: x is Gaussian noise, not graph[\"node_feat\"].\n",
    "# The noise comes from a dedicated seeded generator so that Restart & Run All\n",
    "# reproduces the same features without touching the global RNG state.\n",
    "feature_seed = 0\n",
    "feature_generator = torch.Generator().manual_seed(feature_seed)\n",
    "# Width follows embedding_dim (150), the same value handed to run_pipeline.\n",
    "x = torch.randn(num_nodes, embedding_dim, generator=feature_generator)\n",
    "\n",
    "# Create PyG Data object\n",
    "data = Data(x=x, edge_index=edge_index, y=y)\n",
    "\n",
    "# Print summary\n",
    "print(data)\n",
    "print(f\"Number of nodes: {data.num_nodes}\")\n",
    "print(f\"Number of features: {data.num_node_features}\")\n",
    "print(f\"Number of classes: {data.y.max().item() + 1}\")\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Encoder label -> encoder class; encoders are trained in this order.\n",
    "encoder_classes = {\n",
    "    \"GCN\": GCNEncoder,\n",
    "    \"GAT\": GATEncoder,\n",
    "    \"GraphSAGE\": GraphSAGEEncoder\n",
    "}\n",
    "\n",
    "for split_type in splits:\n",
    "    mask_dir = mask_dirs_template.format(split=split_type)\n",
    "    results_subdir = os.path.join(base_results_dir, split_type)\n",
    "    os.makedirs(results_subdir, exist_ok=True)\n",
    "\n",
    "    # Restrict the run to the seed46 mask file(s) of this split.\n",
    "    mask_files = sorted(glob.glob(os.path.join(mask_dir, \"*seed46.npy\")))\n",
    "    if not mask_files:\n",
    "        print(f\"No seed46 mask file found for {dataset_name} split {split_type}. Skipping.\")\n",
    "        continue\n",
    "\n",
    "    all_encoder_results = []\n",
    "\n",
    "    for encoder_name, encoder_class in encoder_classes.items():\n",
    "        # One result dict per mask file for this encoder.\n",
    "        encoder_results = [\n",
    "            run_pipeline(data, mask_file, split_type, encoder_class, encoder_name, embedding_dim, epochs)\n",
    "            for mask_file in mask_files\n",
    "        ]\n",
    "\n",
    "        # Mean over the processed mask files, stored as fixed-precision strings.\n",
    "        avg_metrics = {\"encoder\": encoder_name}\n",
    "        for metric_name in (\"accuracy\", \"f1_macro\", \"f1_micro\"):\n",
    "            metric_mean = np.mean([r[metric_name] for r in encoder_results])\n",
    "            avg_metrics[metric_name] = f\"{metric_mean:.4f}\"\n",
    "        time_mean = np.mean([r[\"time\"] for r in encoder_results])\n",
    "        avg_metrics[\"time\"] = f\"{time_mean:.2f}\"\n",
    "        all_encoder_results.append(avg_metrics)\n",
    "\n",
    "    # Write the quality metrics and the timings to separate CSV files.\n",
    "    results_df = pd.DataFrame(all_encoder_results)\n",
    "    metrics_df = results_df[['encoder', 'accuracy', 'f1_macro', 'f1_micro']]\n",
    "    times_df = results_df[['encoder', 'time']]\n",
    "\n",
    "    metrics_path = os.path.join(results_subdir, f\"{dataset_name}_results_seed46.csv\")\n",
    "    times_path = os.path.join(results_subdir, f\"{dataset_name}_execution_times_seed46.csv\")\n",
    "    metrics_df.to_csv(metrics_path, index=False)\n",
    "    times_df.to_csv(times_path, index=False)\n",
    "\n",
    "    print(f\"Saved results and execution times for {dataset_name} split {split_type}, seed46 only.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d117f940-0cd3-4ed1-9b21-c328cde405cf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (grafn-env)",
   "language": "python",
   "name": "grafn-env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
