{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20f445fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "import networkx as nx\n",
    "from torch_geometric.utils import to_undirected\n",
    "from torch_geometric.datasets import GNNBenchmarkDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1c0e901",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mon_op(A,B):\n",
    "    r\"\"\"Monoidal operation $\\circ$: combine two square matrices.\n",
    "\n",
    "    Args:\n",
    "        A: 2-D tensor, left operand.\n",
    "        B: 2-D tensor, right operand; must be compatible with ``A`` for both\n",
    "            element-wise addition and ``torch.mm``.\n",
    "\n",
    "    Returns:\n",
    "        Tensor equal to ``A + B + torch.mm(A, B)``.\n",
    "    \"\"\"\n",
    "    # torch.mm only accepts 2-D operands, so non-matrix inputs fail loudly here.\n",
    "    return A + B + torch.mm(A, B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b5478fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def image(X,n,m):\n",
    "    \"\"\"Compute the image (cover) tensor for all nodes.\n",
    "\n",
    "    For every node ``i`` the slice ``cover[i]`` is seeded with column ``i`` of\n",
    "    ``X``. It is then grown iteratively: the columns of the remaining matrix\n",
    "    ``dec`` selected by the mask ``M`` are folded into ``cover[i]`` with\n",
    "    ``mon_op`` and removed from ``dec``, until the mask selects nothing.\n",
    "\n",
    "    Args:\n",
    "        X: ``(n, n)`` matrix (the notebook passes an adjacency matrix).\n",
    "        n: number of nodes, i.e. the side length of ``X``.\n",
    "        m: not read by this implementation; kept so existing calls keep working.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape ``(n, n, n)``; slice ``i`` is the result for node ``i``.\n",
    "    \"\"\"\n",
    "    cover = torch.zeros(n, n, n)\n",
    "    for i in range(n):\n",
    "        # Seed slice i with column i of X; dec is X without row/column i.\n",
    "        dec = torch.clone(X)\n",
    "        cover[i][:, i] = X[:, i]\n",
    "        dec[i] = 0\n",
    "        dec[:, i] = 0\n",
    "\n",
    "        # M marks (as whole columns) every index k whose row k in cover[i]\n",
    "        # has a non-zero sum.\n",
    "        M = torch.zeros(n, n)\n",
    "        for k in range(n):\n",
    "            if cover[i][k].sum() != 0:\n",
    "                M[:, k] = 1\n",
    "\n",
    "        while M.sum() != 0:\n",
    "            # Compute the masked part of dec once per iteration.\n",
    "            selected = M * dec\n",
    "            # For 0/1 entries the subtraction removes pairs that are set in\n",
    "            # both directions before composing with the current cover.\n",
    "            cover[i] = mon_op(selected - selected * selected.t(), cover[i])\n",
    "            dec = dec - selected\n",
    "\n",
    "            # Next mask: rows already present in cover[i] whose column in dec\n",
    "            # still has a non-zero sum.\n",
    "            M = torch.zeros(n, n)\n",
    "            for k_ in range(n):\n",
    "                if cover[i][k_].sum() != 0 and dec[:, k_].sum() != 0:\n",
    "                    M[:, k_] = 1\n",
    "    return cover"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0b0a8b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the CSL benchmark graphs, using /tmp/CSL as the dataset root directory.\n",
    "dataset = GNNBenchmarkDataset(\n",
    "    name='CSL',\n",
    "    root='/tmp/CSL',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c82f439",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of graphs in the loaded dataset.\n",
    "l_dataset = len(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdeddce9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect one scalar signature per graph (total sum of its image tensor),\n",
    "# keeping only values that have not been seen before.\n",
    "Distinguished_graphs = []\n",
    "for graph in range(l_dataset):\n",
    "    gr = dataset[graph]\n",
    "    num_nodes = gr.num_nodes\n",
    "\n",
    "    # Dense adjacency matrix, filled with a single indexed assignment\n",
    "    # instead of a Python loop over the edges.\n",
    "    Ad_mat = torch.zeros(num_nodes, num_nodes)\n",
    "    Ad_mat[gr.edge_index[0], gr.edge_index[1]] = 1\n",
    "\n",
    "    Image = image(Ad_mat, num_nodes, 5)\n",
    "\n",
    "    # .item() yields a plain Python float, so the membership test below\n",
    "    # compares scalars rather than 0-dim tensors.\n",
    "    s = Image.sum().item()\n",
    "\n",
    "    if s not in Distinguished_graphs:\n",
    "        Distinguished_graphs.append(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "153910a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of graphs that the model cannot distinguish:\n",
    "# hard-coded class count minus the number of distinct signatures collected.\n",
    "num_of_different_isomorph_classes = 10\n",
    "print(num_of_different_isomorph_classes - len(Distinguished_graphs))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d0eb312",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
