{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import time\n",
    "import argparse\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import networkx as nx\n",
    "from scipy import sparse\n",
    "from scipy.linalg import fractional_matrix_power\n",
    "\n",
    "from utils import *\n",
    "from models import Graphsn_GIN\n",
    "from dataset_utils import DataLoader\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_StoreAction(option_strings=['--dataset'], dest='dataset', nargs=None, const=None, default='cora', type=None, choices=None, help='Dataset name.', metavar=None)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument('--no-cuda', action='store_true', default=False,\n",
    "                    help='Disables CUDA training.')\n",
    "parser.add_argument('--fastmode', action='store_true', default=False,\n",
    "                    help='Validate during training pass.')\n",
    "parser.add_argument('--seed', type=int, default=42, help='Random seed.')\n",
    "parser.add_argument('--epochs', type=int, default=200,\n",
    "                    help='Number of epochs to train.')\n",
    "parser.add_argument('--lr', type=float, default=0.002,\n",
    "                    help='Initial learning rate.')\n",
    "parser.add_argument('--weight_decay', type=float, default=9e-3,\n",
    "                    help='Weight decay (L2 loss on parameters).')\n",
    "parser.add_argument('--hidden', type=int, default=64,\n",
    "                    help='Number of hidden units.')\n",
    "parser.add_argument('--dropout', type=float, default=0.88,\n",
    "                    help='Dropout rate (1 - keep probability).')\n",
    "parser.add_argument('--dataset', default='cora', help='Dataset name.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = parser.parse_args(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x1fc94ab5900>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.random.seed(args.seed)\n",
    "torch.manual_seed(args.seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index\n",
      "Processing...\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "dname = args.dataset\n",
    "dataset = DataLoader(dname)\n",
    "data = dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "A_norm, A, X, labels, idx_train, idx_val, idx_test = load_citation_data(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "G = nx.from_numpy_matrix(A)\n",
    "feature_dictionary = {}\n",
    "\n",
    "for i in np.arange(len(labels)):\n",
    "    feature_dictionary[i] = labels[i]\n",
    "\n",
    "nx.set_node_attributes(G, feature_dictionary, \"attr_name\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "sub_graphs = []\n",
    "\n",
    "for i in np.arange(len(A)):\n",
    "    s_indexes = []\n",
    "    for j in np.arange(len(A)):\n",
    "        s_indexes.append(i)\n",
    "        if(A[i][j]==1):\n",
    "            s_indexes.append(j)\n",
    "    sub_graphs.append(G.subgraph(s_indexes))\n",
    "\n",
    "subgraph_nodes_list = []\n",
    "\n",
    "for i in np.arange(len(sub_graphs)):\n",
    "    subgraph_nodes_list.append(list(sub_graphs[i].nodes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "sub_graphs_adj = []\n",
    "for index in np.arange(len(sub_graphs)):\n",
    "    sub_graphs_adj.append(nx.adjacency_matrix(sub_graphs[index]).toarray())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_adj = torch.zeros(A.shape[0], A.shape[0])\n",
    "\n",
    "for node in np.arange(len(subgraph_nodes_list)):\n",
    "    sub_adj = sub_graphs_adj[node]\n",
    "    for neighbors in np.arange(len(subgraph_nodes_list[node])):\n",
    "        index = subgraph_nodes_list[node][neighbors]\n",
    "        count = torch.tensor(0).float()\n",
    "        if(index==node):\n",
    "            continue\n",
    "        else:\n",
    "            c_neighbors = set(subgraph_nodes_list[node]).intersection(subgraph_nodes_list[index])\n",
    "            if index in c_neighbors:\n",
    "                nodes_list = subgraph_nodes_list[node]\n",
    "                sub_graph_index = nodes_list.index(index)\n",
    "                c_neighbors_list = list(c_neighbors)\n",
    "                for i, item1 in enumerate(nodes_list):\n",
    "                    if(item1 in c_neighbors):\n",
    "                        for item2 in c_neighbors_list:\n",
    "                            j = nodes_list.index(item2)\n",
    "                            count += sub_adj[i][j]\n",
    "\n",
    "            new_adj[node][index] = count/2\n",
    "            new_adj[node][index] = new_adj[node][index]/(len(c_neighbors)*(len(c_neighbors)-1))\n",
    "            new_adj[node][index] = new_adj[node][index] * (len(c_neighbors)**1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "features = torch.FloatTensor(X)\n",
    "labels = torch.LongTensor(labels)\n",
    "\n",
    "weight = torch.FloatTensor(new_adj)\n",
    "weight = weight / weight.sum(1, keepdim=True)\n",
    "\n",
    "weight = weight + torch.FloatTensor(A)\n",
    "\n",
    "coeff = weight.sum(1, keepdim=True)\n",
    "coeff = torch.diag((coeff.T)[0])\n",
    "\n",
    "weight = weight + coeff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "weight = weight.detach().numpy()\n",
    "weight = np.nan_to_num(weight, nan=0)\n",
    "adj = torch.FloatTensor(weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model and optimizer\n",
    "model = Graphsn_GIN(nfeat=features.shape[1],\n",
    "                    nhid=args.hidden,\n",
    "                    nclass=labels.max().item() + 1,\n",
    "                    dropout=args.dropout)\n",
    "\n",
    "optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch):\n",
    "    t = time.time()\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    output = model(features, adj)\n",
    "    loss_train = F.nll_loss(output[idx_train], labels[idx_train])\n",
    "    acc_train = accuracy(output[idx_train], labels[idx_train])\n",
    "    loss_train.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if not args.fastmode:\n",
    "        # Evaluate validation set performance separately, deactivates dropout during validation run.\n",
    "        model.eval()\n",
    "        output = model(features, adj)\n",
    "\n",
    "    loss_val = F.nll_loss(output[idx_val], labels[idx_val])\n",
    "    acc_val = accuracy(output[idx_val], labels[idx_val])\n",
    "    print('Epoch: {:04d}'.format(epoch+1),\n",
    "          'loss_train: {:.4f}'.format(loss_train.item()),\n",
    "          'acc_train: {:.4f}'.format(acc_train.item()),\n",
    "          'loss_val: {:.4f}'.format(loss_val.item()),\n",
    "          'acc_val: {:.4f}'.format(acc_val.item()),\n",
    "          'time: {:.4f}s'.format(time.time() - t))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test():\n",
    "    model.eval()\n",
    "    output = model(features, adj)\n",
    "    loss_test = F.nll_loss(output[idx_test], labels[idx_test])\n",
    "    acc_test = accuracy(output[idx_test], labels[idx_test])\n",
    "    print(\"Test set results:\",\n",
    "          \"loss= {:.4f}\".format(loss_test.item()),\n",
    "          \"accuracy= {:.4f}\".format(acc_test.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0001 loss_train: 2.0504 acc_train: 0.1786 loss_val: 1.8746 acc_val: 0.3440 time: 1.0253s\n",
      "Epoch: 0002 loss_train: 1.9849 acc_train: 0.2071 loss_val: 1.8429 acc_val: 0.4140 time: 0.6827s\n",
      "Epoch: 0003 loss_train: 1.8407 acc_train: 0.2500 loss_val: 1.8152 acc_val: 0.4720 time: 0.7530s\n",
      "Epoch: 0004 loss_train: 1.7604 acc_train: 0.2929 loss_val: 1.7882 acc_val: 0.5640 time: 0.7181s\n",
      "Epoch: 0005 loss_train: 1.7793 acc_train: 0.3643 loss_val: 1.7618 acc_val: 0.6020 time: 0.6423s\n",
      "Epoch: 0006 loss_train: 1.6856 acc_train: 0.3857 loss_val: 1.7333 acc_val: 0.6380 time: 0.7679s\n",
      "Epoch: 0007 loss_train: 1.6423 acc_train: 0.4714 loss_val: 1.7020 acc_val: 0.6780 time: 0.6732s\n",
      "Epoch: 0008 loss_train: 1.5223 acc_train: 0.5143 loss_val: 1.6692 acc_val: 0.7160 time: 0.7310s\n",
      "Epoch: 0009 loss_train: 1.6090 acc_train: 0.4000 loss_val: 1.6366 acc_val: 0.7300 time: 0.6293s\n",
      "Epoch: 0010 loss_train: 1.5840 acc_train: 0.4571 loss_val: 1.6031 acc_val: 0.7360 time: 0.6154s\n",
      "Epoch: 0011 loss_train: 1.3787 acc_train: 0.5643 loss_val: 1.5679 acc_val: 0.7480 time: 0.6812s\n",
      "Epoch: 0012 loss_train: 1.3795 acc_train: 0.4786 loss_val: 1.5311 acc_val: 0.7500 time: 0.6328s\n",
      "Epoch: 0013 loss_train: 1.4354 acc_train: 0.4714 loss_val: 1.4935 acc_val: 0.7540 time: 0.6524s\n",
      "Epoch: 0014 loss_train: 1.3258 acc_train: 0.5500 loss_val: 1.4557 acc_val: 0.7500 time: 0.5725s\n"
     ]
    }
   ],
   "source": [
    "# Train model\n",
    "t_total = time.time()\n",
    "for epoch in range(args.epochs):\n",
    "    train(epoch)\n",
    "print(\"Optimization Finished!\")\n",
    "print(\"Total time elapsed: {:.4f}s\".format(time.time() - t_total))\n",
    "\n",
    "# Testing\n",
    "test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
