{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c8bd9034",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "import networkx as nx\n",
    "from torch_geometric.utils import to_undirected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3ed555ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Code of this function is from the repository of the official code\n",
    "of the papers Weisfeiler and Lehman Go Cellular: CW Networks (NeurIPS 2021)\n",
    "and Weisfeiler and Lehman Go Topological: Message Passing Simplicial Networks (ICML 2021)\"\"\"\n",
    "def load_sr_dataset(path):\n",
    "    \"\"\"Load the Strongly Regular Graph Dataset from the supplied path.\"\"\"\n",
    "    nx_graphs = nx.read_graph6(path)\n",
    "    graphs = list()\n",
    "    for nx_graph in nx_graphs:\n",
    "        n = nx_graph.number_of_nodes()\n",
    "        edge_index = to_undirected(torch.tensor(list(nx_graph.edges()), dtype=torch.long).transpose(1,0))\n",
    "        graphs.append((edge_index, n))\n",
    "    return graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "79c075bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Monoidal operation $\\circ$\"\"\"\n",
    "def mon_op(A,B):\n",
    "    return A+B+torch.mm(A,B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "96027c3c-e8ca-4c4b-961c-7784717a3a1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"comuting image for all nodes\"\"\"\n",
    "def image(X,n,m):\n",
    "    cover=torch.zeros(n,n,n).to(torch.float64)\n",
    "    for i in range(n):\n",
    "        \n",
    "        dec=(X!=0).float()\n",
    "        wdec=torch.clone(X)\n",
    "        cover[i].t()[i]=X.t()[i]\n",
    "        dec[i]=0\n",
    "        dec.t()[i]=0\n",
    "        #wdec[i]=0\n",
    "        #wdec.t()[i]=0\n",
    "        M=torch.zeros(n,n).to(torch.float64)\n",
    "       # N=torch.ones(n,n).to(torch.float64)\n",
    "        for k in range(n):\n",
    "            if cover[i][k].sum()!=0:\n",
    "                M.t()[k]=1\n",
    "                #N.t()[k]=0\n",
    "        c=0\n",
    "            #M.sum()!=0\n",
    "        while c<m:\n",
    "            om=(M*dec)-(((M*dec)*((M*dec).t())))\n",
    "            cover[i]=mon_op(om*wdec,cover[i])\n",
    "            dec=dec-(M*dec+(M*dec).t()-(M*dec)*(M*dec).t())\n",
    "            #wdec=wdec-(M*wdec)\n",
    "            M=torch.zeros(n,n).to(torch.float64)\n",
    "           # N=torch.ones(n,n).to(torch.float64)\n",
    "            for k_ in range(n):\n",
    "                if cover[i][k_].sum()!=0 and dec.t()[k_].sum()!=0:\n",
    "                    M.t()[k_]=1\n",
    "                   # N.t()[k_]=0\n",
    "            c+=1\n",
    "    return torch.log1p(cover)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1c1c5e6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Loading the collection of strongly regular graphs\n",
    "\n",
    "https://users.cecs.anu.edu.au/~bdm/data/graphs.html\n",
    "\n",
    "sr251256.g6 (15 graphs)\n",
    "sr261034.g6 (10 graphs)\n",
    "sr281264.g6 (4 graphs)\n",
    "sr291467.g6 (41 graphs)\n",
    "sr351668.g6 (3854 graphs)\n",
    "sr351899.g6 (227 graphs)\n",
    "sr361446.g6 (180 graphs)\n",
    "sr361566.g6 (32548 graphs)\n",
    "sr371889some.g6 (6760 graphs)\n",
    "sr401224.g6 (28 graphs)\n",
    "sr65321516some.g6 (32 graphs)\n",
    "\"\"\"\n",
    "dataset=load_sr_dataset(\"sr371889some.g6\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "044a31ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_of_graphs=len(dataset)\n",
    "num_nodes=dataset[0][1]\n",
    "num_edges=len(dataset[0][0][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8284e2ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "Distinguished_graphs=[]\n",
    "for graph in range(num_of_graphs):\n",
    "    gr=dataset[graph]\n",
    "    \n",
    "    Ad_mat=torch.zeros(num_nodes,num_nodes)\n",
    "    \n",
    "    for edge in range(num_edges):\n",
    "        Ad_mat[gr[0][0][edge]][gr[0][1][edge]]=1\n",
    "    \n",
    "    Image=image(Ad_mat,num_nodes,5)\n",
    "    su=torch.sum(Image,0)\n",
    "\n",
    "    output_snn_beta=0.01*mon_op(mon_op(0.01*su.t(),0.01*su),su.t())\n",
    "    #output_snn_beta=mon_op(output_snn_beta,output_snn_beta)\n",
    "    \n",
    "    t=torch.mean(output_snn_beta)\n",
    "    s=torch.var(output_snn_beta)\n",
    "    s1=torch.var(torch.sum(torch.eye(num_nodes)*output_snn_beta,0))\n",
    "    t1=torch.mean(torch.sum(torch.eye(num_nodes)*output_snn_beta,0))\n",
    "    det_su=torch.linalg.det(output_snn_beta)\n",
    "    mm=torch.min(output_snn_beta)\n",
    "    if (t,s,t1,s1) not in Distinguished_graphs:\n",
    "        Distinguished_graphs.append((mm,det_su,t,s,t1,s1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e710f97c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    }
   ],
   "source": [
    "\"\"\"Number of graphs that model can not distinguish\"\"\"\n",
    "print(num_of_graphs-len(Distinguished_graphs))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90670265-346c-4b3b-ae5b-f9dfdf7c26e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "Distinguished_graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "935d6058-d7f7-48ad-8842-dbfe6cead445",
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import isfinite, isinf\n",
    "\n",
    "def all_finite(tuples_list):\n",
    "    # True if no NaN/±inf anywhere\n",
    "    return all(isfinite(x) for tup in tuples_list for x in tup)\n",
    "\n",
    "def none_infinite(tuples_list):\n",
    "    # True if no ±inf (but allows NaN). Use this if you only care about infinity.\n",
    "    return not any(isinf(x) for tup in tuples_list for x in tup)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2a4defba-d440-4869-9d0f-1a9652d9798c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_finite(Distinguished_graphs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
