{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('./codes/')\n",
    "import time\n",
    "from config import args\n",
    "\n",
    "from utils import *\n",
    "from models import GCN2 as GCN\n",
    "from metrics import *\n",
    "\n",
    "import torch\n",
    "import torch.optim\n",
    "\n",
    "args.dataset = 'syn3'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('./dataset/' + args.dataset + '.pkl', 'rb') as fin:\n",
    "    adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask, edge_label_matrix  = pkl.load(fin)\n",
    "\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Some preprocessing\n",
    "if args.normfea:\n",
    "    features = preprocess_features(features)\n",
    "support = preprocess_adj(adj,args.normadj)\n",
    "model = GCN(input_dim=features.shape[1], output_dim=y_train.shape[1], device=device)\n",
    "model.to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=args.elr)\n",
    "\n",
    "features_tensor = torch.Tensor(features).type(torch.float32)\n",
    "\n",
    "i = torch.LongTensor([*support[0]])\n",
    "v = torch.FloatTensor([*support[1]])\n",
    "# LET OP: i moet getransposed worden om sparse tensor te maken met pytorch\n",
    "support_tensor = torch.sparse.FloatTensor(i.t(), v, torch.Size([*support[2]]))\n",
    "support_tensor = support_tensor.type(torch.float32)\n",
    "\n",
    "y_train_tensor = torch.Tensor(y_train).type(torch.float32)\n",
    "train_mask_tensor = torch.Tensor(train_mask)\n",
    "\n",
    "y_test_tensor = torch.Tensor(y_test).type(torch.float32)\n",
    "test_mask_tensor = torch.Tensor(test_mask)\n",
    "\n",
    "y_val_tensor = torch.Tensor(y_val).type(torch.float32)\n",
    "val_mask_tensor = torch.Tensor(val_mask)\n",
    "\n",
    "best_test_acc = 0\n",
    "best_val_acc = 0\n",
    "best_val_loss = 10000\n",
    "clip_value_min = -2.0\n",
    "clip_value_max = 2.0\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print(\"*************\")\n",
    "print(\"Training GNN!\")\n",
    "print(\"*************\")\n",
    "\n",
    "tik = time.time()\n",
    "\n",
    "f = open(\"LISA_TEST_LOGS/GNN_WEIGHTS/GNN_LOG_\" + args.dataset + \".txt\", \"w\")\n",
    "\n",
    "curr_step = 0\n",
    "\n",
    "for epoch in range(args.epochs):\n",
    "#     print(features_tensor.shape, support_tensor.shape)\n",
    "    output = model((features_tensor,support_tensor),training=True)\n",
    "    cross_loss = masked_softmax_cross_entropy(output, y_train_tensor,train_mask_tensor)\n",
    "\n",
    "    lossL2 = torch.sum(torch.Tensor([torch.sum(v**2) / 2 for v in model.parameters()]))\n",
    "    loss = cross_loss + args.weight_decay*lossL2\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value_max)\n",
    "    optimizer.step()\n",
    "\n",
    "    train_acc = masked_accuracy(output, y_train_tensor,train_mask_tensor)\n",
    "    val_acc  = masked_accuracy(output, y_val_tensor,val_mask_tensor)\n",
    "    val_loss = masked_softmax_cross_entropy(output, y_val_tensor, val_mask_tensor)\n",
    "    test_acc  = masked_accuracy(output, y_test_tensor,test_mask_tensor)\n",
    "\n",
    "    if val_acc > best_val_acc:\n",
    "        curr_step = 0\n",
    "        best_test_acc = test_acc\n",
    "        best_val_acc = val_acc\n",
    "        best_val_loss= val_loss\n",
    "        if args.save_model:\n",
    "            best_state_dict = model.state_dict()\n",
    "\n",
    "    else:\n",
    "        curr_step +=1\n",
    "    if curr_step > args.early_stop:\n",
    "        print(\"Early stopping...\")\n",
    "        break\n",
    "\n",
    "    if (epoch + 1) % 100 == 0:\n",
    "        print(\"Epoch:\", '%04d' % (epoch + 1), \"train_loss=\", \"{:.5f}\".format(cross_loss), \"train_acc=\",\n",
    "              \"{:.5f}\".format(train_acc), \"val_acc=\", \"{:.5f}\".format(val_acc), \"test_acc=\", \"{:.5f}\".format(test_acc),\n",
    "              \"best_test_acc=\", \"{:.5f}\".format(best_test_acc))\n",
    "\n",
    "    \n",
    "torch.save(best_state_dict, f'model_weights/GCN_{args.dataset}_BEST.pt')\n",
    "\n",
    "if not args.valid:\n",
    "    torch.save(model.state_dict(), f'model_weights/GCN_{args.dataset}_LAST.pt')\n",
    "\n",
    "tok = time.time()\n",
    "\n",
    "f.write(\"Epoch,%04d\" % (epoch + 1) + \"\\n\")\n",
    "f.write(\"train_loss,{:.5f}\".format(cross_loss) + \"\\n\")\n",
    "f.write(\"train_acc,{:.5f}\".format(train_acc) + \"\\n\")\n",
    "f.write(\"val_acc,{:.5f}\".format(val_acc) + \"\\n\")\n",
    "f.write(\"test_acc,{:.5f}\".format(test_acc) + \"\\n\")\n",
    "f.write(\"best_test_acc,{:.5f}\".format(best_test_acc) + \"\\n\")\n",
    "f.write(\"Time,{}\".format(tok - tik) + \"\\n\")\n",
    "    \n",
    "f.close()\n",
    "\n",
    "print(\"******************\")        \n",
    "print(\"Done training GNN!\")\n",
    "print(\"******************\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
