{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3780b51",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from torch_geometric.datasets import GNNBenchmarkDataset\n",
    "import torch.optim "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8027d0bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"definition of monoidal operation\"\"\"\n",
    "dtype=torch.cuda.FloatTensor\n",
    "def mon_op(A,B):\n",
    "    return A+B+torch.mm(A,B)\n",
    "\"\"\"monoidal operation $\\circ$ for featured graphs as explain in \\textbf{SNN in Practice} in paper\"\"\"\n",
    "def mmon_op(A,B,n,m):\n",
    "    C=torch.zeros(n,n,m)\n",
    "    for i in range(n):\n",
    "        for j in range(n):\n",
    "            C[i][j]=torch.sum(A[i,:,:]*B[:,j,:],0)\n",
    "    return A+B+C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9bb73dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"comuting image for all nodes\"\"\"\n",
    "def mimage(X,Y,d,n,m):\n",
    "    cover=torch.zeros(n,n,n)\n",
    "    mcover=torch.zeros(n,n,n,d)\n",
    "    for i in range(n):\n",
    "        \n",
    "        dec=torch.clone(X)\n",
    "        mdec=torch.clone(Y)\n",
    "        cover[i].t()[i]=X.t()[i]\n",
    "        torch.transpose(mcover[i],0,1)[i]=torch.transpose(Y,0,1)[i]\n",
    "        dec[i]=0\n",
    "        dec.t()[i]=0\n",
    "        mdec[i]=0\n",
    "        torch.transpose(mdec,0,1)[i]=0\n",
    "        M=torch.zeros(n,n)\n",
    "        N=torch.ones(n,n)\n",
    "        mM=torch.zeros(n,n,d)\n",
    "        mN=torch.ones(n,n,d)\n",
    "        for k in range(n):\n",
    "            if cover[i][k].sum()!=0:\n",
    "                M.t()[k]=1\n",
    "                N.t()[k]=0\n",
    "                torch.transpose(mM,0,1)[k]=1\n",
    "                torch.transpose(mN,0,1)[k]=0\n",
    "        if m==-1:\n",
    "            while M.sum()!=0:\n",
    "                cover[i]=mon_op((M*dec)-(((M*dec)*((M*dec).t()))),cover[i])\n",
    "                mcover[i]=mmon_op((mM*mdec)-(((mM*mdec)*(torch.transpose(mM*mdec,0,1)))),mcover[i],n,d)\n",
    "                dec=dec-(M*dec)\n",
    "                mdec=mdec-(mM*mdec)\n",
    "                M=torch.zeros(n,n)\n",
    "                N=torch.ones(n,n)\n",
    "                mM=torch.zeros(n,n,d)\n",
    "                mN=torch.ones(n,n,d)\n",
    "                for k_ in range(n):\n",
    "                    if cover[i][k_].sum()!=0 and dec.t()[k_].sum()!=0:\n",
    "                        M.t()[k_]=1\n",
    "                        N.t()[k_]=0\n",
    "                        torch.transpose(mM,0,1)[k_]=1\n",
    "                        torch.transpose(mN,0,1)[k_]=0\n",
    "        else:\n",
    "            c=0\n",
    "                \n",
    "            while c<m:\n",
    "                cover[i]=comp((M*dec)-(((M*dec)*((M*dec).t()))),cover[i])\n",
    "                mcover[i]=mcomp((mM*mdec)-(((mM*mdec)*(torch.transpose(mM*mdec,0,1)))),mcover[i],n,d)\n",
    "                dec=dec-(M*dec)\n",
    "                mdec=mdec-(mM*mdec)\n",
    "                M=torch.zeros(n,n)\n",
    "                N=torch.ones(n,n)\n",
    "                mM=torch.zeros(n,n,d)\n",
    "                mN=torch.ones(n,n,d)\n",
    "                for k_ in range(n):\n",
    "                    if cover[i][k_].sum()!=0 and dec.t()[k_].sum()!=0:\n",
    "                        M.t()[k_]=1\n",
    "                        N.t()[k_]=0\n",
    "                        torch.transpose(mM,0,1)[k_]=1\n",
    "                        torch.transpose(mN,0,1)[k_]=0\n",
    "                c+=1\n",
    "    return (cover,mcover)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c75e24c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG',cleaned=False,use_node_attr= True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "162c801a",
   "metadata": {},
   "outputs": [],
   "source": [
    "l_dataset=len(dataset)\n",
    "dataset=dataset.shuffle()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4c49f4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"computing the output of SNN version alpha\"\"\"\n",
    "data_list = []\n",
    "for graph in range(l_dataset):\n",
    "    gr=dataset[graph]\n",
    "    num_nodes=gr.num_nodes\n",
    "    num_edges=gr.num_edges\n",
    "    num_n_feat=gr.num_node_features\n",
    "    num_e_feat=gr.num_edge_features\n",
    "    Ad_mat=torch.zeros(num_nodes,num_nodes,num_e_feat)\n",
    "    ad_mat=torch.zeros(num_nodes,num_nodes)\n",
    "    \n",
    "    \n",
    "    for edge in range(num_edges):\n",
    "        Ad_mat[gr.edge_index[0][edge]][gr.edge_index[1][edge]]=gr.edge_attr[edge]\n",
    "        ad_mat[gr.edge_index[0][edge]][gr.edge_index[1][edge]]=1\n",
    "    \"\"\"computing image\"\"\"\n",
    "    Image=mimage(ad_mat,Ad_mat,num_e_feat,num_nodes,0)[1]\n",
    "    \"\"\"computing output of SNN\"\"\"\n",
    "    output_of_SNN=torch.zeros(num_nodes,num_nodes,num_e_feat)\n",
    "    \n",
    "    for i in range(num_nodes):\n",
    "        for j in range(num_nodes):\n",
    "            for k in range(num_e_feat):\n",
    "                if (Image[i,:,:,k].t()[i].sum()*Image[j,:,:,k].t()[j].sum())!=0:\n",
    "                    output_of_SNN[i][j][k]=mon_op(Image[i,:,:,k].t(),Image[j,:,:,k])[i][j]/(Image[i,:,:,k].t()[i].sum()*Image[j,:,:,k].t()[j].sum())\n",
    "    \"\"\"Extracting the new graph from output\"\"\"\n",
    "    L=[]\n",
    "    for k in range(num_nodes):\n",
    "        for l in range(num_nodes):\n",
    "            if output_of_SNN[k][l].sum()!=0:\n",
    "                L.append((k,l,output_of_SNN[k][l]))\n",
    "    le=len(L)\n",
    "    edge_index=torch.zeros(2,le)\n",
    "    edge_attr=torch.zeros(le,num_e_feat)\n",
    "    for f in range(le):\n",
    "        edge_index[0][f]=L[f][0]\n",
    "        edge_index[1][f]=L[f][1]\n",
    "        edge_attr[f]=L[f][2]\n",
    "    data=Data(x=gr.x,edge_index=edge_index.to(torch.int64),edge_attr=edge_attr,y=gr.y)\n",
    "    data_list.append(data)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
