{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5c519b26",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import networkx as nx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "17c27b50",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_labels_to_consecutive_integers(labels):\n",
    "    unique_labels = np.unique(labels)\n",
    "    labels_map = {label: i for i, label in enumerate(unique_labels)}\n",
    "    new_labels = np.array([labels_map[label] for label in labels])\n",
    "    \n",
    "    return new_labels\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6b11035",
   "metadata": {},
   "source": [
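    "Below are functions for computing the graph metrics used in the paper \"Characterizing Graph Datasets for Node Classification: Beyond Homophily-Heterophily Dichotomy\": edge, node, class and adjusted homophily, and edge and node label informativeness. Each function takes a NetworkX graph and a NumPy array of labels as inputs. Node identifiers are used directly as indices into the label array, so they must be valid indices into it (e.g. the integers `0, ..., n-1`)."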
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b45e5688",
   "metadata": {},
   "outputs": [],
   "source": [
    "def h_edge(graph, labels):\n",
    "    edges_with_same_label = 0\n",
    "    for u, v in graph.edges:\n",
    "        if labels[u] == labels[v]:\n",
    "            edges_with_same_label += 1\n",
    "\n",
    "    h_edge = edges_with_same_label / len(graph.edges)\n",
    "    \n",
    "    return h_edge\n",
    "\n",
    "\n",
    "def h_node(graph, labels):\n",
    "    h_sum = 0\n",
    "    num_zero_degree_nodes = 0\n",
    "    for u in graph.nodes():\n",
    "        if graph.degree(u) == 0:\n",
    "            num_zero_degree_nodes += 1\n",
    "            continue\n",
    "        \n",
    "        neighbors_with_same_label = 0\n",
    "        for v in graph.neighbors(u):\n",
    "            if labels[u] == labels[v]:\n",
    "                neighbors_with_same_label += 1\n",
    "\n",
    "        cur_node_h = neighbors_with_same_label / graph.degree(u)\n",
    "        h_sum += cur_node_h\n",
    "\n",
    "    h_node = h_sum / (len(graph.nodes) - num_zero_degree_nodes)\n",
    "    \n",
    "    return h_node\n",
    "\n",
    "\n",
    "def h_class(graph, labels):\n",
    "    \"\"\"Class-adjusted homophily.\n",
    "\n",
    "    For each class k, h_k = (number of same-class neighbors of class-k nodes) /\n",
    "    (total degree of class-k nodes). The metric is\n",
    "    sum_k max(h_k - n_k / n, 0) / (C - 1), where n_k is the number of class-k nodes,\n",
    "    n the number of nodes and C the number of classes. A class whose nodes have no\n",
    "    incident edges (h_k = 0 / 0) contributes 0 instead of making the result NaN.\n",
    "    \"\"\"\n",
    "    labels = convert_labels_to_consecutive_integers(labels)\n",
    "\n",
    "    num_classes = len(np.unique(labels))\n",
    "    num_nodes = np.zeros(num_classes, dtype=int)\n",
    "    numerator = np.zeros(num_classes, dtype=int)\n",
    "    denominator = np.zeros(num_classes, dtype=int)\n",
    "    for u in graph.nodes:\n",
    "        label = labels[u]\n",
    "        numerator[label] += sum(1 for v in graph.neighbors(u) if labels[v] == label)\n",
    "        denominator[label] += graph.degree(u)\n",
    "        num_nodes[label] += 1\n",
    "\n",
    "    # Leave h_k at 0 for classes with zero total degree: after subtracting n_k / n\n",
    "    # and clipping at 0 they contribute nothing, and no 0 / 0 is ever evaluated.\n",
    "    h = np.zeros(num_classes, dtype=float)\n",
    "    has_edges = denominator > 0\n",
    "    h[has_edges] = numerator[has_edges] / denominator[has_edges]\n",
    "    h -= num_nodes / len(graph.nodes)\n",
    "    h = np.maximum(h, 0)\n",
    "    h_class = h.sum() / (num_classes - 1)\n",
    "\n",
    "    return h_class\n",
    "\n",
    "\n",
    "def h_adj(graph, labels):\n",
    "    \"\"\"Adjusted homophily.\n",
    "\n",
    "    With D_k the total degree of class k and |E| the number of edges,\n",
    "    adjust = sum_k D_k^2 / (2|E|)^2 and the result is\n",
    "    (h_edge - adjust) / (1 - adjust).\n",
    "    \"\"\"\n",
    "    labels = convert_labels_to_consecutive_integers(labels)\n",
    "    class_count = len(np.unique(labels))\n",
    "\n",
    "    class_degrees = np.zeros((class_count,))\n",
    "    for node in graph.nodes:\n",
    "        class_degrees[labels[node]] += graph.degree(node)\n",
    "\n",
    "    adjust = (class_degrees ** 2 / (len(graph.edges) * 2) ** 2).sum()\n",
    "    return (h_edge(graph, labels) - adjust) / (1 - adjust)\n",
    "\n",
    "\n",
    "def li_edge(graph, labels, eps=1e-8):\n",
    "    \"\"\"Edge label informativeness LI_edge.\n",
    "\n",
    "    p(c1, c2) is the empirical distribution of (label, label) pairs over both\n",
    "    orientations of every edge, and pbar(c) is the degree-weighted class\n",
    "    distribution. The result is\n",
    "    2 - sum p(c1, c2) log p(c1, c2) / sum pbar(c) log pbar(c).\n",
    "    ``eps`` is added to p(c1, c2) so that the logarithm stays finite for label\n",
    "    pairs that never occur on an edge. Classes with zero total degree follow the\n",
    "    0 * log 0 = 0 convention instead of turning the result into NaN.\n",
    "    \"\"\"\n",
    "    labels = convert_labels_to_consecutive_integers(labels)\n",
    "\n",
    "    num_classes = len(np.unique(labels))\n",
    "\n",
    "    # pbar(c): share of the total degree held by nodes of class c.\n",
    "    class_degree_weighted_probs = np.zeros(num_classes, dtype=float)\n",
    "    for u in graph.nodes:\n",
    "        class_degree_weighted_probs[labels[u]] += graph.degree(u)\n",
    "\n",
    "    class_degree_weighted_probs /= class_degree_weighted_probs.sum()\n",
    "\n",
    "    edge_probs = np.zeros((num_classes, num_classes))\n",
    "    for u, v in graph.edges:\n",
    "        label_u = labels[u]\n",
    "        label_v = labels[v]\n",
    "        edge_probs[label_u, label_v] += 1\n",
    "        edge_probs[label_v, label_u] += 1\n",
    "\n",
    "    edge_probs /= edge_probs.sum()\n",
    "    edge_probs += eps\n",
    "\n",
    "    numerator = (edge_probs * np.log(edge_probs)).sum()\n",
    "    # Entropy term of pbar with the 0 * log 0 = 0 convention.\n",
    "    nonzero_probs = class_degree_weighted_probs[class_degree_weighted_probs > 0]\n",
    "    denominator = (nonzero_probs * np.log(nonzero_probs)).sum()\n",
    "    li_edge = 2 - numerator / denominator\n",
    "\n",
    "    return li_edge\n",
    "\n",
    "\n",
    "def li_node(graph, labels, eps=1e-8):\n",
    "    \"\"\"Node label informativeness LI_node.\n",
    "\n",
    "    Isolated nodes are ignored. p(c) is the share of non-isolated nodes in class c\n",
    "    and pbar(c) the degree-weighted class distribution. For every edge (u, v) and\n",
    "    each of its two orientations, 1 / (n * deg(u)) is added to p(label(u), label(v)),\n",
    "    where n is the number of non-isolated nodes. The result is\n",
    "    -sum p(c1, c2) log(p(c1, c2) / (p(c1) pbar(c2))) / sum p(c) log p(c).\n",
    "    ``eps`` is added to p(c1, c2) so that the logarithm stays finite for label\n",
    "    pairs that never occur on an edge. Classes with no non-isolated nodes follow\n",
    "    the 0 * log 0 = 0 convention instead of producing NaN or inf.\n",
    "    \"\"\"\n",
    "    labels = convert_labels_to_consecutive_integers(labels)\n",
    "\n",
    "    num_classes = len(np.unique(labels))\n",
    "\n",
    "    class_probs = np.zeros(num_classes, dtype=float)\n",
    "    class_degree_weighted_probs = np.zeros(num_classes, dtype=float)\n",
    "    num_zero_degree_nodes = 0\n",
    "    for u in graph.nodes:\n",
    "        degree = graph.degree(u)\n",
    "        if degree == 0:\n",
    "            num_zero_degree_nodes += 1\n",
    "            continue\n",
    "\n",
    "        label = labels[u]\n",
    "        class_probs[label] += 1\n",
    "        class_degree_weighted_probs[label] += degree\n",
    "\n",
    "    class_probs /= class_probs.sum()\n",
    "    class_degree_weighted_probs /= class_degree_weighted_probs.sum()\n",
    "    num_nonzero_degree_nodes = len(graph.nodes) - num_zero_degree_nodes\n",
    "\n",
    "    edge_probs = np.zeros((num_classes, num_classes))\n",
    "    for u, v in graph.edges:\n",
    "        label_u = labels[u]\n",
    "        label_v = labels[v]\n",
    "        edge_probs[label_u, label_v] += 1 / (num_nonzero_degree_nodes * graph.degree(u))\n",
    "        edge_probs[label_v, label_u] += 1 / (num_nonzero_degree_nodes * graph.degree(v))\n",
    "\n",
    "    edge_probs += eps\n",
    "\n",
    "    # Independence baseline p(c1) * pbar(c2). It is 0 only for classes without\n",
    "    # non-isolated nodes, whose cells hold nothing but eps; use a ratio of 1 there\n",
    "    # (log 1 = 0) instead of dividing by zero.\n",
    "    baseline = class_probs.reshape(-1, 1) * class_degree_weighted_probs.reshape(1, -1)\n",
    "    ratio = np.divide(edge_probs, baseline, out=np.ones_like(edge_probs), where=baseline > 0)\n",
    "    numerator = (edge_probs * np.log(ratio)).sum()\n",
    "\n",
    "    # Entropy term of p(c) with the 0 * log 0 = 0 convention.\n",
    "    nonzero_class_probs = class_probs[class_probs > 0]\n",
    "    denominator = (nonzero_class_probs * np.log(nonzero_class_probs)).sum()\n",
    "    li_node = - numerator / denominator\n",
    "\n",
    "    return li_node\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e12d1839",
   "metadata": {},
   "source": [
    "Below is an example of computing all of the above metrics on the popular Cora dataset, loaded with PyTorch Geometric."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e897905f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric import datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5e54d25b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# `root` is the directory where the Planetoid files are kept; [0] selects the first graph,\n",
    "# a PyG Data object (see the output below).\n",
    "dataset = datasets.Planetoid(root='../data', name='cora')[0]\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cdf4b22b",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = dataset.y.numpy()\n",
    "\n",
    "# Undirected graph: the (u, v) and (v, u) entries of edge_index collapse into one edge.\n",
    "# Every node index is added explicitly so that isolated nodes are kept as well.\n",
    "graph = nx.Graph()\n",
    "graph.add_nodes_from(range(len(dataset.x)))\n",
    "graph.add_edges_from(dataset.edge_index.numpy().T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4645b8b3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2708"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "graph.number_of_nodes()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "57837abf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5278"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "graph.number_of_edges()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8722bd04",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8099658961727927"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Edge homophily: fraction of edges whose endpoints share a label.\n",
    "h_edge(graph, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c8716f18",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8251578275927919"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Node homophily: per-node share of same-label neighbors, averaged over non-isolated nodes.\n",
    "h_node(graph, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "30d74e72",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.765718101454461"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Class homophily: per-class same-label share minus the class-size fraction, clipped at 0, summed and divided by C - 1.\n",
    "h_class(graph, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "83408625",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7710854223002092"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Adjusted homophily: edge homophily corrected by the degree-based term sum_k D_k^2 / (2|E|)^2.\n",
    "h_adj(graph, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8660c8bd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5903729565335607"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Edge label informativeness (LI_edge).\n",
    "li_edge(graph, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "98ccf687",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6137828068999315"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Node label informativeness (LI_node); isolated nodes are ignored.\n",
    "li_node(graph, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "feca3783",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
