{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eaac2364",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from torch_geometric.datasets import GNNBenchmarkDataset\n",
    "import torch.optim "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f26a965",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Monoidal operation on square matrices: A o B = A + B + A @ B.\n",
    "# NOTE(review): `dtype` is not referenced anywhere else in this notebook.\n",
    "dtype=torch.cuda.FloatTensor\n",
    "\n",
    "def mon_op(A, B):\n",
    "    \"\"\"Return the monoidal product ``A + B + A @ B`` of two 2-D tensors.\n",
    "\n",
    "    Both arguments must be accepted by ``torch.mm``; in this notebook they are\n",
    "    square (n, n) matrices, so the elementwise sums are well defined.\n",
    "    \"\"\"\n",
    "    summed = A + B\n",
    "    product = torch.mm(A, B)\n",
    "    return summed + product\n",
    "\n",
    "\n",
    "# Monoidal operation for featured graphs, as explained in the section SNN in Practice of the paper.\n",
    "def mmon_op(A, B, n, m):\n",
    "    \"\"\"Featured monoidal product ``A + B + C`` of two (n, n, m) tensors.\n",
    "\n",
    "    ``C[i, j, :] = sum_k A[i, k, :] * B[k, j, :]``, i.e. an independent matrix\n",
    "    product per feature channel. It is computed with a single ``torch.einsum``\n",
    "    call instead of an n*n Python loop, and it stays on the inputs' device and\n",
    "    dtype (a ``torch.zeros(n, n, m)`` buffer would live on the CPU as float32\n",
    "    and break the final sum for CUDA inputs).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    A, B : torch.Tensor\n",
    "        Tensors of shape (n, n, m).\n",
    "    n, m : int\n",
    "        Node and feature counts. Kept for backward compatibility; the shapes\n",
    "        are taken from ``A`` and ``B``.\n",
    "    \"\"\"\n",
    "    C = torch.einsum('ikd,kjd->ijd', A, B)\n",
    "    return A + B + C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f870a2a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Computing the image of every node (uses mon_op / mmon_op from the previous cell).\n",
    "\n",
    "def _mimage_frontier(cover_i, dec, n, d, require_remaining):\n",
    "    \"\"\"Column masks for the nodes currently reached in ``cover_i``.\n",
    "\n",
    "    Column k of the (n, n) mask ``M`` and slice ``[:, k, :]`` of the (n, n, d)\n",
    "    mask ``mM`` are set to 1 when row k of ``cover_i`` is non-zero and, if\n",
    "    ``require_remaining`` is true, column k of ``dec`` is non-zero as well.\n",
    "    \"\"\"\n",
    "    M = torch.zeros(n, n)\n",
    "    mM = torch.zeros(n, n, d)\n",
    "    for k in range(n):\n",
    "        if cover_i[k].sum() != 0 and (not require_remaining or dec.t()[k].sum() != 0):\n",
    "            M.t()[k] = 1\n",
    "            torch.transpose(mM, 0, 1)[k] = 1\n",
    "    return M, mM\n",
    "\n",
    "\n",
    "def _mimage_step(cover_i, mcover_i, dec, mdec, M, mM, n, d):\n",
    "    \"\"\"One expansion step of a single node's image.\n",
    "\n",
    "    The masked part of the remaining graph, minus its elementwise product with\n",
    "    its own transpose, is merged into the image with ``mon_op`` / ``mmon_op``\n",
    "    and then subtracted from ``dec`` / ``mdec``.\n",
    "\n",
    "    Returns the updated ``(cover_i, mcover_i, dec, mdec)``.\n",
    "    \"\"\"\n",
    "    frontier = M * dec\n",
    "    mfrontier = mM * mdec\n",
    "    cover_i = mon_op(frontier - frontier * frontier.t(), cover_i)\n",
    "    mcover_i = mmon_op(mfrontier - mfrontier * torch.transpose(mfrontier, 0, 1), mcover_i, n, d)\n",
    "    return cover_i, mcover_i, dec - frontier, mdec - mfrontier\n",
    "\n",
    "\n",
    "def mimage(X, Y, d, n, m):\n",
    "    \"\"\"Compute the image of every node of a graph.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : torch.Tensor\n",
    "        (n, n) matrix; the 0/1 adjacency matrix in this notebook.\n",
    "    Y : torch.Tensor\n",
    "        (n, n, d) tensor; the edge-feature adjacency tensor in this notebook.\n",
    "    d : int\n",
    "        Number of edge features.\n",
    "    n : int\n",
    "        Number of nodes.\n",
    "    m : int\n",
    "        Number of expansion steps. ``-1`` expands until no reachable node is\n",
    "        left; any other value runs ``max(m, 0)`` steps.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (cover, mcover) : tuple of torch.Tensor\n",
    "        Shapes (n, n, n) and (n, n, n, d); index ``[i]`` is the image of node i.\n",
    "    \"\"\"\n",
    "    cover = torch.zeros(n, n, n)\n",
    "    mcover = torch.zeros(n, n, n, d)\n",
    "    for i in range(n):\n",
    "        dec = torch.clone(X)\n",
    "        mdec = torch.clone(Y)\n",
    "        # Seed the image of node i with column i of X / Y.\n",
    "        cover[i].t()[i] = X.t()[i]\n",
    "        torch.transpose(mcover[i], 0, 1)[i] = torch.transpose(Y, 0, 1)[i]\n",
    "        # Remaining graph: drop row i and column i.\n",
    "        dec[i] = 0\n",
    "        dec.t()[i] = 0\n",
    "        mdec[i] = 0\n",
    "        torch.transpose(mdec, 0, 1)[i] = 0\n",
    "        M, mM = _mimage_frontier(cover[i], dec, n, d, require_remaining=False)\n",
    "        steps = 0\n",
    "        # m == -1: iterate until the frontier is empty; otherwise run m steps.\n",
    "        while (M.sum() != 0) if m == -1 else (steps < m):\n",
    "            cover[i], mcover[i], dec, mdec = _mimage_step(cover[i], mcover[i], dec, mdec, M, mM, n, d)\n",
    "            M, mM = _mimage_frontier(cover[i], dec, n, d, require_remaining=True)\n",
    "            steps += 1\n",
    "    return (cover, mcover)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53d8cd68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PTC_MR from the TU collection; cleaned=False selects the original (not de-duplicated) release.\n",
    "dataset = TUDataset(\n",
    "    name='PTC_MR',\n",
    "    root='/tmp/PTC_MR',\n",
    "    cleaned=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a39ad33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed *before* shuffling so the permutation (and hence the train/test split) is reproducible.\n",
    "torch.manual_seed(12345)\n",
    "l_dataset=len(dataset)\n",
    "dataset=dataset.shuffle()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36fe8727",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Computing the output of SNN version alpha for every graph.\n",
    "\n",
    "def _edge_tensors(gr, num_nodes, num_e_feat):\n",
    "    \"\"\"Dense (n, n) 0/1 adjacency matrix and (n, n, d) edge-feature tensor of ``gr``.\"\"\"\n",
    "    ad_mat = torch.zeros(num_nodes, num_nodes)\n",
    "    Ad_mat = torch.zeros(num_nodes, num_nodes, num_e_feat)\n",
    "    # Sequential writes: for a repeated (src, dst) pair the last edge wins.\n",
    "    for edge, (src, dst) in enumerate(gr.edge_index.t().tolist()):\n",
    "        Ad_mat[src, dst] = gr.edge_attr[edge]\n",
    "        ad_mat[src, dst] = 1\n",
    "    return ad_mat, Ad_mat\n",
    "\n",
    "\n",
    "def _snn_output(image):\n",
    "    \"\"\"Normalised pairwise SNN output from an (n, n, n, d) image tensor.\n",
    "\n",
    "    Entry (i, j, k) is ``mon_op(image[i,:,:,k].t(), image[j,:,:,k])[i, j]``\n",
    "    divided by ``s[i, k] * s[j, k]`` with ``s[i, k] = image[i, :, i, k].sum()``,\n",
    "    and 0 where that denominator is 0. Only entry (i, j) of each product is\n",
    "    needed, which expands to\n",
    "    ``image[i,j,i,k] + image[j,i,j,k] + sum_a image[i,a,i,k] * image[j,a,j,k]``,\n",
    "    so all (i, j, k) are evaluated at once instead of one n x n product each.\n",
    "    \"\"\"\n",
    "    # diag[i, a, k] = image[i, a, i, k]\n",
    "    diag = torch.diagonal(image, dim1=0, dim2=2).permute(2, 0, 1)\n",
    "    scale = diag.sum(dim=1)  # (n, d)\n",
    "    numer = diag + diag.transpose(0, 1) + torch.einsum('iak,jak->ijk', diag, diag)\n",
    "    denom = scale.unsqueeze(1) * scale.unsqueeze(0)  # (n, n, d)\n",
    "    nonzero = denom != 0\n",
    "    safe_denom = torch.where(nonzero, denom, torch.ones_like(denom))\n",
    "    return torch.where(nonzero, numer / safe_denom, torch.zeros_like(numer))\n",
    "\n",
    "\n",
    "def _edges_from_output(output):\n",
    "    \"\"\"``edge_index`` (2, E) and ``edge_attr`` (E, d) for the non-zero (i, j) entries of ``output``.\"\"\"\n",
    "    pairs = (output.sum(dim=2) != 0).nonzero()  # (E, 2), row-major order\n",
    "    edge_index = pairs.t().contiguous()\n",
    "    edge_attr = output[pairs[:, 0], pairs[:, 1]]\n",
    "    return edge_index, edge_attr\n",
    "\n",
    "\n",
    "data_list = []\n",
    "for graph in range(l_dataset):\n",
    "    gr = dataset[graph]\n",
    "    num_nodes = gr.num_nodes\n",
    "    num_e_feat = gr.num_edge_features\n",
    "    ad_mat, Ad_mat = _edge_tensors(gr, num_nodes, num_e_feat)\n",
    "    # Image with m=0 expansion steps; only the featured part (index 1) is used.\n",
    "    Image = mimage(ad_mat, Ad_mat, num_e_feat, num_nodes, 0)[1]\n",
    "    output_of_SNN = _snn_output(Image)\n",
    "    edge_index, edge_attr = _edges_from_output(output_of_SNN)\n",
    "    data_list.append(Data(x=gr.x, edge_index=edge_index, edge_attr=edge_attr, y=gr.y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f02f8630",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: `dataset` was shuffled in an earlier cell, so this seed does not affect the split.\n",
    "torch.manual_seed(12345)\n",
    "\n",
    "# Graphs [TEST_START, TEST_STOP) form the test fold; all remaining graphs are used for training.\n",
    "TEST_START, TEST_STOP = 170, 204\n",
    "train_dataset = data_list[:TEST_START] + data_list[TEST_STOP:]\n",
    "test_dataset = data_list[TEST_START:TEST_STOP]\n",
    "assert test_dataset, f'test fold is empty ({len(data_list)} graphs, TEST_START={TEST_START})'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d868452",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation loader: one graph per batch, fixed order.\n",
    "test_loader = DataLoader(test_dataset, shuffle=False, batch_size=1)\n",
    "# Training loader: batches of 128 graphs, reshuffled each time it is iterated.\n",
    "train_loader = DataLoader(train_dataset, shuffle=True, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d62597bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn import Linear\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GCNConv,GraphConv, NNConv,GATConv,GINEConv\n",
    "from torch_geometric.nn import global_max_pool,global_add_pool,global_mean_pool\n",
    "\n",
    "\n",
    "class GCN(torch.nn.Module):\n",
    "    \"\"\"Graph classifier built from GINEConv layers (despite the class name).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    hidden_channels : int\n",
    "        Width of every GINEConv layer.\n",
    "    num_node_features : int, default 18\n",
    "        Size of the input node features ``x``.\n",
    "    num_edge_features : int, default 4\n",
    "        Size of ``edge_attr`` (GINEConv ``edge_dim``).\n",
    "    num_classes : int, default 2\n",
    "        Number of output logits.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, hidden_channels, num_node_features=18, num_edge_features=4, num_classes=2):\n",
    "        super().__init__()\n",
    "        # Re-seed so every instantiation starts from identical weights.\n",
    "        torch.manual_seed(12345)\n",
    "\n",
    "        def gine(in_channels):\n",
    "            return GINEConv(torch.nn.Linear(in_channels, hidden_channels), edge_dim=num_edge_features)\n",
    "\n",
    "        self.conv1 = gine(num_node_features)\n",
    "        # conv2-conv4 are not used in forward(); they are still built so the\n",
    "        # state_dict layout and the random initialisation of the later layers\n",
    "        # stay the same.\n",
    "        self.conv2 = gine(hidden_channels)\n",
    "        self.conv3 = gine(hidden_channels)\n",
    "        self.conv4 = gine(hidden_channels)\n",
    "        self.conv5 = gine(hidden_channels)\n",
    "        self.conv6 = gine(hidden_channels)\n",
    "        self.lin = Linear(hidden_channels, num_classes)\n",
    "\n",
    "    def forward(self, x, edge_index, edge_attr, batch):\n",
    "        # 1. Node embeddings (conv2-conv4 are skipped).\n",
    "        x = self.conv1(x, edge_index, edge_attr).relu()\n",
    "        x = self.conv5(x, edge_index, edge_attr).tanh()\n",
    "        x = self.conv6(x, edge_index, edge_attr)\n",
    "        # 2. Readout: sum of node embeddings per graph -> [batch_size, hidden_channels].\n",
    "        x = global_add_pool(x, batch)\n",
    "        # 3. Classifier; dropout is only active in training mode.\n",
    "        x = F.dropout(x, p=0.5, training=self.training)\n",
    "        return self.lin(x)\n",
    "\n",
    "\n",
    "model = GCN(hidden_channels=8)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "986cbc8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = GCN(hidden_channels=16)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=8e-4, weight_decay=0)\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "\n",
    "def train():\n",
    "    \"\"\"Run one training epoch over ``train_loader``.\"\"\"\n",
    "    model.train()\n",
    "    for data in train_loader:\n",
    "        out = model(data.x, data.edge_index, data.edge_attr, data.batch)\n",
    "        loss = criterion(out, data.y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def test(loader):\n",
    "    \"\"\"Return the accuracy of ``model`` on ``loader``.\n",
    "\n",
    "    Evaluation needs no gradients, so autograd tracking is disabled to avoid\n",
    "    building (and holding on to) a computation graph for every batch.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    for data in loader:\n",
    "        pred = model(data.x, data.edge_index, data.edge_attr, data.batch).argmax(dim=1)\n",
    "        correct += int((pred == data.y).sum())\n",
    "    return correct / len(loader.dataset)\n",
    "\n",
    "\n",
    "# Epochs 1..1999. NOTE(review): test accuracy is reported every epoch; picking\n",
    "# the best epoch by it gives an optimistic estimate.\n",
    "for epoch in range(1, 2000):\n",
    "    train()\n",
    "    train_acc = test(train_loader)\n",
    "    test_acc = test(test_loader)\n",
    "    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "281db795",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
