{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration passed to main(args) (imported from train); some fields are overridden per dataset later on.\n",
    "class config:\n",
    "    data = 'Cora' #choices=[\"Cora\", \"Citeseer\", \"Pubmed\", \"Coauthor-CS\"]\n",
    "    model = 'GTCN' #choices=[\"GCN\", \"GAT\", \"APPNP\", \"DAGNN\", \"TreeLSTM\", \"GTCN\", \"GTAN\", \"GTCN2\"]\n",
    "    n_in = 0\n",
    "    n_hid = 64\n",
    "    n_out = 0\n",
    "    num_heads = 1\n",
    "    num_out_heads = 1\n",
    "    device = 'cpu'\n",
    "    dropout = 0.6\n",
    "    dropout2 = 0.6\n",
    "    learning_rate = 0.01\n",
    "    weight_decay = 5e-4\n",
    "    patience = 200\n",
    "    num_iter = 1000\n",
    "    num_test = 30 # number of independent training runs per dataset (or per random split)\n",
    "    hop = 5\n",
    "    alpha = 0.1 # used by APPNP only\n",
    "    random_label_split = False\n",
    "    num_train = 20 # for random label split only\n",
    "    num_val = 30 # for random label split only\n",
    "    data_load = True # load the saved label split to rerun the test (for reproducibility)\n",
    "    test_id = 1 # index i of the random label split, used to record/reload the i-th split (for reproducibility)\n",
    "    filter_pct = 0.1 # fraction of runs dropped from each end (highest and lowest test acc) before computing the filtered statistics\n",
    "    log = False # whether to show the training log or not\n",
    "    eval_metric = 'acc' # evaluation metric, choices=[\"acc\", \"f1-macro\", \"f1-micro\"]\n",
    "    root_dir = '../..' # dir of the source code\n",
    "\n",
    "args = config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Make the project root importable; the guard keeps re-runs from appending duplicate entries.\n",
    "if args.root_dir not in sys.path:\n",
    "    sys.path.append(args.root_dir)\n",
    "# Explicit import instead of `from train import *`, so it is clear where `main` comes from.\n",
    "from train import main"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "start testing on Cora dataset\n",
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 3.682, best val acc = 0.818, test acc = 0.833\n",
      "time duration = 5.302, best val acc = 0.840, test acc = 0.850\n",
      "time duration = 6.501, best val acc = 0.830, test acc = 0.844\n",
      "time duration = 3.929, best val acc = 0.820, test acc = 0.847\n",
      "time duration = 5.763, best val acc = 0.828, test acc = 0.851\n",
      "time duration = 6.241, best val acc = 0.828, test acc = 0.843\n",
      "time duration = 4.720, best val acc = 0.818, test acc = 0.846\n",
      "time duration = 3.313, best val acc = 0.828, test acc = 0.836\n",
      "time duration = 8.785, best val acc = 0.842, test acc = 0.853\n",
      "time duration = 4.014, best val acc = 0.824, test acc = 0.851\n",
      "time duration = 5.703, best val acc = 0.832, test acc = 0.841\n",
      "time duration = 7.314, best val acc = 0.842, test acc = 0.849\n",
      "time duration = 10.669, best val acc = 0.830, test acc = 0.846\n",
      "time duration = 8.256, best val acc = 0.838, test acc = 0.847\n",
      "time duration = 4.047, best val acc = 0.826, test acc = 0.851\n",
      "time duration = 2.757, best val acc = 0.804, test acc = 0.820\n",
      "time duration = 3.684, best val acc = 0.834, test acc = 0.850\n",
      "time duration = 6.348, best val acc = 0.830, test acc = 0.844\n",
      "time duration = 3.609, best val acc = 0.836, test acc = 0.847\n",
      "time duration = 3.003, best val acc = 0.814, test acc = 0.828\n",
      "time duration = 3.388, best val acc = 0.822, test acc = 0.852\n",
      "time duration = 6.276, best val acc = 0.816, test acc = 0.848\n",
      "time duration = 2.733, best val acc = 0.826, test acc = 0.828\n",
      "time duration = 3.030, best val acc = 0.816, test acc = 0.831\n",
      "time duration = 4.280, best val acc = 0.824, test acc = 0.847\n",
      "time duration = 7.911, best val acc = 0.832, test acc = 0.838\n",
      "time duration = 9.650, best val acc = 0.834, test acc = 0.847\n",
      "time duration = 5.922, best val acc = 0.830, test acc = 0.841\n",
      "time duration = 3.340, best val acc = 0.826, test acc = 0.833\n",
      "time duration = 3.237, best val acc = 0.828, test acc = 0.836\n",
      "test acc (mean, std):  0.8426000416278839 0.00832905886151587\n",
      "test acc (mean, std) after filter:  0.8435833727320036 0.005964590998984945\n",
      "\n",
      "start testing on Citeseer dataset\n",
      "  NumNodes: 3327\n",
      "  NumEdges: 9228\n",
      "  NumFeats: 3703\n",
      "  NumClasses: 6\n",
      "  NumTrainingSamples: 120\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 3.093, best val acc = 0.734, test acc = 0.728\n",
      "time duration = 4.730, best val acc = 0.726, test acc = 0.709\n",
      "time duration = 3.320, best val acc = 0.742, test acc = 0.736\n",
      "time duration = 4.601, best val acc = 0.726, test acc = 0.722\n",
      "time duration = 4.652, best val acc = 0.732, test acc = 0.737\n",
      "time duration = 2.829, best val acc = 0.732, test acc = 0.721\n",
      "time duration = 3.202, best val acc = 0.744, test acc = 0.733\n",
      "time duration = 5.032, best val acc = 0.736, test acc = 0.730\n",
      "time duration = 3.313, best val acc = 0.740, test acc = 0.727\n",
      "time duration = 4.160, best val acc = 0.728, test acc = 0.732\n",
      "time duration = 3.011, best val acc = 0.742, test acc = 0.731\n",
      "time duration = 3.785, best val acc = 0.734, test acc = 0.731\n",
      "time duration = 4.906, best val acc = 0.744, test acc = 0.741\n",
      "time duration = 4.654, best val acc = 0.728, test acc = 0.719\n",
      "time duration = 3.909, best val acc = 0.740, test acc = 0.722\n",
      "time duration = 2.924, best val acc = 0.732, test acc = 0.725\n",
      "time duration = 4.282, best val acc = 0.738, test acc = 0.728\n",
      "time duration = 2.888, best val acc = 0.730, test acc = 0.722\n",
      "time duration = 5.400, best val acc = 0.732, test acc = 0.730\n",
      "time duration = 3.429, best val acc = 0.720, test acc = 0.728\n",
      "time duration = 3.318, best val acc = 0.732, test acc = 0.739\n",
      "time duration = 3.452, best val acc = 0.728, test acc = 0.728\n",
      "time duration = 4.226, best val acc = 0.738, test acc = 0.719\n",
      "time duration = 3.805, best val acc = 0.730, test acc = 0.717\n",
      "time duration = 5.005, best val acc = 0.728, test acc = 0.725\n",
      "time duration = 3.077, best val acc = 0.732, test acc = 0.736\n",
      "time duration = 3.094, best val acc = 0.734, test acc = 0.732\n",
      "time duration = 2.777, best val acc = 0.716, test acc = 0.707\n",
      "time duration = 4.448, best val acc = 0.734, test acc = 0.728\n",
      "time duration = 3.790, best val acc = 0.734, test acc = 0.729\n",
      "test acc (mean, std):  0.7270667056242625 0.007784317505775022\n",
      "test acc (mean, std) after filter:  0.7275833735863367 0.004768967680872285\n",
      "\n",
      "start testing on Pubmed dataset\n",
      "  NumNodes: 19717\n",
      "  NumEdges: 88651\n",
      "  NumFeats: 500\n",
      "  NumClasses: 3\n",
      "  NumTrainingSamples: 60\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 6.828, best val acc = 0.812, test acc = 0.790\n",
      "time duration = 8.953, best val acc = 0.812, test acc = 0.784\n",
      "time duration = 6.409, best val acc = 0.806, test acc = 0.779\n",
      "time duration = 4.918, best val acc = 0.800, test acc = 0.773\n",
      "time duration = 7.785, best val acc = 0.814, test acc = 0.785\n",
      "time duration = 11.731, best val acc = 0.818, test acc = 0.786\n",
      "time duration = 9.796, best val acc = 0.814, test acc = 0.792\n",
      "time duration = 4.957, best val acc = 0.812, test acc = 0.791\n",
      "time duration = 11.922, best val acc = 0.818, test acc = 0.782\n",
      "time duration = 8.793, best val acc = 0.814, test acc = 0.792\n",
      "time duration = 8.489, best val acc = 0.818, test acc = 0.791\n",
      "time duration = 10.291, best val acc = 0.818, test acc = 0.781\n",
      "time duration = 10.248, best val acc = 0.812, test acc = 0.783\n",
      "time duration = 11.559, best val acc = 0.816, test acc = 0.787\n",
      "time duration = 9.012, best val acc = 0.814, test acc = 0.789\n",
      "time duration = 10.008, best val acc = 0.818, test acc = 0.792\n",
      "time duration = 7.380, best val acc = 0.810, test acc = 0.780\n",
      "time duration = 10.542, best val acc = 0.818, test acc = 0.799\n",
      "time duration = 6.283, best val acc = 0.808, test acc = 0.774\n",
      "time duration = 10.565, best val acc = 0.818, test acc = 0.794\n",
      "time duration = 4.224, best val acc = 0.812, test acc = 0.772\n",
      "time duration = 7.028, best val acc = 0.812, test acc = 0.783\n",
      "time duration = 9.111, best val acc = 0.818, test acc = 0.794\n",
      "time duration = 7.007, best val acc = 0.808, test acc = 0.789\n",
      "time duration = 11.711, best val acc = 0.820, test acc = 0.788\n",
      "time duration = 5.761, best val acc = 0.814, test acc = 0.788\n",
      "time duration = 4.862, best val acc = 0.814, test acc = 0.792\n",
      "time duration = 5.957, best val acc = 0.814, test acc = 0.796\n",
      "time duration = 4.649, best val acc = 0.808, test acc = 0.771\n",
      "time duration = 4.865, best val acc = 0.808, test acc = 0.780\n",
      "test acc (mean, std):  0.7859000384807586 0.0071942108610064375\n",
      "test acc (mean, std) after filter:  0.7863333721955618 0.005112619333732436\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 0\n",
      "time duration = 8.165, best val acc = 0.924, test acc = 0.922\n",
      "time duration = 5.944, best val acc = 0.924, test acc = 0.922\n",
      "time duration = 7.270, best val acc = 0.920, test acc = 0.926\n",
      "time duration = 9.490, best val acc = 0.918, test acc = 0.925\n",
      "time duration = 7.560, best val acc = 0.918, test acc = 0.922\n",
      "time duration = 7.084, best val acc = 0.924, test acc = 0.926\n",
      "time duration = 10.250, best val acc = 0.929, test acc = 0.921\n",
      "time duration = 9.202, best val acc = 0.918, test acc = 0.925\n",
      "time duration = 6.722, best val acc = 0.918, test acc = 0.922\n",
      "time duration = 6.167, best val acc = 0.916, test acc = 0.925\n",
      "time duration = 5.954, best val acc = 0.920, test acc = 0.922\n",
      "time duration = 6.858, best val acc = 0.920, test acc = 0.925\n",
      "time duration = 9.748, best val acc = 0.924, test acc = 0.922\n",
      "time duration = 7.224, best val acc = 0.920, test acc = 0.927\n",
      "time duration = 8.233, best val acc = 0.924, test acc = 0.928\n",
      "time duration = 7.296, best val acc = 0.922, test acc = 0.927\n",
      "time duration = 9.463, best val acc = 0.920, test acc = 0.921\n",
      "time duration = 6.922, best val acc = 0.922, test acc = 0.927\n",
      "time duration = 7.690, best val acc = 0.922, test acc = 0.926\n",
      "time duration = 6.554, best val acc = 0.916, test acc = 0.927\n",
      "time duration = 7.477, best val acc = 0.918, test acc = 0.923\n",
      "time duration = 7.694, best val acc = 0.920, test acc = 0.926\n",
      "time duration = 6.694, best val acc = 0.922, test acc = 0.922\n",
      "time duration = 8.238, best val acc = 0.922, test acc = 0.926\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time duration = 6.876, best val acc = 0.918, test acc = 0.919\n",
      "time duration = 12.011, best val acc = 0.918, test acc = 0.922\n",
      "time duration = 7.425, best val acc = 0.920, test acc = 0.927\n",
      "time duration = 6.426, best val acc = 0.913, test acc = 0.920\n",
      "time duration = 7.258, best val acc = 0.920, test acc = 0.926\n",
      "time duration = 10.916, best val acc = 0.924, test acc = 0.927\n",
      "test acc (mean, std):  0.924212779601415 0.002482767783476714\n",
      "test acc (mean, std) after filter:  0.9243350575367609 0.0020017591955747358\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 1\n",
      "time duration = 16.724, best val acc = 0.931, test acc = 0.916\n",
      "time duration = 8.452, best val acc = 0.936, test acc = 0.925\n",
      "time duration = 10.145, best val acc = 0.931, test acc = 0.924\n",
      "time duration = 8.697, best val acc = 0.927, test acc = 0.917\n",
      "time duration = 7.564, best val acc = 0.931, test acc = 0.920\n",
      "time duration = 7.369, best val acc = 0.929, test acc = 0.919\n",
      "time duration = 8.536, best val acc = 0.931, test acc = 0.921\n",
      "time duration = 10.192, best val acc = 0.929, test acc = 0.925\n",
      "time duration = 6.290, best val acc = 0.929, test acc = 0.923\n",
      "time duration = 9.494, best val acc = 0.933, test acc = 0.921\n",
      "time duration = 8.976, best val acc = 0.929, test acc = 0.924\n",
      "time duration = 15.516, best val acc = 0.938, test acc = 0.920\n",
      "time duration = 9.760, best val acc = 0.931, test acc = 0.920\n",
      "time duration = 7.046, best val acc = 0.936, test acc = 0.918\n",
      "time duration = 7.848, best val acc = 0.933, test acc = 0.917\n",
      "time duration = 10.588, best val acc = 0.933, test acc = 0.917\n",
      "time duration = 7.276, best val acc = 0.929, test acc = 0.925\n",
      "time duration = 9.967, best val acc = 0.929, test acc = 0.917\n",
      "time duration = 7.207, best val acc = 0.933, test acc = 0.921\n",
      "time duration = 7.193, best val acc = 0.929, test acc = 0.921\n",
      "time duration = 7.376, best val acc = 0.933, test acc = 0.923\n",
      "time duration = 7.157, best val acc = 0.931, test acc = 0.923\n",
      "time duration = 7.866, best val acc = 0.933, test acc = 0.916\n",
      "time duration = 8.405, best val acc = 0.933, test acc = 0.921\n",
      "time duration = 8.482, best val acc = 0.931, test acc = 0.916\n",
      "time duration = 8.580, best val acc = 0.936, test acc = 0.915\n",
      "time duration = 9.080, best val acc = 0.931, test acc = 0.918\n",
      "time duration = 9.651, best val acc = 0.927, test acc = 0.921\n",
      "time duration = 6.575, best val acc = 0.927, test acc = 0.921\n",
      "time duration = 6.663, best val acc = 0.929, test acc = 0.919\n",
      "test acc (mean, std):  0.9201804717381795 0.0028985102270156363\n",
      "test acc (mean, std) after filter:  0.9201714669664701 0.002299307766730085\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 2\n",
      "time duration = 10.066, best val acc = 0.922, test acc = 0.915\n",
      "time duration = 9.639, best val acc = 0.920, test acc = 0.919\n",
      "time duration = 6.759, best val acc = 0.922, test acc = 0.919\n",
      "time duration = 8.983, best val acc = 0.920, test acc = 0.920\n",
      "time duration = 8.150, best val acc = 0.924, test acc = 0.920\n",
      "time duration = 10.837, best val acc = 0.922, test acc = 0.922\n",
      "time duration = 8.053, best val acc = 0.916, test acc = 0.921\n",
      "time duration = 8.295, best val acc = 0.918, test acc = 0.918\n",
      "time duration = 7.284, best val acc = 0.918, test acc = 0.918\n",
      "time duration = 7.205, best val acc = 0.922, test acc = 0.920\n",
      "time duration = 7.981, best val acc = 0.920, test acc = 0.922\n",
      "time duration = 6.794, best val acc = 0.918, test acc = 0.924\n",
      "time duration = 7.251, best val acc = 0.922, test acc = 0.919\n",
      "time duration = 7.891, best val acc = 0.920, test acc = 0.919\n",
      "time duration = 11.038, best val acc = 0.922, test acc = 0.920\n",
      "time duration = 8.891, best val acc = 0.920, test acc = 0.923\n",
      "time duration = 10.244, best val acc = 0.922, test acc = 0.920\n",
      "time duration = 8.115, best val acc = 0.922, test acc = 0.915\n",
      "time duration = 6.669, best val acc = 0.920, test acc = 0.925\n",
      "time duration = 9.939, best val acc = 0.920, test acc = 0.923\n",
      "time duration = 7.562, best val acc = 0.918, test acc = 0.918\n",
      "time duration = 6.229, best val acc = 0.920, test acc = 0.918\n",
      "time duration = 8.509, best val acc = 0.918, test acc = 0.914\n",
      "time duration = 11.683, best val acc = 0.922, test acc = 0.920\n",
      "time duration = 6.841, best val acc = 0.922, test acc = 0.921\n",
      "time duration = 7.997, best val acc = 0.920, test acc = 0.919\n",
      "time duration = 6.391, best val acc = 0.922, test acc = 0.916\n",
      "time duration = 6.791, best val acc = 0.920, test acc = 0.919\n",
      "time duration = 10.567, best val acc = 0.920, test acc = 0.926\n",
      "time duration = 6.557, best val acc = 0.920, test acc = 0.918\n",
      "test acc (mean, std):  0.919706525405248 0.0027270078097157823\n",
      "test acc (mean, std) after filter:  0.9196832999587059 0.0016405525122699907\n"
     ]
    }
   ],
   "source": [
    "# Tuned hyperparameters per benchmark, copied onto `args` (in this order) before each run.\n",
    "HPARAMS = {\n",
    "    'Cora':        {'dropout': 0.6, 'dropout2': 0.6, 'weight_decay': 5e-4, 'learning_rate': 0.01},\n",
    "    'Citeseer':    {'dropout': 0.8, 'dropout2': 0.6, 'weight_decay': 5e-4, 'learning_rate': 0.01},\n",
    "    'Pubmed':      {'dropout': 0.8, 'dropout2': 0.5, 'weight_decay': 5e-4, 'learning_rate': 0.02},\n",
    "    'Coauthor-CS': {'dropout': 0.6, 'dropout2': 0.2, 'weight_decay': 5e-3, 'learning_rate': 0.01},\n",
    "}\n",
    "# Datasets evaluated over several random label splits; args.test_id is set to the split index.\n",
    "RANDOM_SPLITS = {'Coauthor-CS': 3}\n",
    "\n",
    "accs = []  # one result per main(args) call\n",
    "datasets = ['Cora', 'Citeseer', 'Pubmed', 'Coauthor-CS']\n",
    "for dataset in datasets:\n",
    "    args.data = dataset\n",
    "    if dataset not in HPARAMS:\n",
    "        continue\n",
    "    for key, value in HPARAMS[dataset].items():\n",
    "        setattr(args, key, value)\n",
    "    if dataset in RANDOM_SPLITS:\n",
    "        for split in range(RANDOM_SPLITS[dataset]):\n",
    "            print(f'\\nstart testing on {dataset} dataset with random split: {split}')\n",
    "            args.test_id = split\n",
    "            accs.append(main(args))\n",
    "    else:\n",
    "        print(f'\\nstart testing on {dataset} dataset')\n",
    "        accs.append(main(args))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the per-run results (one entry per main(args) call) and save them.\n",
    "# The file name is derived from the config so it cannot drift from the model/hop actually run.\n",
    "acc_table = np.array(accs)\n",
    "np.savetxt(f'acc_{args.model}_hop{args.hop}.txt', acc_table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
