{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('./codes/forgraph')\n",
    "from config import args\n",
    "# import tensorflow as tf\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import time\n",
    "from models import GCN2 as GCN\n",
    "from metrics import *\n",
    "import pickle as pkl\n",
    "from matplotlib import pyplot as plt\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "\n",
    "# args.bn = True\n",
    "# args.concat = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative dataset (BA-2motif), stored as a single pickle:\n",
    "# with open('./dataset/BA-2motif.pkl','rb') as fin:\n",
    "#     (adjs, feas, labels) = pkl.load(fin)\n",
    "\n",
    "# Chunk 1 of the MUTAG data: per-graph adjacency matrices, node features and labels.\n",
    "adjs, feas, labels = (\n",
    "    np.load('./dataset/mutag_chunks/sub_adjs_1.npy'),\n",
    "    np.load('./dataset/mutag_chunks/sub_feas_1.npy'),\n",
    "    np.load('./dataset/mutag_chunks/sub_labels_1.npy'),\n",
    ")\n",
    "\n",
    "def vis(adj):\n",
    "    \"\"\"Draw the graph whose adjacency matrix is `adj` using a Kamada-Kawai layout.\n",
    "\n",
    "    `adj` is handed to networkx.from_numpy_array, so it must be a square 2-D array.\n",
    "    (nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array exists since 2.0.)\n",
    "    \"\"\"\n",
    "    G = nx.from_numpy_array(adj)\n",
    "    pos = nx.kamada_kawai_layout(G)\n",
    "    fig, ax = plt.subplots()\n",
    "    nx.draw_networkx_nodes(G, pos, ax=ax)\n",
    "    nx.draw_networkx_edges(G, pos, ax=ax)\n",
    "    ax.axis('off')\n",
    "    plt.show()\n",
    "    # Release the figure so repeated calls do not accumulate open figures.\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick visual sanity check (optional):\n",
    "# vis(adjs[0])\n",
    "# vis(adjs[500])\n",
    "\n",
    "assert adjs.shape[0] == feas.shape[0] == labels.shape[0], \"adjs/feas/labels disagree on the number of graphs\"\n",
    "\n",
    "# Shuffle before splitting: if the arrays are stored grouped by class, an unshuffled\n",
    "# 80/10/10 split yields class-skewed validation/test sets. The fixed seed keeps the\n",
    "# split reproducible between runs.\n",
    "SPLIT_SEED = 0\n",
    "order = np.random.RandomState(SPLIT_SEED).permutation(adjs.shape[0])\n",
    "shuffle_adjs = adjs[order]\n",
    "shuffle_feas = feas[order]\n",
    "shuffle_labels = labels[order]\n",
    "\n",
    "# 80% train, 10% validation, 10% test.\n",
    "train_split = int(adjs.shape[0] * 0.8)\n",
    "val_split = int(adjs.shape[0] * 0.9)\n",
    "\n",
    "train_adjs = shuffle_adjs[:train_split]\n",
    "train_feas = shuffle_feas[:train_split]\n",
    "train_labels = shuffle_labels[:train_split]\n",
    "\n",
    "val_adjs = shuffle_adjs[train_split:val_split]\n",
    "val_feas = shuffle_feas[train_split:val_split]\n",
    "val_labels = shuffle_labels[train_split:val_split]\n",
    "\n",
    "test_adjs = shuffle_adjs[val_split:]\n",
    "test_feas = shuffle_feas[val_split:]\n",
    "test_labels = shuffle_labels[val_split:]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"*************\")\n",
    "print(\"Training GNN!\")\n",
    "print(\"*************\")\n",
    "\n",
    "tik = time.time()\n",
    "\n",
    "log_path = \"LISA_TEST_LOGS/GNN_WEIGHTS/GNN_LOG_\" + args.dataset + \".txt\"\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "model = GCN(input_dim=train_feas.shape[1:], output_dim=train_labels.shape[1], device=device)\n",
    "model.to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n",
    "\n",
    "# clip_grad_value_ clips symmetrically, so one bound covers the range [-2.0, 2.0].\n",
    "CLIP_VALUE = 2.0\n",
    "BATCH_SIZE = 64\n",
    "\n",
    "\n",
    "def to_device_tensor(array):\n",
    "    \"\"\"Return `array` as a float32 tensor on the device the model was moved to.\"\"\"\n",
    "    return torch.tensor(array, dtype=torch.float32, device=device)\n",
    "\n",
    "\n",
    "def l2_penalty(module):\n",
    "    \"\"\"Return sum(||p||^2) / 2 over `module`'s parameters as a differentiable tensor.\n",
    "\n",
    "    Summing the parameter tensors directly keeps the penalty in the autograd graph;\n",
    "    copying the values into a fresh torch.Tensor([...]) would detach it, and the\n",
    "    weight-decay term would then never reach the gradients.\n",
    "    \"\"\"\n",
    "    return sum(torch.sum(p ** 2) / 2 for p in module.parameters())\n",
    "\n",
    "\n",
    "def train_step(features, adjs_batch, labels_batch):\n",
    "    \"\"\"Run one optimisation step on a batch.\n",
    "\n",
    "    Returns the detached model output (used for accuracy without keeping the graph\n",
    "    alive) and the batch cross-entropy as a Python float.\n",
    "    \"\"\"\n",
    "    output = model((features, adjs_batch), training=True)\n",
    "    cross_loss = softmax_cross_entropy(output, labels_batch)\n",
    "    loss = cross_loss + args.weight_decay * l2_penalty(model)\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    nn.utils.clip_grad_value_(model.parameters(), CLIP_VALUE)\n",
    "    optimizer.step()\n",
    "    return output.detach(), cross_loss.item()\n",
    "\n",
    "\n",
    "# Inputs must live on the same device as the model, otherwise forward fails on CUDA.\n",
    "train_adjs_tensor = to_device_tensor(train_adjs)\n",
    "train_features_tensor = to_device_tensor(train_feas)\n",
    "train_labels_tensor = to_device_tensor(train_labels)\n",
    "\n",
    "val_adjs_tensor = to_device_tensor(val_adjs)\n",
    "val_features_tensor = to_device_tensor(val_feas)\n",
    "val_labels_tensor = to_device_tensor(val_labels)\n",
    "\n",
    "test_adjs_tensor = to_device_tensor(test_adjs)\n",
    "test_features_tensor = to_device_tensor(test_feas)\n",
    "test_labels_tensor = to_device_tensor(test_labels)\n",
    "\n",
    "best_test_acc = 0\n",
    "best_val_acc = -1\n",
    "best_state_dict = None\n",
    "curr_step = 0\n",
    "\n",
    "for epoch in range(args.epochs):\n",
    "    model.train()\n",
    "    if args.batch:\n",
    "        # Mini-batches of BATCH_SIZE graphs; the last batch may be smaller.\n",
    "        # cross_loss ends up holding the loss of the final batch.\n",
    "        outputs = []\n",
    "        for begin in range(0, train_adjs_tensor.shape[0], BATCH_SIZE):\n",
    "            end = begin + BATCH_SIZE\n",
    "            batch_output, cross_loss = train_step(\n",
    "                train_features_tensor[begin:end],\n",
    "                train_adjs_tensor[begin:end],\n",
    "                train_labels_tensor[begin:end],\n",
    "            )\n",
    "            outputs.append(batch_output)\n",
    "        output = torch.cat(outputs, dim=0)\n",
    "    else:\n",
    "        output, cross_loss = train_step(train_features_tensor, train_adjs_tensor, train_labels_tensor)\n",
    "\n",
    "    # Evaluate without building autograd graphs. model.eval() switches any nn.Module-level\n",
    "    # layers that depend on the training flag (e.g. BatchNorm/Dropout, if GCN2 uses them)\n",
    "    # to inference mode; the explicit training=False argument is passed as before.\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        train_acc = accuracy(output, train_labels_tensor)\n",
    "        val_output = model((val_features_tensor, val_adjs_tensor), training=False)\n",
    "        val_acc = accuracy(val_output, val_labels_tensor)\n",
    "        val_loss = softmax_cross_entropy(val_output, val_labels_tensor)\n",
    "\n",
    "        test_output = model((test_features_tensor, test_adjs_tensor), training=False)\n",
    "        test_acc = accuracy(test_output, test_labels_tensor)\n",
    "        test_loss = softmax_cross_entropy(test_output, test_labels_tensor)\n",
    "\n",
    "    # Model selection on validation accuracy; test accuracy is only reported.\n",
    "    if val_acc > best_val_acc:\n",
    "        curr_step = 0\n",
    "        best_test_acc = test_acc\n",
    "        best_val_acc = val_acc\n",
    "        if args.save_model:\n",
    "            # state_dict() returns references to the live parameters; clone them so later\n",
    "            # optimizer steps do not overwrite the \"best\" snapshot.\n",
    "            best_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}\n",
    "    else:\n",
    "        curr_step += 1\n",
    "\n",
    "    if curr_step > args.early_stop:\n",
    "        print(\"Early stopping...\")\n",
    "        break\n",
    "\n",
    "    if (epoch + 1) % 100 == 0:\n",
    "        print(\"Epoch:\", '%04d' % (epoch + 1), \"train_loss=\", \"{:.5f}\".format(cross_loss), \"train_acc=\",\n",
    "              \"{:.5f}\".format(train_acc), \"val_acc=\", \"{:.5f}\".format(val_acc),\n",
    "              \"test_acc=\", \"{:.5f}\".format(test_acc),\n",
    "              \"test_best_acc=\", \"{:.5f}\".format(best_test_acc))\n",
    "\n",
    "# NOTE(review): the weight filenames say BA2motif although this notebook loads MUTAG chunks;\n",
    "# left unchanged because other code may load these exact paths.\n",
    "if args.save_model:\n",
    "    if best_state_dict is not None:\n",
    "        torch.save(best_state_dict, 'model_weights/GCN_BA2motif_BEST.pt')\n",
    "    torch.save(model.state_dict(), 'model_weights/GCN_BA2motif_LAST.pt')\n",
    "\n",
    "tok = time.time()\n",
    "\n",
    "# Metrics below are from the last epoch that ran; best_test_acc is the test accuracy\n",
    "# at the epoch with the best validation accuracy.\n",
    "with open(log_path, \"w\") as log_file:\n",
    "    log_file.write(\"Epoch,%04d\" % (epoch + 1) + \"\\n\")\n",
    "    log_file.write(\"train_loss,{:.5f}\".format(cross_loss) + \"\\n\")\n",
    "    log_file.write(\"train_acc,{:.5f}\".format(train_acc) + \"\\n\")\n",
    "    log_file.write(\"val_acc,{:.5f}\".format(val_acc) + \"\\n\")\n",
    "    log_file.write(\"test_acc,{:.5f}\".format(test_acc) + \"\\n\")\n",
    "    log_file.write(\"best_test_acc,{:.5f}\".format(best_test_acc) + \"\\n\")\n",
    "    log_file.write(\"Time,{}\".format(tok - tik) + \"\\n\")\n",
    "\n",
    "print(\"******************\")\n",
    "print(\"Done training GNN!\")\n",
    "print(\"******************\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
