{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# **Notebook for reproducing CCAM results**"
      ],
      "metadata": {
        "id": "nWhFueeANvHz"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "SybjTKqRT6Pb",
        "outputId": "e2cf0ac1-5595-43d3-97b6-fef217ecdc5d"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "2.0.1+cu118\n",
            "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
            "  Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for torch_geometric (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n"
          ]
        }
      ],
      "source": [
        "# Install required packages.\n",
        "import os\n",
        "import torch\n",
        "# Expose the installed torch version to shell commands as $TORCH; only the\n",
        "# commented-out wheel-index installs below reference it.\n",
        "os.environ['TORCH'] = torch.__version__\n",
        "print(torch.__version__)\n",
        "\n",
        "#!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html\n",
        "#!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html\n",
        "# PyTorch Geometric is installed from the tip of its GitHub repository (no version pin).\n",
        "!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git\n",
        "\n",
        "# Helper function for visualization.\n",
        "%matplotlib inline\n",
        "import matplotlib.pyplot as plt\n",
        "from sklearn.manifold import TSNE\n",
        "\n",
        "def visualize(h, color):\n",
        "    \"\"\"Scatter-plot a 2-D t-SNE projection of ``h``, colouring points by ``color``.\n",
        "\n",
        "    ``h`` is detached, moved to the CPU and converted to NumPy before it is\n",
        "    embedded; ``color`` is forwarded unchanged to matplotlib's ``c`` argument.\n",
        "    \"\"\"\n",
        "    points = h.detach().cpu().numpy()\n",
        "    embedding = TSNE(n_components=2).fit_transform(points)\n",
        "    xs, ys = embedding[:, 0], embedding[:, 1]\n",
        "\n",
        "    # Square canvas with the axis ticks removed.\n",
        "    plt.figure(figsize=(10, 10))\n",
        "    plt.xticks([])\n",
        "    plt.yticks([])\n",
        "    plt.scatter(xs, ys, s=70, c=color, cmap=\"Set2\")\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [],
      "metadata": {
        "id": "b0wXbAtbNrog"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 283
        },
        "id": "IVkhYxXnUJdA",
        "outputId": "0eb86783-93ab-4112-ae00-de629a3168e1"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting scipy==1.8.1\n",
            "  Downloading scipy-1.8.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (42.2 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m42.2/42.2 MB\u001b[0m \u001b[31m37.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: numpy<1.25.0,>=1.17.3 in /usr/local/lib/python3.10/dist-packages (from scipy==1.8.1) (1.23.5)\n",
            "Installing collected packages: scipy\n",
            "  Attempting uninstall: scipy\n",
            "    Found existing installation: scipy 1.11.2\n",
            "    Uninstalling scipy-1.11.2:\n",
            "      Successfully uninstalled scipy-1.11.2\n",
            "Successfully installed scipy-1.8.1\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.colab-display-data+json": {
              "pip_warning": {
                "packages": [
                  "scipy"
                ]
              }
            }
          },
          "metadata": {}
        }
      ],
      "source": [
        "# Pin SciPy to 1.8.1; this replaces the SciPy build already present in the runtime.\n",
        "!pip install scipy==1.8.1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "G4mmDC2IUNB5",
        "outputId": "de166d25-421d-4c6b-8299-5896d0cc00ec"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ],
      "source": [
        "from google.colab import drive\n",
        "# Mount Google Drive at /content/drive (requires the google.colab runtime).\n",
        "drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from torch_geometric.datasets import Planetoid, Amazon,WebKB,FB15k_237,HeterophilousGraphDataset,WikipediaNetwork,CitationFull,Actor\n",
        "from torch_geometric.transforms import NormalizeFeatures\n",
        "from torch_geometric.utils import homophily,add_self_loops, is_undirected,to_networkx,from_networkx,to_undirected, to_dense_adj, dense_to_sparse\n",
        "import time\n"
      ],
      "metadata": {
        "id": "BlXvVa2xygTh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "An example from the Roman-Empire dataset"
      ],
      "metadata": {
        "id": "HMqwbMW9-ZTi"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "J1zUL5K6UPAf",
        "outputId": "1c37a2c2-e892-42a0-dda9-0ed2765e6a5a"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading https://github.com/yandex-research/heterophilous-graphs/raw/main/data/roman_empire.npz\n",
            "Processing...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "Dataset: HeterophilousGraphDataset(name=roman_empire):\n",
            "======================\n",
            "Number of graphs: 1\n",
            "Number of features: 300\n",
            "Number of classes: 18\n",
            "\n",
            "Data(x=[22662, 300], edge_index=[2, 65854], y=[22662], train_mask=[22662, 10], val_mask=[22662, 10], test_mask=[22662, 10])\n",
            "===========================================================================================================\n",
            "Number of nodes: 22662\n",
            "Number of edges: 65854\n",
            "Average node degree: 2.91\n",
            "Number of training nodes: 113310\n",
            "Training node label rate: 5.00\n",
            "Has isolated nodes: False\n",
            "Has self-loops: False\n",
            "Is undirected: True\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Done!\n"
          ]
        }
      ],
      "source": [
        "\n",
        "# Select the benchmark to analyse: exactly one `dataset = ...` line should be active.\n",
        "#dataset = Planetoid(root='data/Planetoid', name='cora', transform=NormalizeFeatures())\n",
        "#dataset = Amazon(root='data/amazon', name='photo', transform=NormalizeFeatures())\n",
        "#dataset = WebKB(root='data/', name='Texas',transform = NormalizeFeatures())\n",
        "#dataset = CitationFull(root='data/Citeseer', name='Citeseer', to_undirected = False)#\n",
        "#dataset = WikipediaNetwork(root='data/WikipediaNetwork', name='squirrel')\n",
        "dataset = HeterophilousGraphDataset(root='data/Roman-empire', name='Roman-empire')\n",
        "#dataset = Actor(root='data/Actor')\n",
        "\n",
        "print()\n",
        "print(f'Dataset: {dataset}:')\n",
        "print('======================')\n",
        "print(f'Number of graphs: {len(dataset)}')\n",
        "print(f'Number of features: {dataset.num_features}')\n",
        "print(f'Number of classes: {dataset.num_classes}')\n",
        "\n",
        "data = dataset[0]  # Get the first graph object.\n",
        "\n",
        "print()\n",
        "print(data)\n",
        "print('===========================================================================================================')\n",
        "\n",
        "# Gather some statistics about the graph.\n",
        "print(f'Number of nodes: {data.num_nodes}')\n",
        "print(f'Number of edges: {data.num_edges}')\n",
        "print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')\n",
        "\n",
        "# Some datasets ship several train/val/test splits as a 2-D mask of shape\n",
        "# [num_nodes, num_splits] (Roman-empire has 10). Summing the whole mask counts\n",
        "# every split and yields a label rate above 1, so report the first split only.\n",
        "# Datasets without a train_mask are reported as such instead of raising.\n",
        "train_mask = getattr(data, 'train_mask', None)\n",
        "if train_mask is None:\n",
        "    print('Number of training nodes: n/a (dataset defines no train_mask)')\n",
        "else:\n",
        "    if train_mask.dim() > 1:\n",
        "        train_mask = train_mask[:, 0]\n",
        "    num_train_nodes = int(train_mask.sum())\n",
        "    print(f'Number of training nodes: {num_train_nodes}')\n",
        "    print(f'Training node label rate: {num_train_nodes / data.num_nodes:.2f}')\n",
        "print(f'Has isolated nodes: {data.has_isolated_nodes()}')\n",
        "print(f'Has self-loops: {data.has_self_loops()}')\n",
        "print(f'Is undirected: {data.is_undirected()}')\n",
        "\n",
        "# Symmetrise the edge index, then add one self-loop per node.\n",
        "data_und = to_undirected(data.edge_index)\n",
        "data_und,_= add_self_loops(data_und, num_nodes=data.num_nodes)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Import the libraries needed to calculate the curvature"
      ],
      "metadata": {
        "id": "2hbMk7uAx4sA"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 411
        },
        "id": "C3EIddJ2UUud",
        "outputId": "3f92bce7-d668-4feb-95b6-9ea417ffcf2c"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting GraphRicciCurvature\n",
            "  Downloading GraphRicciCurvature-0.5.3.1-py3-none-any.whl (23 kB)\n",
            "Requirement already satisfied: cython in /usr/local/lib/python3.10/dist-packages (from GraphRicciCurvature) (3.0.2)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from GraphRicciCurvature) (1.23.5)\n",
            "Requirement already satisfied: scipy>=1.0 in /usr/local/lib/python3.10/dist-packages (from GraphRicciCurvature) (1.8.1)\n",
            "Requirement already satisfied: networkx>=2.0 in /usr/local/lib/python3.10/dist-packages (from GraphRicciCurvature) (3.1)\n",
            "Collecting pot>=0.8.0 (from GraphRicciCurvature)\n",
            "  Downloading POT-0.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (789 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m790.0/790.0 kB\u001b[0m \u001b[31m20.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from GraphRicciCurvature) (23.1)\n",
            "Collecting networkit>=6.1 (from GraphRicciCurvature)\n",
            "  Downloading networkit-10.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (9.8 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.8/9.8 MB\u001b[0m \u001b[31m7.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: python-louvain in /usr/local/lib/python3.10/dist-packages (from GraphRicciCurvature) (0.16)\n",
            "Installing collected packages: pot, networkit, GraphRicciCurvature\n",
            "Successfully installed GraphRicciCurvature-0.5.3.1 networkit-10.1 pot-0.9.1\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "\n",
              "\t\t\t<script type=\"text/javascript\">\n",
              "\t\t\t<!--\n",
              "\t\t\t\t\n",
              "\t\t\t{\n",
              "\t\t\t\tvar element = document.getElementById('NetworKit_script');\n",
              "\t\t\t\tif (element) {\n",
              "\t\t\t\t\telement.parentNode.removeChild(element);\n",
              "\t\t\t\t}\n",
              "\t\t\t\telement = document.createElement('script');\n",
              "\t\t\t\telement.type = 'text/javascript';\n",
              "\t\t\t\telement.innerHTML = 'function NetworKit_pageEmbed(id) { var i, j; var elements; elements = document.getElementById(id).getElementsByClassName(\"Plot\"); for (i=0; i<elements.length; i++) { elements[i].id = id + \"_Plot_\" + i; var data = elements[i].getAttribute(\"data-image\").split(\"|\"); elements[i].removeAttribute(\"data-image\"); var content = \"<div class=\\\\\"Image\\\\\" id=\\\\\"\" + elements[i].id + \"_Image\\\\\" />\"; elements[i].innerHTML = content; elements[i].setAttribute(\"data-image-index\", 0); elements[i].setAttribute(\"data-image-length\", data.length); for (j=0; j<data.length; j++) { elements[i].setAttribute(\"data-image-\" + j, data[j]); } NetworKit_plotUpdate(elements[i]); elements[i].onclick = function (e) { NetworKit_overlayShow((e.target) ? e.target : e.srcElement); } } elements = document.getElementById(id).getElementsByClassName(\"HeatCell\"); for (i=0; i<elements.length; i++) { var data = parseFloat(elements[i].getAttribute(\"data-heat\")); var color = \"#00FF00\"; if (data <= 1 && data > 0) { color = \"hsla(0, 100%, 75%, \" + (data) + \")\"; } else if (data <= 0 && data >= -1) { color = \"hsla(240, 100%, 75%, \" + (-data) + \")\"; } elements[i].style.backgroundColor = color; } elements = document.getElementById(id).getElementsByClassName(\"Details\"); for (i=0; i<elements.length; i++) { elements[i].setAttribute(\"data-title\", \"-\"); NetworKit_toggleDetails(elements[i]); elements[i].onclick = function (e) { NetworKit_toggleDetails((e.target) ? 
e.target : e.srcElement); } } elements = document.getElementById(id).getElementsByClassName(\"MathValue\"); for (i=elements.length-1; i>=0; i--) { value = elements[i].innerHTML.trim(); if (value === \"nan\") { elements[i].parentNode.innerHTML = \"\" } } elements = document.getElementById(id).getElementsByClassName(\"SubCategory\"); for (i=elements.length-1; i>=0; i--) { value = elements[i].innerHTML.trim(); if (value === \"\") { elements[i].parentNode.removeChild(elements[i]) } } elements = document.getElementById(id).getElementsByClassName(\"Category\"); for (i=elements.length-1; i>=0; i--) { value = elements[i].innerHTML.trim(); if (value === \"\") { elements[i].parentNode.removeChild(elements[i]) } } var isFirefox = false; try { isFirefox = typeof InstallTrigger !== \"undefined\"; } catch (e) {} if (!isFirefox) { alert(\"Currently the function\\'s output is only fully supported by Firefox.\"); } } function NetworKit_plotUpdate(source) { var index = source.getAttribute(\"data-image-index\"); var data = source.getAttribute(\"data-image-\" + index); var image = document.getElementById(source.id + \"_Image\"); image.style.backgroundImage = \"url(\" + data + \")\"; } function NetworKit_showElement(id, show) { var element = document.getElementById(id); element.style.display = (show) ? 
\"block\" : \"none\"; } function NetworKit_overlayShow(source) { NetworKit_overlayUpdate(source); NetworKit_showElement(\"NetworKit_Overlay\", true); } function NetworKit_overlayUpdate(source) { document.getElementById(\"NetworKit_Overlay_Title\").innerHTML = source.title; var index = source.getAttribute(\"data-image-index\"); var data = source.getAttribute(\"data-image-\" + index); var image = document.getElementById(\"NetworKit_Overlay_Image\"); image.setAttribute(\"data-id\", source.id); image.style.backgroundImage = \"url(\" + data + \")\"; var link = document.getElementById(\"NetworKit_Overlay_Toolbar_Bottom_Save\"); link.href = data; link.download = source.title + \".svg\"; } function NetworKit_overlayImageShift(delta) { var image = document.getElementById(\"NetworKit_Overlay_Image\"); var source = document.getElementById(image.getAttribute(\"data-id\")); var index = parseInt(source.getAttribute(\"data-image-index\")); var length = parseInt(source.getAttribute(\"data-image-length\")); var index = (index+delta) % length; if (index < 0) { index = length + index; } source.setAttribute(\"data-image-index\", index); NetworKit_overlayUpdate(source); } function NetworKit_toggleDetails(source) { var childs = source.children; var show = false; if (source.getAttribute(\"data-title\") == \"-\") { source.setAttribute(\"data-title\", \"+\"); show = false; } else { source.setAttribute(\"data-title\", \"-\"); show = true; } for (i=0; i<childs.length; i++) { if (show) { childs[i].style.display = \"block\"; } else { childs[i].style.display = \"none\"; } } }';\n",
              "\t\t\t\telement.setAttribute('id', 'NetworKit_script');\n",
              "\t\t\t\tdocument.head.appendChild(element);\n",
              "\t\t\t}\n",
              "\t\t\n",
              "\t\t\t\t\n",
              "\t\t\t{\n",
              "\t\t\t\tvar element = document.getElementById('NetworKit_style');\n",
              "\t\t\t\tif (element) {\n",
              "\t\t\t\t\telement.parentNode.removeChild(element);\n",
              "\t\t\t\t}\n",
              "\t\t\t\telement = document.createElement('style');\n",
              "\t\t\t\telement.type = 'text/css';\n",
              "\t\t\t\telement.innerHTML = '.NetworKit_Page { font-family: Arial, Helvetica, sans-serif; font-size: 14px; } .NetworKit_Page .Value:before { font-family: Arial, Helvetica, sans-serif; font-size: 1.05em; content: attr(data-title) \":\"; margin-left: -2.5em; padding-right: 0.5em; } .NetworKit_Page .Details .Value:before { display: block; } .NetworKit_Page .Value { font-family: monospace; white-space: pre; padding-left: 2.5em; white-space: -moz-pre-wrap !important; white-space: -pre-wrap; white-space: -o-pre-wrap; white-space: pre-wrap; word-wrap: break-word; tab-size: 4; -moz-tab-size: 4; } .NetworKit_Page .Category { clear: both; padding-left: 1em; margin-bottom: 1.5em; } .NetworKit_Page .Category:before { content: attr(data-title); font-size: 1.75em; display: block; margin-left: -0.8em; margin-bottom: 0.5em; } .NetworKit_Page .SubCategory { margin-bottom: 1.5em; padding-left: 1em; } .NetworKit_Page .SubCategory:before { font-size: 1.6em; display: block; margin-left: -0.8em; margin-bottom: 0.5em; } .NetworKit_Page .SubCategory[data-title]:before { content: attr(data-title); } .NetworKit_Page .Block { display: block; } .NetworKit_Page .Block:after { content: \".\"; visibility: hidden; display: block; height: 0; clear: both; } .NetworKit_Page .Block .Thumbnail_Overview, .NetworKit_Page .Block .Thumbnail_ScatterPlot { width: 260px; float: left; } .NetworKit_Page .Block .Thumbnail_Overview img, .NetworKit_Page .Block .Thumbnail_ScatterPlot img { width: 260px; } .NetworKit_Page .Block .Thumbnail_Overview:before, .NetworKit_Page .Block .Thumbnail_ScatterPlot:before { display: block; text-align: center; font-weight: bold; } .NetworKit_Page .Block .Thumbnail_Overview:before { content: attr(data-title); } .NetworKit_Page .HeatCell { font-family: \"Courier New\", Courier, monospace; cursor: pointer; } .NetworKit_Page .HeatCell, .NetworKit_Page .HeatCellName { display: inline; padding: 0.1em; margin-right: 2px; background-color: #FFFFFF } .NetworKit_Page 
.HeatCellName { margin-left: 0.25em; } .NetworKit_Page .HeatCell:before { content: attr(data-heat); display: inline-block; color: #000000; width: 4em; text-align: center; } .NetworKit_Page .Measure { clear: both; } .NetworKit_Page .Measure .Details { cursor: pointer; } .NetworKit_Page .Measure .Details:before { content: \"[\" attr(data-title) \"]\"; display: block; } .NetworKit_Page .Measure .Details .Value { border-left: 1px dotted black; margin-left: 0.4em; padding-left: 3.5em; pointer-events: none; } .NetworKit_Page .Measure .Details .Spacer:before { content: \".\"; opacity: 0.0; pointer-events: none; } .NetworKit_Page .Measure .Plot { width: 440px; height: 440px; cursor: pointer; float: left; margin-left: -0.9em; margin-right: 20px; } .NetworKit_Page .Measure .Plot .Image { background-repeat: no-repeat; background-position: center center; background-size: contain; height: 100%; pointer-events: none; } .NetworKit_Page .Measure .Stat { width: 500px; float: left; } .NetworKit_Page .Measure .Stat .Group { padding-left: 1.25em; margin-bottom: 0.75em; } .NetworKit_Page .Measure .Stat .Group .Title { font-size: 1.1em; display: block; margin-bottom: 0.3em; margin-left: -0.75em; border-right-style: dotted; border-right-width: 1px; border-bottom-style: dotted; border-bottom-width: 1px; background-color: #D0D0D0; padding-left: 0.2em; } .NetworKit_Page .Measure .Stat .Group .List { -webkit-column-count: 3; -moz-column-count: 3; column-count: 3; } .NetworKit_Page .Measure .Stat .Group .List .Entry { position: relative; line-height: 1.75em; } .NetworKit_Page .Measure .Stat .Group .List .Entry[data-tooltip]:before { position: absolute; left: 0; top: -40px; background-color: #808080; color: #ffffff; height: 30px; line-height: 30px; border-radius: 5px; padding: 0 15px; content: attr(data-tooltip); white-space: nowrap; display: none; } .NetworKit_Page .Measure .Stat .Group .List .Entry[data-tooltip]:after { position: absolute; left: 15px; top: -10px; border-top: 7px solid 
#808080; border-left: 7px solid transparent; border-right: 7px solid transparent; content: \"\"; display: none; } .NetworKit_Page .Measure .Stat .Group .List .Entry[data-tooltip]:hover:after, .NetworKit_Page .Measure .Stat .Group .List .Entry[data-tooltip]:hover:before { display: block; } .NetworKit_Page .Measure .Stat .Group .List .Entry .MathValue { font-family: \"Courier New\", Courier, monospace; } .NetworKit_Page .Measure:after { content: \".\"; visibility: hidden; display: block; height: 0; clear: both; } .NetworKit_Page .PartitionPie { clear: both; } .NetworKit_Page .PartitionPie img { width: 600px; } #NetworKit_Overlay { left: 0px; top: 0px; display: none; position: absolute; width: 100%; height: 100%; background-color: rgba(0,0,0,0.6); z-index: 1000; } #NetworKit_Overlay_Title { position: absolute; color: white; transform: rotate(-90deg); width: 32em; height: 32em; padding-right: 0.5em; padding-top: 0.5em; text-align: right; font-size: 40px; } #NetworKit_Overlay .button { background: white; cursor: pointer; } #NetworKit_Overlay .button:before { size: 13px; display: inline-block; text-align: center; margin-top: 0.5em; margin-bottom: 0.5em; width: 1.5em; height: 1.5em; } #NetworKit_Overlay .icon-close:before { content: \"X\"; } #NetworKit_Overlay .icon-previous:before { content: \"P\"; } #NetworKit_Overlay .icon-next:before { content: \"N\"; } #NetworKit_Overlay .icon-save:before { content: \"S\"; } #NetworKit_Overlay_Toolbar_Top, #NetworKit_Overlay_Toolbar_Bottom { position: absolute; width: 40px; right: 13px; text-align: right; z-index: 1100; } #NetworKit_Overlay_Toolbar_Top { top: 0.5em; } #NetworKit_Overlay_Toolbar_Bottom { Bottom: 0.5em; } #NetworKit_Overlay_ImageContainer { position: absolute; top: 5%; left: 5%; height: 90%; width: 90%; background-repeat: no-repeat; background-position: center center; background-size: contain; } #NetworKit_Overlay_Image { height: 100%; width: 100%; background-repeat: no-repeat; background-position: center center; 
background-size: contain; }';\n",
              "\t\t\t\telement.setAttribute('id', 'NetworKit_style');\n",
              "\t\t\t\tdocument.head.appendChild(element);\n",
              "\t\t\t}\n",
              "\t\t\n",
              "\t\t\t\t\n",
              "\t\t\t{\n",
              "\t\t\t\tvar element = document.getElementById('NetworKit_Overlay');\n",
              "\t\t\t\tif (element) {\n",
              "\t\t\t\t\telement.parentNode.removeChild(element);\n",
              "\t\t\t\t}\n",
              "\t\t\t\telement = document.createElement('div');\n",
              "\t\t\t\telement.innerHTML = '<div id=\"NetworKit_Overlay_Toolbar_Top\"><div class=\"button icon-close\" id=\"NetworKit_Overlay_Close\" /></div><div id=\"NetworKit_Overlay_Title\" /> <div id=\"NetworKit_Overlay_ImageContainer\"> <div id=\"NetworKit_Overlay_Image\" /> </div> <div id=\"NetworKit_Overlay_Toolbar_Bottom\"> <div class=\"button icon-previous\" onclick=\"NetworKit_overlayImageShift(-1)\" /> <div class=\"button icon-next\" onclick=\"NetworKit_overlayImageShift(1)\" /> <a id=\"NetworKit_Overlay_Toolbar_Bottom_Save\"><div class=\"button icon-save\" /></a> </div>';\n",
              "\t\t\t\telement.setAttribute('id', 'NetworKit_Overlay');\n",
              "\t\t\t\tdocument.body.appendChild(element);\n",
              "\t\t\t\tdocument.getElementById('NetworKit_Overlay_Close').onclick = function (e) {\n",
              "\t\t\t\t\tdocument.getElementById('NetworKit_Overlay').style.display = 'none';\n",
              "\t\t\t\t}\n",
              "\t\t\t}\n",
              "\t\t\n",
              "\t\t\t-->\n",
              "\t\t\t</script>\n",
              "\t\t"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/ot/backend.py:1368: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.\n",
            "  jax.device_put(jnp.array(1, dtype=jnp.float64), d)\n",
            "/usr/local/lib/python3.10/dist-packages/ot/backend.py:2998: UserWarning: To use TensorflowBackend, you need to activate the tensorflow numpy API. You can activate it by running: \n",
            "from tensorflow.python.ops.numpy_ops import np_config\n",
            "np_config.enable_numpy_behavior()\n",
            "  register_backend(TensorflowBackend())\n"
          ]
        }
      ],
      "source": [
        "# Pinned to the release this notebook's recorded outputs were produced with.\n",
        "!pip install GraphRicciCurvature==0.5.3.1\n",
        "\n",
        "import logging\n",
        "import math\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import networkx as nx\n",
        "import numpy as np\n",
        "%matplotlib inline\n",
        "\n",
        "# load GraphRicciCurvature package\n",
        "from GraphRicciCurvature.FormanRicci import FormanRicci\n",
        "from GraphRicciCurvature.OllivierRicci import OllivierRicci\n",
        "\n",
        "# to print logs in jupyter notebook\n",
        "logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.ERROR)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Calculate curvature according to Ollivier"
      ],
      "metadata": {
        "id": "2h-bYI-8-ngo"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "OVPYRvLtUswc",
        "outputId": "bc08efe5-d2b1-4924-f849-b07fef817fb0"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:GraphRicciCurvature:Self-loop edge detected. Removing 22662 self-loop edges.\n",
            "TRACE:GraphRicciCurvature:Number of nodes: 22662\n",
            "TRACE:GraphRicciCurvature:Number of edges: 32927\n",
            "TRACE:GraphRicciCurvature:Start to compute all pair shortest path.\n",
            "TRACE:GraphRicciCurvature:32.199761 secs for all pair by NetworKit.\n",
            "INFO:GraphRicciCurvature:3.637583 secs for Ricci curvature computation.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "71.98827242851257\n"
          ]
        }
      ],
      "source": [
        "# Rebuild the graph in NetworkX from the dense adjacency of the symmetrised edge index.\n",
        "A = torch.squeeze(to_dense_adj(data_und)).numpy()\n",
        "G=nx.Graph(A)\n",
        "curvature = \"ricciCurvature\"  # edge attribute read back after the computation\n",
        "orc = OllivierRicci(G, alpha=0.5, verbose=\"TRACE\")\n",
        "\n",
        "# Time the curvature computation (graph construction above is excluded).\n",
        "start_time = time.time()\n",
        "orc.compute_ricci_curvature()\n",
        "G_curv = orc.G.copy()\n",
        "print(time.time() - start_time)\n",
        "\n",
        "\n",
        "# Dense [num_nodes, num_nodes] matrices: mat_ricci holds strictly positive\n",
        "# curvatures, mat_riccim holds negative curvatures plus zero-curvature edges.\n",
        "mat_ricci=np.zeros((data.num_nodes,data.num_nodes))\n",
        "mat_riccim=np.zeros((data.num_nodes,data.num_nodes))\n",
        "\n",
        "\n",
        "# Walk the edges directly instead of coercing the graph through np.unique;\n",
        "# node ids double as matrix indices, and each undirected edge fills both\n",
        "# (u, v) and (v, u).\n",
        "for u, v, kappa in G_curv.edges(data=curvature):\n",
        "    if kappa > 0:\n",
        "        mat_ricci[u][v] = mat_ricci[v][u] = kappa\n",
        "    elif kappa < 0:\n",
        "        mat_riccim[u][v] = mat_riccim[v][u] = kappa\n",
        "    elif kappa == 0:\n",
        "        # A stored 0 would be indistinguishable from 'no edge', so zero-curvature\n",
        "        # edges are marked with 0.01 in the negative-side matrix.\n",
        "        mat_riccim[u][v] = mat_riccim[v][u] = kappa + 0.01\n",
        "        #mat_ricci[a][j]=G_curv[i][j][curvature]+0.01"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OFc8fiAiU2tA"
      },
      "outputs": [],
      "source": [
        "def _nonzero_edge_index(dense_matrix):\n",
        "    \"\"\"Return the coordinates of the non-zero entries as a contiguous [2, k] tensor.\"\"\"\n",
        "    return torch.Tensor(dense_matrix).nonzero().t().contiguous()\n",
        "\n",
        "# Edge indices of the two curvature-restricted graphs.\n",
        "curvp = _nonzero_edge_index(mat_ricci)\n",
        "curvm = _nonzero_edge_index(mat_riccim)\n",
        "\n",
        "# Same graphs with one self-loop added per node.\n",
        "edge_index_curvp, _ = add_self_loops(curvp, num_nodes=data.num_nodes)\n",
        "edge_index_curvm, _ = add_self_loops(curvm, num_nodes=data.num_nodes)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Edge homophily as introduced by Zhu et al., followed by the curvature-constrained homophily measure."
      ],
      "metadata": {
        "id": "Jji-UdFJ-shs"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fThbWL5PjLDl",
        "outputId": "58f6964c-fa06-493e-81b7-614b14ac8651"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0.2909078598022461\n",
            "0.3991025388240814\n",
            "0.47972702980041504\n"
          ]
        }
      ],
      "source": [
        "# Homophily of, in order: the full undirected graph, the positive-curvature\n",
        "# graph, and the graph of negative- and zero-curvature edges.\n",
        "for _edges in (data_und, edge_index_curvp, edge_index_curvm):\n",
        "    print(homophily(_edges, data.y))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Optionally, build a two-hop graph based on curvature."
      ],
      "metadata": {
        "id": "Bb6PsH0C0mQG"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7smMb0dX3_OG"
      },
      "outputs": [],
      "source": [
        "from collections import defaultdict\n",
        "\n",
        "from torch_geometric.utils import add_self_loops\n",
        "\n",
        "num_nodes = data.num_nodes\n",
        "\n",
        "# Build source -> [targets] adjacency lists in one pass over the edge index, so\n",
        "# neighbour lookups no longer rescan every edge with a boolean mask. Targets\n",
        "# keep the order in which they appear in edge_index_curvm.\n",
        "adjacency = defaultdict(list)\n",
        "for src, dst in edge_index_curvm.t().tolist():\n",
        "    adjacency[src].append(dst)\n",
        "\n",
        "# Collect the new edges (node -> neighbour of a neighbour), skipping paths that\n",
        "# lead back to the starting node. Duplicates are kept, as before.\n",
        "new_edge_index = []\n",
        "for node in range(num_nodes):\n",
        "    for neighbor in adjacency.get(node, ()):\n",
        "        for neighbor_of_neighbor in adjacency.get(neighbor, ()):\n",
        "            if neighbor_of_neighbor != node:\n",
        "                new_edge_index.append([node, neighbor_of_neighbor])\n",
        "\n",
        "# reshape(-1, 2) keeps the result a valid [2, 0] edge index when no two-hop edge exists.\n",
        "two_hop_tensor = torch.tensor(new_edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous()\n",
        "\n",
        "two_hop_tensor,_ = add_self_loops(two_hop_tensor, num_nodes=data.num_nodes)\n",
        "print(homophily(two_hop_tensor, data.y))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WxXSWxlNUO_V"
      },
      "outputs": [],
      "source": [
        "from torch_geometric.nn import GCNConv\n",
        "\n",
        "class GCN(torch.nn.Module):\n",
        "    \"\"\"Two-layer GCN whose input and output widths come from the global `dataset`.\n",
        "\n",
        "    NOTE: `forward` always convolves over the global `edge_index_curvm`;\n",
        "    its `edge_index` argument is accepted but never read.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, hidden_channels):\n",
        "        super().__init__()\n",
        "        # dataset.num_features -> hidden_channels -> dataset.num_classes\n",
        "        self.conv1 = GCNConv(dataset.num_features, hidden_channels)\n",
        "        self.conv2 = GCNConv(hidden_channels, dataset.num_classes)\n",
        "\n",
        "    def forward(self, x, edge_index):\n",
        "        \"\"\"Return the output of the second convolution for node features `x`.\"\"\"\n",
        "        graph = edge_index_curvm.to(device)\n",
        "        hidden = self.conv1(x.to(device), graph).relu()\n",
        "        # Dropout is only active while the module is in training mode.\n",
        "        hidden = F.dropout(hidden, p=0.5, training=self.training)\n",
        "        return self.conv2(hidden, graph)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hm2vbJ1R6PDq"
      },
      "outputs": [],
      "source": [
        "from torch_geometric.nn import GATConv\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "class GAT(torch.nn.Module):\n",
        "    \"\"\"Two-layer GAT whose input and output widths come from the global `dataset`.\n",
        "\n",
        "    Args:\n",
        "        hidden_channels: output width of each attention head in the first layer.\n",
        "        heads: number of attention heads in the first layer (default 8).\n",
        "\n",
        "    NOTE: `forward` always attends over the global `edge_index_curvm`;\n",
        "    its `edge_index` argument is accepted but never read.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, hidden_channels, heads=8):\n",
        "        super().__init__()\n",
        "        self.conv1 = GATConv(dataset.num_features, hidden_channels, heads=heads, dropout=0.5)\n",
        "        # conv1 concatenates its heads (GATConv default), so conv2 sees hidden_channels * heads inputs.\n",
        "        self.conv2 = GATConv(hidden_channels * heads, dataset.num_classes, heads=1, dropout=0.5)\n",
        "\n",
        "    def forward(self, x, edge_index):\n",
        "        \"\"\"Return the output of the second attention layer for node features `x`.\"\"\"\n",
        "        graph = edge_index_curvm.to(device)\n",
        "        x = F.elu(self.conv1(x.to(device), graph))\n",
        "        # Dropout is only active while the module is in training mode.\n",
        "        x = F.dropout(x, p=0.5, training=self.training)\n",
        "        return self.conv2(x, graph)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HHPAR4DGdtsB"
      },
      "outputs": [],
      "source": [
        "criterion = torch.nn.CrossEntropyLoss()\n",
        "\n",
        "def train(t,v):\n",
        "    \"\"\"Run one optimisation step on the global `model` and return both losses.\n",
        "\n",
        "    Args:\n",
        "        t: mask selecting the nodes whose loss is back-propagated.\n",
        "        v: mask selecting the nodes used for the reported validation loss.\n",
        "\n",
        "    Returns:\n",
        "        (loss, val_loss), both detached from the autograd graph.\n",
        "    \"\"\"\n",
        "    model.train()\n",
        "    optimizer.zero_grad()  # Clear gradients.\n",
        "    out = model(data.x.to(device), data.edge_index.to(device))  # Single forward pass.\n",
        "    labels = data.y.to(device)\n",
        "    # Use the mask that was passed in rather than the global data_train_mask.\n",
        "    loss = criterion(out[t], labels[t])\n",
        "    # Monitoring only: taken from the same training-mode forward pass and never back-propagated.\n",
        "    val_loss = criterion(out[v], labels[v]).detach()\n",
        "    loss.backward()  # Derive gradients.\n",
        "    optimizer.step()  # Update parameters based on gradients.\n",
        "    return loss.detach(), val_loss\n",
        "def test(mask):\n",
        "    \"\"\"Return the accuracy of the global `model` on the nodes selected by `mask`.\n",
        "\n",
        "    Raises:\n",
        "        ZeroDivisionError: if `mask` selects no node.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    # Evaluation needs no gradients; skip building the autograd graph.\n",
        "    with torch.no_grad():\n",
        "        out = model(data.x.to(device), data.edge_index.to(device))\n",
        "    pred = out.argmax(dim=1)  # Use the class with the highest score.\n",
        "    correct = pred[mask] == data.y.to(device)[mask]  # Check against ground-truth labels.\n",
        "    return int(correct.sum()) / int(mask.sum())  # Ratio of correct predictions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "AlURyXBjdu-U",
        "outputId": "bcd97e6f-c934-425a-c574-48c94ce23835"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "runs 0\n",
            "Test Accuracy: 0.5877\n",
            "runs 1\n",
            "Test Accuracy: 0.5837\n",
            "runs 2\n",
            "Test Accuracy: 0.5747\n",
            "runs 3\n",
            "Test Accuracy: 0.5881\n",
            "runs 4\n",
            "Test Accuracy: 0.5879\n",
            "runs 5\n",
            "Test Accuracy: 0.5839\n",
            "runs 6\n",
            "Test Accuracy: 0.5778\n",
            "runs 7\n",
            "Test Accuracy: 0.5802\n",
            "runs 8\n",
            "Test Accuracy: 0.5767\n",
            "runs 9\n",
            "Test Accuracy: 0.5877\n",
            "-------------------------\n",
            "mod GCN(\n",
            "  (conv1): GCNConv(300, 48)\n",
            "  (conv2): GCNConv(48, 18)\n",
            ")\n",
            "hidden_channels 48\n",
            "weight_decay 5e-05\n",
            "lr 0.005\n",
            "moyenne 0.5828369733068609\n",
            "std 0.0049031790802764015\n",
            "pm nan\n",
            "temps 35.83308641910553\n",
            "-------------------------\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/numpy/core/fromnumeric.py:3432: RuntimeWarning: Mean of empty slice.\n",
            "  return _methods._mean(a, axis=axis, dtype=dtype,\n",
            "/usr/local/lib/python3.10/dist-packages/numpy/core/_methods.py:190: RuntimeWarning: invalid value encountered in double_scalars\n",
            "  ret = ret.dtype.type(ret / rcount)\n"
          ]
        }
      ],
      "source": [
        "import time\n",
        "import numpy as np\n",
        "\n",
        "# Experiment configuration.\n",
        "ramdom = True  # currently has no effect: see the split note inside the loop\n",
        "nb_runs = 10\n",
        "hidden_channels = 48\n",
        "lr = 0.005\n",
        "weight_decay = 5e-5\n",
        "max_epochs = 2000\n",
        "patience = 100\n",
        "\n",
        "moy = []  # test accuracy selected for each run\n",
        "ep = []   # last epoch reached in each run\n",
        "T = []    # wall-clock seconds of each run\n",
        "p_m = []\n",
        "\n",
        "num_labels = len(data.y)\n",
        "train_end = int(0.6 * num_labels)\n",
        "val_end = int(0.8 * num_labels)\n",
        "\n",
        "\n",
        "def _mask_from_indices(node_ids, size):\n",
        "    \"\"\"Return a boolean mask of length `size` that is True exactly at `node_ids`.\"\"\"\n",
        "    mask = torch.zeros(size, dtype=torch.bool)\n",
        "    mask[node_ids] = True\n",
        "    return mask\n",
        "\n",
        "\n",
        "for runs in range(nb_runs):\n",
        "    model = GCN(hidden_channels=hidden_channels).to(device)\n",
        "    #model = GAT(hidden_channels=hidden_channels).to(device)\n",
        "    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
        "\n",
        "    criterion = torch.nn.CrossEntropyLoss()\n",
        "    print(\"runs\", runs)\n",
        "    start_time = time.time()\n",
        "    shuffled_indices = torch.randperm(len(data.y))\n",
        "\n",
        "    # Random 60/20/20 train/val/test split.\n",
        "    # NOTE(review): the previous `ramdom` branch built a 20-per-class split, but it was\n",
        "    # overwritten unconditionally by this split, so it never affected a run; it is dropped.\n",
        "    data_train_mask = _mask_from_indices(shuffled_indices[:train_end], num_labels)\n",
        "    data_val_mask = _mask_from_indices(shuffled_indices[train_end:val_end], num_labels)\n",
        "    data_test_mask = _mask_from_indices(shuffled_indices[val_end:], num_labels)\n",
        "\n",
        "    best_val_acc = 0\n",
        "    Test_acc = 0.0  # reset per run so a run never reports the previous run's value\n",
        "    epochs_since_best = 0\n",
        "    for epoch in range(1, max_epochs + 1):\n",
        "        loss, val_loss = train(data_train_mask, data_val_mask)\n",
        "        val_acc = test(data_val_mask)\n",
        "        test_acc = test(data_test_mask)\n",
        "        # Model selection: keep the test accuracy at the best validation accuracy.\n",
        "        if val_acc > best_val_acc:\n",
        "            best_val_acc = val_acc\n",
        "            Test_acc = test_acc\n",
        "            epochs_since_best = 0\n",
        "        epochs_since_best += 1\n",
        "        # Early stopping; the counter is already 1 on the improving epoch itself.\n",
        "        if epochs_since_best == patience:\n",
        "            break\n",
        "\n",
        "    T.append(time.time() - start_time)\n",
        "    print(f'Test Accuracy: {Test_acc:.4f}')\n",
        "    moy.append(Test_acc)\n",
        "    ep.append(epoch)\n",
        "\n",
        "print(\"-------------------------\")\n",
        "print(\"mod\", model)\n",
        "print(\"hidden_channels\", hidden_channels)\n",
        "print(\"weight_decay\", weight_decay)\n",
        "print(\"lr\", lr)\n",
        "print(\"moyenne\", np.mean(moy))\n",
        "print(\"std\", np.std(moy))\n",
        "print(\"temps\", np.mean(T))\n",
        "\n",
        "print(\"-------------------------\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sQd2_XiegCvi"
      },
      "outputs": [],
      "source": [
        "from torch_geometric.utils import to_dense_adj\n",
        "def spectral_gap(adj):\n",
        "    \"\"\"Return a normalised spectral quantity of the graph given by edge index `adj`.\n",
        "\n",
        "    The normalised-Laplacian eigenvalues are sorted in descending order; the value\n",
        "    at position 1 is divided by the sum of all eigenvalues from position 1 onwards.\n",
        "\n",
        "    NOTE(review): with a descending sort, position 1 is the second-LARGEST eigenvalue.\n",
        "    If the second-smallest (the usual spectral gap) was intended, the sort order\n",
        "    should be ascending - confirm against the reference method before changing.\n",
        "    \"\"\"\n",
        "    dense = to_dense_adj(adj.squeeze()).squeeze()\n",
        "    # networkx needs a host-side NumPy array; .numpy() alone fails for CUDA or grad-tracking tensors.\n",
        "    G = nx.Graph(dense.detach().cpu().numpy())\n",
        "    eigenvalues = nx.normalized_laplacian_spectrum(G)\n",
        "\n",
        "    eigenvalues_sorted = np.sort(eigenvalues)[::-1]\n",
        "\n",
        "    # Named `gap` so the local no longer shadows the function name.\n",
        "    gap = eigenvalues_sorted[1]\n",
        "    sum_of_eigenvalues = np.sum(eigenvalues_sorted[1:])\n",
        "\n",
        "    return gap / sum_of_eigenvalues\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "R-1LcamYrI1J",
        "outputId": "228cc789-a0b1-4b7e-ffc8-622965de711d"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "improvement of the spectral gap of 86.9467363071268 %\n"
          ]
        }
      ],
      "source": [
        "# Relative change, in percent, of spectral_gap from data_und to edge_index_curvm.\n",
        "before = spectral_gap(data_und)\n",
        "after = spectral_gap(edge_index_curvm)\n",
        "relative_change_pct = ((after - before)/before)*100\n",
        "print(\"improvement of the spectral gap of\", relative_change_pct, \"%\" )"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}