{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ffec5cc1-fdba-4f14-ba72-bf9c8a7f19ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import argparse\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "import logging\n",
    "import time\n",
    "import os\n",
    "import sys\n",
    "from torch_geometric.data import Data\n",
    "import networkx as nx\n",
    "from torch_geometric.utils import to_undirected\n",
    "import torch.nn.functional as F\n",
    "from brec.dataset import BRECDataset\n",
    "from brec.evaluator import evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "70dfaad1-6b84-4e8e-9612-d6cf77bd6270",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mon_op(A, B):\n",
    "    return A + B + torch.mm(A, B)\n",
    "def image(X, m):\n",
    "    \"\"\"Build a per-node (n, n, n) cover tensor from a square matrix ``X``.\n",
    "\n",
    "    For each node ``i``, slice ``cover[i]`` is seeded with column ``i`` of\n",
    "    ``X`` and then refined by an iterative update. The update works on\n",
    "    ``dec``, a 0/1 indicator of the nonzero entries of ``X`` with row and\n",
    "    column ``i`` zeroed. Each step selects entries ``om`` of ``dec``, folds\n",
    "    them (weighted by ``X``) into ``cover[i]`` with ``mon_op``, and then\n",
    "    removes them, together with the transposed masked entries, from ``dec``.\n",
    "\n",
    "    Args:\n",
    "        X: square 2-D tensor, presumably a (weighted) adjacency matrix. The\n",
    "            update rules appear to assume a symmetric ``X`` -- TODO confirm.\n",
    "        m: step count. ``m == 0`` returns all zeros; otherwise the seed is\n",
    "            applied and the update loop runs ``max(m - 1, 0)`` times.\n",
    "\n",
    "    Returns:\n",
    "        float64 tensor of shape (n, n, n) on ``X.device`` holding\n",
    "        ``log1p`` of the stacked ``cover[i]`` slices (for ``m == 0`` the\n",
    "        zero tensor, which equals ``log1p(0)``).\n",
    "\n",
    "    Cost is one n x n matrix product per step per node, i.e. roughly\n",
    "    O(m * n^4) with dense matmul.\n",
    "    \"\"\"\n",
    "    \n",
    "    n = X.size(0)\n",
    "    device, out_dtype = X.device, torch.float64\n",
    "\n",
    "    cover = torch.zeros(n, n, n, device=device, dtype=out_dtype)\n",
    "    if m == 0:\n",
    "        return cover\n",
    "\n",
    "    \n",
    "    # 0/1 indicator of X's nonzero entries; each node i gets its own copy.\n",
    "    base_dec = (X != 0).to(out_dtype)\n",
    "    # Weight matrix for the cover update. X is never written in this\n",
    "    # function, so the clone is not strictly needed.\n",
    "    wdec = X.clone()  \n",
    "\n",
    "    for i in range(n):\n",
    "        dec = base_dec.clone()\n",
    "\n",
    "        \n",
    "        # Seed: column i of cover[i] takes column i of X (via transposed views).\n",
    "        cover[i].t()[i] = X.t()[i]   \n",
    "        # Exclude node i from the indicator (row i and column i).\n",
    "        dec[i] = 0                   \n",
    "        dec[:, i] = 0               \n",
    "\n",
    "        c = 1\n",
    "        while c < m:\n",
    "            \n",
    "            # Nodes whose row of cover[i] sums to nonzero (mixed-sign\n",
    "            # weights could cancel to zero -- presumably X is nonnegative).\n",
    "            row_active = (cover[i].sum(dim=1) != 0)   \n",
    "            # Nodes whose column of dec still has a nonzero entry.\n",
    "            col_active = (dec.sum(dim=0) != 0)        \n",
    "            mask = row_active & col_active            \n",
    "            # Broadcast the node mask along rows: M[r, c] = mask[c].\n",
    "            M = mask.to(out_dtype).unsqueeze(0).expand(n, -1)  \n",
    "\n",
    "            # dec restricted to the columns of masked nodes.\n",
    "            Md = M * dec                               \n",
    "            # IMPORTANT: element-wise product with transpose (NOT matmul).\n",
    "            # For 0/1 entries, om[r, c] = 1 iff Md[r, c] is set and Md[c, r] is not.\n",
    "            om = Md - (Md * Md.t())                    \n",
    "\n",
    "            # Update cover: cover[i] <- om*X + cover[i] + (om*X) @ cover[i] (see mon_op).\n",
    "            cover[i] = mon_op(om * wdec, cover[i])    \n",
    "\n",
    "            # Update dec: drop the om and Md^T entries. For a symmetric 0/1\n",
    "            # dec this zeroes every entry whose row or column node is masked,\n",
    "            # keeping dec symmetric and 0/1.\n",
    "            dec = dec - (om + Md.t())\n",
    "            c += 1\n",
    "\n",
    "    return torch.log1p(cover)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a697fb55-33bf-44fe-82bc-d8732345e609",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math, numbers\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "def _tuple_contains_inf(x):\n",
    "    if isinstance(x, torch.Tensor):\n",
    "        return torch.isinf(x).any().item()\n",
    "    if isinstance(x, np.ndarray):\n",
    "        return np.isinf(x).any()\n",
    "    if isinstance(x, (list, tuple)):\n",
    "        return any(_tuple_contains_inf(v) for v in x)\n",
    "    if isinstance(x, numbers.Real):\n",
    "        return math.isinf(float(x))\n",
    "    return False\n",
    "\n",
    "def _tuple_contains_nan(x):\n",
    "    if isinstance(x, torch.Tensor):\n",
    "        return torch.isnan(x).any().item()\n",
    "    if isinstance(x, np.ndarray):\n",
    "        return np.isnan(x).any()\n",
    "    if isinstance(x, (list, tuple)):\n",
    "        return any(_tuple_contains_nan(v) for v in x)\n",
    "    if isinstance(x, numbers.Real):\n",
    "        return math.isnan(float(x))\n",
    "    return False\n",
    "\n",
    "def _tuple_contains_nonfinite(x):\n",
    "    # True if any NaN or ±inf anywhere\n",
    "    if isinstance(x, torch.Tensor):\n",
    "        return (~torch.isfinite(x)).any().item()\n",
    "    if isinstance(x, np.ndarray):\n",
    "        return (~np.isfinite(x)).any()\n",
    "    if isinstance(x, (list, tuple)):\n",
    "        return any(_tuple_contains_nonfinite(v) for v in x)\n",
    "    if isinstance(x, numbers.Real):\n",
    "        return not math.isfinite(float(x))\n",
    "    return False\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9aa78522-5ea9-4c26-9f51-4d782bb7b4ae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "First item in graph_tuple_list: [<networkx.classes.graph.Graph object at 0x000002B93C9206D0>\n",
      " <networkx.classes.graph.Graph object at 0x000002B93C9207F0>]\n",
      "Type of first graph in tuple: <class 'networkx.classes.graph.Graph'>\n",
      "snn test starting ---\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "00%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:02<00:00, 21.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Basic part costs time 2.84; Correct in 60 / 60\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "00%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:12<00:00,  8.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Regular part costs time 60.16; Correct in 100 / 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "00%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.92it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extension part costs time 8.78; Correct in 100 / 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "00%|████████████████████████████████████████████████████████████████████████████████| 100/100 [01:43<00:00,  1.03s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CFI part costs time 615.3; Correct in 100 / 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "00%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:10<00:00,  1.89it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4-Vertex_Condition part costs time 63.47; Correct in 20 / 20\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "00%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.21it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Distance_Regular part costs time 37.36; Correct in 20 / 20\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 400/400 [02:15<00:00,  2.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reliability part costs time 775.53; Correct in 400 / 400\n",
      "Costs time 1563.44; Correct in 800 / 400, Acc = 2.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "\"\"\"We utilized the Non-GNNs code from the BREC dataset repository (https://github.com/GraphPKU/BREC/tree/Release/Non-GNNs)\n",
    "and integrated our code as a function named snn within this framework.\"\"\"\n",
    "\n",
    "# Set random seeds for reproducibility\n",
    "np.random.seed(2022)\n",
    "random.seed(2022)\n",
    "\n",
    "# Placeholder functions for methods not used\n",
    "def func_None():\n",
    "    raise NotImplementedError(f\"Cannot find func {args.method}\")\n",
    "\n",
    "\n",
    "import networkx as nx\n",
    "import torch\n",
    "\n",
    "def snn(gr, m=25):\n",
    "    \"\"\"Compute a ``(det, var)`` signature of graph ``gr``.\n",
    "\n",
    "    The dense adjacency matrix of ``gr`` is passed through ``image`` with\n",
    "    ``m`` steps and the resulting (n, n, n) tensor is summed over its first\n",
    "    axis into an n x n matrix ``su``. With ``P = mon_op(su.T, su) ** (1/4)``\n",
    "    (element-wise power), the signature is the determinant and the variance\n",
    "    of ``mon_op(P, P)``.\n",
    "\n",
    "    Args:\n",
    "        gr: a NetworkX graph; matrix rows/columns follow the node order used\n",
    "            by ``nx.to_numpy_array``.\n",
    "        m: number of steps passed to ``image`` (default 25, the value that\n",
    "            was previously hard-coded).\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(det, var)`` of 0-d float64 tensors.\n",
    "\n",
    "    Raises:\n",
    "        TypeError: if ``gr`` is not a ``networkx.Graph``.\n",
    "    \"\"\"\n",
    "    if not isinstance(gr, nx.Graph):\n",
    "        raise TypeError(f\"Expected a NetworkX graph, got {type(gr)}\")\n",
    "\n",
    "    adjacency = torch.tensor(nx.to_numpy_array(gr), dtype=torch.float64)\n",
    "    su = torch.sum(image(adjacency, m), 0)\n",
    "\n",
    "    # Computed once and used for both operands of the outer mon_op.\n",
    "    quarter_power = mon_op(su.t(), su) ** (1 / 4)\n",
    "    output_snn_beta = mon_op(quarter_power, quarter_power)\n",
    "\n",
    "    s = torch.var(output_snn_beta)\n",
    "    det = torch.linalg.det(output_snn_beta)\n",
    "    result = (det, s)\n",
    "\n",
    "    # A NaN in the result (e.g. from the fractional power of a negative\n",
    "    # entry) makes the exact `==` comparison used by the caller always\n",
    "    # report a difference, so surface non-finite values. NaN and ±inf are\n",
    "    # reported separately; a combined non-finite check would only repeat them.\n",
    "    if _tuple_contains_nan(result):\n",
    "        print(\"[WARNING] snn returned NaN in its output tuple:\", result)\n",
    "    if _tuple_contains_inf(result):\n",
    "        print(\"[WARNING] snn returned ±inf in its output tuple:\", result)\n",
    "\n",
    "    return result\n",
    "\n",
    "\n",
    "# Method name -> graph function. \"fwl\" and \"wl\" are placeholders here.\n",
    "func_dict = dict(\n",
    "    fwl=func_None,\n",
    "    wl=func_None,\n",
    "    snn=snn,\n",
    ")\n",
    "\n",
    "def wl_method(method, G, k=None, mode=None):\n",
    "    \"\"\"Apply the function registered for ``method`` to graph ``G``.\n",
    "\n",
    "    Unknown method names fall back to ``func_None``. ``k`` and ``mode`` are\n",
    "    accepted but unused (presumably kept for compatibility with the\n",
    "    upstream BREC interface).\n",
    "    \"\"\"\n",
    "    handler = func_dict.get(method, func_None)\n",
    "    return handler(G)\n",
    "\n",
    "# Define dataset partitions\n",
    "part_dict = {\n",
    "    \"Basic\": (0, 60),\n",
    "    \"Regular\": (60, 160),\n",
    "    \"Extension\": (160, 260),\n",
    "    \"CFI\": (260, 360),\n",
    "    \"4-Vertex_Condition\": (360, 380),\n",
    "    \"Distance_Regular\": (380, 400),\n",
    "    \"Reliability\": (400, 800),\n",
    "}\n",
    "\n",
    "# Handle argument parsing for terminal and interactive environments.\n",
    "# The CLI path is used only when --file is given (\"--file x\" or \"--file=x\");\n",
    "# otherwise (e.g. inside Jupyter, whose kernel puts its own options in\n",
    "# sys.argv) the interactive defaults below apply.\n",
    "_cli_has_file = any(a == \"--file\" or a.startswith(\"--file=\") for a in sys.argv[1:])\n",
    "if _cli_has_file:\n",
    "    parser = argparse.ArgumentParser(description=\"Test non-GNN methods on BREC.\")\n",
    "    parser.add_argument(\"--file\", type=str, default=\"brec_nonGNN.npy\")\n",
    "    parser.add_argument(\"--method\", type=str, default=\"snn\")\n",
    "    parser.add_argument(\"--graph_type\", type=str, default=\"none\")\n",
    "    args = parser.parse_args()\n",
    "else:\n",
    "    # Manual argument setup for interactive environments. A Namespace logs\n",
    "    # readably via logging.info(args). The dataset path can be overridden\n",
    "    # with the BREC_NONGNN_FILE environment variable instead of editing the\n",
    "    # hard-coded default.\n",
    "    args = argparse.Namespace(\n",
    "        file=os.environ.get(\"BREC_NONGNN_FILE\", r\"C:\\brec_nonGNN.npy\"),\n",
    "        method=\"snn\",\n",
    "        graph_type=\"none\",\n",
    "    )\n",
    "\n",
    "G_TYPE = args.graph_type.strip()\n",
    "if G_TYPE == \"none\":\n",
    "    method_name = args.method\n",
    "else:\n",
    "    if G_TYPE in part_dict:\n",
    "        method_name = f\"{args.method}_{G_TYPE}\"\n",
    "    else:\n",
    "        raise NotImplementedError(f\"{G_TYPE} does not exist!\")\n",
    "\n",
    "path = os.path.join(\"result\", method_name)\n",
    "# Creates `path` as well as its part_result subdirectory.\n",
    "os.makedirs(os.path.join(path, \"part_result\"), exist_ok=True)\n",
    "\n",
    "LOG_FORMAT = \"%(asctime)s - %(levelname)s - %(message)s\"\n",
    "DATE_FORMAT = \"%m/%d/%Y %H:%M:%S %p\"\n",
    "# force=True (Python 3.8+) replaces handlers left over from an earlier run of\n",
    "# this cell in the same kernel; without it basicConfig is a no-op on re-runs\n",
    "# and the log keeps going to the first run's file.\n",
    "logging.basicConfig(\n",
    "    filename=os.path.join(path, \"logging.log\"),\n",
    "    level=logging.INFO,\n",
    "    format=LOG_FORMAT,\n",
    "    datefmt=DATE_FORMAT,\n",
    "    force=True,\n",
    ")\n",
    "logging.info(args)\n",
    "\n",
    "def count_distinguish_num(graph_tuple_list):\n",
    "    \"\"\"Count the graph pairs that ``args.method`` distinguishes, per BREC part.\n",
    "\n",
    "    Runs every part in ``part_dict`` (or only ``G_TYPE`` when it is not\n",
    "    \"none\"). A pair counts as distinguished when the method's outputs for\n",
    "    its two graphs are not ``==``. Per-part counts and times are printed and\n",
    "    logged, and the distinguished pair ids are saved to\n",
    "    ``<path>/part_result/<part>.npy``; the ids behind the overall accuracy\n",
    "    are saved to ``<path>/result.npy``.\n",
    "\n",
    "    When all parts are run, the Reliability part is reported on its own but\n",
    "    kept out of the overall count, so that numerator and denominator (400\n",
    "    pairs) cover the same parts. Counting its hits in the numerator only\n",
    "    produced impossible results such as 800 / 400, Acc = 2.0.\n",
    "    NOTE(review): presumably (per the BREC design) the Reliability pairs are\n",
    "    controls a sound method should *not* separate; hits there may indicate\n",
    "    float noise in the exact comparison -- confirm against the dataset.\n",
    "\n",
    "    Args:\n",
    "        graph_tuple_list: indexable collection of graph pairs addressed by\n",
    "            the index ranges in ``part_dict``.\n",
    "    \"\"\"\n",
    "    logging.info(f\"{method_name} test starting ---\")\n",
    "    print(f\"{method_name} test starting ---\")\n",
    "\n",
    "    reliability_part = \"Reliability\"\n",
    "    selected_parts = [\n",
    "        name for name in part_dict if G_TYPE == \"none\" or G_TYPE == name\n",
    "    ]\n",
    "    # Parts that feed the overall accuracy. Reliability counts only when it\n",
    "    # is the single part explicitly requested via graph_type.\n",
    "    scored_parts = {\n",
    "        name\n",
    "        for name in selected_parts\n",
    "        if name != reliability_part or G_TYPE == reliability_part\n",
    "    }\n",
    "\n",
    "    cnt = 0\n",
    "    correct_list = []\n",
    "    time_cost = 0\n",
    "    # 400 for a full run, the part size for a single part.\n",
    "    DATA_NUM = sum(part_dict[name][1] - part_dict[name][0] for name in scored_parts)\n",
    "\n",
    "    for part_name in selected_parts:\n",
    "        part_range = part_dict[part_name]\n",
    "        part_size = part_range[1] - part_range[0]\n",
    "        logging.info(f\"{part_name} part starting ---\")\n",
    "\n",
    "        cnt_part = 0\n",
    "        correct_list_part = []\n",
    "        # process_time() is CPU time summed over all threads, so it can exceed\n",
    "        # the wall-clock time shown by tqdm when torch runs multi-threaded.\n",
    "        start = time.process_time()\n",
    "\n",
    "        for pair_id in tqdm(range(part_range[0], part_range[1])):\n",
    "            graph_tuple = graph_tuple_list[pair_id]\n",
    "            # Exact equality: any difference, including float round-off or a\n",
    "            # NaN (NaN != NaN), counts as distinguished.\n",
    "            if not wl_method(\n",
    "                args.method, graph_tuple[0]\n",
    "            ) == wl_method(args.method, graph_tuple[1]):\n",
    "                cnt_part += 1\n",
    "                correct_list_part.append(pair_id)\n",
    "            else:\n",
    "                logging.info(f\"Wrong in {pair_id}\")\n",
    "\n",
    "        end = time.process_time()\n",
    "        time_cost_part = round(end - start, 2)\n",
    "        time_cost += time_cost_part\n",
    "        if part_name in scored_parts:\n",
    "            cnt += cnt_part\n",
    "            correct_list.extend(correct_list_part)\n",
    "\n",
    "        logging.info(\n",
    "            f\"{part_name} part costs time {time_cost_part}; Correct in {cnt_part} / {part_size}\"\n",
    "        )\n",
    "        print(\n",
    "            f\"{part_name} part costs time {time_cost_part}; Correct in {cnt_part} / {part_size}\"\n",
    "        )\n",
    "        np.save(os.path.join(path, \"part_result\", part_name), correct_list_part)\n",
    "\n",
    "    time_cost = round(time_cost, 2)\n",
    "    Acc = round(cnt / DATA_NUM, 2)\n",
    "\n",
    "    logging.info(f\"Costs time {time_cost}; Correct in {cnt} / {DATA_NUM}, Acc = {Acc}\")\n",
    "    print(f\"Costs time {time_cost}; Correct in {cnt} / {DATA_NUM}, Acc = {Acc}\")\n",
    "\n",
    "    np.save(os.path.join(path, \"result\"), correct_list)\n",
    "\n",
    "    return\n",
    "\n",
    "def main():\n",
    "    \"\"\"Load the BREC pair list from ``args.file`` and run the evaluation.\"\"\"\n",
    "    # SECURITY: allow_pickle=True can execute arbitrary code embedded in the\n",
    "    # file; only load trusted dataset files.\n",
    "    pairs = np.load(args.file, allow_pickle=True)\n",
    "    first_pair = pairs[0]\n",
    "    print(\"First item in graph_tuple_list:\", first_pair)\n",
    "    print(\"Type of first graph in tuple:\", type(first_pair[0]))\n",
    "    count_distinguish_num(pairs)\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac9af4f5-a33b-4ccc-ab4c-7451514a0f85",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
