{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ls5uvXWJEtGj"
      },
      "source": [
        "# Code for Layer-diverse Negative sampling GCN\n",
        "This code can be run directly on Google Colab."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dR5F_DpFFOR0"
      },
      "source": [
        "# 0. Preliminaries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "HC1O7XNFX7Np",
        "outputId": "40c99b5e-e4e7-488a-a3f5-c4b225adae22"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "1.13.1+cu116\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.4/9.4 MB\u001b[0m \u001b[31m41.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.5/4.5 MB\u001b[0m \u001b[31m24.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
            "  Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m280.2/280.2 KB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m512.4/512.4 KB\u001b[0m \u001b[31m25.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Building wheel for torch-geometric (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
            "Collecting pykeops\n",
            "  Downloading pykeops-2.1.1.tar.gz (87 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m87.4/87.4 KB\u001b[0m \u001b[31m4.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.8/dist-packages (from pykeops) (1.21.6)\n",
            "Collecting pybind11\n",
            "  Downloading pybind11-2.10.3-py3-none-any.whl (222 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m222.4/222.4 KB\u001b[0m \u001b[31m10.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting keopscore==2.1.1\n",
            "  Downloading keopscore-2.1.1.tar.gz (84 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.6/84.6 KB\u001b[0m \u001b[31m9.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Building wheels for collected packages: pykeops, keopscore\n",
            "  Building wheel for pykeops (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for pykeops: filename=pykeops-2.1.1-py3-none-any.whl size=112292 sha256=860c4affa92e98e1b8809768de6b8116798c097dbeb6a0b9d11057799d3e0587\n",
            "  Stored in directory: /root/.cache/pip/wheels/d7/4e/bf/e93e607209605d0374bb41fce87fc39623c94cd40e2740a2fb\n",
            "  Building wheel for keopscore (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for keopscore: filename=keopscore-2.1.1-py3-none-any.whl size=148013 sha256=78a9ba0c7d02a0cf537efb16865e35cda2ae0ddc259295df40fdec97d39d403b\n",
            "  Stored in directory: /root/.cache/pip/wheels/5f/bb/b4/11efe588aaa15cb3ee7cefb93cd7c429daf0a18857fc7887be\n",
            "Successfully built pykeops keopscore\n",
            "Installing collected packages: pybind11, keopscore, pykeops\n",
            "Successfully installed keopscore-2.1.1 pybind11-2.10.3 pykeops-2.1.1\n",
            "[KeOps] Warning : cuda was detected, but driver API could not be initialized. Switching to cpu only.\n"
          ]
        }
      ],
      "source": [
        "# Install required packages.\n",
        "import os\n",
        "import torch\n",
        "os.environ['TORCH'] = torch.__version__\n",
        "print(torch.__version__)\n",
        "\n",
        "!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html\n",
        "!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html\n",
        "!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git\n",
        "\n",
        "# !pip install dppy\n",
        "!pip install pykeops\n",
        "\n",
        "from pykeops.torch import LazyTensor\n",
        "\n",
        "import torch\n",
        "import numpy as np\n",
        "import networkx as nx\n",
        "import matplotlib.pyplot as plt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X2kcBreSqV1u"
      },
      "outputs": [],
      "source": [
        "# This code is for calculation of Cluster Overlap Rate / Not the main algorithm\n",
        "# K-means\n",
        "use_cuda = torch.cuda.is_available()\n",
        "dtype = torch.float32 if use_cuda else torch.float64\n",
        "device_id = \"cuda:0\" if use_cuda else \"cpu\"\n",
        "def KMeans(x, K=10, Niter=10, verbose=True):\n",
        "    \"\"\"Cluster the rows of ``x`` with Lloyd's algorithm for the Euclidean metric.\n",
        "\n",
        "    Args:\n",
        "        x: (N, D) tensor of samples; its first K rows initialise the centroids.\n",
        "        K: number of clusters.\n",
        "        Niter: number of E/M iterations, at least 1.\n",
        "        verbose: when True, print the problem size and the elapsed time.\n",
        "\n",
        "    Returns:\n",
        "        Tuple ``(cl, c)``: ``cl`` is the (N,) tensor of cluster labels and\n",
        "        ``c`` the (K, D) tensor of centroids.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if ``Niter`` is smaller than 1.\n",
        "    \"\"\"\n",
        "    # Local import: the notebook only imports `time` in a later cell, so without\n",
        "    # it a call made before that cell has run fails with NameError.\n",
        "    import time\n",
        "\n",
        "    if Niter < 1:\n",
        "        # Without an iteration `cl` is never assigned and the timing divides by zero.\n",
        "        raise ValueError(\"Niter must be at least 1\")\n",
        "\n",
        "    start = time.time()\n",
        "    N, D = x.shape  # Number of samples, dimension of the ambient space\n",
        "\n",
        "    c = x[:K, :].clone()  # Simplistic initialization for the centroids\n",
        "\n",
        "    x_i = LazyTensor(x.view(N, 1, D))  # (N, 1, D) samples\n",
        "    c_j = LazyTensor(c.view(1, K, D))  # (1, K, D) centroids\n",
        "\n",
        "    # K-means loop:\n",
        "    # - x  is the (N, D) point cloud,\n",
        "    # - cl is the (N,) vector of class labels\n",
        "    # - c  is the (K, D) cloud of cluster centroids\n",
        "    # c is only ever updated in place below, so c_j (built from a view of c)\n",
        "    # keeps referring to the current centroids.\n",
        "    for i in range(Niter):\n",
        "\n",
        "        # E step: assign points to the closest cluster -------------------------\n",
        "        D_ij = ((x_i - c_j) ** 2).sum(-1)  # (N, K) symbolic squared distances\n",
        "        cl = D_ij.argmin(dim=1).long().view(-1)  # Points -> Nearest cluster\n",
        "\n",
        "        # M step: update the centroids to the normalized cluster average: ------\n",
        "        # Compute the sum of points per cluster:\n",
        "        c.zero_()\n",
        "        c.scatter_add_(0, cl[:, None].repeat(1, D), x)\n",
        "\n",
        "        # Divide by the number of points per cluster:\n",
        "        Ncl = torch.bincount(cl, minlength=K).type_as(c).view(K, 1)\n",
        "        # An empty cluster has a count of 0 and a zero sum; clamping the count to 1\n",
        "        # keeps its centroid at zero instead of turning it into NaN (0 / 0).\n",
        "        c /= Ncl.clamp(min=1)  # in-place division to compute the average\n",
        "\n",
        "    if verbose:  # Fancy display -----------------------------------------------\n",
        "        if use_cuda:\n",
        "            torch.cuda.synchronize()\n",
        "        end = time.time()\n",
        "        print(\n",
        "            f\"K-means for the Euclidean metric with {N:,} points in dimension {D:,}, K = {K:,}:\"\n",
        "        )\n",
        "        print(\n",
        "            \"Timing for {} iterations: {:.5f}s = {} x {:.5f}s\\n\".format(\n",
        "                Niter, end - start, Niter, (end - start) / Niter\n",
        "            )\n",
        "        )\n",
        "\n",
        "    return cl, c"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qhip7suoFsGG"
      },
      "source": [
        "# 1 Dataset\n",
        "## 1.1 Load Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zAPuaAzbYrMv",
        "outputId": "7937b26e-de8a-4c64-aa9e-e4f1bd2301f7"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "Dataset: Cora():\n",
            "======================\n",
            "Number of graphs: 1\n",
            "Number of features: 1433\n",
            "Number of classes: 7\n",
            "\n",
            "Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])\n",
            "===========================================================================================================\n",
            "Number of nodes: 2708\n",
            "Number of edges: 10556\n",
            "Average node degree: 3.90\n",
            "Contains isolated nodes: False\n",
            "Contains self-loops: False\n",
            "Is undirected: True\n"
          ]
        }
      ],
      "source": [
        "from torch_geometric.datasets import Planetoid, Amazon, Coauthor\n",
        "from torch_geometric.transforms import NormalizeFeatures\n",
        "\n",
        "# Dataset selection: the Planetoid loader is active; the Coauthor and Amazon\n",
        "# loaders below are commented-out alternatives that reuse DatasetName.\n",
        "DatasetName = 'Cora'\n",
        "dataset = Planetoid(root='data/Planetoid', name=DatasetName, transform=NormalizeFeatures())\n",
        "# dataset = Coauthor(root='data/Coauthor', name=DatasetName, transform=NormalizeFeatures())\n",
        "# dataset = Amazon(root='data/Amazon', name=DatasetName, transform=NormalizeFeatures())\n",
        "\n",
        "# Dataset-level summary.\n",
        "print()\n",
        "print(f'Dataset: {dataset}:')\n",
        "print('======================')\n",
        "print(f'Number of graphs: {len(dataset)}')\n",
        "print(f'Number of features: {dataset.num_features}')\n",
        "print(f'Number of classes: {dataset.num_classes}')\n",
        "\n",
        "# Work with the first graph of the dataset; later cells read it as `data`.\n",
        "data = dataset[0]\n",
        "\n",
        "print()\n",
        "print(data)\n",
        "print('===========================================================================================================')\n",
        "\n",
        "# Gather some statistics about the graph.\n",
        "print(f'Number of nodes: {data.num_nodes}')\n",
        "print(f'Number of edges: {data.num_edges}')\n",
        "# Ratio of num_edges to num_nodes, printed with two decimals.\n",
        "print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')\n",
        "print(f'Contains isolated nodes: {data.has_isolated_nodes()}')\n",
        "print(f'Contains self-loops: {data.has_self_loops()}')\n",
        "print(f'Is undirected: {data.is_undirected()}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "YXvQHFkh0Yh8",
        "outputId": "5fd54cbe-d2c7-4178-8bce-f010ea76e154"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Number of training nodes: 140\n",
            "Training node label rate: 0.0517\n"
          ]
        }
      ],
      "source": [
        "'''\n",
        "When choose Amazon, Coauthor\n",
        "Need add mask for train/val/test\n",
        "'''\n",
        "# Commented-out recipe that builds the three boolean masks by hand:\n",
        "#   - train: the first 20 nodes of every class,\n",
        "#   - test:  the slice [-1001:-1] (1000 nodes),\n",
        "#   - val:   the slice [-2001:-1001] (1000 nodes).\n",
        "# NOTE(review): both slices stop at -1, so the very last node ends up in\n",
        "# neither the test nor the validation mask - confirm this is intended.\n",
        "# train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)\n",
        "# print(train_mask.sum())\n",
        "\n",
        "# for cls in range(dataset.num_classes):\n",
        "#   index = np.argwhere(data.y.numpy()==cls)[0:20]\n",
        "#   index = np.squeeze(index)\n",
        "#   # print(index)\n",
        "#   train_mask[index] = True\n",
        "\n",
        "# test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)\n",
        "# test_mask[-1001:-1] = True\n",
        "# print(test_mask.sum())\n",
        "# val_mask = torch.zeros(data.num_nodes, dtype=torch.bool)\n",
        "# val_mask[-2001:-1001] = True\n",
        "# print(val_mask.sum())\n",
        "# data.train_mask = train_mask\n",
        "# data.test_mask = test_mask\n",
        "# data.val_mask = val_mask\n",
        "# print(data)\n",
        "\n",
        "# Size of the training split and its share of all nodes.\n",
        "print(f'Number of training nodes: {data.train_mask.sum()}')\n",
        "print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.4f}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DzjfXQv008pG"
      },
      "source": [
        "## 1.2 Get the largest connected subgraph for shortest-path computation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Zr1drKdbYt6y"
      },
      "outputs": [],
      "source": [
        "from torch_geometric.utils import to_networkx, from_networkx\n",
        "import networkx as nx\n",
        "\n",
        "def GetMaxConnectGraph(OrginData):\n",
        "  \"\"\"Extract the largest connected component of a PyG graph.\n",
        "\n",
        "  Args:\n",
        "    OrginData: graph object accepted by to_networkx that also exposes\n",
        "      x, y, train_mask, val_mask and test_mask.\n",
        "\n",
        "  Returns:\n",
        "    (PyTSubGraph, SubGraph): the component as a PyG object whose masks, x and\n",
        "    y are sliced from OrginData, and the same component converted back to an\n",
        "    undirected networkx graph.\n",
        "  \"\"\"\n",
        "  G = to_networkx(OrginData, to_undirected=True)\n",
        "  Subnode = max(nx.connected_components(G), key=len)\n",
        "  print(f'Number of nodes of max-connect Graph: {len(Subnode)}')\n",
        "\n",
        "  InducedGraph = G.subgraph(Subnode)\n",
        "  # from_networkx renumbers the nodes following the graph's own node order, so\n",
        "  # the attribute rows are taken in exactly that order (the iteration order of\n",
        "  # the `Subnode` set is not guaranteed to match it).\n",
        "  SubnodeList = list(InducedGraph.nodes())\n",
        "\n",
        "  #Change from networkx to pytorch\n",
        "  PyTSubGraph = from_networkx(InducedGraph)\n",
        "  SubGraph = to_networkx(PyTSubGraph, to_undirected=True)\n",
        "\n",
        "  #Make new sub-graph data from the graph that was passed in (previously the\n",
        "  #module-level `data` was read here and the argument was ignored)\n",
        "  PyTSubGraph.test_mask = OrginData.test_mask[SubnodeList]\n",
        "  PyTSubGraph.train_mask = OrginData.train_mask[SubnodeList]\n",
        "  PyTSubGraph.val_mask = OrginData.val_mask[SubnodeList]\n",
        "  PyTSubGraph.x = OrginData.x[SubnodeList]\n",
        "  PyTSubGraph.y = OrginData.y[SubnodeList]\n",
        "  return PyTSubGraph,SubGraph"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EnarMfhLqxrT",
        "outputId": "c7a0a1c2-19e7-4364-9343-f131de676139"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Number of nodes of max-connect Graph: 2485\n",
            "Data(edge_index=[2, 10138], num_nodes=2485, test_mask=[2485], train_mask=[2485], val_mask=[2485], x=[2485, 1433], y=[2485])\n",
            "PyTSubGraph: Data(edge_index=[2, 10138], num_nodes=2485, test_mask=[2485], train_mask=[2485], val_mask=[2485], x=[2485, 1433], y=[2485])\n",
            "Average degree: 4.0796780684104625\n"
          ]
        }
      ],
      "source": [
        "\n",
        "\n",
        "# Restrict the dataset graph to its largest connected component.\n",
        "PyTSubGraph, SubGraph = GetMaxConnectGraph(data)\n",
        "print(PyTSubGraph)\n",
        "# num_nodes, edge_index and avgDegree are module-level values read by later cells.\n",
        "num_nodes = PyTSubGraph.num_nodes\n",
        "edge_index = PyTSubGraph.edge_index\n",
        "\n",
        "# Number of edge_index columns divided by the number of nodes.\n",
        "avgDegree = edge_index.shape[1]/num_nodes\n",
        "print(f'PyTSubGraph: {PyTSubGraph}')\n",
        "print(f'Average degree: {avgDegree}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dgZxKUThk8CY"
      },
      "source": [
        "## 1.3 Get Candidate Sets\n",
        "### 1.3.1 Get the path dictionary of the selected nodes"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vYZEqXEVaIwd"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "import random\n",
        "from networkx import single_source_shortest_path,single_source_shortest_path_length\n",
        "import collections\n",
        "from collections import defaultdict\n",
        "'''\n",
        "  Using Shortest Path Dictionary\n",
        "  Save the last node of a path:\n",
        "      1) length >2\n",
        "      2) the degree of the beginning node of this path >10\n",
        "'''\n",
        "def GetNodeLenDict(SPathDict,DegreeList,avgDegree):\n",
        "  \"\"\"Group the end nodes of shortest paths by path length.\n",
        "\n",
        "  SPathDict maps an end node to its shortest-path length and DegreeList is\n",
        "  indexed by node. A node is kept only when its path length is greater than 2\n",
        "  and its own degree is greater than avgDegree.\n",
        "\n",
        "  Returns a defaultdict(list) mapping path length -> list of kept end nodes.\n",
        "  \"\"\"\n",
        "  grouped = defaultdict(list)\n",
        "  for end_node, hops in SPathDict.items():\n",
        "    # Drop every node at distance 2 or less before looking at its degree.\n",
        "    if not hops > 2:\n",
        "      continue\n",
        "    if DegreeList[end_node] > avgDegree:\n",
        "      grouped[hops].append(end_node)\n",
        "  return grouped"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ShPcXj8ElGYu"
      },
      "outputs": [],
      "source": [
        "'''\n",
        "  1 calculate Degree of every node in sparse Adj matrix\n",
        "'''\n",
        "# A node's degree is the number of times it occurs as a source in edge_index[0].\n",
        "out_list = edge_index[0].tolist()\n",
        "d = collections.Counter(out_list)\n",
        "# Counter returns 0 for a node that never occurs, so every index in\n",
        "# range(num_nodes) gets an entry.\n",
        "DegreeList = [d[i] for i in range(num_nodes)]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xX1jw4TK3skI"
      },
      "outputs": [],
      "source": [
        "# 2 Randomly select 1% of the nodes (int(num_nodes/100)) as central nodes and\n",
        "#   compute the shortest-path length from each of them to the other nodes\n",
        "#   (only paths up to the cutoff of 10 are explored).\n",
        "#   The end nodes, grouped by path length, are saved in SPEndNodeDict.\n",
        "#\n",
        "#   Example:\n",
        "#   1) Randomly select central node \"6\" as \"i\" in \"SelCenNodeList\"\n",
        "#   2) SPathDict: shortest-path length from \"6\" to the other nodes\n",
        "#   3) LenDict: end nodes grouped by path length\n",
        "#       LenDict.keys() : dict_keys([3, 4, 5, 6, 7, 8, 9])\n",
        "#       LenDict[4] : [16384, 16387, 16399, 18, 16405....]\n",
        "SPEndNodeDict = defaultdict(dict)\n",
        "SelCenNodeList = random.sample(range(0, num_nodes), int(num_nodes/100))\n",
        "for i in SelCenNodeList:\n",
        "    SPathDict = single_source_shortest_path_length(SubGraph, i,cutoff=10)\n",
        "    LenDict = GetNodeLenDict(SPathDict,DegreeList,avgDegree)\n",
        "    SPEndNodeDict[i]=LenDict"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qZjDbhdfUBmq"
      },
      "source": [
        "### 1.3.2 Get the candidate set"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UZW--YFmdMUB"
      },
      "outputs": [],
      "source": [
        "'''\n",
        "Get central node on shortest path\n",
        "'''\n",
        "def GetCandidateNode(SPEndNodeDict,SelCenNodeList,SelectNum):\n",
        "  \"\"\"Pick up to SelectNum end nodes for every central node.\n",
        "\n",
        "  Args:\n",
        "    SPEndNodeDict: central node -> {path length -> list of end nodes}.\n",
        "    SelCenNodeList: central nodes to process.\n",
        "    SelectNum: maximum number of end nodes per central node; a value of 0\n",
        "      or less selects nothing.\n",
        "\n",
        "  Returns:\n",
        "    defaultdict(list): central node -> selected end nodes, one random node\n",
        "    per path length, taken in the iteration order of the inner dictionary.\n",
        "  \"\"\"\n",
        "  candidates = defaultdict(list)\n",
        "  for node in SelCenNodeList:\n",
        "    SelNode = []\n",
        "    for end_nodes in SPEndNodeDict[node].values():\n",
        "      # Test the quota before drawing: the previous countdown only stopped on\n",
        "      # reaching exactly 0, so SelectNum <= 0 selected from every length.\n",
        "      if len(SelNode) >= SelectNum:\n",
        "        break\n",
        "      if len(end_nodes) > 0:\n",
        "        SelNode.extend(random.sample(end_nodes, 1))\n",
        "    candidates[node].extend(SelNode)\n",
        "  return candidates"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "8TwNPg1go7uR",
        "outputId": "061b392e-09de-44f6-d92d-e50af2ec2a8f"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0.0007550716400146484\n",
            "defaultdict(<class 'list'>, {1434: [1660, 1231, 2087, 998, 1295], 1497: [877, 2021, 1856, 1024, 791], 690: [2070, 745, 1819, 1360, 998], 739: [1075, 1742, 202, 1521, 1985], 1245: [91, 1232, 2229, 239, 531], 2225: [813, 524, 1407, 262, 1152], 355: [1563, 1729, 1976, 361, 1618], 89: [2280, 121, 491, 714, 1805], 1039: [542, 596, 1651, 1918, 2270], 2085: [476, 203, 1089, 1204, 803], 2484: [497, 133, 2285, 2284, 828], 1678: [603, 961, 2093, 1564, 1102], 409: [1427, 25, 2016, 150, 935], 2295: [882, 1564, 782, 81, 202], 2000: [917, 801, 1061, 30, 1833], 2246: [823, 431, 1122, 1708, 1699], 1635: [458, 2040, 1779, 2345, 44], 533: [2082, 11, 999, 202, 626], 1067: [476, 376, 121, 1639, 958], 2287: [977, 444, 52, 1404, 838], 701: [727, 1660, 1412, 11, 1418], 636: [1744, 350, 1415, 1545, 25], 1275: [883, 1914, 2025, 48, 1835], 1490: [1640, 1708, 1504, 2280, 2388]})\n"
          ]
        }
      ],
      "source": [
        "# Time the candidate selection; 5 is passed as SelectNum.\n",
        "begin = time.time()\n",
        "candidates = GetCandidateNode(SPEndNodeDict,SelCenNodeList,5)\n",
        "end = time.time()\n",
        "print(end-begin)\n",
        "# Prints the whole mapping: central node -> selected end nodes.\n",
        "print(candidates)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yGfAe5A-Spi7"
      },
      "outputs": [],
      "source": [
        "'''\n",
        "Get candidate set using central node and its first-order neighbors\n",
        "'''\n",
        "# Upper bound on the neighbours kept per candidate node. It is computed from\n",
        "# the full graph `data`, not from the largest connected subgraph.\n",
        "AvgNodeDgree = int(data.num_edges / data.num_nodes) + 1\n",
        "def GetCandiSet(SelCenNodeList,candidates):\n",
        "  \"\"\"Build the candidate set of every central node.\n",
        "\n",
        "  For each node listed in candidates[central node], collect its neighbours\n",
        "  from the module-level edge_index (randomly thinned to AvgNodeDgree when\n",
        "  there are more) followed by the node itself.\n",
        "\n",
        "  Returns a defaultdict(list): central node -> flat list of node indices.\n",
        "  \"\"\"\n",
        "  Adjcandidates = defaultdict(list)\n",
        "  for center in SelCenNodeList:\n",
        "    gathered = []\n",
        "    for CenterIndex in candidates[center]:\n",
        "      # Columns of edge_index whose source is the candidate node.\n",
        "      columns = torch.nonzero(edge_index[0]==CenterIndex).reshape(-1)\n",
        "      neighbours = edge_index[1][columns].tolist()\n",
        "      if len(neighbours) > AvgNodeDgree:\n",
        "        neighbours = random.sample(neighbours, AvgNodeDgree)\n",
        "      gathered.extend(neighbours)\n",
        "      gathered.append(CenterIndex)\n",
        "    Adjcandidates[center] = gathered\n",
        "  return Adjcandidates"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "HIaRqtXvplCk",
        "outputId": "9eca48ad-f0ec-4871-c3c8-6781c1e5a18a"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0.01597285270690918\n",
            "defaultdict(<class 'list'>, {1434: [878, 977, 293, 271, 1660, 2018, 1513, 224, 103, 1231, 1381, 2089, 2088, 1012, 2087, 543, 12, 1165, 995, 998, 742, 1792, 1418, 1796, 1295], 1497: [1942, 1619, 1057, 685, 877, 762, 944, 1458, 2022, 2021, 291, 1846, 1923, 1922, 1856, 1521, 1995, 763, 1672, 1024, 1462, 1063, 32, 387, 791], 690: [1051, 1208, 1601, 2069, 2070, 1877, 78, 1232, 48, 745, 831, 342, 104, 114, 1819, 2178, 928, 2175, 1534, 1360, 543, 1165, 1468, 2209, 998], 739: [745, 374, 78, 142, 1075, 78, 1201, 467, 469, 1742, 154, 1728, 500, 157, 202, 1034, 1024, 2226, 2125, 1521, 1015, 798, 1131, 1477, 1985], 1245: [121, 552, 438, 1652, 91, 1788, 21, 2158, 745, 1232, 954, 1387, 1289, 239, 2229, 146, 2229, 579, 2325, 239, 566, 2132, 2294, 304, 531], 2225: [806, 2228, 393, 689, 813, 603, 2190, 1672, 456, 524, 2084, 37, 2295, 1256, 1407, 826, 542, 162, 176, 262, 911, 2362, 446, 2462, 1152], 355: [1724, 1723, 1320, 496, 1563, 421, 178, 1002, 245, 1729, 703, 1754, 1727, 1974, 1976, 1001, 1063, 47, 970, 361, 1500, 1245, 665, 2441, 1618], 89: [1578, 2281, 553, 1114, 2280, 1491, 599, 837, 1905, 121, 831, 1247, 587, 935, 491, 2218, 579, 966, 821, 714, 803, 1806, 1166, 274, 1805], 1039: [1504, 1724, 1193, 1730, 542, 321, 1591, 466, 799, 596, 384, 95, 1635, 1486, 1651, 921, 364, 652, 126, 1918, 992, 2241, 882, 338, 2270], 2085: [896, 1245, 2307, 501, 476, 1846, 797, 1563, 202, 203, 65, 1039, 1245, 2201, 1089, 1202, 424, 1203, 2249, 1204, 1810, 274, 1802, 841, 803], 2484: [1797, 1400, 771, 1997, 497, 95, 1491, 416, 271, 133, 236, 1619, 2232, 685, 2285, 887, 1294, 1404, 1178, 2284, 2049, 664, 2037, 1932, 828], 1678: [524, 1979, 296, 357, 603, 906, 1450, 1474, 1314, 961, 2345, 1989, 2092, 332, 2093, 1831, 419, 1135, 1830, 1564, 2351, 2218, 1387, 1310, 1102], 409: [2407, 1640, 994, 30, 1427, 1298, 2185, 671, 2015, 25, 1195, 1236, 1380, 2015, 2016, 2038, 689, 211, 1555, 150, 114, 1821, 1823, 342, 935], 2295: [227, 1299, 999, 641, 882, 1252, 1830, 419, 2082, 1564, 308, 859, 
785, 178, 782, 294, 954, 1982, 1861, 81, 348, 1728, 157, 1563, 202], 2000: [1259, 2419, 327, 542, 917, 1570, 1245, 229, 1358, 801, 2096, 267, 1503, 2224, 1061, 1427, 1508, 1966, 1643, 30, 444, 1361, 26, 751, 1833], 2246: [397, 29, 1773, 1189, 823, 271, 8, 1662, 1846, 431, 1403, 1034, 1402, 1120, 1122, 2259, 1769, 1261, 289, 1708, 1521, 1104, 2287, 896, 1699], 1635: [521, 1942, 771, 1641, 458, 1546, 1786, 61, 775, 2040, 1774, 1122, 710, 711, 1779, 128, 1038, 1152, 343, 2345, 1276, 645, 2063, 414, 44], 533: [131, 569, 2084, 1829, 2082, 138, 2460, 1936, 1934, 11, 1519, 1256, 1323, 79, 999, 1728, 500, 348, 157, 202, 1583, 1071, 1185, 1245, 626], 1067: [1071, 1245, 896, 501, 476, 1894, 271, 2321, 1731, 376, 271, 1642, 599, 91, 121, 1492, 2002, 1491, 115, 1639, 2225, 27, 838, 821, 958], 2287: [1157, 1665, 891, 868, 977, 1964, 1841, 1343, 1739, 444, 93, 2450, 1245, 1495, 52, 1533, 788, 53, 2284, 1404, 1411, 821, 958, 2218, 838], 701: [2201, 730, 65, 1695, 727, 1456, 1644, 271, 293, 1660, 1239, 2136, 2013, 84, 1412, 1936, 2460, 1934, 1894, 11, 220, 302, 1450, 1795, 1418], 636: [99, 1818, 782, 1745, 1744, 308, 2284, 1605, 1362, 350, 103, 583, 2036, 1036, 1415, 416, 631, 2162, 870, 1545, 633, 2015, 1298, 2185, 25], 1275: [2048, 1272, 667, 226, 883, 1992, 1042, 413, 1471, 1914, 2033, 45, 477, 1496, 2025, 704, 591, 1843, 1880, 48, 1303, 1257, 1540, 1833, 1835], 1490: [1492, 1633, 314, 1427, 1640, 1518, 289, 1122, 1769, 1708, 542, 2000, 1563, 20, 1504, 1578, 1114, 553, 1309, 2280, 1054, 2390, 2391, 660, 2388]})\n"
          ]
        }
      ],
      "source": [
        "# Time the expansion of the selected nodes into candidate sets.\n",
        "begin = time.time()\n",
        "Adjcandidate = GetCandiSet(SelCenNodeList,candidates)\n",
        "end = time.time()\n",
        "print(end-begin)\n",
        "# Prints the whole mapping: central node -> candidate-set node indices.\n",
        "print(Adjcandidate)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lw8wGb-SUTuA"
      },
      "source": [
        "## 1.4 Get Community"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "tDP0RPEI6STz",
        "outputId": "1a4696f2-1a78-444a-9044-d568100d4cc1"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "100\n"
          ]
        }
      ],
      "source": [
        "from networkx.algorithms import community\n",
        "# Split SubGraph into NumbCom communities with networkx's asyn_fluidc;\n",
        "# Iteration is passed as its third positional argument.\n",
        "NumbCom = 100\n",
        "Iteration =100\n",
        "FluidCom = community.asyn_fluidc(SubGraph,NumbCom,Iteration)\n",
        "# Consume the result once into a list with one entry per community.\n",
        "LabelComu = list(FluidCom)\n",
        "print(len(LabelComu))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wvUgx8N0UskY"
      },
      "outputs": [],
      "source": [
        "def GetCommunityMatrix(label_community,num_nodes):\n",
        "  \"\"\"Return the 0/1 membership matrix of shape (communities, num_nodes).\n",
        "\n",
        "  Row r has a 1 in column n exactly when node n is in label_community[r].\n",
        "  \"\"\"\n",
        "  membership = torch.zeros(len(label_community), num_nodes)\n",
        "  for row, members in enumerate(label_community):\n",
        "    for node in members:\n",
        "      membership[row, node] = 1\n",
        "  return membership"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "70igsQ8lUwBz",
        "outputId": "68364976-7cef-42d7-f9e3-95b3ccbd0112"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "torch.Size([100, 2485])\n",
            "2485\n",
            "tensor([21., 17., 10., 29., 19., 27., 41., 25., 29., 82.])\n"
          ]
        }
      ],
      "source": [
        "# Get community matrix / Every node belong to different community\n",
        "LComuMatrix = GetCommunityMatrix(LabelComu,num_nodes)\n",
        "print(LComuMatrix.shape)\n",
        "print(LComuMatrix.shape[1])\n",
        "print(LComuMatrix.sum(1)[0:10])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qFCZiG_IHvSP"
      },
      "source": [
        "# 2 Define DPP"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zGnQ0_hIvzVC"
      },
      "outputs": [],
      "source": [
        "def elementary_symmetric_poly(k,N,eigenvalues):\n",
        "  \"\"\"Tabulate the elementary symmetric polynomials of the eigenvalues.\n",
        "\n",
        "  Args:\n",
        "    k: highest polynomial order to compute.\n",
        "    N: number of eigenvalues.\n",
        "    eigenvalues: 1-D tensor with N entries.\n",
        "\n",
        "  Returns:\n",
        "    Tensor e_M of shape (N + 1, k + 1), allocated on the device of\n",
        "    `eigenvalues`. e_M[n, l] approximates the order-l polynomial of the first\n",
        "    n eigenvalues: column 0 is 1, and row 0 of the other columns is 1e-6\n",
        "    instead of an exact 0 (presumably to keep later ratios of table entries\n",
        "    finite).\n",
        "  \"\"\"\n",
        "  # Allocate next to the input so that CUDA eigenvalues do not hit a device\n",
        "  # mismatch in the product below; non-tensor inputs keep the default device.\n",
        "  device = eigenvalues.device if torch.is_tensor(eigenvalues) else None\n",
        "  e_M = 1e-6*torch.ones((N + 1, k + 1), device=device) #e_M = e^n_l\n",
        "  e_M[:,0] = 1\n",
        "  for l in range(1,k+1):\n",
        "    # e^n_l = sum over m <= n of eigenvalue_m * e^(m-1)_(l-1), as a running sum.\n",
        "    v = e_M[:-1,l-1] * eigenvalues\n",
        "    e_M[1:,l] = torch.cumsum(v, dim=0)\n",
        "  return e_M\n",
        "\n",
        "def k_DPP_Sampling_VectorIndex(k,N,eigenvalues):\n",
        "  \"\"\"Randomly pick at most k indices out of range(N), scanning from N-1 down.\n",
        "\n",
        "  With l picks still to make, index n-1 is accepted when a uniform draw is\n",
        "  below eigenvalues[n-1] * e_M[n-1, l-1] / e_M[n, l], where e_M is the table\n",
        "  returned by elementary_symmetric_poly.\n",
        "\n",
        "  Returns the accepted indices as a list in decreasing order.\n",
        "  \"\"\"\n",
        "  e_M = elementary_symmetric_poly(k,N,eigenvalues) # e_M[N+1,k+1]\n",
        "  draws = torch.rand([N,k])\n",
        "  # accept[n-1, l-1] is the threshold for index n-1 with l picks remaining.\n",
        "  accept = torch.unsqueeze(eigenvalues, 1)*e_M[:-1,:-1]/e_M[1:,1:]\n",
        "  J = []\n",
        "  remaining = k\n",
        "  row = N - 1\n",
        "  while row >= 0 and remaining != 0:\n",
        "    col = remaining - 1\n",
        "    if draws[row, col] < accept[row, col]:\n",
        "      J.append(row)\n",
        "      remaining -= 1\n",
        "    row -= 1\n",
        "  return J"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JRb011jyv0W9"
      },
      "outputs": [],
      "source": [
        "def torch_proj_dpp_sampler_eig_KuTa12(eig_vecs, size=None, IsMask=False, Mask = None):\n",
        "    \"\"\"Sample item indices from a projection DPP given by its eigenvectors.\n",
        "\n",
        "    Items are drawn one at a time with probability proportional to the squared\n",
        "    row norms of the working matrix V. After each draw V is updated so that the\n",
        "    drawn row becomes zero, one column is removed and the remaining columns are\n",
        "    re-orthonormalised with a QR decomposition.\n",
        "\n",
        "    :param eig_vecs:\n",
        "        2-D tensor of shape (N, rank); rows index the ground set. It is cloned,\n",
        "        so the argument itself is not modified.\n",
        "    :param size:\n",
        "        Number of items to draw. ``None`` or a value larger than ``rank``\n",
        "        draws ``rank`` items.\n",
        "    :param IsMask:\n",
        "        Not used by this function.\n",
        "    :param Mask:\n",
        "        Not used by this function.\n",
        "    :return: list\n",
        "        ``size`` integer indices. NOTE(review): when the probabilities contain\n",
        "        NaN the loop stops early and the remaining entries keep their initial\n",
        "        value 0, which cannot be told apart from item 0 - confirm callers\n",
        "        tolerate this.\n",
        "    \"\"\"\n",
        "\n",
        "    # rng = check_random_state(random_state)\n",
        "\n",
        "    # Initialization\n",
        "    V = eig_vecs.clone()\n",
        "\n",
        "    N, rank = V.shape  # ground set size / rank(K)\n",
        "    if size is None or size > rank:  # full projection DPP\n",
        "        size = rank\n",
        "    # else: k-DPP with k = size\n",
        "\n",
        "    sampl = torch.zeros(size).int()  # sample list\n",
        "\n",
        "    # Phase 1: Already performed!\n",
        "    # Select eigvecs with Bernoulli variables with parameter the eigvals\n",
        "\n",
        "    # Phase 2: Chain rule\n",
        "    norms_2 = (V*V).sum(axis=1) # ||V_i:||^2\n",
        "    # print(rank)\n",
        "    # Following [Algo 1, KuTa12], the aim is to compute the ortho\n",
        "    #  complement of the subspace spanned by the selected eigenvectors\n",
        "    #  to the canonical vectors \\{e_i ; i \\in Y\\}. We proceed recursively.\n",
        "    for it in range(size):\n",
        "        # j = np.random.choice(N, p=np.abs(norms_2) / (rank - it))\n",
        "        probs = torch.abs(norms_2) /(rank - it)\n",
        "        # print(probs)\n",
        "\n",
        "        # Stop sampling as soon as the weights are no longer finite numbers;\n",
        "        # the entries of sampl not written yet stay 0.\n",
        "        if torch.isnan(probs).any():\n",
        "          break\n",
        "\n",
        "        dist=torch.distributions.categorical.Categorical(probs=probs)\n",
        "        j = dist.sample()\n",
        "        sampl[it] = j\n",
        "        # The last draw needs no update of V.\n",
        "        if it == size - 1:\n",
        "            break\n",
        "\n",
        "        # Cancel the contribution of e_i to the remaining vectors that is,\n",
        "        #  find the subspace of V that is orthogonal to \\{e_i ; i \\in Y\\}\n",
        "        # Take the index of a vector that has a non null contribution on e_j\n",
        "        # (first column with a non-zero entry in row j; raises IndexError if\n",
        "        #  the whole row is zero)\n",
        "        k = torch.where(V[j, :] != 0)[0][0]\n",
        "        # Cancel the contribution of the remaining vectors on e_j, but stay\n",
        "        #  in the subspace spanned by V i.e. get the subspace of V orthogonal\n",
        "        #  to \\{e_i ; i \\in Y\\}\n",
        "        V -= torch.outer(V[:, k] / V[j, k], V[j, :])\n",
        "\n",
        "        # V_:j is set to 0 so we delete it and we can derive an orthononormal\n",
        "        #  basis of the subspace under consideration\n",
        "        # V, _ = la.qr(np.delete(V, k, axis=1), mode='economic')\n",
        "        # Keep every column except k.\n",
        "        tempIndex = list(range(V.shape[1]))\n",
        "        tempIndex.remove(k)\n",
        "        V_delet_k = V[:,tempIndex]\n",
        "\n",
        "        V, _ =  torch.linalg.qr(V_delet_k)\n",
        "        norms_2 = (V*V).sum(axis=1)   # ||V_i:||^2\n",
        "\n",
        "    return sampl.tolist()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MLEVizYmUauj"
      },
      "source": [
        "# 3 Define LDGCN model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "52qmgivTlUXk"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from torch_geometric.nn import MessagePassing\n",
        "from torch_geometric.utils import add_self_loops, degree\n",
        "from torch_scatter import gather_csr, scatter\n",
        "class GCNConv(MessagePassing):\n",
        "    \"\"\"GCN layer: a linear map followed by degree-normalised sum aggregation.\"\"\"\n",
        "\n",
        "    def __init__(self, in_channels, out_channels):\n",
        "        # Incoming messages are combined by summation.\n",
        "        super().__init__(aggr='add')\n",
        "        self.lin = torch.nn.Linear(in_channels, out_channels)\n",
        "\n",
        "    def forward(self, x, edge_index):\n",
        "        \"\"\"Run the layer on features ``x`` [N, in_channels] and ``edge_index`` [2, E].\"\"\"\n",
        "        node_count = x.size(0)\n",
        "\n",
        "        # Give every node an edge to itself before normalising.\n",
        "        looped_edges, _ = add_self_loops(edge_index, num_nodes=node_count)\n",
        "\n",
        "        # Project the features first; propagation works on the projected rows.\n",
        "        projected = self.lin(x)\n",
        "\n",
        "        return self.propagate(looped_edges, size=(node_count, node_count), x=projected)\n",
        "\n",
        "    def message(self, x_j, edge_index, size):\n",
        "        \"\"\"Weight each message ``x_j`` [E, out_channels] by deg(row)^-0.5 * deg(col)^-0.5.\"\"\"\n",
        "        row_idx, col_idx = edge_index\n",
        "        node_degree = degree(row_idx, size[0], dtype=x_j.dtype)  # [N, ]\n",
        "        inv_sqrt_degree = node_degree.pow(-0.5)  # [N, ]\n",
        "        edge_weight = inv_sqrt_degree[row_idx] * inv_sqrt_degree[col_idx]\n",
        "\n",
        "        return edge_weight.view(-1, 1) * x_j\n",
        "\n",
        "    def update(self, aggr_out):\n",
        "        \"\"\"Return the aggregated messages [N, out_channels] unchanged.\"\"\"\n",
        "        return aggr_out"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RQ7elDEdldNk"
      },
      "outputs": [],
      "source": [
        "import gc\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import random\n",
        "\n",
        "from numpy import linalg as LA\n",
        "\n",
        "import time\n",
        "\n",
        "class LDGCN(torch.nn.Module):\n",
        "    def __init__(self,seed, hidden_channels,num_layers,num_nodes,LabelComu,LComuMatrix,DegreeList):\n",
        "        super(LDGCN, self).__init__()\n",
        "        torch.manual_seed(seed)\n",
        "\n",
        "        self.num_layers = num_layers\n",
        "        self.CONVs = torch.nn.ModuleList()\n",
        "        self.DPPCONVs = torch.nn.ModuleList()\n",
        "        self.NegRates = nn.ParameterList()\n",
        "        for layer in range(self.num_layers-1):\n",
        "            if layer == 0:\n",
        "                self.CONVs.append(GCNConv(dataset.num_features, hidden_channels))\n",
        "                self.DPPCONVs.append(GCNConv(dataset.num_features, hidden_channels))\n",
        "                self.NegRates.append(nn.Parameter(torch.FloatTensor([1])))\n",
        "            else:\n",
        "                self.CONVs.append(GCNConv(hidden_channels, hidden_channels))\n",
        "                self.DPPCONVs.append(GCNConv(hidden_channels, hidden_channels))\n",
        "                self.NegRates.append(nn.Parameter(torch.FloatTensor([1])))\n",
        "        self.CONVs.append(GCNConv(hidden_channels, dataset.num_classes))\n",
        "        self.DPPCONVs.append(GCNConv(hidden_channels, dataset.num_classes))\n",
        "        self.NegRates.append(nn.Parameter(torch.FloatTensor([1])))\n",
        "\n",
        "        self.LabelComu = LabelComu\n",
        "        self.LComuMatrix = LComuMatrix\n",
        "        self.DegreeList = DegreeList\n",
        "        self.num_classes = dataset.num_classes\n",
        "        self.num_nodes = num_nodes\n",
        "        # According to experimental experience, setting NegRate to 1~3 can achieve better results\n",
        "        self.NegRate = 1\n",
        "        self.DPPedge_index_list = []\n",
        "        self.ComuFeatureM = 0\n",
        "        self.TrueClusterList = []\n",
        "        self.MultiClusterList = []\n",
        "        self.LayerDppIndex = []\n",
        "\n",
        "\n",
        "    def forward(self, x, edge_index,Adjcandidate,DPPedge_index_list,train_neg_rate):\n",
        "        self.DPPedge_index_list = DPPedge_index_list\n",
        "        for layer in range(self.num_layers):\n",
        "          self.ComuFeatureM = self.GetComuFeatureMatrix(x,)\n",
        "          x_temp = x.clone()\n",
        "          posi_x = self.CONVs[layer](x, edge_index)\n",
        "          if len(self.DPPedge_index_list) < self.num_layers:\n",
        "            with torch.no_grad():\n",
        "              Dppedge_index = self.GetDppMatrix(x_temp,Adjcandidate)\n",
        "              self.DPPedge_index_list.append(Dppedge_index)\n",
        "          else:\n",
        "            Dppedge_index = self.DPPedge_index_list[layer]\n",
        "\n",
        "          nega_x = self.DPPCONVs[layer](x_temp, Dppedge_index)\n",
        "          if train_neg_rate: #NegRate is a trainable parameter\n",
        "            x = posi_x - self.NegRates[layer] * nega_x\n",
        "          else:\n",
        "            x = posi_x - self.NegRate * nega_x\n",
        "\n",
        "          if layer < (self.num_layers-1):\n",
        "            x = x.relu()\n",
        "            x = F.dropout(x, p=0.5, training=self.training)\n",
        "            if self.training:\n",
        "              T_cl, _ = KMeans(x, self.num_classes, verbose=False)\n",
        "              M_cl, _ = KMeans(x, 5*self.num_classes, verbose=False)\n",
        "              self.TrueClusterList.append(T_cl)\n",
        "              self.MultiClusterList.append(M_cl)\n",
        "        return x,self.DPPedge_index_list\n",
        "\n",
        "    def GetComuFeatureMatrix(self,x):\n",
        "        temp = 0\n",
        "        n = 0\n",
        "        ComuFeatureM = torch.zeros(len(self.LabelComu), x.shape[1]) #[number of Comu, number of features]\n",
        "        for c in self.LabelComu:\n",
        "          for ele in c:\n",
        "            temp= temp + x[ele]\n",
        "          ComuFeatureM[n] = temp / len(c)\n",
        "          n = n+1\n",
        "        return ComuFeatureM\n",
        "\n",
        "    def GetDppMatrix(self,x,Adjcandidate):\n",
        "        # for sparse adj matrix\n",
        "        beginList = []\n",
        "        endList = []\n",
        "        valueList = []\n",
        "        LstDppIndex = defaultdict(list)\n",
        "        for node in Adjcandidate:\n",
        "          nodeDegree = self.DegreeList[node]\n",
        "          if nodeDegree > len(Adjcandidate[node]):\n",
        "              end = Adjcandidate[node]\n",
        "          else:\n",
        "              #Get community index of central node \"node\"\n",
        "              cumIndex = torch.nonzero(self.LComuMatrix[:,node]).reshape(-1)\n",
        "              #Get feature of community of \"node\"\n",
        "              CNodeFeature = self.ComuFeatureM[cumIndex.item()]\n",
        "              #index list of cummunity of every node\n",
        "              AjdIndex = Adjcandidate[node]\n",
        "              cumiList = torch.nonzero(self.LComuMatrix[:,AjdIndex])[:,0]\n",
        "\n",
        "              Index = self.Rek_dpp_kernel(node, AjdIndex,x, cumiList, CNodeFeature, nodeDegree)\n",
        "              LstDppIndex[node].extend(Index)\n",
        "              AjdIndex = np.array(AjdIndex)\n",
        "              end = AjdIndex[Index]\n",
        "              if type(end) == np.int64:\n",
        "                end = [end]\n",
        "          value = [1]*len(end)\n",
        "          begin = [node]*len(end)\n",
        "          beginList.extend(begin)\n",
        "          endList.extend(end)\n",
        "          valueList.extend(value)\n",
        "        i = torch.LongTensor([beginList,endList])   #row, col\n",
        "        v = torch.FloatTensor(valueList)    #data\n",
        "        sparseMatrix = torch.sparse.FloatTensor(i, v, torch.Size([num_nodes,num_nodes]))\n",
        "        self.LayerDppIndex.append(LstDppIndex)\n",
        "        return sparseMatrix._indices()\n",
        "\n",
        "    def Rek_dpp_kernel(self, GivenNode, AjdIndex, x, cumiList,CNodeFeature,k_number):\n",
        "        # print(x)\n",
        "        NodeWeight = x[AjdIndex]\n",
        "        ComuWeight = self.ComuFeatureM[cumiList]\n",
        "        NodecosSim = self.CosSimilarity(NodeWeight,NodeWeight)\n",
        "        ComucosSim = self.CosSimilarity(NodeWeight,ComuWeight)\n",
        "\n",
        "        cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6)\n",
        "        NodeFeatureCos = cos(CNodeFeature,(NodeWeight.sum(0)/NodeWeight.shape[0]))\n",
        "\n",
        "        CNodeFeMatrix = torch.ones_like(ComuWeight)\n",
        "        CNodeFeMatrix = CNodeFeMatrix * CNodeFeature\n",
        "        QijMatrix =  NodeFeatureCos * self.CosSimilarity(CNodeFeMatrix,ComuWeight)\n",
        "\n",
        "        Lmatrix = QijMatrix*torch.mm(ComucosSim, ComucosSim.transpose(0,1)) * torch.exp(NodecosSim-1)* QijMatrix.transpose(0,1)\n",
        "\n",
        "        Lmatrix = Lmatrix + 0.01*torch.eye(Lmatrix.shape[0])\n",
        "        EigVal, EigVec = torch.linalg.eigh(Lmatrix)\n",
        "        ind_selected = k_DPP_Sampling_VectorIndex(k_number,Lmatrix.shape[0],EigVal.real)\n",
        "\n",
        "        if len(self.DPPedge_index_list)>=1:\n",
        "          EigVec = self.GetMaskEigVector(GivenNode, x, AjdIndex,EigVec)\n",
        "          DppEigVec = EigVec[:, ind_selected].real\n",
        "          samples = torch_proj_dpp_sampler_eig_KuTa12(DppEigVec,k_number) #,IsMask = True, Mask = probMask)\n",
        "        else:\n",
        "          DppEigVec = EigVec[:, ind_selected].real\n",
        "          samples = torch_proj_dpp_sampler_eig_KuTa12(DppEigVec,k_number)\n",
        "\n",
        "        return samples\n",
        "\n",
        "    def GetMaskEigVector(self, GivenNode, x, AjdIndex,EigVec):\n",
        "        LstNegNodesList = []\n",
        "        # Get Node List of last layer negative sample of GivenNode\n",
        "        tempIndex = torch.nonzero(self.DPPedge_index_list[-1][0] == GivenNode).squeeze()\n",
        "        tempList = self.DPPedge_index_list[-1][1,tempIndex].tolist()\n",
        "        if isinstance(tempList, int):\n",
        "          LstNegNodesList.extend([tempList])\n",
        "        else:\n",
        "          LstNegNodesList.extend(tempList)\n",
        "\n",
        "        # comuSim = self.CosSimilarity(ComuWeight,LstNegNodesCom)\n",
        "        NodeSim = torch.cdist(x[AjdIndex], x[LstNegNodesList], p=2)\n",
        "\n",
        "        NodeSim = self.CosSimilarity(x[AjdIndex], x[LstNegNodesList])\n",
        "\n",
        "        simValue, _ = torch.max(NodeSim, 1)\n",
        "\n",
        "        lastIndex = self.LayerDppIndex[-1][GivenNode]\n",
        "        ThreshIndex = []\n",
        "        TempIndex = (simValue>0.9).nonzero().squeeze().tolist()\n",
        "        if isinstance(TempIndex, int):\n",
        "          ThreshIndex.extend([TempIndex])\n",
        "        else:\n",
        "          ThreshIndex.extend(TempIndex)\n",
        "        UnionIndex = list(set(lastIndex).union(set(ThreshIndex)))\n",
        "        # leftAmount = int(len(index)/2)\n",
        "        # leftindex = random.sample(index, leftAmount)\n",
        "        V = EigVec.clone()\n",
        "        for j in UnionIndex:\n",
        "            k = torch.argmax(torch.abs(V[j, :]))\n",
        "            V -= 0.9 * torch.outer(V[:, k] / V[j, k], V[j, :])\n",
        "\n",
        "        return V\n",
        "\n",
        "    def CosSimilarity(self,a,b):\n",
        "        AnormTemp = a.norm(dim=1)[:, None]\n",
        "        AoneTemp = torch.ones_like(AnormTemp)\n",
        "        AnormTemp = torch.where(AnormTemp == 0,AoneTemp,AnormTemp)\n",
        "        Aweight_norm = a / AnormTemp\n",
        "        BnormTemp = b.norm(dim=1)[:, None]\n",
        "        BoneTemp = torch.ones_like(BnormTemp)\n",
        "        BnormTemp = torch.where(BnormTemp == 0,BoneTemp,BnormTemp)\n",
        "        Bweight_norm = b / BnormTemp\n",
        "        cosSim = torch.mm(Aweight_norm, Bweight_norm.transpose(0,1))\n",
        "        return cosSim"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HJZ_70QwLhCq"
      },
      "outputs": [],
      "source": [
        "def MAD(x):\n",
        "  x_norm = x / x.norm(dim=1)[:, None]\n",
        "  dist = 1-torch.mm(x_norm, x_norm.transpose(0,1))\n",
        "  one1 = torch.ones_like(dist)\n",
        "  zreo1 = torch.zeros_like(dist)\n",
        "  dist_1 = torch.where(dist > 0, one1, zreo1)\n",
        "  D = dist.sum(0)/dist_1.sum(0)\n",
        "  one = torch.ones_like(D)\n",
        "  zero = torch.zeros_like(D)\n",
        "  D_1 = torch.where(D > 0, one, zero)\n",
        "  Mad = D.sum()/D_1.sum()\n",
        "  return Mad.detach().numpy()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0IWd2jfei1W0"
      },
      "outputs": [],
      "source": [
        "def GetNodeOverate(edge_index_list,SelCenNodeList):\n",
        "  overate = []\n",
        "  for j in range(1,len(edge_index_list)):\n",
        "    count = 0\n",
        "    for node in SelCenNodeList:\n",
        "      EdgeIndex1 = torch.nonzero(edge_index_list[j-1][0]==node)\n",
        "      EdgeIndex1 = torch.reshape(EdgeIndex1, (-1,))\n",
        "      EndNode1 = edge_index_list[j-1][1][EdgeIndex1]\n",
        "\n",
        "      EdgeIndex2 = torch.nonzero(edge_index_list[j][0]==node)\n",
        "      EdgeIndex2 = torch.reshape(EdgeIndex2, (-1,))\n",
        "      EndNode2 = edge_index_list[j][1][EdgeIndex2]\n",
        "\n",
        "      for n in EndNode1:\n",
        "        count = count + int(torch.isin(n,EndNode2))\n",
        "    overate.append(count/edge_index_list[j].shape[1])\n",
        "  print(\"Layer Node overate: \", np.average(np.array(overate)))\n",
        "  return np.average(np.array(overate))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5aW-cEPZdHO8"
      },
      "outputs": [],
      "source": [
        "def ComOverate(edge_index_list,LComuMatrix,SelCenNodeList):\n",
        "  overate = []\n",
        "  for j in range(1,len(edge_index_list)):\n",
        "    count = 0\n",
        "    for node in (SelCenNodeList):\n",
        "      EdgeIndex1 = torch.nonzero(edge_index_list[j-1][0]==node)\n",
        "      EdgeIndex1 = torch.reshape(EdgeIndex1, (-1,))\n",
        "      EndNode1 = edge_index_list[j-1][1][EdgeIndex1]\n",
        "\n",
        "      EdgeIndex2 = torch.nonzero(edge_index_list[j][0]==node)\n",
        "      EdgeIndex2 = torch.reshape(EdgeIndex2, (-1,))\n",
        "      EndNode2 = edge_index_list[j][1][EdgeIndex2]\n",
        "\n",
        "      A1com = torch.nonzero(LComuMatrix[:,EndNode1])[:,0]\n",
        "      A2com = torch.nonzero(LComuMatrix[:,EndNode2])[:,0]\n",
        "      for n in A1com:\n",
        "        count = count + int(torch.isin(n,A2com))\n",
        "    # print(\"Layer: %1d/%1d  Comunity Repetition rate: %.4f\" % (l-1,l, count/DPPedge_index_list[l].shape[1]))\n",
        "    overate.append(count/DPPedge_index_list[j].shape[1])\n",
        "  print(\"Layer Community overate: \", np.average(np.array(overate)))\n",
        "  return np.average(np.array(overate))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m367EEj6xr4N"
      },
      "outputs": [],
      "source": [
        "def ClusterOverate(edge_index_list,ClusterList,SelCenNodeList):\n",
        "  overate = []\n",
        "  for j in range(1,len(edge_index_list)):\n",
        "    count = 0\n",
        "    for node in (SelCenNodeList):\n",
        "      EdgeIndex1 = torch.nonzero(edge_index_list[j-1][0]==node)\n",
        "      EdgeIndex1 = torch.reshape(EdgeIndex1, (-1,))\n",
        "      EndNode1 = edge_index_list[j-1][1][EdgeIndex1]\n",
        "\n",
        "      EdgeIndex2 = torch.nonzero(edge_index_list[j][0]==node)\n",
        "      EdgeIndex2 = torch.reshape(EdgeIndex2, (-1,))\n",
        "      EndNode2 = edge_index_list[j][1][EdgeIndex2]\n",
        "\n",
        "      A1com = ClusterList[j-1][EndNode1]\n",
        "      A2com = ClusterList[j-1][EndNode2]\n",
        "      for n in A1com:\n",
        "        count = count + int(torch.isin(n,A2com))\n",
        "    overate.append(count/DPPedge_index_list[j].shape[1])\n",
        "  # print(\"Layer Cluste overate:\", np.average(np.array(overate)))\n",
        "  return np.average(np.array(overate))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GiWDAOpOH91j"
      },
      "source": [
        "# 4 Training and Test"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-oUugkrUmYcw",
        "outputId": "ee1ea857-3f04-4471-f573-a3e88b13eb2d"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "LDGCN(\n",
            "  (CONVs): ModuleList(\n",
            "    (0): GCNConv()\n",
            "    (1): GCNConv()\n",
            "    (2): GCNConv()\n",
            "    (3): GCNConv()\n",
            "  )\n",
            "  (DPPCONVs): ModuleList(\n",
            "    (0): GCNConv()\n",
            "    (1): GCNConv()\n",
            "    (2): GCNConv()\n",
            "    (3): GCNConv()\n",
            "  )\n",
            "  (NegRates): ParameterList(\n",
            "      (0): Parameter containing: [torch.float32 of size 1]\n",
            "      (1): Parameter containing: [torch.float32 of size 1]\n",
            "      (2): Parameter containing: [torch.float32 of size 1]\n",
            "      (3): Parameter containing: [torch.float32 of size 1]\n",
            "  )\n",
            ")\n",
            "Dataset: Cora\n",
            "Number of negative sample: 24\n",
            "Seed: 1\n",
            "---------------\n",
            "DPPedge_index_list torch.Size([2, 117])\n",
            "Epoch time: 0.5037224292755127\n",
            "---------------\n",
            "Layer Node overate:  0.3190883190883191\n",
            "---------------\n",
            "Layer Community overate:  0.6153846153846154\n",
            "---------------\n",
            "Layer True-Cluster overate:  0.7150997150997153\n",
            "Layer 5*Cluster overate:  0.41595441595441596\n",
            "---------------\n",
            "num: 001, Epoch: 020, Loss: 1.5201, Train Acc: 0.6721, Val Acc: 0.4967, Test Acc: 0.4896, MAD: 0.5638, Time: 0.4846\n",
            "num: 001, Epoch: 040, Loss: 0.4305, Train Acc: 0.9344, Val Acc: 0.7211, Test Acc: 0.7552, MAD: 0.7900, Time: 0.4900\n",
            "num: 001, Epoch: 060, Loss: 0.2628, Train Acc: 0.9918, Val Acc: 0.6972, Test Acc: 0.7563, MAD: 0.8037, Time: 0.4807\n",
            "num: 001, Epoch: 080, Loss: 0.2767, Train Acc: 1.0000, Val Acc: 0.7233, Test Acc: 0.7596, MAD: 0.7587, Time: 0.4759\n",
            "num: 001, Epoch: 100, Loss: 0.2033, Train Acc: 1.0000, Val Acc: 0.7407, Test Acc: 0.7596, MAD: 0.7742, Time: 0.4794\n",
            "num: 001, Epoch: 120, Loss: 0.1220, Train Acc: 1.0000, Val Acc: 0.7233, Test Acc: 0.7596, MAD: 0.7854, Time: 0.4804\n",
            "num: 001, Epoch: 140, Loss: 0.2448, Train Acc: 1.0000, Val Acc: 0.7364, Test Acc: 0.7596, MAD: 0.7831, Time: 0.4865\n",
            "num: 001, Epoch: 160, Loss: 0.1421, Train Acc: 0.9918, Val Acc: 0.7255, Test Acc: 0.7596, MAD: 0.7849, Time: 0.4830\n",
            "num: 001, Epoch: 180, Loss: 0.1443, Train Acc: 1.0000, Val Acc: 0.7407, Test Acc: 0.7596, MAD: 0.7687, Time: 0.4898\n",
            "num: 001, Epoch: 200, Loss: 0.0641, Train Acc: 1.0000, Val Acc: 0.7647, Test Acc: 0.7650, MAD: 0.7967, Time: 0.4773\n"
          ]
        }
      ],
      "source": [
        "import sys\n",
        "\n",
        "def test(Adjcandidate,DPPedge_index_list,train_neg_rate):\n",
        "      \"\"\"Evaluate the notebook-level ``model`` on ``PyTSubGraph``.\n",
        "\n",
        "      The arguments are forwarded to ``model``. Returns the accuracies on the\n",
        "      train, validation and test masks as ``[train_acc, val_acc, test_acc]``.\n",
        "      \"\"\"\n",
        "      model.eval()\n",
        "      # Inference only: no autograd graph is needed for the accuracy computation.\n",
        "      with torch.no_grad():\n",
        "          out,_= model(PyTSubGraph.x, edge_index,Adjcandidate,DPPedge_index_list,train_neg_rate)\n",
        "      pred = out.argmax(dim=1)  # Use the class with highest probability.\n",
        "\n",
        "      accs = []\n",
        "      for mask in [PyTSubGraph.train_mask, PyTSubGraph.val_mask, PyTSubGraph.test_mask]:\n",
        "          accs.append(int((pred[mask] == PyTSubGraph.y[mask]).sum()) / int(mask.sum()))\n",
        "      return accs\n",
        "\n",
        "#----------------\n",
        "# Set training parameters\n",
        "num_layers = 4\n",
        "# The run loop below is range(1, Runs), i.e. it executes Runs - 1 runs.\n",
        "Runs = 2\n",
        "# The epoch loop below is range(begin+1, Epochs), i.e. Epochs - 1 epochs when begin is 0.\n",
        "Epochs = 201\n",
        "# sel_num = 10\n",
        "#--------------------\n",
        "#Set whether NegRate is a trainable parameter\n",
        "#if False:\n",
        "#   NegRate = 1\n",
        "train_neg_rate = False\n",
        "#-----------------\n",
        "\n",
        "# Layer-overlap statistics, one entry per run (measured at the first epoch).\n",
        "lavgNodeOverateList = []\n",
        "lavgComOverateList = []\n",
        "lavgClsOverateList = []\n",
        "\n",
        "for num in range(1,Runs):\n",
        "  # The run index doubles as the random seed.\n",
        "  seed = num\n",
        "  model = LDGCN(seed = seed,\n",
        "          hidden_channels=16,\n",
        "          num_layers=num_layers,\n",
        "          num_nodes=num_nodes,\n",
        "          DegreeList = DegreeList,\n",
        "          LabelComu=LabelComu,\n",
        "          LComuMatrix=LComuMatrix)\n",
        "\n",
        "\n",
        "  # print(\"num_layers\",num_layers)\n",
        "  optimizer = torch.optim.Adam(model.parameters(), lr=0.02, weight_decay=5e-4)\n",
        "  criterion = torch.nn.CrossEntropyLoss()\n",
        "  begin = 0\n",
        "  if num == 1 :\n",
        "    print(model)\n",
        "\n",
        "  EpochList = []\n",
        "  LossList = []\n",
        "  TrainAccList = []\n",
        "  ValAccList = []\n",
        "  TestAccList = []\n",
        "  MADList =[]\n",
        "\n",
        "  # test_acc is reset per run so it can neither be undefined (first epoch with\n",
        "  # val_acc == 0) nor carry over the value of a previous run.\n",
        "  best_val_acc = final_test_acc = test_acc = 0\n",
        "  print(f'Dataset: {DatasetName}')\n",
        "  print(f'Number of negative sample: {len(SelCenNodeList)}')\n",
        "  print(f'Seed: {seed}')\n",
        "  for epoch in range(begin+1, Epochs):\n",
        "\n",
        "\n",
        "      start = time.time()\n",
        "      # An empty list makes the model sample fresh negative edges for every layer.\n",
        "      DPPedge_index_list = []\n",
        "      ClusterList = []\n",
        "      model.train()\n",
        "      optimizer.zero_grad()  # Clear gradients.\n",
        "\n",
        "      # #Get candidates\n",
        "      if epoch % 1 ==0:\n",
        "        candidates = GetCandidateNode(SPEndNodeDict,SelCenNodeList,5)\n",
        "        Adjcandidate = GetCandiSet(SelCenNodeList,candidates)\n",
        "\n",
        "      out,DPPedge_index_list = model(PyTSubGraph.x, edge_index,Adjcandidate,DPPedge_index_list,train_neg_rate)\n",
        "      loss = criterion(out[PyTSubGraph.train_mask], PyTSubGraph.y[PyTSubGraph.train_mask])  # Compute the loss solely based on the training nodes.\n",
        "      loss.backward()  # Derive gradients.\n",
        "      optimizer.step()\n",
        "      end = time.time()\n",
        "\n",
        "      # Evaluation reuses the negative edges sampled during this epoch's training pass.\n",
        "      train_acc, val_acc, tmp_test_acc = test(Adjcandidate,DPPedge_index_list,train_neg_rate)\n",
        "      # Report the test accuracy at the epoch with the best validation accuracy so far.\n",
        "      if val_acc > best_val_acc:\n",
        "          best_val_acc = val_acc\n",
        "          test_acc = tmp_test_acc\n",
        "\n",
        "      EpochList.append(epoch)\n",
        "      LossList.append(loss.detach().numpy())\n",
        "      TrainAccList.append(train_acc)\n",
        "      ValAccList.append(val_acc)\n",
        "      TestAccList.append(test_acc)\n",
        "      madvalues = MAD(out)\n",
        "      MADList.append(madvalues)\n",
        "\n",
        "      # Overlap diagnostics are only computed for the first epoch of a run.\n",
        "      if epoch == 1 or epoch == begin+1:\n",
        "          print(\"---------------\")\n",
        "          print(\"DPPedge_index_list\",DPPedge_index_list[0].shape)\n",
        "          print(\"Epoch time:\",end-start)\n",
        "          print(\"---------------\")\n",
        "          lavgNodeOverate  = GetNodeOverate(DPPedge_index_list,SelCenNodeList)\n",
        "          lavgNodeOverateList.append(lavgNodeOverate)\n",
        "          print(\"---------------\")\n",
        "          lavgComOverate = ComOverate(DPPedge_index_list,LComuMatrix,SelCenNodeList)\n",
        "          lavgComOverateList.append(lavgComOverate)\n",
        "          print(\"---------------\")\n",
        "          lavgTClsOverate = ClusterOverate(DPPedge_index_list,model.TrueClusterList,SelCenNodeList)\n",
        "          print(\"Layer True-Cluster overate: \", lavgTClsOverate)\n",
        "          lavgMulClsOverate = ClusterOverate(DPPedge_index_list,model.MultiClusterList,SelCenNodeList)\n",
        "          print(\"Layer 5*Cluster overate: \", lavgMulClsOverate)\n",
        "          print(\"---------------\")\n",
        "\n",
        "      if epoch % 20 ==0:\n",
        "          print(f'num: {num:03d}, Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}, MAD: {madvalues:.4f}, Time: {end-start:.4f}')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "dR5F_DpFFOR0",
        "qhip7suoFsGG",
        "qFCZiG_IHvSP"
      ],
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}