{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "multiple-making",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import init\n",
    "from torch.autograd import Variable\n",
    "\n",
    "import argparse\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import time\n",
    "import random\n",
    "from sklearn.metrics import f1_score\n",
    "from collections import defaultdict\n",
    "import networkx as nx\n",
    "\n",
    "from encoders import Encoder\n",
    "from aggregators import MeanAggregator\n",
    "from model import SupervisedGraphSage\n",
    "from dataset_utils import DataLoader\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "existing-bristol",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_StoreAction(option_strings=['--dataset'], dest='dataset', nargs=None, const=None, default='cora', type=None, choices=None, help='Dataset name.', metavar=None)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run configuration. Every option has a default, so the parser can be\n",
    "# evaluated with an empty argument list.\n",
    "parser = argparse.ArgumentParser()\n",
    "\n",
    "# Boolean switches: False unless the flag is passed.\n",
    "for _flag, _help in (\n",
    "    ('--no-cuda', 'Disables CUDA training.'),\n",
    "    ('--fastmode', 'Validate during training pass.'),\n",
    "):\n",
    "    parser.add_argument(_flag, action='store_true', default=False, help=_help)\n",
    "\n",
    "# Typed hyper-parameters, registered in a fixed order (flag, type, default, help).\n",
    "for _flag, _type, _default, _help in (\n",
    "    ('--seed', int, 42, 'Random seed.'),\n",
    "    ('--epochs', int, 200, 'Number of epochs to train.'),\n",
    "    ('--lr', float, 0.015, 'Initial learning rate.'),\n",
    "    ('--weight_decay', float, 9e-4, 'Weight decay (L2 loss on parameters).'),\n",
    "    ('--hidden', int, 64, 'Number of hidden units.'),\n",
    "    ('--num_samples', int, 25, 'Number of samples.'),\n",
    "):\n",
    "    parser.add_argument(_flag, type=_type, default=_default, help=_help)\n",
    "\n",
    "# Name of the dataset to load (plain string option).\n",
    "parser.add_argument('--dataset', default='cora', help='Dataset name.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "exceptional-variation",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse an explicit empty argument list so every option takes its default;\n",
    "# without an explicit list argparse would read the kernel's own sys.argv.\n",
    "args = parser.parse_args([])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "considerable-lobby",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed NumPy, PyTorch and Python's `random` with the same configured value.\n",
    "for _seed_fn in (np.random.seed, torch.manual_seed, random.seed):\n",
    "    _seed_fn(args.seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cubic-puzzle",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset named by --dataset through the project's DataLoader helper.\n",
    "dname = args.dataset\n",
    "dataset = DataLoader(dname)\n",
    "# NOTE(review): presumably element 0 is the single graph object of the dataset -\n",
    "# confirm against dataset_utils.\n",
    "data = dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "appointed-culture",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack the graph data used by the rest of the notebook.\n",
    "# `load_citation_data` is not imported by name here; it presumably comes from\n",
    "# `from utils import *` - verify.\n",
    "# NOTE(review): later cells index A as a dense square matrix (A[i][j], A.shape[0]);\n",
    "# the exact types of the returned values are defined by the helper, not shown here.\n",
    "A_norm, A, X, labels, idx_train, idx_val, idx_test = load_citation_data(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "looking-enlargement",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a NetworkX graph from the dense adjacency matrix A.\n",
    "# `from_numpy_array` is used because `from_numpy_matrix` is deprecated and no\n",
    "# longer exists in NetworkX 3.x; for an ndarray input both build the same graph.\n",
    "G = nx.from_numpy_array(A)\n",
    "\n",
    "# Attach labels[i] to node i under the attribute name `attr_name`.\n",
    "feature_dictionary = {i: labels[i] for i in range(len(labels))}\n",
    "nx.set_node_attributes(G, feature_dictionary, \"attr_name\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "turned-taiwan",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Induced subgraph around every node i: the node itself plus every j with A[i][j] == 1.\n",
    "sub_graphs = []\n",
    "\n",
    "for i in np.arange(len(A)):\n",
    "    # Add the centre node once up front; repeating it for every column would only\n",
    "    # create duplicates that G.subgraph discards.\n",
    "    s_indexes = [i]\n",
    "    for j in np.arange(len(A)):\n",
    "        if A[i][j] == 1:\n",
    "            s_indexes.append(j)\n",
    "    sub_graphs.append(G.subgraph(s_indexes))\n",
    "\n",
    "# Node ids of each subgraph, in the order NetworkX iterates them.\n",
    "subgraph_nodes_list = [list(sub_graph.nodes) for sub_graph in sub_graphs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "noticed-transsexual",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense adjacency matrix of every subgraph, in the same order as `sub_graphs`.\n",
    "sub_graphs_adj = [\n",
    "    nx.adjacency_matrix(sub_graph).toarray()\n",
    "    for sub_graph in sub_graphs\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "south-closing",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Structural weight new_adj[node][index] for every node and each other member of\n",
    "# its subgraph, computed from the overlap of the two nodes' subgraphs.\n",
    "new_adj = torch.zeros(A.shape[0], A.shape[0])\n",
    "\n",
    "for node, nodes_list in enumerate(subgraph_nodes_list):\n",
    "    sub_adj = sub_graphs_adj[node]\n",
    "    node_members = set(nodes_list)\n",
    "    # Position of each member inside nodes_list (dict lookup instead of list.index).\n",
    "    # Assumes the rows/columns of sub_adj follow the order of nodes_list.\n",
    "    position = {member: pos for pos, member in enumerate(nodes_list)}\n",
    "\n",
    "    for index in nodes_list:\n",
    "        if index == node:\n",
    "            continue\n",
    "\n",
    "        # Nodes that belong to both the subgraph of `node` and the subgraph of `index`.\n",
    "        c_neighbors = node_members.intersection(subgraph_nodes_list[index])\n",
    "        count = torch.tensor(0).float()\n",
    "\n",
    "        if index in c_neighbors:\n",
    "            rows = [pos for pos, member in enumerate(nodes_list) if member in c_neighbors]\n",
    "            cols = [position[member] for member in c_neighbors]\n",
    "            # Sum sub_adj over every (shared, shared) pair of positions.\n",
    "            for i in rows:\n",
    "                for j in cols:\n",
    "                    count += sub_adj[i][j]\n",
    "\n",
    "        overlap_size = len(c_neighbors)\n",
    "        # Half the summed entries, divided by overlap_size * (overlap_size - 1), then\n",
    "        # scaled by overlap_size ** 1. `count` is a float tensor, so a zero denominator\n",
    "        # produces nan/inf instead of raising.\n",
    "        # NOTE(review): the exponent 1 looks like a tunable hyper-parameter - confirm.\n",
    "        new_adj[node][index] = (count / 2) / (overlap_size * (overlap_size - 1)) * (overlap_size ** 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "integral-greene",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Divide every row of the structural weights by its row sum\n",
    "# (a row that sums to zero becomes nan).\n",
    "weight = torch.FloatTensor(new_adj)\n",
    "structural_row_sums = weight.sum(dim=1, keepdim=True)\n",
    "weight = weight / structural_row_sums\n",
    "\n",
    "# Add the original adjacency matrix.\n",
    "weight = weight + torch.FloatTensor(A)\n",
    "\n",
    "# Add each row's total to its diagonal entry.\n",
    "coeff = torch.diag(weight.sum(dim=1))\n",
    "weight = weight + coeff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "drawn-ideal",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the weight tensor to a NumPy array and replace nan entries with 0.\n",
    "# `weight` itself is left as a tensor: rebinding it to an ndarray would make a\n",
    "# second run of this cell fail, since ndarrays have no .detach().\n",
    "adj = np.nan_to_num(weight.detach().numpy(), nan=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "associate-reminder",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node features as a NumPy array.\n",
    "feat_data = np.array(X)\n",
    "\n",
    "# Map every node index to the set of node ids in its subgraph.\n",
    "adj_lists = defaultdict(set)\n",
    "adj_lists.update(\n",
    "    (i, set(subgraph_nodes_list[i]))\n",
    "    for i in np.arange(len(sub_graphs))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "likely-maldives",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reshape labels into a single column with one row per row of A.\n",
    "labels = labels.reshape(A.shape[0], 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "musical-matter",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index tensor 0 .. N-1 (int64) covering every node, N being the number of rows of A.\n",
    "full_nodes = torch.arange(A.shape[0], dtype=torch.long)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "expensive-vintage",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "F:\\Projects\\Research\\Task_4\\a_new_perspective_on_how_graph-Supplementary Material\\GraphSNN_Final\\Node_Classification\\4.GraphSN_SAGE\\encoders.py:35: UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.\n",
      "  init.xavier_uniform(self.weight)\n",
      "F:\\Projects\\Research\\Task_4\\a_new_perspective_on_how_graph-Supplementary Material\\GraphSNN_Final\\Node_Classification\\4.GraphSN_SAGE\\model.py:13: UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.\n",
      "  init.xavier_uniform(self.weight)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 77.8\n",
      "Test Accuracy: 80.7\n"
     ]
    }
   ],
   "source": [
    "num_nodes = A.shape[0]\n",
    "num_features = feat_data.shape[1]\n",
    "\n",
    "# Embedding table holding the raw node features; excluded from gradient updates.\n",
    "features = nn.Embedding(num_nodes, num_features)\n",
    "features.weight = nn.Parameter(torch.FloatTensor(feat_data), requires_grad=False)\n",
    "feat_data = torch.FloatTensor(feat_data)\n",
    "\n",
    "# Two stacked encoders: the second one reads the transposed output of the first.\n",
    "agg1 = MeanAggregator(features, cuda=False)\n",
    "enc1 = Encoder(features, num_features, args.hidden, adj_lists, adj, feat_data, agg1, cuda=False)\n",
    "agg2 = MeanAggregator(lambda nodes : enc1(nodes, full_nodes).t(), cuda=False)\n",
    "enc2 = Encoder(lambda nodes : enc1(nodes, full_nodes).t(), enc1.embed_dim, args.hidden, adj_lists, adj, feat_data, agg2,\n",
    "               base_model=enc1, cuda=False)\n",
    "enc1.num_samples = args.num_samples\n",
    "enc2.num_samples = args.num_samples\n",
    "\n",
    "num_classes = np.unique(labels).shape[0]\n",
    "graphsage = SupervisedGraphSage(num_classes, enc2)\n",
    "\n",
    "\n",
    "# Positions of the non-zero entries of a split mask, as a 1-D NumPy index array.\n",
    "def _mask_to_indices(mask):\n",
    "    return torch.nonzero(torch.as_tensor(mask), as_tuple=False).reshape(-1).cpu().numpy()\n",
    "\n",
    "\n",
    "# Take the node ids of each split from the masks themselves. Building contiguous\n",
    "# ranges from the mask sizes is only correct when the masks happen to be laid out\n",
    "# as train, then val, then test starting at node 0; in that case both approaches\n",
    "# give identical index arrays.\n",
    "# NOTE(review): assumes idx_train / idx_val / idx_test are boolean (or 0/1) masks\n",
    "# over all nodes - confirm against load_citation_data.\n",
    "train = _mask_to_indices(idx_train)\n",
    "val = _mask_to_indices(idx_val)\n",
    "test = _mask_to_indices(idx_test)\n",
    "train_size, val_size, test_size = len(train), len(val), len(test)\n",
    "\n",
    "optimizer = torch.optim.Adam(graphsage.parameters(), lr=args.lr, weight_decay=args.weight_decay)\n",
    "\n",
    "# Full-batch training: every step uses all training nodes, in a freshly shuffled order.\n",
    "times = []\n",
    "for epoch in range(args.epochs):\n",
    "    random.shuffle(train)\n",
    "    batch_nodes = train[:train_size]\n",
    "    start_time = time.time()\n",
    "    optimizer.zero_grad()\n",
    "    loss = graphsage.loss(batch_nodes, full_nodes, torch.LongTensor(labels[batch_nodes]))\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    end_time = time.time()\n",
    "    times.append(end_time-start_time)\n",
    "\n",
    "# Evaluation needs no gradients; micro-averaged F1 over single-label predictions\n",
    "# is reported as the accuracy.\n",
    "with torch.no_grad():\n",
    "    val_output = graphsage.forward(val, full_nodes)\n",
    "print(\"Validation Accuracy:\", 100*f1_score(labels[val], val_output.data.numpy().argmax(axis=1), average=\"micro\"))\n",
    "\n",
    "with torch.no_grad():\n",
    "    test_output = graphsage.forward(test, full_nodes)\n",
    "print(\"Test Accuracy:\", 100*f1_score(labels[test], test_output.data.numpy().argmax(axis=1), average=\"micro\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "periodic-italic",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
