{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "46e18457-6dbb-4928-adad-4b279400977c",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 46"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "48b4862d-cf6b-4b8a-bfc3-b886ff3d6262",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sujan/miniforge3/envs/tf-gpu/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "%run load_datasets.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8e7376eb-3241-43dd-94ce-f8cde1ce430c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Cora] Saved: ./masks/Cora/70_30/Cora_70_30_masked_indices_seed42.npy\n",
      "[Cora] Saved: ./masks/Cora/70_30/Cora_70_30_masked_indices_seed46.npy\n",
      "[Cora] Saved: ./masks/Cora/70_30/Cora_70_30_masked_indices_seed123.npy\n",
      "[Cora] Saved: ./masks/Cora/70_30/Cora_70_30_masked_indices_seed2025.npy\n",
      "[Cora] Saved: ./masks/Cora/70_30/Cora_70_30_masked_indices_seed999.npy\n",
      "[Cora] Saved: ./masks/Cora/30_70/Cora_30_70_masked_indices_seed42.npy\n",
      "[Cora] Saved: ./masks/Cora/30_70/Cora_30_70_masked_indices_seed46.npy\n",
      "[Cora] Saved: ./masks/Cora/30_70/Cora_30_70_masked_indices_seed123.npy\n",
      "[Cora] Saved: ./masks/Cora/30_70/Cora_30_70_masked_indices_seed2025.npy\n",
      "[Cora] Saved: ./masks/Cora/30_70/Cora_30_70_masked_indices_seed999.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sujan/miniforge3/envs/tf-gpu/lib/python3.10/site-packages/scipy/sparse/_index.py:108: SparseEfficiencyWarning: Changing the sparsity structure of a csr_matrix is expensive. lil_matrix is more efficient.\n",
      "  self._set_intXint(row, col, x.flat[0])\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CiteSeer] Saved: ./masks/CiteSeer/70_30/CiteSeer_70_30_masked_indices_seed42.npy\n",
      "[CiteSeer] Saved: ./masks/CiteSeer/70_30/CiteSeer_70_30_masked_indices_seed46.npy\n",
      "[CiteSeer] Saved: ./masks/CiteSeer/70_30/CiteSeer_70_30_masked_indices_seed123.npy\n",
      "[CiteSeer] Saved: ./masks/CiteSeer/70_30/CiteSeer_70_30_masked_indices_seed2025.npy\n",
      "[CiteSeer] Saved: ./masks/CiteSeer/70_30/CiteSeer_70_30_masked_indices_seed999.npy\n",
      "[CiteSeer] Saved: ./masks/CiteSeer/30_70/CiteSeer_30_70_masked_indices_seed42.npy\n",
      "[CiteSeer] Saved: ./masks/CiteSeer/30_70/CiteSeer_30_70_masked_indices_seed46.npy\n",
      "[CiteSeer] Saved: ./masks/CiteSeer/30_70/CiteSeer_30_70_masked_indices_seed123.npy\n",
      "[CiteSeer] Saved: ./masks/CiteSeer/30_70/CiteSeer_30_70_masked_indices_seed2025.npy\n",
      "[CiteSeer] Saved: ./masks/CiteSeer/30_70/CiteSeer_30_70_masked_indices_seed999.npy\n",
      "[PubMed] Saved: ./masks/PubMed/70_30/PubMed_70_30_masked_indices_seed42.npy\n",
      "[PubMed] Saved: ./masks/PubMed/70_30/PubMed_70_30_masked_indices_seed46.npy\n",
      "[PubMed] Saved: ./masks/PubMed/70_30/PubMed_70_30_masked_indices_seed123.npy\n",
      "[PubMed] Saved: ./masks/PubMed/70_30/PubMed_70_30_masked_indices_seed2025.npy\n",
      "[PubMed] Saved: ./masks/PubMed/70_30/PubMed_70_30_masked_indices_seed999.npy\n",
      "[PubMed] Saved: ./masks/PubMed/30_70/PubMed_30_70_masked_indices_seed42.npy\n",
      "[PubMed] Saved: ./masks/PubMed/30_70/PubMed_30_70_masked_indices_seed46.npy\n",
      "[PubMed] Saved: ./masks/PubMed/30_70/PubMed_30_70_masked_indices_seed123.npy\n",
      "[PubMed] Saved: ./masks/PubMed/30_70/PubMed_30_70_masked_indices_seed2025.npy\n",
      "[PubMed] Saved: ./masks/PubMed/30_70/PubMed_30_70_masked_indices_seed999.npy\n",
      "[AmazonPhotos] Saved: ./masks/AmazonPhotos/70_30/AmazonPhotos_70_30_masked_indices_seed42.npy\n",
      "[AmazonPhotos] Saved: ./masks/AmazonPhotos/70_30/AmazonPhotos_70_30_masked_indices_seed46.npy\n",
      "[AmazonPhotos] Saved: ./masks/AmazonPhotos/70_30/AmazonPhotos_70_30_masked_indices_seed123.npy\n",
      "[AmazonPhotos] Saved: ./masks/AmazonPhotos/70_30/AmazonPhotos_70_30_masked_indices_seed2025.npy\n",
      "[AmazonPhotos] Saved: ./masks/AmazonPhotos/70_30/AmazonPhotos_70_30_masked_indices_seed999.npy\n",
      "[AmazonPhotos] Saved: ./masks/AmazonPhotos/30_70/AmazonPhotos_30_70_masked_indices_seed42.npy\n",
      "[AmazonPhotos] Saved: ./masks/AmazonPhotos/30_70/AmazonPhotos_30_70_masked_indices_seed46.npy\n",
      "[AmazonPhotos] Saved: ./masks/AmazonPhotos/30_70/AmazonPhotos_30_70_masked_indices_seed123.npy\n",
      "[AmazonPhotos] Saved: ./masks/AmazonPhotos/30_70/AmazonPhotos_30_70_masked_indices_seed2025.npy\n",
      "[AmazonPhotos] Saved: ./masks/AmazonPhotos/30_70/AmazonPhotos_30_70_masked_indices_seed999.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sujan/miniforge3/envs/tf-gpu/lib/python3.10/site-packages/torch_geometric/datasets/wikics.py:45: UserWarning: The WikiCS dataset now returns an undirected graph by default. Please explicitly specify 'is_undirected=False' to restore the old behavior.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[WikiCS] Saved: ./masks/WikiCS/70_30/WikiCS_70_30_masked_indices_seed42.npy\n",
      "[WikiCS] Saved: ./masks/WikiCS/70_30/WikiCS_70_30_masked_indices_seed46.npy\n",
      "[WikiCS] Saved: ./masks/WikiCS/70_30/WikiCS_70_30_masked_indices_seed123.npy\n",
      "[WikiCS] Saved: ./masks/WikiCS/70_30/WikiCS_70_30_masked_indices_seed2025.npy\n",
      "[WikiCS] Saved: ./masks/WikiCS/70_30/WikiCS_70_30_masked_indices_seed999.npy\n",
      "[WikiCS] Saved: ./masks/WikiCS/30_70/WikiCS_30_70_masked_indices_seed42.npy\n",
      "[WikiCS] Saved: ./masks/WikiCS/30_70/WikiCS_30_70_masked_indices_seed46.npy\n",
      "[WikiCS] Saved: ./masks/WikiCS/30_70/WikiCS_30_70_masked_indices_seed123.npy\n",
      "[WikiCS] Saved: ./masks/WikiCS/30_70/WikiCS_30_70_masked_indices_seed2025.npy\n",
      "[WikiCS] Saved: ./masks/WikiCS/30_70/WikiCS_30_70_masked_indices_seed999.npy\n",
      "[Arxiv] Saved: ./masks/Arxiv/70_30/Arxiv_70_30_masked_indices_seed42.npy\n",
      "[Arxiv] Saved: ./masks/Arxiv/70_30/Arxiv_70_30_masked_indices_seed46.npy\n",
      "[Arxiv] Saved: ./masks/Arxiv/70_30/Arxiv_70_30_masked_indices_seed123.npy\n",
      "[Arxiv] Saved: ./masks/Arxiv/70_30/Arxiv_70_30_masked_indices_seed2025.npy\n",
      "[Arxiv] Saved: ./masks/Arxiv/70_30/Arxiv_70_30_masked_indices_seed999.npy\n",
      "[Arxiv] Saved: ./masks/Arxiv/30_70/Arxiv_30_70_masked_indices_seed42.npy\n",
      "[Arxiv] Saved: ./masks/Arxiv/30_70/Arxiv_30_70_masked_indices_seed46.npy\n",
      "[Arxiv] Saved: ./masks/Arxiv/30_70/Arxiv_30_70_masked_indices_seed123.npy\n",
      "[Arxiv] Saved: ./masks/Arxiv/30_70/Arxiv_30_70_masked_indices_seed2025.npy\n",
      "[Arxiv] Saved: ./masks/Arxiv/30_70/Arxiv_30_70_masked_indices_seed999.npy\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "# ---- helper to extract labels ----\n",
    "def get_labels(dataset_class):\n",
    "    dataset = dataset_class()\n",
    "    graph = dataset[0]\n",
    "    y = graph.y\n",
    "    if y.ndim > 1:  # one-hot encoded\n",
    "        labels = np.argmax(y, axis=1)\n",
    "    else:  # already integers\n",
    "        labels = y\n",
    "    return labels\n",
    "\n",
    "\n",
    "# ---- masking function ----\n",
    "def generate_and_save_masks(\n",
    "    dataset_classes,\n",
    "    dataset_names,\n",
    "    base_dir=\"./masks\",\n",
    "    splits={\"70_30\": 0.3, \"30_70\": 0.7},\n",
    "    seeds=[42, 46, 123, 2025, 999]\n",
    "):\n",
    "    for dataset_class, dataset_name in zip(dataset_classes, dataset_names):\n",
    "        labels = get_labels(dataset_class)\n",
    "        num_nodes = len(labels)\n",
    "\n",
    "        for split_name, mask_fraction in splits.items():\n",
    "            split_dir = os.path.join(base_dir, dataset_name, split_name)\n",
    "            os.makedirs(split_dir, exist_ok=True)\n",
    "\n",
    "            for seed in seeds:\n",
    "                rng = np.random.default_rng(seed)\n",
    "                masked_indices = rng.choice(\n",
    "                    np.arange(num_nodes),\n",
    "                    int(num_nodes * mask_fraction),\n",
    "                    replace=False\n",
    "                )\n",
    "\n",
    "                filename = f\"{dataset_name}_{split_name}_masked_indices_seed{seed}.npy\"\n",
    "                save_path = os.path.join(split_dir, filename)\n",
    "\n",
    "                np.save(save_path, masked_indices)\n",
    "                print(f\"[{dataset_name}] Saved: {save_path}\")\n",
    "\n",
    "\n",
    "# ---- run with all datasets ----\n",
    "generate_and_save_masks(\n",
    "    dataset_classes=[\n",
    "        CoraDataset,\n",
    "        CiteSeerDataset,\n",
    "        PubMedDataset,\n",
    "        AmazonPhotosDataset,\n",
    "        WikiCSDataset,\n",
    "        ArxivDataset\n",
    "    ],\n",
    "    dataset_names=[\"Cora\", \"CiteSeer\", \"PubMed\", \"AmazonPhotos\", \"WikiCS\", \"Arxiv\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7ddb93d0-a38a-49fd-9c4e-9d8bf4cd4101",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shape: (812,)\n"
     ]
    }
   ],
   "source": [
    "# Example: load one saved mask\n",
    "arr = np.load(\"./masks/Cora/70_30/Cora_70_30_masked_indices_seed42.npy\")\n",
    "\n",
    "print(\"Shape:\", arr.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a557650b-ff17-4671-84e6-2d9871747457",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (tf-gpu)",
   "language": "python",
   "name": "tf-gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
