{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sherirto/CSTAG"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "bash command\n",
    "\n",
    "```shell\n",
    "export HF_ENDPOINT=https://hf-mirror.com\n",
    "huggingface-cli download --repo-type dataset --resume-download Sherirto/CSTAG --local-dir graph-feature-selection/data/CSTAG --local-dir-use-symlinks False\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CSTAG_DATASET_BASE_PATH = \"../data/CSTAG\"\n",
    "# 查找文件夹下所有目录下的.pt文件，返回文件路径列表\n",
    "import os\n",
    "import dgl\n",
    "import numpy  as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "def find_files(dir_path, end_with='.pt'):\n",
    "    file_list = []\n",
    "    for root, dirs, files in os.walk(dir_path):\n",
    "        for file in files:\n",
    "            if file.endswith(end_with):\n",
    "                file_list.append(os.path.join(root, file))\n",
    "    return file_list\n",
    "\n",
    "find_files(CSTAG_DATASET_BASE_PATH, '.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = [\"Children\", \"Computers\", \"Fitness\", \"History\", \"Photo\"]\n",
    "num_data_splits = 10\n",
    "train_prop = 0.5\n",
    "valid_prop = 0.25\n",
    "save_path_base = \"../data\"\n",
    "\n",
    "for dataset in datasets:\n",
    "    # edges\n",
    "    edges = None\n",
    "    graph_file = find_files(f\"{CSTAG_DATASET_BASE_PATH}/{dataset}\", '.pt')\n",
    "    glist, label_dict = dgl.load_graphs(graph_file[0])\n",
    "    for i, g in enumerate(glist):\n",
    "        print(f\"Dataset: {dataset}, Graph {i+1}:\")\n",
    "        print(f\"Number of nodes: {g.number_of_nodes()}\")\n",
    "        edges = torch.tensor(np.array(g.edges()).T)\n",
    "        print(f\"Number of edges: {g.number_of_edges()}, edges dtype: {edges.dtype}\")\n",
    "    # node labels\n",
    "    csv_file = find_files(f\"{CSTAG_DATASET_BASE_PATH}/{dataset}\", '.csv')\n",
    "    csv_data = pd.read_csv(csv_file[0])\n",
    "    node_labels = torch.tensor(np.array(csv_data['label'].values))\n",
    "    print(f\"Number of labels: {len(node_labels)}, node_labels dtype: {node_labels.dtype}\")\n",
    "    # node features\n",
    "    features_file = find_files(f\"{CSTAG_DATASET_BASE_PATH}/{dataset}\", '.npy')\n",
    "    node_features = torch.tensor(np.load(features_file[0])).to(torch.float)\n",
    "    print(f\"Number of features: {len(node_features)}, node_features dtype: {node_features.dtype}\")\n",
    "    print(f\"Feature dimension: {len(node_features[0])}\")\n",
    "    # Splits\n",
    "    seed = 0\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    num_nodes = len(node_labels)\n",
    "    mask_number = torch.zeros(num_data_splits, num_nodes)\n",
    "    for i in range(num_data_splits):\n",
    "        mask_number[i] = torch.randperm(num_nodes)\n",
    "    train_masks = (mask_number<=(train_prop*num_nodes))\n",
    "    val_masks = (torch.logical_and(mask_number<=((train_prop+valid_prop)*num_nodes),mask_number>(train_prop*num_nodes)))\n",
    "    test_masks = (mask_number>((train_prop+valid_prop)*num_nodes))\n",
    "    print(f\"mask ratio(train:val:test): {train_masks.sum().item()/num_nodes/num_data_splits:.2f}\"\\\n",
    "          f\":{val_masks.sum().item()/num_nodes/num_data_splits:.2f}\"\n",
    "            f\":{test_masks.sum().item()/num_nodes/num_data_splits:.2f}\")\n",
    "    np.savez(f'{save_path_base}/{dataset}.npz',\n",
    "            node_features=node_features.numpy(),\n",
    "            node_labels=node_labels.numpy(),\n",
    "            edges=edges.numpy(),\n",
    "            train_masks=train_masks.numpy(),\n",
    "            val_masks=val_masks.numpy(),\n",
    "            test_masks=test_masks.numpy())\n",
    "    print(f\"save {dataset} graph, train, val, test masks to {save_path_base}/{dataset}.npz\")\n",
    "    print(\"-\" * 50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## For roman-empire, amazon-ratings, minesweeper, tolokers, questions datasets , please refer to [CLICK THIS LINK](https://github.com/yandex-research/heterophilous-graphs/tree/main/data) for download"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hg",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
