{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f0537b07",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import networkx as nx\n",
    "import os\n",
    "from dgl import function as fn\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from dataset import load_graph_dataset\n",
    "from dataset_wikics import load_wikics\n",
    "from tqdm import tqdm\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from train import train_preprocessed_data\n",
    "from model_edge_influence import EdgeInfluenceSGC\n",
    "from tqdm import tqdm\n",
    "from dgl.data import AmazonCoBuyComputerDataset, AmazonCoBuyPhotoDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "024d6deb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(13752, torch.Size([13752, 767]))"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = AmazonCoBuyComputerDataset()\n",
    "num_class = dataset.num_classes\n",
    "g = dataset[0]  # graph object; node data is read from g.ndata below\n",
    "# node input features and ground-truth class labels\n",
    "feat, label = g.ndata['feat'], g.ndata['label']\n",
    "g.number_of_nodes(), feat.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "069ca83a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from https://github.com/mengliu1998/DeeperGNN/blob/master/DeeperGNN/train_eval.py\n",
    "# (adapted: private seeded RNG, configurable per-class split sizes)\n",
    "\n",
    "def index_to_mask(index, size):\n",
    "    \"\"\"Return a bool tensor of length `size` that is True exactly at `index`.\n",
    "\n",
    "    The mask is created on the same device as `index`.\n",
    "    \"\"\"\n",
    "    mask = torch.zeros(size, dtype=torch.bool, device=index.device)\n",
    "    mask[index] = True\n",
    "    return mask\n",
    "\n",
    "def random_amazon_splits(data, num_classes, seed=42, train_per_class=20, val_per_class=30):\n",
    "    \"\"\"Build reproducible per-class random train/val/test node masks.\n",
    "\n",
    "    Nodes of each class c in range(num_classes) are shuffled; the first\n",
    "    `train_per_class` go to training, the next `val_per_class` to validation\n",
    "    and all remaining ones to testing. A class with too few nodes simply\n",
    "    contributes fewer (possibly zero) nodes to the later splits.\n",
    "\n",
    "    Args:\n",
    "        data: indexable dataset whose element 0 is the graph; the graph must\n",
    "            expose ndata['label'] and number_of_nodes().\n",
    "        num_classes: number of label values to split (labels 0..num_classes-1).\n",
    "        seed: seed for a private torch.Generator. The global torch RNG is left\n",
    "            untouched, so random operations in later cells are unaffected.\n",
    "        train_per_class: training nodes drawn per class (default 20).\n",
    "        val_per_class: validation nodes drawn per class (default 30).\n",
    "\n",
    "    Returns:\n",
    "        (train_mask, val_mask, test_mask): disjoint bool tensors of length\n",
    "        number_of_nodes(), on the same device as the labels.\n",
    "    \"\"\"\n",
    "    graph = data[0]\n",
    "    labels = graph.ndata['label']\n",
    "    num_nodes = graph.number_of_nodes()\n",
    "\n",
    "    # torch.manual_seed would silently reseed the global RNG for the rest of\n",
    "    # the notebook; a dedicated generator keeps the split self-contained.\n",
    "    generator = torch.Generator().manual_seed(seed)\n",
    "\n",
    "    val_end = train_per_class + val_per_class\n",
    "    train_parts, val_parts, test_parts = [], [], []\n",
    "    for c in range(num_classes):\n",
    "        index = torch.nonzero(labels == c, as_tuple=True)[0]\n",
    "        perm = torch.randperm(index.numel(), generator=generator)\n",
    "        index = index[perm.to(index.device)]\n",
    "        train_parts.append(index[:train_per_class])\n",
    "        val_parts.append(index[train_per_class:val_end])\n",
    "        test_parts.append(index[val_end:])\n",
    "\n",
    "    # Ordering is lost when indices become masks, so the combined test\n",
    "    # indices are not shuffled again.\n",
    "    train_mask = index_to_mask(torch.cat(train_parts), size=num_nodes)\n",
    "    val_mask = index_to_mask(torch.cat(val_parts), size=num_nodes)\n",
    "    test_mask = index_to_mask(torch.cat(test_parts), size=num_nodes)\n",
    "    return train_mask, val_mask, test_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "a20c9690",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeded per-class split: 20 train / 30 val nodes per class, the rest for testing.\n",
    "train_mask, val_mask, test_mask = random_amazon_splits(dataset, num_class, seed=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "24476070",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repeat the split with identical arguments; compared below to confirm reproducibility.\n",
    "train_mask2, val_mask2, test_mask2 = random_amazon_splits(dataset, num_class, seed=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "3082969a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The split is seeded, so repeating it must reproduce every mask, not only the training one.\n",
    "all(torch.equal(first, second) for first, second in [(train_mask, train_mask2), (val_mask, val_mask2), (test_mask, test_mask2)])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
