{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "class config:\n",
    "    data = 'Cora' #choices=[\"Cora\", \"Citeseer\", \"Pubmed\", \"Coauthor-CS\"]\n",
    "    model = 'GAT' #choices=[\"GCN\", \"GAT\", \"APPNP\", \"DAGNN\", \"TreeLSTM\", \"GTCN\", \"GTAN\", \"GTCN2\"]\n",
    "    n_in = 0\n",
    "    n_hid = 64\n",
    "    n_out = 0\n",
    "    num_heads = 1\n",
    "    num_out_heads = 1\n",
    "    device = 'cpu'\n",
    "    dropout = 0.8\n",
    "    dropout2 = 0.8\n",
    "    learning_rate = 0.01\n",
    "    weight_decay = 5e-4\n",
    "    patience = 200\n",
    "    num_iter = 1000\n",
    "    num_test = 30\n",
    "    hop = 2\n",
    "    alpha = 0.1 # used by APPNP only\n",
    "    random_label_split = False\n",
    "    num_train = 20 # for random label split only\n",
    "    num_val = 30 # for random label split only\n",
    "    data_load = True # load the saved label split to rerun the test (for reproduce purpose)\n",
    "    test_id = 1 # number of the test, only used to record the ith number of the random label split (for reproduce purpose)\n",
    "    filter_pct = 0.1 # remove the top and bottom filer_pct points before obtaining statistics of test accuracy\n",
    "    log = False # whether to show the training log or not\n",
    "    eval_metric = 'acc' # evaluation metrics, choices=[\"acc\", \"f1-macro\", \"f1-micro\"]\n",
    "    root_dir = '../..' # dir of the source code\n",
    "\n",
    "args = config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append(args.root_dir)\n",
    "from train import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "start testing on Cora dataset\n",
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 7.050, best val acc = 0.822, test acc = 0.826\n",
      "time duration = 4.955, best val acc = 0.828, test acc = 0.831\n",
      "time duration = 4.608, best val acc = 0.814, test acc = 0.817\n",
      "time duration = 8.104, best val acc = 0.830, test acc = 0.845\n",
      "time duration = 7.585, best val acc = 0.832, test acc = 0.833\n",
      "time duration = 5.510, best val acc = 0.828, test acc = 0.837\n",
      "time duration = 5.429, best val acc = 0.822, test acc = 0.849\n",
      "time duration = 6.674, best val acc = 0.826, test acc = 0.832\n",
      "time duration = 5.478, best val acc = 0.814, test acc = 0.826\n",
      "time duration = 9.544, best val acc = 0.822, test acc = 0.819\n",
      "time duration = 9.342, best val acc = 0.832, test acc = 0.852\n",
      "time duration = 6.192, best val acc = 0.824, test acc = 0.836\n",
      "time duration = 7.592, best val acc = 0.830, test acc = 0.829\n",
      "time duration = 7.499, best val acc = 0.818, test acc = 0.826\n",
      "time duration = 12.829, best val acc = 0.832, test acc = 0.823\n",
      "time duration = 9.052, best val acc = 0.840, test acc = 0.835\n",
      "time duration = 4.785, best val acc = 0.822, test acc = 0.823\n",
      "time duration = 8.364, best val acc = 0.824, test acc = 0.813\n",
      "time duration = 6.289, best val acc = 0.842, test acc = 0.837\n",
      "time duration = 5.764, best val acc = 0.830, test acc = 0.824\n",
      "time duration = 7.570, best val acc = 0.834, test acc = 0.831\n",
      "time duration = 5.193, best val acc = 0.810, test acc = 0.833\n",
      "time duration = 10.478, best val acc = 0.830, test acc = 0.829\n",
      "time duration = 7.220, best val acc = 0.836, test acc = 0.843\n",
      "time duration = 8.362, best val acc = 0.816, test acc = 0.833\n",
      "time duration = 8.683, best val acc = 0.822, test acc = 0.829\n",
      "time duration = 7.914, best val acc = 0.812, test acc = 0.831\n",
      "time duration = 8.767, best val acc = 0.822, test acc = 0.818\n",
      "time duration = 5.720, best val acc = 0.828, test acc = 0.836\n",
      "time duration = 6.198, best val acc = 0.812, test acc = 0.826\n",
      "test acc (mean, std):  0.8307333707809448 0.008899190787240014\n",
      "test acc (mean, std) after filter:  0.8303333719571432 0.005467071364008335\n",
      "\n",
      "start testing on Citeseer dataset\n",
      "  NumNodes: 3327\n",
      "  NumEdges: 9228\n",
      "  NumFeats: 3703\n",
      "  NumClasses: 6\n",
      "  NumTrainingSamples: 120\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 7.137, best val acc = 0.732, test acc = 0.719\n",
      "time duration = 7.593, best val acc = 0.728, test acc = 0.702\n",
      "time duration = 7.665, best val acc = 0.726, test acc = 0.729\n",
      "time duration = 5.905, best val acc = 0.716, test acc = 0.704\n",
      "time duration = 6.517, best val acc = 0.730, test acc = 0.730\n",
      "time duration = 11.130, best val acc = 0.736, test acc = 0.703\n",
      "time duration = 6.984, best val acc = 0.744, test acc = 0.722\n",
      "time duration = 6.799, best val acc = 0.738, test acc = 0.730\n",
      "time duration = 7.793, best val acc = 0.718, test acc = 0.694\n",
      "time duration = 9.733, best val acc = 0.734, test acc = 0.718\n",
      "time duration = 6.051, best val acc = 0.722, test acc = 0.680\n",
      "time duration = 6.581, best val acc = 0.722, test acc = 0.708\n",
      "time duration = 8.499, best val acc = 0.734, test acc = 0.711\n",
      "time duration = 7.858, best val acc = 0.742, test acc = 0.725\n",
      "time duration = 10.194, best val acc = 0.726, test acc = 0.705\n",
      "time duration = 7.526, best val acc = 0.730, test acc = 0.716\n",
      "time duration = 6.956, best val acc = 0.736, test acc = 0.699\n",
      "time duration = 5.839, best val acc = 0.720, test acc = 0.700\n",
      "time duration = 9.535, best val acc = 0.732, test acc = 0.703\n",
      "time duration = 6.047, best val acc = 0.740, test acc = 0.714\n",
      "time duration = 5.598, best val acc = 0.712, test acc = 0.696\n",
      "time duration = 6.945, best val acc = 0.724, test acc = 0.727\n",
      "time duration = 11.355, best val acc = 0.724, test acc = 0.719\n",
      "time duration = 7.747, best val acc = 0.740, test acc = 0.712\n",
      "time duration = 8.489, best val acc = 0.726, test acc = 0.720\n",
      "time duration = 6.136, best val acc = 0.736, test acc = 0.713\n",
      "time duration = 5.911, best val acc = 0.728, test acc = 0.712\n",
      "time duration = 6.989, best val acc = 0.726, test acc = 0.702\n",
      "time duration = 10.403, best val acc = 0.730, test acc = 0.704\n",
      "time duration = 7.927, best val acc = 0.738, test acc = 0.691\n",
      "test acc (mean, std):  0.7102666993935903 0.01221183672054972\n",
      "test acc (mean, std) after filter:  0.7105833689371744 0.008635764860846423\n",
      "\n",
      "start testing on Pubmed dataset\n",
      "  NumNodes: 19717\n",
      "  NumEdges: 88651\n",
      "  NumFeats: 500\n",
      "  NumClasses: 3\n",
      "  NumTrainingSamples: 60\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 19.178, best val acc = 0.806, test acc = 0.785\n",
      "time duration = 33.891, best val acc = 0.812, test acc = 0.766\n",
      "time duration = 26.893, best val acc = 0.806, test acc = 0.775\n",
      "time duration = 43.723, best val acc = 0.806, test acc = 0.777\n",
      "time duration = 21.231, best val acc = 0.808, test acc = 0.777\n",
      "time duration = 18.326, best val acc = 0.806, test acc = 0.783\n",
      "time duration = 24.509, best val acc = 0.806, test acc = 0.765\n",
      "time duration = 49.995, best val acc = 0.804, test acc = 0.776\n",
      "time duration = 27.834, best val acc = 0.806, test acc = 0.779\n",
      "time duration = 25.448, best val acc = 0.806, test acc = 0.773\n",
      "time duration = 23.027, best val acc = 0.812, test acc = 0.771\n",
      "time duration = 33.848, best val acc = 0.814, test acc = 0.768\n",
      "time duration = 34.909, best val acc = 0.814, test acc = 0.777\n",
      "time duration = 26.360, best val acc = 0.804, test acc = 0.777\n",
      "time duration = 32.948, best val acc = 0.808, test acc = 0.782\n",
      "time duration = 48.881, best val acc = 0.812, test acc = 0.776\n",
      "time duration = 34.702, best val acc = 0.806, test acc = 0.759\n",
      "time duration = 30.935, best val acc = 0.812, test acc = 0.781\n",
      "time duration = 28.115, best val acc = 0.812, test acc = 0.781\n",
      "time duration = 32.980, best val acc = 0.808, test acc = 0.769\n",
      "time duration = 23.570, best val acc = 0.804, test acc = 0.777\n",
      "time duration = 33.631, best val acc = 0.808, test acc = 0.774\n",
      "time duration = 31.089, best val acc = 0.808, test acc = 0.777\n",
      "time duration = 27.935, best val acc = 0.814, test acc = 0.773\n",
      "time duration = 23.951, best val acc = 0.812, test acc = 0.772\n",
      "time duration = 35.298, best val acc = 0.808, test acc = 0.783\n",
      "time duration = 30.078, best val acc = 0.816, test acc = 0.775\n",
      "time duration = 22.033, best val acc = 0.806, test acc = 0.775\n",
      "time duration = 30.418, best val acc = 0.804, test acc = 0.777\n",
      "time duration = 27.337, best val acc = 0.810, test acc = 0.787\n",
      "test acc (mean, std):  0.775566699107488 0.0060148874236890925\n",
      "test acc (mean, std) after filter:  0.7759166955947876 0.0037071626848310226\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 0\n",
      "time duration = 30.160, best val acc = 0.869, test acc = 0.900\n",
      "time duration = 30.343, best val acc = 0.873, test acc = 0.899\n",
      "time duration = 62.958, best val acc = 0.873, test acc = 0.911\n",
      "time duration = 31.097, best val acc = 0.873, test acc = 0.900\n",
      "time duration = 36.170, best val acc = 0.880, test acc = 0.903\n",
      "time duration = 29.992, best val acc = 0.867, test acc = 0.904\n",
      "time duration = 36.447, best val acc = 0.880, test acc = 0.903\n",
      "time duration = 32.044, best val acc = 0.873, test acc = 0.907\n",
      "time duration = 30.211, best val acc = 0.873, test acc = 0.902\n",
      "time duration = 29.821, best val acc = 0.871, test acc = 0.900\n",
      "time duration = 31.323, best val acc = 0.858, test acc = 0.896\n",
      "time duration = 48.959, best val acc = 0.876, test acc = 0.910\n",
      "time duration = 39.536, best val acc = 0.876, test acc = 0.909\n",
      "time duration = 43.235, best val acc = 0.878, test acc = 0.908\n",
      "time duration = 32.009, best val acc = 0.869, test acc = 0.901\n",
      "time duration = 31.872, best val acc = 0.871, test acc = 0.906\n",
      "time duration = 34.221, best val acc = 0.873, test acc = 0.911\n",
      "time duration = 30.337, best val acc = 0.858, test acc = 0.901\n",
      "time duration = 32.134, best val acc = 0.871, test acc = 0.904\n",
      "time duration = 29.989, best val acc = 0.878, test acc = 0.905\n",
      "time duration = 31.184, best val acc = 0.867, test acc = 0.905\n",
      "time duration = 30.948, best val acc = 0.873, test acc = 0.900\n",
      "time duration = 30.591, best val acc = 0.864, test acc = 0.901\n",
      "time duration = 48.506, best val acc = 0.878, test acc = 0.902\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time duration = 31.295, best val acc = 0.878, test acc = 0.902\n",
      "time duration = 36.198, best val acc = 0.878, test acc = 0.910\n",
      "time duration = 32.331, best val acc = 0.880, test acc = 0.907\n",
      "time duration = 31.030, best val acc = 0.864, test acc = 0.903\n",
      "time duration = 33.589, best val acc = 0.878, test acc = 0.909\n",
      "time duration = 30.696, best val acc = 0.873, test acc = 0.901\n",
      "test acc (mean, std):  0.904026613632838 0.0038731554309239316\n",
      "test acc (mean, std) after filter:  0.9039413010080656 0.0030140921554298976\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 1\n",
      "time duration = 48.912, best val acc = 0.911, test acc = 0.906\n",
      "time duration = 34.066, best val acc = 0.907, test acc = 0.906\n",
      "time duration = 30.552, best val acc = 0.902, test acc = 0.895\n",
      "time duration = 56.769, best val acc = 0.913, test acc = 0.897\n",
      "time duration = 34.721, best val acc = 0.916, test acc = 0.912\n",
      "time duration = 30.257, best val acc = 0.904, test acc = 0.900\n",
      "time duration = 31.418, best val acc = 0.900, test acc = 0.893\n",
      "time duration = 46.452, best val acc = 0.916, test acc = 0.906\n",
      "time duration = 31.130, best val acc = 0.904, test acc = 0.895\n",
      "time duration = 33.130, best val acc = 0.904, test acc = 0.906\n",
      "time duration = 29.954, best val acc = 0.898, test acc = 0.896\n",
      "time duration = 31.003, best val acc = 0.900, test acc = 0.891\n",
      "time duration = 54.681, best val acc = 0.916, test acc = 0.911\n",
      "time duration = 32.878, best val acc = 0.904, test acc = 0.902\n",
      "time duration = 36.900, best val acc = 0.907, test acc = 0.911\n",
      "time duration = 52.866, best val acc = 0.916, test acc = 0.906\n",
      "time duration = 34.817, best val acc = 0.909, test acc = 0.900\n",
      "time duration = 30.756, best val acc = 0.907, test acc = 0.881\n",
      "time duration = 32.666, best val acc = 0.909, test acc = 0.908\n",
      "time duration = 54.334, best val acc = 0.909, test acc = 0.905\n",
      "time duration = 32.030, best val acc = 0.907, test acc = 0.905\n",
      "time duration = 30.224, best val acc = 0.902, test acc = 0.903\n",
      "time duration = 30.740, best val acc = 0.907, test acc = 0.893\n",
      "time duration = 31.055, best val acc = 0.907, test acc = 0.901\n",
      "time duration = 41.081, best val acc = 0.909, test acc = 0.910\n",
      "time duration = 51.191, best val acc = 0.913, test acc = 0.908\n",
      "time duration = 31.649, best val acc = 0.909, test acc = 0.912\n",
      "time duration = 55.659, best val acc = 0.909, test acc = 0.899\n",
      "time duration = 31.277, best val acc = 0.907, test acc = 0.908\n",
      "time duration = 31.281, best val acc = 0.904, test acc = 0.899\n",
      "test acc (mean, std):  0.9021858155727387 0.007123920029858689\n",
      "test acc (mean, std) after filter:  0.9026853491862615 0.004986345763845884\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 2\n",
      "time duration = 31.023, best val acc = 0.867, test acc = 0.895\n",
      "time duration = 30.984, best val acc = 0.873, test acc = 0.890\n",
      "time duration = 52.860, best val acc = 0.893, test acc = 0.895\n",
      "time duration = 46.210, best val acc = 0.893, test acc = 0.895\n",
      "time duration = 30.408, best val acc = 0.862, test acc = 0.873\n",
      "time duration = 33.768, best val acc = 0.898, test acc = 0.901\n",
      "time duration = 32.187, best val acc = 0.887, test acc = 0.897\n",
      "time duration = 46.042, best val acc = 0.898, test acc = 0.904\n",
      "time duration = 47.689, best val acc = 0.902, test acc = 0.908\n",
      "time duration = 36.513, best val acc = 0.893, test acc = 0.896\n",
      "time duration = 53.607, best val acc = 0.900, test acc = 0.895\n",
      "time duration = 31.191, best val acc = 0.882, test acc = 0.894\n",
      "time duration = 30.880, best val acc = 0.873, test acc = 0.894\n",
      "time duration = 31.642, best val acc = 0.882, test acc = 0.899\n",
      "time duration = 36.395, best val acc = 0.900, test acc = 0.894\n",
      "time duration = 31.420, best val acc = 0.880, test acc = 0.897\n",
      "time duration = 31.017, best val acc = 0.871, test acc = 0.896\n",
      "time duration = 30.863, best val acc = 0.869, test acc = 0.887\n",
      "time duration = 31.178, best val acc = 0.889, test acc = 0.907\n",
      "time duration = 34.325, best val acc = 0.900, test acc = 0.887\n",
      "time duration = 45.598, best val acc = 0.902, test acc = 0.891\n",
      "time duration = 31.482, best val acc = 0.867, test acc = 0.895\n",
      "time duration = 37.108, best val acc = 0.893, test acc = 0.908\n",
      "time duration = 30.112, best val acc = 0.864, test acc = 0.890\n",
      "time duration = 29.730, best val acc = 0.867, test acc = 0.887\n",
      "time duration = 36.556, best val acc = 0.893, test acc = 0.899\n",
      "time duration = 30.226, best val acc = 0.858, test acc = 0.891\n",
      "time duration = 31.551, best val acc = 0.889, test acc = 0.894\n",
      "time duration = 30.536, best val acc = 0.873, test acc = 0.892\n",
      "time duration = 30.317, best val acc = 0.869, test acc = 0.893\n",
      "test acc (mean, std):  0.8948169569174449 0.006805450641381312\n",
      "test acc (mean, std) after filter:  0.8947681412100792 0.0036730358332925134\n"
     ]
    }
   ],
   "source": [
    "accs = []\n",
    "datasets = ['Cora', 'Citeseer', 'Pubmed', 'Coauthor-CS']\n",
    "for dataset in datasets:\n",
    "    args.data = dataset\n",
    "    if dataset == 'Cora':\n",
    "        args.dropout = 0.8\n",
    "        args.dropout2 = 0.8\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Citeseer':\n",
    "        args.dropout = 0.8\n",
    "        args.dropout2 = 0.8\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Pubmed':\n",
    "        args.dropout = 0.8\n",
    "        args.dropout2 = 0.2\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Coauthor-CS':\n",
    "        args.dropout = 0.8\n",
    "        args.dropout2 = 0.2\n",
    "        for i in range(3):\n",
    "            print('\\nstart testing on ' + dataset + ' dataset with random split: ' + str(i))\n",
    "            args.test_id = i\n",
    "            accs.append(main(args))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "accs = np.array(accs)\n",
    "np.savetxt('acc_GAT_hop2.txt',accs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
