{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bb10fc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from torch_geometric.datasets import GNNBenchmarkDataset\n",
    "import numpy as np\n",
    "import torch.optim \n",
    "import networkx as nx\n",
    "from torch_geometric.utils.convert import from_networkx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d1e3dca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Monoidal operation $\\circ$ for featured graphs, as explained in the section\n",
    "# \"SNN in Practice\" of the paper.\n",
    "# NOTE(review): `dtype` is not referenced by any other cell of this notebook.\n",
    "dtype=torch.cuda.FloatTensor\n",
    "\n",
    "def mon_op(A, B):\n",
    "    \"\"\"Return A o B = A + B + A @ B, where A @ B is the matrix product torch.mm(A, B).\n",
    "\n",
    "    A and B must be 2-D tensors accepted by torch.mm whose sum and product are\n",
    "    shape-compatible (e.g. both n x n, as everywhere in this notebook).\n",
    "    \"\"\"\n",
    "    return A + B + torch.mm(A, B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94b8793c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the TUDataset benchmark 'NCI1' (cleaned=False) with root directory NCI1_ROOT.\n",
    "NCI1_ROOT = '/tmp/NCI1'\n",
    "dataset = TUDataset(root=NCI1_ROOT, name='NCI1', cleaned=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2be33e6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch's global RNG so the shuffle below is reproducible across runs.\n",
    "# NOTE(review): assumes Dataset.shuffle draws from the global torch RNG - confirm for the installed PyG.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "l_dataset = len(dataset)\n",
    "dataset = dataset.shuffle()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f574c37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the image of every node of a graph (see `image` below).\n",
    "def _column_mask(cover_i, dec=None):\n",
    "    \"\"\"Return a 0/1 matrix shaped like `cover_i` whose column k is all ones iff\n",
    "    row k of `cover_i` has a nonzero sum and, when `dec` is given, column k of\n",
    "    `dec` has a nonzero sum as well.\n",
    "    \"\"\"\n",
    "    mask = torch.zeros_like(cover_i)\n",
    "    for k in range(cover_i.shape[0]):\n",
    "        if cover_i[k].sum() != 0 and (dec is None or dec[:, k].sum() != 0):\n",
    "            mask[:, k] = 1\n",
    "    return mask\n",
    "\n",
    "\n",
    "def image(X, n, m):\n",
    "    \"\"\"Compute the image of each of the n nodes of the graph given by matrix X.\n",
    "\n",
    "    Args:\n",
    "        X: (n, n) tensor; column i of X seeds the image of node i.\n",
    "        n: number of nodes.\n",
    "        m: number of expansion steps (an int); m=0 returns only the seed columns.\n",
    "\n",
    "    Returns:\n",
    "        (n, n, n) tensor `cover` with the same dtype and device as X, where\n",
    "        cover[i] is the image of node i.\n",
    "    \"\"\"\n",
    "    # Allocate with X's dtype/device so torch.mm inside mon_op never mixes dtypes.\n",
    "    cover = torch.zeros(n, n, n, dtype=X.dtype, device=X.device)\n",
    "    for i in range(n):\n",
    "        cover[i][:, i] = X[:, i]\n",
    "        # Entries still available for expansion: X without row i and column i.\n",
    "        dec = X.clone()\n",
    "        dec[i] = 0\n",
    "        dec[:, i] = 0\n",
    "        for step in range(m):\n",
    "            # Step 0 selects the columns k whose row k in cover[i] is nonzero;\n",
    "            # later steps additionally require column k of dec to be nonzero.\n",
    "            mask = _column_mask(cover[i], dec if step > 0 else None)\n",
    "            selected = mask * dec\n",
    "            # NOTE(review): subtracting selected * selected.t() presumably discounts\n",
    "            # entries present in both directions - confirm against the paper.\n",
    "            cover[i] = mon_op(selected - selected * selected.t(), cover[i])\n",
    "            # Zero the consumed entries so they cannot be selected again.\n",
    "            dec = dec - selected\n",
    "    return cover"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cab88803",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the output of SNN version alpha for every graph. Each resulting\n",
    "# PyG graph in data_list can be used as input to any GNN.\n",
    "data_list = []\n",
    "for graph in range(l_dataset):\n",
    "    gr = dataset[graph]\n",
    "    num_nodes = gr.num_nodes\n",
    "    # Adjacency matrix with weight 0.5 on every edge listed in edge_index.\n",
    "    Ad_mat = torch.zeros(num_nodes, num_nodes)\n",
    "    Ad_mat[gr.edge_index[0], gr.edge_index[1]] = 0.5\n",
    "    # Co-images (m=0) and images (m=1) of all nodes.\n",
    "    CoImage = image(Ad_mat, num_nodes, 0)\n",
    "    Image = image(Ad_mat, num_nodes, 1)\n",
    "    # co_cols[i] = CoImage[i][:, i] and im_cols[j] = Image[j][:, j]: torch.diagonal over\n",
    "    # dims 0 and 2 returns T[i, k, i] indexed as [k, i], so .t() makes it [i, k].\n",
    "    co_cols = torch.diagonal(CoImage, dim1=0, dim2=2).t()\n",
    "    im_cols = torch.diagonal(Image, dim1=0, dim2=2).t()\n",
    "    # Entry (i, j) of mon_op(co_cols, im_cols.t()) equals entry (i, j) of\n",
    "    # mon_op(CoImage[i].t(), Image[j]), so one n x n product replaces n^2 of them.\n",
    "    numerator = mon_op(co_cols, im_cols.t())\n",
    "    # Normaliser: sum(CoImage[i][:, i]) * sum(Image[j][:, j]); entries with a zero normaliser stay 0.\n",
    "    denominator = torch.outer(co_cols.sum(dim=1), im_cols.sum(dim=1))\n",
    "    output_of_SNN = torch.zeros(num_nodes, num_nodes)\n",
    "    defined = denominator != 0\n",
    "    output_of_SNN[defined] = numerator[defined] / denominator[defined]\n",
    "    # Build a PyG graph from the output matrix (nx.from_numpy_matrix was removed in NetworkX 3.0).\n",
    "    G = from_networkx(nx.from_numpy_array(output_of_SNN.numpy()))\n",
    "    G['x'] = gr.x\n",
    "    G['y'] = torch.tensor([gr.y])\n",
    "    data_list.append(G)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41e65e2f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
