{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8b96a952-f8cd-4b68-ae2b-c1f2599b9898",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded AMAZON-PHOTOS\n",
      "Data(x=[7650, 150], edge_index=[2, 238162], y=[7650])\n",
      "Number of nodes: 7650, Number of features: 150, Number of classes: 8\n",
      "------------------------------------------------------------\n",
      "Loaded WIKICS\n",
      "Data(x=[11701, 150], edge_index=[2, 431726], y=[11701], train_mask=[11701, 20], val_mask=[11701, 20], test_mask=[11701], stopping_mask=[11701, 20])\n",
      "Number of nodes: 11701, Number of features: 150, Number of classes: 10\n",
      "------------------------------------------------------------\n",
      "Loaded PUBMED\n",
      "Data(x=[19717, 150], edge_index=[2, 88648], y=[19717], train_mask=[19717], val_mask=[19717], test_mask=[19717])\n",
      "Number of nodes: 19717, Number of features: 150, Number of classes: 3\n",
      "------------------------------------------------------------\n",
      "Loaded CITESEER\n",
      "Data(x=[3327, 150], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])\n",
      "Number of nodes: 3327, Number of features: 150, Number of classes: 6\n",
      "------------------------------------------------------------\n",
      "Loaded CORA\n",
      "Data(x=[2708, 150], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])\n",
      "Number of nodes: 2708, Number of features: 150, Number of classes: 7\n",
      "------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sujan/miniforge3/envs/tf-gpu/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/Users/sujan/miniforge3/envs/tf-gpu/lib/python3.10/site-packages/torch_geometric/datasets/wikics.py:45: UserWarning: The WikiCS dataset now returns an undirected graph by default. Please explicitly specify 'is_undirected=False' to restore the old behavior.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# load_datasets_revar.ipynb\n",
    "\n",
    "import os\n",
    "import torch\n",
    "from torch_geometric.datasets import Amazon, Planetoid, WikiCS\n",
    "\n",
    "# Default dataset path\n",
    "DATASET_PATH = '/Users/sujan/Downloads/ReVar/data'\n",
    "os.makedirs(DATASET_PATH, exist_ok=True)\n",
    "\n",
    "\n",
    "def load_dataset_revar(dataset_name: str, hidden_dim: int = 150):\n",
    "    \"\"\"\n",
    "    Load a graph dataset and replace its features with random embeddings.\n",
    "    \n",
    "    Args:\n",
    "        dataset_name (str): One of ['Amazon-Photos', 'WikiCS', 'Pubmed', 'Citeseer', 'Cora'].\n",
    "        hidden_dim (int): Dimensionality of the random embeddings.\n",
    "\n",
    "    Returns:\n",
    "        data (torch_geometric.data.Data): Graph data object with x, edge_index, y, train/test masks (if available).\n",
    "        dataset (torch_geometric.data.Dataset): The loaded dataset.\n",
    "    \"\"\"\n",
    "    dataset_name = dataset_name.lower()\n",
    "\n",
    "    if dataset_name == 'amazon-photos':\n",
    "        dataset = Amazon(root=DATASET_PATH, name='Photo')\n",
    "    elif dataset_name == 'wikics':\n",
    "        dataset = WikiCS(root=DATASET_PATH)\n",
    "    elif dataset_name == 'pubmed':\n",
    "        dataset = Planetoid(root=DATASET_PATH, name='Pubmed')\n",
    "    elif dataset_name == 'citeseer':\n",
    "        dataset = Planetoid(root=DATASET_PATH, name='Citeseer')\n",
    "    elif dataset_name == 'cora':\n",
    "        dataset = Planetoid(root=DATASET_PATH, name='Cora')\n",
    "    else:\n",
    "        raise ValueError(f\"Dataset {dataset_name} not supported!\")\n",
    "\n",
    "    data = dataset[0]\n",
    "\n",
    "    # Replace node features with random embeddings\n",
    "    data.x = torch.randn(data.num_nodes, hidden_dim)\n",
    "\n",
    "    print(f\"Loaded {dataset_name.upper()}\")\n",
    "    print(data)\n",
    "    print(f\"Number of nodes: {data.num_nodes}, \"\n",
    "          f\"Number of features: {data.num_node_features}, \"\n",
    "          f\"Number of classes: {dataset.num_classes}\")\n",
    "\n",
    "    return data, dataset\n",
    "\n",
    "\n",
    "# Example usage:\n",
    "if __name__ == \"__main__\":\n",
    "    for name in [\"Amazon-Photos\", \"WikiCS\", \"Pubmed\", \"Citeseer\", \"Cora\"]:\n",
    "        load_dataset_revar(name, hidden_dim=150)\n",
    "        print(\"-\" * 60)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a04bc90d-1c20-455a-b2fa-85f0db507034",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (tf-gpu)",
   "language": "python",
   "name": "tf-gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
