{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8bd9034",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "import networkx as nx\n",
    "from torch_geometric.utils import to_undirected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ed555ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adapted from the official code repository of the papers\n",
    "# Weisfeiler and Lehman Go Cellular: CW Networks (NeurIPS 2021) and\n",
    "# Weisfeiler and Lehman Go Topological: Message Passing Simplicial Networks (ICML 2021).\n",
    "def load_sr_dataset(path):\n",
    "    \"\"\"Load the Strongly Regular Graph Dataset from the supplied path.\n",
    "\n",
    "    Args:\n",
    "        path: Location of a graph6 (.g6) file, handed to nx.read_graph6.\n",
    "\n",
    "    Returns:\n",
    "        A list with one (edge_index, n) tuple per graph: edge_index is the\n",
    "        torch.long tensor returned by to_undirected and n is the number of\n",
    "        nodes of that graph.\n",
    "    \"\"\"\n",
    "    nx_graphs = nx.read_graph6(path)\n",
    "    # read_graph6 returns a bare Graph (not a list) when the file holds a\n",
    "    # single graph; wrap it so the loop iterates over graphs, not over nodes.\n",
    "    if isinstance(nx_graphs, nx.Graph):\n",
    "        nx_graphs = [nx_graphs]\n",
    "    graphs = []\n",
    "    for nx_graph in nx_graphs:\n",
    "        n = nx_graph.number_of_nodes()\n",
    "        # reshape(-1, 2) keeps an edgeless graph two-dimensional, so the\n",
    "        # transpose below yields a (2, 0) tensor instead of raising.\n",
    "        edges = torch.tensor(list(nx_graph.edges()), dtype=torch.long).reshape(-1, 2)\n",
    "        edge_index = to_undirected(edges.t())\n",
    "        graphs.append((edge_index, n))\n",
    "    return graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79c075bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mon_op(A,B):\n",
    "    r\"\"\"Monoidal operation $\\circ$: A + B + torch.mm(A, B).\n",
    "\n",
    "    Args:\n",
    "        A: 2-D tensor used as the left operand of torch.mm.\n",
    "        B: 2-D tensor used as the right operand of torch.mm.\n",
    "\n",
    "    Returns:\n",
    "        The elementwise sum of A and B plus their matrix product.\n",
    "    \"\"\"\n",
    "    elementwise_sum = A + B\n",
    "    return elementwise_sum + torch.mm(A, B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b1e2438",
   "metadata": {},
   "outputs": [],
   "source": [
    "def image(X,n,m):\n",
    "    \"\"\"Computing Image for all nodes.\n",
    "\n",
    "    Args:\n",
    "        X: Matrix indexed as n x n; the caller in this notebook passes a\n",
    "            dense 0/1 adjacency matrix.\n",
    "        n: Number of nodes.\n",
    "        m: Unused. Presumably intended as a cap on the number of rounds of\n",
    "            the inner loop (TODO confirm); the loop runs until no column is\n",
    "            marked any more.\n",
    "\n",
    "    Returns:\n",
    "        Tensor `cover` of shape (n, n, n); cover[i] is the matrix built for\n",
    "        node i.\n",
    "    \"\"\"\n",
    "    cover = torch.zeros(n, n, n)\n",
    "    for i in range(n):\n",
    "        # Seed cover[i] with column i of X.\n",
    "        cover[i][:, i] = X[:, i]\n",
    "\n",
    "        # Working copy of X with row i and column i zeroed out.\n",
    "        dec = torch.clone(X)\n",
    "        dec[i] = 0\n",
    "        dec[:, i] = 0\n",
    "\n",
    "        # M marks (with ones) every column k whose row k of cover[i] has a\n",
    "        # nonzero sum.\n",
    "        M = torch.zeros(n, n)\n",
    "        for k in range(n):\n",
    "            if cover[i][k].sum() != 0:\n",
    "                M[:, k] = 1\n",
    "\n",
    "        while M.sum() != 0:\n",
    "            # Entries of dec that lie in the marked columns.\n",
    "            selected = M * dec\n",
    "            # Subtract the elementwise product with the transpose, then fold\n",
    "            # the result into cover[i] through mon_op.\n",
    "            cover[i] = mon_op(selected - selected * selected.t(), cover[i])\n",
    "            # The marked columns are consumed: remove them from dec.\n",
    "            dec = dec - selected\n",
    "            # Re-mark columns whose row in cover[i] is nonzero and that still\n",
    "            # have a nonzero column sum in dec.\n",
    "            M = torch.zeros(n, n)\n",
    "            for k in range(n):\n",
    "                if cover[i][k].sum() != 0 and dec[:, k].sum() != 0:\n",
    "                    M[:, k] = 1\n",
    "    return cover"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c1c5e6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loading the collection of strongly regular graphs\n",
    "#\n",
    "# https://users.cecs.anu.edu.au/~bdm/data/graphs.html\n",
    "#\n",
    "# sr251256.g6 (15 graphs)\n",
    "# sr261034.g6 (10 graphs)\n",
    "# sr281264.g6 (4 graphs)\n",
    "# sr291467.g6 (41 graphs)\n",
    "# sr351668.g6 (3854 graphs)\n",
    "# sr351899.g6 (227 graphs)\n",
    "# sr361446.g6 (180 graphs)\n",
    "# sr361566.g6 (32548 graphs)\n",
    "# sr371889some.g6 (6760 graphs)\n",
    "# sr401224.g6 (28 graphs)\n",
    "# sr65321516some.g6 (32 graphs)\n",
    "\n",
    "# File from the list above that this notebook evaluates.\n",
    "DATASET_PATH = \"sr291467.g6\"\n",
    "dataset = load_sr_dataset(DATASET_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "044a31ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset-wide sizes; node and edge counts are read from the first graph.\n",
    "num_of_graphs = len(dataset)\n",
    "first_edge_index, num_nodes = dataset[0]\n",
    "# Length of the first row of edge_index, i.e. the number of edge entries\n",
    "# after to_undirected.\n",
    "num_edges = len(first_edge_index[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8284e2ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One fingerprint per graph: (mean, variance, diagonal mean, diagonal variance)\n",
    "# of the model output. Distinguished_graphs keeps each distinct fingerprint\n",
    "# once, in order of first appearance.\n",
    "Distinguished_graphs = []\n",
    "seen_fingerprints = set()\n",
    "for edge_index, n in dataset:\n",
    "    # Dense adjacency matrix, sized and filled from this graph's own data\n",
    "    # rather than from the sizes of the first graph in the dataset.\n",
    "    Ad_mat = torch.zeros(n, n)\n",
    "    Ad_mat[edge_index[0], edge_index[1]] = 1\n",
    "\n",
    "    # The third argument is currently not used by image().\n",
    "    Image = image(Ad_mat, n, 5)\n",
    "    su = torch.sum(Image, 0)\n",
    "\n",
    "    output_snn_beta = mon_op(mon_op(su.t(), su), su.t())\n",
    "    # Masking with the identity and summing the columns leaves the diagonal.\n",
    "    diagonal = torch.sum(torch.eye(n) * output_snn_beta, 0)\n",
    "\n",
    "    # Plain Python floats are hashable, so membership is a set lookup instead\n",
    "    # of a scan over tuples of tensors. Fingerprints are still compared by\n",
    "    # exact equality.\n",
    "    fingerprint = (\n",
    "        torch.mean(output_snn_beta).item(),\n",
    "        torch.var(output_snn_beta).item(),\n",
    "        torch.mean(diagonal).item(),\n",
    "        torch.var(diagonal).item(),\n",
    "    )\n",
    "    if fingerprint not in seen_fingerprints:\n",
    "        seen_fingerprints.add(fingerprint)\n",
    "        Distinguished_graphs.append(fingerprint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e710f97c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of graphs that the model cannot distinguish: total number of graphs\n",
    "# minus the number of distinct fingerprints collected.\n",
    "num_undistinguished = num_of_graphs - len(Distinguished_graphs)\n",
    "print(num_undistinguished)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba8a1610",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
