{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5861494c-e662-4ca7-85f1-4c919a95afe3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%config InlineBackend.figure_format = 'svg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c17614e8-1b81-4af5-907f-a1e0b8099371",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sujan/miniforge3/envs/tf-gpu/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import random\n",
    "import time \n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import umap.umap_ as umap\n",
    "from sklearn.cluster import AgglomerativeClustering\n",
    "from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score\n",
    "from node2vec import Node2Vec\n",
    "from tqdm import tqdm\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from tensorflow.keras.losses import CategoricalCrossentropy\n",
    "from tensorflow.keras.metrics import CategoricalAccuracy\n",
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "import spektral\n",
    "from spektral.layers import GCNConv, GATConv\n",
    "from spektral.layers import GraphSageConv\n",
    "from spektral.data import Graph, Dataset, BatchLoader\n",
    "from scipy.sparse import csr_matrix\n",
    "from torch_geometric.nn import DeepGraphInfomax, VGAE\n",
    "from torch_geometric.utils import from_networkx\n",
    "import scipy.sparse as sp\n",
    "from sklearn.metrics import accuracy_score, confusion_matrix, f1_score\n",
    "from scipy.sparse.csgraph import laplacian\n",
    "from scipy.sparse.linalg import eigsh\n",
    "from collections import Counter\n",
    "from sklearn.preprocessing import normalize\n",
    "from joblib import Parallel, delayed\n",
    "from torch_geometric.nn import GCNConv as PyG_GCNConv, VGAE as PyG_VGAE\n",
    "from torch_geometric.data import Data\n",
    "from scipy.sparse import lil_matrix\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d9151ce5-ae93-4024-ac2f-c1b39856848d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from spektral.datasets import Cora\n",
    "\n",
    "# Create a custom Dataset for the graph\n",
    "class CoraDataset(Dataset):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def read(self):\n",
    "        data = Cora()  # Load the dataset\n",
    "        graph = data.graphs[0]  # Access the first graph in the dataset\n",
    "        return [Graph(x=graph.x, a=graph.a, y=graph.y)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "037503e6-d875-4175-98ea-23f087dcaede",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.datasets import Planetoid\n",
    "\n",
    "# Create a custom Dataset for the graph\n",
    "class CiteSeerDataset(Dataset):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def read(self):\n",
    "        dataset = Planetoid(root=\".\", name=\"CiteSeer\")  # Load CiteSeer dataset\n",
    "        data = dataset[0]  # Access the first graph\n",
    "        \n",
    "        # Convert Torch tensors to NumPy\n",
    "        x = data.x.numpy()\n",
    "        edge_index = data.edge_index.numpy()\n",
    "        y = data.y.numpy()\n",
    "\n",
    "        # One-hot encode labels\n",
    "        num_classes = y.max() + 1  # Number of classes\n",
    "        y_one_hot = np.eye(num_classes)[y]  # One-hot encoding\n",
    "\n",
    "        # Convert edge_index to a sparse adjacency matrix\n",
    "        num_nodes = x.shape[0]\n",
    "        adj = csr_matrix((num_nodes, num_nodes))  # Initialize sparse matrix\n",
    "        for i in range(edge_index.shape[1]):\n",
    "            src, dst = edge_index[:, i]\n",
    "            adj[src, dst] = 1\n",
    "            adj[dst, src] = 1  # Ensure undirected graph\n",
    "\n",
    "        return [Graph(x=x, a=adj, y=y_one_hot)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bbe7ec52-56b6-4069-a3f5-c07a072f0cd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Create a custom Dataset for the graph\n",
    "class PubMedDataset(Dataset):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def read(self):\n",
    "        dataset = Planetoid(root=\".\", name=\"PubMed\")  # Load PubMed dataset\n",
    "        data = dataset[0]  # Access the first graph\n",
    "        \n",
    "        # Convert Torch tensors to NumPy\n",
    "        x = data.x.numpy()\n",
    "        edge_index = data.edge_index.numpy()\n",
    "        y = data.y.numpy()\n",
    "\n",
    "        # One-hot encode labels\n",
    "        num_classes = y.max() + 1  # Number of classes\n",
    "        y_one_hot = np.eye(num_classes)[y]  # One-hot encoding\n",
    "\n",
    "        # Convert edge_index to a sparse adjacency matrix\n",
    "        num_nodes = x.shape[0]\n",
    "        adj = lil_matrix((num_nodes, num_nodes), dtype=np.float32)\n",
    "        for i in range(edge_index.shape[1]):\n",
    "            src, dst = edge_index[:, i]\n",
    "            adj[src, dst] = 1\n",
    "            adj[dst, src] = 1  # Ensure undirected graph\n",
    "\n",
    "        return [Graph(x=x, a=adj, y=y_one_hot)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7888fabb-25d0-4df7-8784-60ed43d4f1b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.datasets import Amazon\n",
    "\n",
    "# Create a custom Dataset for DBLP\n",
    "class AmazonPhotosDataset(Dataset):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def read(self):\n",
    "        dataset = Amazon(\".\", name=\"photo\")  # Load Amazon Computers dataset\n",
    "        data = dataset[0]\n",
    "\n",
    "        x = data.x.numpy()\n",
    "        edge_index = data.edge_index.numpy()\n",
    "        y = data.y.numpy()\n",
    "\n",
    "        # One-hot encode labels\n",
    "        num_classes = y.max() + 1\n",
    "        y_one_hot = np.eye(num_classes)[y]\n",
    "\n",
    "        # Convert edge_index to adjacency matrix\n",
    "        num_nodes = x.shape[0]\n",
    "        adj = lil_matrix((num_nodes, num_nodes), dtype=np.float32)\n",
    "        for i in range(edge_index.shape[1]):\n",
    "            src, dst = edge_index[:, i]\n",
    "            adj[src, dst] = 1\n",
    "            adj[dst, src] = 1  \n",
    "\n",
    "        return [Graph(x=x, a=adj, y=y_one_hot)]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "359e48e0-47a3-4abb-99d7-9e8a6c5551fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.datasets import WikiCS\n",
    "\n",
    "class WikiCSDataset(Dataset):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def read(self):\n",
    "        dataset = WikiCS(root=\"./data/WikiCS\")  # Download & load WikiCS dataset\n",
    "        data = dataset[0]  # Access the graph\n",
    "\n",
    "        # Features and labels\n",
    "        x = data.x.numpy()\n",
    "        y = data.y.numpy().flatten()\n",
    "\n",
    "        # One-hot encode labels\n",
    "        num_classes = y.max() + 1\n",
    "        y_one_hot = np.eye(num_classes)[y]\n",
    "\n",
    "        # Build adjacency matrix\n",
    "        edge_index = data.edge_index.numpy()\n",
    "        num_nodes = x.shape[0]\n",
    "        adj = lil_matrix((num_nodes, num_nodes), dtype=np.float32)\n",
    "        for i in range(edge_index.shape[1]):\n",
    "            src, dst = edge_index[:, i]\n",
    "            adj[src, dst] = 1\n",
    "            adj[dst, src] = 1  # Ensure undirected\n",
    "\n",
    "        return [Graph(x=x, a=adj, y=y_one_hot)]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ebae1cf3-591c-4a06-a407-78459ab92a77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# patch torch.load\n",
    "torch_load_old = torch.load\n",
    "def torch_load_patched(*args, **kwargs):\n",
    "    kwargs[\"weights_only\"] = False\n",
    "    return torch_load_old(*args, **kwargs)\n",
    "torch.load = torch_load_patched\n",
    "from ogb.nodeproppred import NodePropPredDataset\n",
    "\n",
    "class ArxivDataset(Dataset):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def read(self):\n",
    "        dataset = NodePropPredDataset(name=\"ogbn-arxiv\")\n",
    "        graph, labels = dataset[0]\n",
    "\n",
    "        x = graph[\"node_feat\"]  # (num_nodes, num_features)\n",
    "        a = sp.coo_matrix(\n",
    "            (np.ones(graph[\"edge_index\"].shape[1]), \n",
    "             (graph[\"edge_index\"][0], graph[\"edge_index\"][1])),\n",
    "            shape=(x.shape[0], x.shape[0]),\n",
    "        )\n",
    "\n",
    "        # labels is shape (num_nodes, 1) with integer class indices\n",
    "        labels = labels.squeeze()\n",
    "\n",
    "        # Convert to one-hot encoded labels\n",
    "        num_classes = labels.max() + 1\n",
    "        y = np.eye(num_classes)[labels]   # (num_nodes, num_classes)\n",
    "\n",
    "        return [Graph(x=x, a=a, y=y)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2faf6e58-fe84-4bf5-b872-efd2feed0e76",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (tf-gpu)",
   "language": "python",
   "name": "tf-gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
