{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "roman-empire\n"
     ]
    }
   ],
   "source": [
    "import argparse\n",
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.utils import to_undirected, remove_self_loops, add_self_loops\n",
    "\n",
    "from logger import *\n",
    "from dataset import load_dataset\n",
    "from data_utils import eval_acc, eval_rocauc, load_fixed_splits\n",
    "from eval import *\n",
    "from parse import parse_method, parser_add_main_args\n",
    "\n",
    "from my_utils.utils import spade,hnsw,construct_adj, spectral_embedding_eig,SPF,construct_weighted_adj,spade_nonetworkx\n",
    "from scipy.sparse import csr_matrix\n",
    "\n",
    "\n",
    "device = torch.device(\"cuda\")\n",
    "### Load and preprocess data ###\n",
    "dataset = load_dataset('./data/', 'roman-empire')\n",
    "if len(dataset.label.shape) == 1:\n",
    "    dataset.label = dataset.label.unsqueeze(1)\n",
    "dataset.label = dataset.label.to(device)\n",
    "# split_idx_lst = load_fixed_splits('./data/', dataset, name=dataset)\n",
    "### Basic information of datasets ###\n",
    "n = dataset.graph['num_nodes']\n",
    "e = dataset.graph['edge_index'].shape[1]\n",
    "c = max(dataset.label.max().item() + 1, dataset.label.shape[1])\n",
    "d = dataset.graph['node_feat'].shape[1]\n",
    "dataset.graph['edge_index'] = to_undirected(dataset.graph['edge_index'])\n",
    "dataset.graph['edge_index'], _ = remove_self_loops(dataset.graph['edge_index'])\n",
    "dataset.graph['edge_index'], _ = add_self_loops(dataset.graph['edge_index'], num_nodes=n)\n",
    "dataset.graph['edge_index'], dataset.graph['node_feat'] = \\\n",
    "    dataset.graph['edge_index'].to(device), dataset.graph['node_feat'].to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[    0,     0,     1,  ..., 22659, 22660, 22661],\n",
       "        [    1,     2,     0,  ..., 22659, 22660, 22661]], device='cuda:0')"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset.graph['edge_index']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0493,  0.0997,  0.0219,  ...,  0.0593, -0.0207,  0.0046],\n",
       "        [ 0.0922, -0.0325,  0.0003,  ...,  0.1769, -0.0559,  0.1271],\n",
       "        [ 0.0067,  0.0278,  0.0351,  ...,  0.1029, -0.0392, -0.0053],\n",
       "        ...,\n",
       "        [ 0.0142, -0.0185,  0.0129,  ...,  0.0558, -0.0206, -0.0319],\n",
       "        [ 0.0120, -0.0353, -0.0055,  ...,  0.0438, -0.0054,  0.0237],\n",
       "        [-0.0073,  0.0157, -0.0554,  ...,  0.0766,  0.0084, -0.0336]],\n",
       "       device='cuda:0')"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset.graph['node_feat']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def edge_index_to_csr(edge_index, num_nodes=None):\n",
    "    # Ensure edge_index is on CPU\n",
    "    edge_index = edge_index.cpu()\n",
    "    # Extract source and target nodes\n",
    "    row = edge_index[0].numpy()\n",
    "    col = edge_index[1].numpy()\n",
    "    # Create data array (1s for unweighted graph)\n",
    "    data = np.ones(row.shape[0], dtype=np.float32)\n",
    "    # Infer number of nodes if not provided\n",
    "    if num_nodes is None:\n",
    "        num_nodes = max(row.max(), col.max()) + 1\n",
    "    # Create the CSR matrix\n",
    "    adj_csr = csr_matrix((data, (row, col)), shape=(num_nodes, num_nodes))\n",
    "\n",
    "    return adj_csr\n",
    "\n",
    "test123 = edge_index_to_csr(dataset.graph['edge_index'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<22662x22662 sparse matrix of type '<class 'numpy.float32'>'\n",
       "\twith 88516 stored elements in Compressed Sparse Row format>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test123"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Polynormer",
   "language": "python",
   "name": "new_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
