{
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "%pip install ogb"
      ],
      "metadata": {
        "id": "sskEb33qnnKy"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tRMchmRhHpZ0"
      },
      "outputs": [],
      "source": [
        "%pip install torch_geometric"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import os.path as osp\n",
        "import time\n",
        "import torch\n",
        "from torch_geometric.utils import to_undirected\n",
        "from torch_geometric.nn import LabelPropagation\n",
        "from torch_geometric.nn.conv.gcn_conv import gcn_norm\n",
        "from ogb.nodeproppred import PygNodePropPredDataset, Evaluator\n",
        "\n",
        "# -------------------------------\n",
        "# Patch torch.load for PyTorch >=2.6\n",
        "# -------------------------------\n",
        "# PyTorch >= 2.6 changed torch.load's default to weights_only=True; OGB's cached\n",
        "# dataset files are loaded with a plain torch.load and need weights_only=False.\n",
        "# NOTE: weights_only=False unpickles arbitrary objects -- only load trusted files.\n",
        "# The original loader is bound as a default argument (not looked up through a\n",
        "# global name) and the wrapper is tagged, so re-running this cell, or another cell\n",
        "# carrying the same patch, can never make torch.load call itself recursively.\n",
        "if not getattr(torch.load, \"_weights_only_default_patched\", False):\n",
        "    def torch_load_new(*args, _orig_load=torch.load, **kwargs):\n",
        "        # setdefault keeps an explicit weights_only chosen by the caller.\n",
        "        kwargs.setdefault(\"weights_only\", False)\n",
        "        return _orig_load(*args, **kwargs)\n",
        "    torch_load_new._weights_only_default_patched = True\n",
        "    torch.load = torch_load_new\n",
        "\n",
        "# -------------------------------\n",
        "# Normalization helper\n",
        "# -------------------------------\n",
        "def normalize_adjs(edge_index, num_nodes, edge_weight=None, eps=1e-12):\n",
        "    \"\"\"\n",
        "    Return normalized (edge_index, edge_weight) pairs, in this order, for:\n",
        "    - DAD (symmetric): D^{-1/2} A D^{-1/2}\n",
        "    - GCN-DAD: same but with self-loops, computed by gcn_norm (add_self_loops=True)\n",
        "    - DA  (row):       D^{-1} A\n",
        "    - AD  (col):       A D^{-1}\n",
        "\n",
        "    D is the weighted degree summed over edge_index[0]; it equals the column\n",
        "    degree for the symmetric (to_undirected) graphs this notebook passes in.\n",
        "    Degrees are clamped to eps before inversion to avoid division by zero.\n",
        "\n",
        "    NOTE: the formulas describe the matrix whose [row, col] entry is the returned\n",
        "    weight. PyG message passing (default flow \"source_to_target\") sends messages\n",
        "    from edge_index[0] to edge_index[1], so the operator LabelPropagation applies\n",
        "    is the transpose: for symmetric A, \"DA\" propagates with A D^{-1} and \"AD\"\n",
        "    with D^{-1} A (plain neighbour averaging).\n",
        "\n",
        "    Args:\n",
        "        edge_index: (2, E) index tensor unpacked as (row, col).\n",
        "        num_nodes: number of nodes N.\n",
        "        edge_weight: optional (E,) edge weights; ones when None. Non-floating\n",
        "            weights (e.g. integer path counts) are cast to float first.\n",
        "        eps: lower bound applied to degrees before inversion.\n",
        "\n",
        "    Returns:\n",
        "        Four (edge_index, edge_weight) pairs: DAD, GCN-DAD, DA, AD.\n",
        "    \"\"\"\n",
        "    if edge_weight is None:\n",
        "        edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)\n",
        "    elif not edge_weight.is_floating_point():\n",
        "        # Integer weights would break the float degree/sqrt/normalization below.\n",
        "        edge_weight = edge_weight.float()\n",
        "\n",
        "    row, col = edge_index\n",
        "    # Accumulate degrees in the weights' own dtype and device: scatter_add_ requires\n",
        "    # them to match (a hard-coded float32 buffer fails for e.g. float64 weights).\n",
        "    deg = edge_weight.new_zeros(num_nodes).scatter_add_(0, row, edge_weight)\n",
        "    deg_inv = 1.0 / deg.clamp_min(eps)\n",
        "    deg_inv_sqrt = deg_inv.sqrt()\n",
        "\n",
        "    # DAD (no self-loops)\n",
        "    dad_w = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]\n",
        "\n",
        "    # GCN-style DAD (with self-loops); gcn_norm returns its own augmented edge_index\n",
        "    gcn_ei, gcn_w = gcn_norm(\n",
        "        edge_index=edge_index, edge_weight=edge_weight,\n",
        "        num_nodes=num_nodes, add_self_loops=True\n",
        "    )\n",
        "\n",
        "    # DA (row-normalized)\n",
        "    da_w = deg_inv[row] * edge_weight\n",
        "    # AD (col-normalized)\n",
        "    ad_w = deg_inv[col] * edge_weight\n",
        "\n",
        "    return (edge_index, dad_w), (gcn_ei, gcn_w), (edge_index, da_w), (edge_index, ad_w)\n",
        "\n",
        "# -------------------------------\n",
        "# Run LP\n",
        "# -------------------------------\n",
        "def run_lp(edge_index, edge_weight, y, split_idx, evaluator, name=\"\",\n",
        "           num_layers=1, alpha=1.0):\n",
        "    \"\"\"\n",
        "    Run PyG LabelPropagation seeded with the training labels and score every split.\n",
        "\n",
        "    split_idx['train'] is passed as LabelPropagation's mask, so only training\n",
        "    labels are propagated. With the defaults (one layer, alpha=1.0) the residual\n",
        "    (1 - alpha) term of PyG's update vanishes: each prediction is a single weighted\n",
        "    aggregation of neighbour seeds, so a training node only sees its own label via\n",
        "    a self-loop (as in GCN-DAD). Nodes receiving no labelled mass presumably end up\n",
        "    with an all-zero row, which argmax maps to class 0.\n",
        "\n",
        "    Args:\n",
        "        edge_index, edge_weight: graph passed straight to LabelPropagation.\n",
        "        y: node labels; assumes OGB's (N, 1) integer layout -- TODO confirm elsewhere.\n",
        "        split_idx: dict with 'train' / 'valid' / 'test' node index tensors.\n",
        "        evaluator: OGB Evaluator; its eval() result is indexed with \"acc\".\n",
        "        name: tag printed in front of the accuracies.\n",
        "        num_layers: propagation steps (default 1, the previously hard-coded value).\n",
        "        alpha: LabelPropagation alpha (default 1.0, the previously hard-coded value).\n",
        "\n",
        "    Returns:\n",
        "        dict mapping 'train' / 'valid' / 'test' to accuracy.\n",
        "    \"\"\"\n",
        "    model = LabelPropagation(num_layers=num_layers, alpha=alpha)\n",
        "    out = model(y, edge_index, mask=split_idx['train'], edge_weight=edge_weight)\n",
        "    y_pred = out.argmax(dim=-1, keepdim=True)\n",
        "\n",
        "    accs = {}\n",
        "    for split in [\"train\", \"valid\", \"test\"]:\n",
        "        accs[split] = evaluator.eval({\n",
        "            \"y_true\": y[split_idx[split]],\n",
        "            \"y_pred\": y_pred[split_idx[split]],\n",
        "        })[\"acc\"]\n",
        "\n",
        "    print(f\"[{name}] \"\n",
        "          f\"Train: {accs['train']:.4f} | \"\n",
        "          f\"Val: {accs['valid']:.4f} | \"\n",
        "          f\"Test: {accs['test']:.4f}\")\n",
        "    return accs\n",
        "\n",
        "# -------------------------------\n",
        "# Load dataset\n",
        "# -------------------------------\n",
        "root = './data/OGB'\n",
        "# ogbn-arxiv graph, its standard train/valid/test split and the matching evaluator.\n",
        "dataset = PygNodePropPredDataset(name='ogbn-arxiv', root=root, transform=None)\n",
        "graph, split_idx = dataset[0], dataset.get_idx_split()\n",
        "evaluator = Evaluator(name='ogbn-arxiv')\n",
        "\n",
        "print(\"OGB Arxiv dataset loaded successfully!\")\n",
        "print(f\"num_nodes: {graph.num_nodes}, num_edges: {graph.num_edges}\\n\")\n",
        "\n",
        "# -------------------------------\n",
        "# Helper: evaluate one graph\n",
        "# -------------------------------\n",
        "def evaluate_graph(name, edge_index, edge_weight, y, split_idx, evaluator, num_nodes):\n",
        "    \"\"\"\n",
        "    Score label propagation on one graph under every adjacency normalization.\n",
        "\n",
        "    The four (edge_index, edge_weight) pairs from normalize_adjs are run through\n",
        "    run_lp in the order DAD, GCN-DAD, DA, AD, each tagged \"{name} LP ({variant})\".\n",
        "    Returns {variant: {'train': acc, 'valid': acc, 'test': acc}} in that order.\n",
        "    \"\"\"\n",
        "    variants = (\"DAD\", \"GCN-DAD\", \"DA\", \"AD\")\n",
        "    normalized = normalize_adjs(edge_index, num_nodes, edge_weight)\n",
        "    return {\n",
        "        variant: run_lp(ei, ew, y, split_idx, evaluator, f\"{name} LP ({variant})\")\n",
        "        for variant, (ei, ew) in zip(variants, normalized)\n",
        "    }\n",
        "\n",
        "# -------------------------------\n",
        "# Step 1: Original adjacency A\n",
        "# -------------------------------\n",
        "# Symmetrize the dataset's edges; edge_weight=None makes normalize_adjs use unit weights.\n",
        "edge_index = to_undirected(graph.edge_index, num_nodes=graph.num_nodes)\n",
        "print(\"=== Original Graph (A) ===\")\n",
        "results_A = evaluate_graph(\n",
        "    \"OGB-Arxiv\", edge_index, edge_weight=None, y=graph.y,\n",
        "    split_idx=split_idx, evaluator=evaluator, num_nodes=graph.num_nodes,\n",
        ")\n",
        "\n",
        "# -------------------------------\n",
        "# Step 2: Transformed adjacency (DMI = 2D + D^2)\n",
        "# -------------------------------\n",
        "num_nodes, num_edges = graph.num_nodes, graph.num_edges\n",
        "# Sparse adjacency of the dataset's (unsymmetrized) edges with unit weights.\n",
        "D = torch.sparse_coo_tensor(\n",
        "    graph.edge_index,\n",
        "    torch.ones(num_edges, device=graph.x.device),\n",
        "    (num_nodes, num_nodes)\n",
        ")\n",
        "\n",
        "start = time.time()\n",
        "D2 = torch.sparse.mm(D, D)  # length-2 walk counts\n",
        "# 2D + D^2 = (I + D)^2 - I: walks of length 1 and 2 with binomial weights.\n",
        "# D^3 is not part of this transform, so it is not computed (it is costly here).\n",
        "DMI = 2*D + D2\n",
        "print(f\"\\nSNN(beta,(1,1)) built in {time.time()-start:.2f}s\")\n",
        "\n",
        "w = DMI.coalesce()\n",
        "edge_index_dmi, edge_weight_dmi = w.indices(), w.values()\n",
        "# Symmetrize; to_undirected sums the weights of (i, j) and (j, i) (reduce=\"add\").\n",
        "edge_index_dmi, edge_weight_dmi = to_undirected(edge_index_dmi, edge_weight_dmi, num_nodes=num_nodes)\n",
        "\n",
        "print(\"\\n=== Transformed Graph (SNN(beta,(1,1))) ===\")\n",
        "results_DMI = evaluate_graph(\"DMI\", edge_index_dmi, edge_weight_dmi, graph.y, split_idx, evaluator, num_nodes)\n",
        "\n",
        "# -------------------------------\n",
        "# Final summary\n",
        "# -------------------------------\n",
        "print(\"\\n================= SUMMARY =================\")\n",
        "# One line per normalization, grouped by graph, in evaluation order.\n",
        "for graph_name, results in ((\"OGB-Arxiv\", results_A),\n",
        "                            (\"SNN(beta,(1,1))\", results_DMI)):\n",
        "    print(f\"\\n{graph_name}:\")\n",
        "    for norm_name, accs in results.items():\n",
        "        train_acc, val_acc, test_acc = accs['train'], accs['valid'], accs['test']\n",
        "        print(f\"{norm_name:8s} -> Train: {train_acc:.4f} | Val: {val_acc:.4f} | Test: {test_acc:.4f}\")\n"
      ],
      "metadata": {
        "id": "3qbwE1Uph4VF"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import os.path as osp\n",
        "import time\n",
        "import torch\n",
        "from torch_geometric.utils import to_undirected\n",
        "from torch_geometric.nn import LabelPropagation\n",
        "from torch_geometric.nn.conv.gcn_conv import gcn_norm\n",
        "from ogb.nodeproppred import PygNodePropPredDataset, Evaluator\n",
        "\n",
        "# -------------------------------\n",
        "# Patch torch.load for PyTorch >=2.6\n",
        "# -------------------------------\n",
        "# PyTorch >= 2.6 changed torch.load's default to weights_only=True; OGB's cached\n",
        "# dataset files are loaded with a plain torch.load and need weights_only=False.\n",
        "# NOTE: weights_only=False unpickles arbitrary objects -- only load trusted files.\n",
        "# The original loader is bound as a default argument (not looked up through a\n",
        "# global name) and the wrapper is tagged, so running this cell after the previous\n",
        "# one (which carries the same patch) can never make torch.load call itself\n",
        "# recursively.\n",
        "if not getattr(torch.load, \"_weights_only_default_patched\", False):\n",
        "    def torch_load_new(*args, _orig_load=torch.load, **kwargs):\n",
        "        # setdefault keeps an explicit weights_only chosen by the caller.\n",
        "        kwargs.setdefault(\"weights_only\", False)\n",
        "        return _orig_load(*args, **kwargs)\n",
        "    torch_load_new._weights_only_default_patched = True\n",
        "    torch.load = torch_load_new\n",
        "\n",
        "# -------------------------------\n",
        "# Normalization helper\n",
        "# -------------------------------\n",
        "def normalize_adjs(edge_index, num_nodes, edge_weight=None, eps=1e-12):\n",
        "    \"\"\"\n",
        "    Return normalized (edge_index, edge_weight) pairs, in this order, for:\n",
        "    - DAD (symmetric): D^{-1/2} A D^{-1/2}\n",
        "    - GCN-DAD: same but with self-loops, computed by gcn_norm (add_self_loops=True)\n",
        "    - DA  (row):       D^{-1} A\n",
        "    - AD  (col):       A D^{-1}\n",
        "\n",
        "    D is the weighted degree summed over edge_index[0]; it equals the column\n",
        "    degree for the symmetric (to_undirected) graphs this notebook passes in.\n",
        "    Degrees are clamped to eps before inversion to avoid division by zero.\n",
        "\n",
        "    NOTE: the formulas describe the matrix whose [row, col] entry is the returned\n",
        "    weight. PyG message passing (default flow \"source_to_target\") sends messages\n",
        "    from edge_index[0] to edge_index[1], so the operator LabelPropagation applies\n",
        "    is the transpose: for symmetric A, \"DA\" propagates with A D^{-1} and \"AD\"\n",
        "    with D^{-1} A (plain neighbour averaging).\n",
        "\n",
        "    Args:\n",
        "        edge_index: (2, E) index tensor unpacked as (row, col).\n",
        "        num_nodes: number of nodes N.\n",
        "        edge_weight: optional (E,) edge weights; ones when None. Non-floating\n",
        "            weights (e.g. integer path counts) are cast to float first.\n",
        "        eps: lower bound applied to degrees before inversion.\n",
        "\n",
        "    Returns:\n",
        "        Four (edge_index, edge_weight) pairs: DAD, GCN-DAD, DA, AD.\n",
        "    \"\"\"\n",
        "    if edge_weight is None:\n",
        "        edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)\n",
        "    elif not edge_weight.is_floating_point():\n",
        "        # Integer weights would break the float degree/sqrt/normalization below.\n",
        "        edge_weight = edge_weight.float()\n",
        "\n",
        "    row, col = edge_index\n",
        "    # Accumulate degrees in the weights' own dtype and device: scatter_add_ requires\n",
        "    # them to match (a hard-coded float32 buffer fails for e.g. float64 weights).\n",
        "    deg = edge_weight.new_zeros(num_nodes).scatter_add_(0, row, edge_weight)\n",
        "    deg_inv = 1.0 / deg.clamp_min(eps)\n",
        "    deg_inv_sqrt = deg_inv.sqrt()\n",
        "\n",
        "    # DAD (no self-loops)\n",
        "    dad_w = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]\n",
        "\n",
        "    # GCN-style DAD (with self-loops); gcn_norm returns its own augmented edge_index\n",
        "    gcn_ei, gcn_w = gcn_norm(\n",
        "        edge_index=edge_index, edge_weight=edge_weight,\n",
        "        num_nodes=num_nodes, add_self_loops=True\n",
        "    )\n",
        "\n",
        "    # DA (row-normalized)\n",
        "    da_w = deg_inv[row] * edge_weight\n",
        "    # AD (col-normalized)\n",
        "    ad_w = deg_inv[col] * edge_weight\n",
        "\n",
        "    return (edge_index, dad_w), (gcn_ei, gcn_w), (edge_index, da_w), (edge_index, ad_w)\n",
        "\n",
        "# -------------------------------\n",
        "# Run LP\n",
        "# -------------------------------\n",
        "def run_lp(edge_index, edge_weight, y, split_idx, evaluator, name=\"\",\n",
        "           num_layers=1, alpha=1.0):\n",
        "    \"\"\"\n",
        "    Run PyG LabelPropagation seeded with the training labels and score every split.\n",
        "\n",
        "    split_idx['train'] is passed as LabelPropagation's mask, so only training\n",
        "    labels are propagated. With the defaults (one layer, alpha=1.0) the residual\n",
        "    (1 - alpha) term of PyG's update vanishes: each prediction is a single weighted\n",
        "    aggregation of neighbour seeds, so a training node only sees its own label via\n",
        "    a self-loop (as in GCN-DAD). Nodes receiving no labelled mass presumably end up\n",
        "    with an all-zero row, which argmax maps to class 0.\n",
        "\n",
        "    Args:\n",
        "        edge_index, edge_weight: graph passed straight to LabelPropagation.\n",
        "        y: node labels; assumes OGB's (N, 1) integer layout -- TODO confirm elsewhere.\n",
        "        split_idx: dict with 'train' / 'valid' / 'test' node index tensors.\n",
        "        evaluator: OGB Evaluator; its eval() result is indexed with \"acc\".\n",
        "        name: tag printed in front of the accuracies.\n",
        "        num_layers: propagation steps (default 1, the previously hard-coded value).\n",
        "        alpha: LabelPropagation alpha (default 1.0, the previously hard-coded value).\n",
        "\n",
        "    Returns:\n",
        "        dict mapping 'train' / 'valid' / 'test' to accuracy.\n",
        "    \"\"\"\n",
        "    model = LabelPropagation(num_layers=num_layers, alpha=alpha)\n",
        "    out = model(y, edge_index, mask=split_idx['train'], edge_weight=edge_weight)\n",
        "    y_pred = out.argmax(dim=-1, keepdim=True)\n",
        "\n",
        "    accs = {}\n",
        "    for split in [\"train\", \"valid\", \"test\"]:\n",
        "        accs[split] = evaluator.eval({\n",
        "            \"y_true\": y[split_idx[split]],\n",
        "            \"y_pred\": y_pred[split_idx[split]],\n",
        "        })[\"acc\"]\n",
        "\n",
        "    print(f\"[{name}] \"\n",
        "          f\"Train: {accs['train']:.4f} | \"\n",
        "          f\"Val: {accs['valid']:.4f} | \"\n",
        "          f\"Test: {accs['test']:.4f}\")\n",
        "    return accs\n",
        "\n",
        "# -------------------------------\n",
        "# Load dataset\n",
        "# -------------------------------\n",
        "root = './data/OGB'\n",
        "# ogbn-arxiv graph, its standard train/valid/test split and the matching evaluator.\n",
        "dataset = PygNodePropPredDataset(name='ogbn-arxiv', root=root, transform=None)\n",
        "graph, split_idx = dataset[0], dataset.get_idx_split()\n",
        "evaluator = Evaluator(name='ogbn-arxiv')\n",
        "\n",
        "print(\"OGB Arxiv dataset loaded successfully!\")\n",
        "print(f\"num_nodes: {graph.num_nodes}, num_edges: {graph.num_edges}\\n\")\n",
        "\n",
        "# -------------------------------\n",
        "# Helper: evaluate one graph\n",
        "# -------------------------------\n",
        "def evaluate_graph(name, edge_index, edge_weight, y, split_idx, evaluator, num_nodes):\n",
        "    \"\"\"\n",
        "    Score label propagation on one graph under every adjacency normalization.\n",
        "\n",
        "    The four (edge_index, edge_weight) pairs from normalize_adjs are run through\n",
        "    run_lp in the order DAD, GCN-DAD, DA, AD, each tagged \"{name} LP ({variant})\".\n",
        "    Returns {variant: {'train': acc, 'valid': acc, 'test': acc}} in that order.\n",
        "    \"\"\"\n",
        "    variants = (\"DAD\", \"GCN-DAD\", \"DA\", \"AD\")\n",
        "    normalized = normalize_adjs(edge_index, num_nodes, edge_weight)\n",
        "    return {\n",
        "        variant: run_lp(ei, ew, y, split_idx, evaluator, f\"{name} LP ({variant})\")\n",
        "        for variant, (ei, ew) in zip(variants, normalized)\n",
        "    }\n",
        "\n",
        "# -------------------------------\n",
        "# Step 1: Original adjacency A\n",
        "# -------------------------------\n",
        "# Symmetrize the dataset's edges; edge_weight=None makes normalize_adjs use unit weights.\n",
        "edge_index = to_undirected(graph.edge_index, num_nodes=graph.num_nodes)\n",
        "print(\"=== Original Graph (A) ===\")\n",
        "results_A = evaluate_graph(\n",
        "    \"OGB-Arxiv\", edge_index, edge_weight=None, y=graph.y,\n",
        "    split_idx=split_idx, evaluator=evaluator, num_nodes=graph.num_nodes,\n",
        ")\n",
        "\n",
        "# -------------------------------\n",
        "# Step 2: Transformed adjacency (DMI = 3D + 3D^2 + D^3)\n",
        "# -------------------------------\n",
        "num_nodes, num_edges = graph.num_nodes, graph.num_edges\n",
        "# Sparse adjacency of the dataset's (unsymmetrized) edges with unit weights.\n",
        "D = torch.sparse_coo_tensor(\n",
        "    graph.edge_index,\n",
        "    torch.ones(num_edges, device=graph.x.device),\n",
        "    (num_nodes, num_nodes)\n",
        ")\n",
        "\n",
        "start = time.time()\n",
        "D2 = torch.sparse.mm(D, D)   # length-2 walk counts\n",
        "D3 = torch.sparse.mm(D, D2)  # length-3 walk counts\n",
        "# 3D + 3D^2 + D^3 = (I + D)^3 - I: walks of length 1..3 with binomial weights.\n",
        "DMI = 3*D + 3*D2 + D3\n",
        "print(f\"\\nSNN(beta,(1,1,1)) built in {time.time()-start:.2f}s\")\n",
        "\n",
        "w = DMI.coalesce()\n",
        "edge_index_dmi, edge_weight_dmi = w.indices(), w.values()\n",
        "# Symmetrize; to_undirected sums the weights of (i, j) and (j, i) (reduce=\"add\").\n",
        "edge_index_dmi, edge_weight_dmi = to_undirected(edge_index_dmi, edge_weight_dmi, num_nodes=num_nodes)\n",
        "\n",
        "print(\"\\n=== Transformed Graph (SNN(beta,(1,1,1))) ===\")\n",
        "results_DMI = evaluate_graph(\"DMI\", edge_index_dmi, edge_weight_dmi, graph.y, split_idx, evaluator, num_nodes)\n",
        "\n",
        "# -------------------------------\n",
        "# Final summary\n",
        "# -------------------------------\n",
        "print(\"\\n================= SUMMARY =================\")\n",
        "# One line per normalization, grouped by graph, in evaluation order.\n",
        "for graph_name, results in ((\"OGB-Arxiv\", results_A),\n",
        "                            (\"SNN(beta,(1,1,1))\", results_DMI)):\n",
        "    print(f\"\\n{graph_name}:\")\n",
        "    for norm_name, accs in results.items():\n",
        "        train_acc, val_acc, test_acc = accs['train'], accs['valid'], accs['test']\n",
        "        print(f\"{norm_name:8s} -> Train: {train_acc:.4f} | Val: {val_acc:.4f} | Test: {test_acc:.4f}\")\n"
      ],
      "metadata": {
        "id": "T600OfW3GPgT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "S6iE-tkaZSrv"
      },
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "colab": {
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}