{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl\n",
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import torch\n",
    "\n",
    "# sys.path is extended first so that config, models and metrics (presumably\n",
    "# project-local modules under ./codes/forgraph/) can be imported.\n",
    "sys.path.append('./codes/forgraph/')\n",
    "from config import args\n",
    "from models import GCN\n",
    "# Star import is kept last, as before, so any names it exports shadow the ones above.\n",
    "from metrics import *\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pickled dataset and cut it into shuffled 80% / 10% / 10% train / val / test subsets.\n",
    "args.dataset = 'Mutagenicity'\n",
    "SPLIT_SEED = 0  # fixed seed so Restart & Run All reproduces the same split\n",
    "\n",
    "# NOTE: unpickling can execute arbitrary code -- only load dataset files from a trusted source.\n",
    "with open(f'./dataset/{args.dataset}.pkl', 'rb') as fin:\n",
    "    adjs, features, labels = pkl.load(fin)\n",
    "\n",
    "num_samples = adjs.shape[0]\n",
    "# All three arrays are indexed by the same permutation below, so they must be equally long.\n",
    "assert features.shape[0] == num_samples and labels.shape[0] == num_samples, \\\n",
    "    'adjs/features/labels differ in length'\n",
    "\n",
    "# NOTE(review): the checkpoint evaluated later was trained on its own split; unless that split\n",
    "# used this same permutation, the subsets built here may overlap its training data -- confirm.\n",
    "order = np.random.RandomState(SPLIT_SEED).permutation(num_samples)\n",
    "shuffle_adjs = adjs[order]\n",
    "shuffle_features = features[order]\n",
    "shuffle_labels = labels[order]\n",
    "\n",
    "# Boundaries inside the shuffled order: [0, train_split) train, [train_split, val_split) val, rest test.\n",
    "train_split = int(num_samples * 0.8)\n",
    "val_split = int(num_samples * 0.9)\n",
    "\n",
    "train_adjs = shuffle_adjs[:train_split]\n",
    "train_features = shuffle_features[:train_split]\n",
    "train_labels = shuffle_labels[:train_split]\n",
    "train_ids = order[:train_split]  # original dataset indices of the training samples\n",
    "\n",
    "val_adjs = shuffle_adjs[train_split:val_split]\n",
    "val_features = shuffle_features[train_split:val_split]\n",
    "val_labels = shuffle_labels[train_split:val_split]\n",
    "val_ids = order[train_split:val_split]\n",
    "\n",
    "test_adjs = shuffle_adjs[val_split:]\n",
    "test_features = shuffle_features[val_split:]\n",
    "test_labels = shuffle_labels[val_split:]\n",
    "test_ids = order[val_split:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_acc=0.87259 val_acc= 0.87558 test_acc= 0.87097\n"
     ]
    }
   ],
   "source": [
    "def to_float32_tensors(*arrays):\n",
    "    \"\"\"Convert each array to a float32 TensorFlow tensor, keeping the argument order.\"\"\"\n",
    "    return tuple(tf.convert_to_tensor(array, dtype=tf.float32) for array in arrays)\n",
    "\n",
    "\n",
    "def run_inference(model, features_tensor, adjs_tensor):\n",
    "    \"\"\"Return the model output for one split, computed with training=False.\n",
    "\n",
    "    NOTE(review): Keras recommends model(inputs, training=False) over invoking\n",
    "    .call() directly; .call() is kept so behavior matches the original cell.\n",
    "    \"\"\"\n",
    "    return model.call((features_tensor, adjs_tensor), training=False)\n",
    "\n",
    "\n",
    "# Build the classifier from the data dimensions and restore the weights stored at\n",
    "# args.save_path + args.dataset. This cell only evaluates, so no optimizer or\n",
    "# training-loop state is created.\n",
    "model = GCN(input_dim=train_features.shape[-1], output_dim=train_labels.shape[1])\n",
    "model.load_weights(args.save_path + args.dataset)\n",
    "\n",
    "train_adjs_tensor, train_features_tensor, train_labels_tensor = to_float32_tensors(\n",
    "    train_adjs, train_features, train_labels)\n",
    "val_adjs_tensor, val_features_tensor, val_labels_tensor = to_float32_tensors(\n",
    "    val_adjs, val_features, val_labels)\n",
    "test_adjs_tensor, test_features_tensor, test_labels_tensor = to_float32_tensors(\n",
    "    test_adjs, test_features, test_labels)\n",
    "\n",
    "# accuracy and softmax_cross_entropy presumably come from the metrics star import.\n",
    "output = run_inference(model, train_features_tensor, train_adjs_tensor)\n",
    "train_acc = accuracy(output, train_labels_tensor)\n",
    "\n",
    "val_output = run_inference(model, val_features_tensor, val_adjs_tensor)\n",
    "val_acc = accuracy(val_output, val_labels_tensor)\n",
    "val_loss = softmax_cross_entropy(val_output, val_labels_tensor)\n",
    "\n",
    "test_output = run_inference(model, test_features_tensor, test_adjs_tensor)\n",
    "test_acc = accuracy(test_output, test_labels_tensor)\n",
    "test_loss = softmax_cross_entropy(test_output, test_labels_tensor)\n",
    "\n",
    "print(\"train_acc={:.5f}\".format(train_acc), \"val_acc=\", \"{:.5f}\".format(val_acc),\n",
    "      \"test_acc=\", \"{:.5f}\".format(test_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
