{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('./codes/')\n",
    "from config import args\n",
    "\n",
    "from utils import *\n",
    "from models import GCN#2 as GCN\n",
    "from metrics import *\n",
    "\n",
    "import torch\n",
    "import torch.optim\n",
    "\n",
    "args.dataset = 'syn1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('./dataset/' + args.dataset + '.pkl', 'rb') as fin:\n",
    "    adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask, edge_label_matrix  = pkl.load(fin)\n",
    "\n",
    "# Some preprocessing\n",
    "if args.normfea:\n",
    "    features = preprocess_features(features)\n",
    "support = preprocess_adj(adj,args.normadj)\n",
    "model = GCN(input_dim=features.shape[1], output_dim=y_train.shape[1])\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=args.elr)\n",
    "\n",
    "features_tensor = torch.Tensor(features).type(torch.float32)\n",
    "\n",
    "i = torch.LongTensor([*support[0]])\n",
    "v = torch.FloatTensor([*support[1]])\n",
    "# LET OP: i moet getransposed worden om sparse tensor te maken met pytorch\n",
    "support_tensor = torch.sparse.FloatTensor(i.t(), v, torch.Size([*support[2]]))\n",
    "support_tensor = support_tensor.type(torch.float32)\n",
    "\n",
    "y_train_tensor = torch.Tensor(y_train).type(torch.float32)\n",
    "train_mask_tensor = torch.Tensor(train_mask)\n",
    "\n",
    "y_test_tensor = torch.Tensor(y_test).type(torch.float32)\n",
    "test_mask_tensor = torch.Tensor(test_mask)\n",
    "\n",
    "y_val_tensor = torch.Tensor(y_val).type(torch.float32)\n",
    "val_mask_tensor = torch.Tensor(val_mask)\n",
    "\n",
    "best_test_acc = 0\n",
    "best_val_acc = 0\n",
    "best_val_loss = 10000\n",
    "clip_value_min = -2.0\n",
    "clip_value_max = 2.0\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([10, 20])\n",
      "torch.Size([20])\n",
      "torch.Size([20])\n",
      "torch.Size([20])\n",
      "torch.Size([20, 20])\n",
      "torch.Size([20])\n",
      "torch.Size([20])\n",
      "torch.Size([20])\n",
      "torch.Size([20, 20])\n",
      "torch.Size([20])\n",
      "torch.Size([4, 60])\n",
      "torch.Size([4])\n",
      "torch.Size([700, 10]) torch.Size([700, 700])\n",
      "(tensor([[1., 1., 1.,  ..., 1., 1., 1.],\n",
      "        [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "        [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "        ...,\n",
      "        [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "        [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "        [1., 1., 1.,  ..., 1., 1., 1.]]), tensor(indices=tensor([[  0,   0,   0,  ..., 698, 699, 699],\n",
      "                       [  5,   7,   9,  ..., 697, 695, 696]]),\n",
      "       values=tensor([1., 1., 1.,  ..., 1., 1., 1.]),\n",
      "       size=(700, 700), nnz=4110, layout=torch.sparse_coo))\n",
      "(tensor([[0.0000, 0.3764, 0.2215,  ..., 0.2627, 0.0000, 0.0000],\n",
      "        [0.0000, 0.3764, 0.2215,  ..., 0.2627, 0.0000, 0.0000],\n",
      "        [0.0000, 0.3764, 0.2215,  ..., 0.2627, 0.0000, 0.0000],\n",
      "        ...,\n",
      "        [0.0000, 0.3764, 0.2215,  ..., 0.2627, 0.0000, 0.0000],\n",
      "        [0.0000, 0.3764, 0.2215,  ..., 0.2627, 0.0000, 0.0000],\n",
      "        [0.0000, 0.3764, 0.2215,  ..., 0.2627, 0.0000, 0.0000]],\n",
      "       grad_fn=<ReluBackward0>), tensor(indices=tensor([[  0,   0,   0,  ..., 698, 699, 699],\n",
      "                       [  5,   7,   9,  ..., 697, 695, 696]]),\n",
      "       values=tensor([1., 1., 1.,  ..., 1., 1., 1.]),\n",
      "       size=(700, 700), nnz=4110, layout=torch.sparse_coo))\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "forward() takes 2 positional arguments but 3 were given",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-3-d5d646e36c96>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures_tensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msupport_tensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m     \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures_tensor\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0msupport_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtraining\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      9\u001b[0m     \u001b[0mcross_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmasked_softmax_cross_entropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train_tensor\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtrain_mask_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/fact/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    720\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    721\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 722\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    723\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    724\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Desktop/master/FACT/FACTAI21/PGExplainer_torch/codes/models.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, inputs, training)\u001b[0m\n\u001b[1;32m     46\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     47\u001b[0m         \u001b[0;31m# Run network\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 48\u001b[0;31m         \u001b[0memb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     49\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpred_layer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0memb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     50\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Desktop/master/FACT/FACTAI21/PGExplainer_torch/codes/models.py\u001b[0m in \u001b[0;36membedding\u001b[0;34m(self, inputs, training)\u001b[0m\n\u001b[1;32m     55\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers_\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     56\u001b[0m             \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msupport\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 57\u001b[0;31m             \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msupport\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     58\u001b[0m             \u001b[0mx_all\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     59\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcat\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mTypeError\u001b[0m: forward() takes 2 positional arguments but 3 were given"
     ]
    }
   ],
   "source": [
    "curr_step = 0\n",
    "for param in model.parameters():\n",
    "    print(param.shape)\n",
    "# print(model.parameters())\n",
    "# raise SystemExit(0)\n",
    "for epoch in range(args.epochs):\n",
    "    print(features_tensor.shape, support_tensor.shape)\n",
    "    output = model((features_tensor,support_tensor),training=True)\n",
    "    cross_loss = masked_softmax_cross_entropy(output, y_train_tensor,train_mask_tensor)\n",
    "\n",
    "    # L2 loss not sure though (TODO)\n",
    "    lossL2 = torch.sum(torch.Tensor([torch.sum(v**2) / 2 for v in model.parameters()]))\n",
    "    loss = cross_loss + args.weight_decay*lossL2\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value_max)\n",
    "    optimizer.step()\n",
    "\n",
    "    train_acc = masked_accuracy(output, y_train_tensor,train_mask_tensor)\n",
    "    val_acc  = masked_accuracy(output, y_val_tensor,val_mask_tensor)\n",
    "    val_loss = masked_softmax_cross_entropy(output, y_val_tensor, val_mask_tensor)\n",
    "    test_acc  = masked_accuracy(output, y_test_tensor,test_mask_tensor)\n",
    "\n",
    "    if val_acc > best_val_acc:\n",
    "        curr_step = 0\n",
    "        best_test_acc = test_acc\n",
    "        best_val_acc = val_acc\n",
    "        best_val_loss= val_loss\n",
    "#         if args.save_model:\n",
    "#             torch.save(model.state_dict(), f'model_weights/GCN_{args.dataset}_{epoch}.pt')\n",
    "\n",
    "    else:\n",
    "        curr_step +=1\n",
    "    if curr_step > args.early_stop:\n",
    "        print(\"Early stopping...\")\n",
    "        break\n",
    "\n",
    "    print(\"Epoch:\", '%04d' % (epoch + 1), \"train_loss=\", \"{:.5f}\".format(cross_loss), \"train_acc=\",\n",
    "          \"{:.5f}\".format(train_acc), \"val_acc=\", \"{:.5f}\".format(val_acc), \"test_acc=\", \"{:.5f}\".format(test_acc),\n",
    "          \"best_test_acc=\", \"{:.5f}\".format(best_test_acc))\n",
    "\n",
    "# if not args.valid:\n",
    "#     torch.save(model.state_dict(), f'model_weights/GCN_{args.dataset}_{epoch}.pt')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
