{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from ogb.nodeproppred import PygNodePropPredDataset\n",
    "import numpy as np\n",
    "import torch_geometric.transforms as T\n",
    "\n",
    "dataset = PygNodePropPredDataset(name = \"ogbn-arxiv\", root = \"datasets/\", transform=T.Compose([T.ToUndirected(), T.ToSparseTensor()])) \n",
    "\n",
    "split_idx = dataset.get_idx_split()\n",
    "train_idx, valid_idx, test_idx = split_idx[\"train\"], split_idx[\"valid\"], split_idx[\"test\"]\n",
    "graph = dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('ogbn-arxiv', 'ogbn_arxiv')"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset_name = \"ogbn-arxiv\"\n",
    "new = dataset_name.replace(\"-\", \"_\")\n",
    "dataset_name, new"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Data(num_nodes=169343, x=[169343, 128], node_year=[169343, 1], y=[169343, 1], adj_t=[169343, 169343, nnz=2315598])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "128"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "num_features = graph.x.shape[1]\n",
    "num_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'train': tensor([     0,      1,      2,  ..., 169145, 169148, 169251]),\n",
       " 'valid': tensor([   349,    357,    366,  ..., 169185, 169261, 169296]),\n",
       " 'test': tensor([   346,    398,    451,  ..., 169340, 169341, 169342])}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "split_idx = dataset.get_idx_split()\n",
    "split_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SparseTensor(row=tensor([     0,      0,      0,  ..., 169341, 169342, 169342]),\n",
       "             col=tensor([   411,    640,   1162,  ..., 163274,  27824, 158981]),\n",
       "             size=(169343, 169343), nnz=2315598, density=0.01%)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "graph.adj_t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([169343])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset.y.squeeze().size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ml_collections import ConfigDict\n",
    "dataset_name = \"ogbn-arxiv\"\n",
    "num_nodes = graph.num_nodes\n",
    "num_features = graph.num_features\n",
    "dataset_info = ConfigDict()\n",
    "dataset_info.n_features = num_features\n",
    "dataset_info.n_classes = len(graph.y.unique())\n",
    "dataset_info.n_nodes = num_nodes\n",
    "dataset_info.dataset_name = dataset_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dataset_name: ogbn-arxiv\n",
       "n_classes: 40\n",
       "n_features: 128\n",
       "n_nodes: 169343"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_nodes = graph.num_nodes\n",
    "training_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "training_mask[train_idx] = True\n",
    "validation_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "validation_mask[valid_idx] = True\n",
    "test_mask = torch.zeros(num_nodes, dtype=torch.bool)\n",
    "test_mask[test_idx] = True\n",
    "\n",
    "unlabeled_mask = ~(training_mask | validation_mask | test_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
       "        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
       "        36, 37, 38, 39])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "graph.y.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([90941]), torch.Size([29799]), torch.Size([48603]))"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_idx.size(), valid_idx.size(), test_idx.size()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## load OGB-Arxiv dataset as in their experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.sparse as sp\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch_sparse\n",
    "import logging\n",
    "from robust_diffusion.helper import utils\n",
    "\n",
    "make_undirected = True\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "pyg_dataset = PygNodePropPredDataset(name = \"ogbn-arxiv\", root = \"datasets/\")\n",
    "\n",
    "data = pyg_dataset[0]\n",
    "\n",
    "if hasattr(data, '__num_nodes__'):\n",
    "    num_nodes = data.__num_nodes__\n",
    "else:\n",
    "    num_nodes = data.num_nodes\n",
    "\n",
    "if hasattr(pyg_dataset, 'get_idx_split'):\n",
    "    split = pyg_dataset.get_idx_split()\n",
    "else:\n",
    "    split = dict(\n",
    "        train=data.train_mask.nonzero().squeeze(),\n",
    "        valid=data.val_mask.nonzero().squeeze(),\n",
    "        test=data.test_mask.nonzero().squeeze()\n",
    "    )\n",
    "\n",
    "# converting to numpy arrays, so we don't have to handle different\n",
    "# array types (tensor/numpy/list) later on.\n",
    "# Also we need numpy arrays because Numba cant determine type of torch.Tensor\n",
    "split = {k: v.numpy() for k, v in split.items()}\n",
    "\n",
    "edge_index = data.edge_index.cpu()\n",
    "if data.edge_attr is None:\n",
    "    edge_weight = torch.ones(edge_index.size(1))\n",
    "else:\n",
    "    edge_weight = data.edge_attr\n",
    "edge_weight = edge_weight.cpu()\n",
    "\n",
    "adj = sp.csr_matrix((edge_weight, edge_index), (num_nodes, num_nodes))\n",
    "\n",
    "del edge_index\n",
    "del edge_weight\n",
    "\n",
    "# make unweighted\n",
    "adj.data = np.ones_like(adj.data)\n",
    "\n",
    "if make_undirected:\n",
    "    adj = utils.to_symmetric_scipy(adj)\n",
    "\n",
    "    logging.debug(\"Memory Usage after making the graph undirected:\")\n",
    "    logging.debug(utils.get_max_memory_bytes() / (1024 ** 3))\n",
    "\n",
    "logging.debug(\"Memory Usage after normalizing the graph\")\n",
    "logging.debug(utils.get_max_memory_bytes() / (1024 ** 3))\n",
    "\n",
    "adj = torch_sparse.SparseTensor.from_scipy(adj).coalesce().to(device)\n",
    "\n",
    "attr_matrix = data.x.cpu().numpy()\n",
    "\n",
    "attr = torch.from_numpy(attr_matrix).to(device)\n",
    "\n",
    "logging.debug(\"Memory Usage after normalizing graph attributes:\")\n",
    "logging.debug(utils.get_max_memory_bytes() / (1024 ** 3))\n",
    "\n",
    "labels = data.y.squeeze().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([169343, 128]),\n",
       " SparseTensor(row=tensor([     0,      0,      0,  ..., 169341, 169342, 169342], device='cuda:0'),\n",
       "              col=tensor([   411,    640,   1162,  ..., 163274,  27824, 158981], device='cuda:0'),\n",
       "              val=tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0', dtype=torch.float64),\n",
       "              size=(169343, 169343), nnz=2315598, density=0.01%),\n",
       " torch.Size([169343]))"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "attr.size(), adj, labels.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((90941,), (29799,), (48603,))"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "split[\"train\"].shape, split[\"valid\"].shape, split[\"test\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([     0,      1,      2, ..., 169145, 169148, 169251]),\n",
       " array([   349,    357,    366, ..., 169185, 169261, 169296]),\n",
       " array([   346,    398,    451, ..., 169340, 169341, 169342]))"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "split[\"train\"], split[\"valid\"], split[\"test\"]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
