{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "class config:\n",
    "    \"\"\"Experiment settings passed to `main` (imported from `train`).\n",
    "\n",
    "    The values below are class-level defaults. The benchmark cell overrides\n",
    "    `data`, `dropout`, `dropout2`, `weight_decay`, `learning_rate` and\n",
    "    `test_id` per dataset by assigning attributes on the `args` instance.\n",
    "    \"\"\"\n",
    "    # --- dataset / model ---\n",
    "    data = 'Cora' #choices=[\"Cora\", \"Citeseer\", \"Pubmed\", \"Coauthor-CS\"]\n",
    "    model = 'GTCN' #choices=[\"GCN\", \"GAT\", \"APPNP\", \"DAGNN\", \"TreeLSTM\", \"GTCN\", \"GTAN\", \"GTCN2\"]\n",
    "    n_in = 0 # placeholder; presumably filled from the dataset by main -- TODO confirm\n",
    "    n_hid = 64\n",
    "    n_out = 0 # placeholder; presumably filled from the dataset by main -- TODO confirm\n",
    "    num_heads = 1\n",
    "    num_out_heads = 1\n",
    "    device = 'cpu'\n",
    "    # --- optimisation ---\n",
    "    dropout = 0.6\n",
    "    dropout2 = 0.6\n",
    "    learning_rate = 0.01\n",
    "    weight_decay = 5e-4\n",
    "    patience = 200\n",
    "    num_iter = 1000\n",
    "    num_test = 30 # number of repeated runs whose test accuracies are aggregated\n",
    "    hop = 2\n",
    "    alpha = 0.1 # used by APPNP only\n",
    "    # --- label split / reproducibility ---\n",
    "    random_label_split = False\n",
    "    num_train = 20 # for random label split only\n",
    "    num_val = 30 # for random label split only\n",
    "    data_load = True # load the saved label split to rerun the test (for reproducibility)\n",
    "    test_id = 1 # index of the test; only used to record which random label split is used (for reproducibility)\n",
    "    # --- reporting ---\n",
    "    filter_pct = 0.1 # fraction of runs dropped from EACH end (highest and lowest test accuracies) before computing the filtered mean/std\n",
    "    log = False # whether to show the training log or not\n",
    "    eval_metric = 'acc' # evaluation metrics, choices=[\"acc\", \"f1-macro\", \"f1-micro\"]\n",
    "    root_dir = '../..' # dir of the source code\n",
    "\n",
    "args = config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Make the project's `train` module importable; the guard keeps re-runs of\n",
    "# this cell from appending the same path again.\n",
    "if args.root_dir not in sys.path:\n",
    "    sys.path.append(args.root_dir)\n",
    "\n",
    "# Explicit import instead of `from train import *`: this notebook only needs\n",
    "# `main` from `train`, and a star import could silently shadow notebook names\n",
    "# such as `args` or `config`.\n",
    "from train import main"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "start testing on Cora dataset\n",
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 3.494, best val acc = 0.820, test acc = 0.830\n",
      "time duration = 3.283, best val acc = 0.818, test acc = 0.832\n",
      "time duration = 5.549, best val acc = 0.826, test acc = 0.837\n",
      "time duration = 4.793, best val acc = 0.826, test acc = 0.829\n",
      "time duration = 2.509, best val acc = 0.820, test acc = 0.840\n",
      "time duration = 2.353, best val acc = 0.812, test acc = 0.836\n",
      "time duration = 2.346, best val acc = 0.814, test acc = 0.836\n",
      "time duration = 1.447, best val acc = 0.812, test acc = 0.828\n",
      "time duration = 2.585, best val acc = 0.820, test acc = 0.836\n",
      "time duration = 2.895, best val acc = 0.816, test acc = 0.831\n",
      "time duration = 5.201, best val acc = 0.822, test acc = 0.833\n",
      "time duration = 2.129, best val acc = 0.816, test acc = 0.831\n",
      "time duration = 2.854, best val acc = 0.812, test acc = 0.831\n",
      "time duration = 2.503, best val acc = 0.814, test acc = 0.829\n",
      "time duration = 1.789, best val acc = 0.810, test acc = 0.826\n",
      "time duration = 1.880, best val acc = 0.816, test acc = 0.829\n",
      "time duration = 2.749, best val acc = 0.822, test acc = 0.841\n",
      "time duration = 1.729, best val acc = 0.820, test acc = 0.832\n",
      "time duration = 2.822, best val acc = 0.808, test acc = 0.839\n",
      "time duration = 3.058, best val acc = 0.826, test acc = 0.835\n",
      "time duration = 1.731, best val acc = 0.814, test acc = 0.825\n",
      "time duration = 1.787, best val acc = 0.814, test acc = 0.829\n",
      "time duration = 2.073, best val acc = 0.816, test acc = 0.833\n",
      "time duration = 2.075, best val acc = 0.818, test acc = 0.833\n",
      "time duration = 3.339, best val acc = 0.828, test acc = 0.845\n",
      "time duration = 1.616, best val acc = 0.818, test acc = 0.822\n",
      "time duration = 2.357, best val acc = 0.816, test acc = 0.836\n",
      "time duration = 2.339, best val acc = 0.822, test acc = 0.836\n",
      "time duration = 3.246, best val acc = 0.822, test acc = 0.831\n",
      "time duration = 2.407, best val acc = 0.820, test acc = 0.829\n",
      "test acc (mean, std):  0.8326667050520579 0.004901241062525223\n",
      "test acc (mean, std) after filter:  0.8325417066613833 0.003081916700366099\n",
      "\n",
      "start testing on Citeseer dataset\n",
      "  NumNodes: 3327\n",
      "  NumEdges: 9228\n",
      "  NumFeats: 3703\n",
      "  NumClasses: 6\n",
      "  NumTrainingSamples: 120\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 2.842, best val acc = 0.732, test acc = 0.728\n",
      "time duration = 2.012, best val acc = 0.728, test acc = 0.712\n",
      "time duration = 2.150, best val acc = 0.728, test acc = 0.722\n",
      "time duration = 1.982, best val acc = 0.736, test acc = 0.711\n",
      "time duration = 1.959, best val acc = 0.726, test acc = 0.733\n",
      "time duration = 1.450, best val acc = 0.734, test acc = 0.698\n",
      "time duration = 3.189, best val acc = 0.734, test acc = 0.718\n",
      "time duration = 2.276, best val acc = 0.734, test acc = 0.714\n",
      "time duration = 2.176, best val acc = 0.742, test acc = 0.719\n",
      "time duration = 1.588, best val acc = 0.748, test acc = 0.720\n",
      "time duration = 1.944, best val acc = 0.730, test acc = 0.727\n",
      "time duration = 2.317, best val acc = 0.736, test acc = 0.727\n",
      "time duration = 2.075, best val acc = 0.728, test acc = 0.730\n",
      "time duration = 2.313, best val acc = 0.728, test acc = 0.730\n",
      "time duration = 1.792, best val acc = 0.720, test acc = 0.716\n",
      "time duration = 2.294, best val acc = 0.734, test acc = 0.728\n",
      "time duration = 1.705, best val acc = 0.740, test acc = 0.716\n",
      "time duration = 2.528, best val acc = 0.730, test acc = 0.710\n",
      "time duration = 1.621, best val acc = 0.746, test acc = 0.712\n",
      "time duration = 1.780, best val acc = 0.748, test acc = 0.726\n",
      "time duration = 2.350, best val acc = 0.728, test acc = 0.708\n",
      "time duration = 2.630, best val acc = 0.724, test acc = 0.721\n",
      "time duration = 1.696, best val acc = 0.734, test acc = 0.731\n",
      "time duration = 1.933, best val acc = 0.740, test acc = 0.731\n",
      "time duration = 1.832, best val acc = 0.730, test acc = 0.728\n",
      "time duration = 2.195, best val acc = 0.742, test acc = 0.723\n",
      "time duration = 2.154, best val acc = 0.740, test acc = 0.741\n",
      "time duration = 2.516, best val acc = 0.736, test acc = 0.727\n",
      "time duration = 2.142, best val acc = 0.728, test acc = 0.725\n",
      "time duration = 1.984, best val acc = 0.728, test acc = 0.710\n",
      "test acc (mean, std):  0.7214000324408213 0.009160063621520656\n",
      "test acc (mean, std) after filter:  0.7217083672682444 0.006604793636676623\n",
      "\n",
      "start testing on Pubmed dataset\n",
      "  NumNodes: 19717\n",
      "  NumEdges: 88651\n",
      "  NumFeats: 500\n",
      "  NumClasses: 3\n",
      "  NumTrainingSamples: 60\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 3.470, best val acc = 0.816, test acc = 0.780\n",
      "time duration = 4.237, best val acc = 0.818, test acc = 0.792\n",
      "time duration = 2.591, best val acc = 0.810, test acc = 0.783\n",
      "time duration = 4.535, best val acc = 0.808, test acc = 0.784\n",
      "time duration = 4.546, best val acc = 0.814, test acc = 0.790\n",
      "time duration = 5.966, best val acc = 0.814, test acc = 0.788\n",
      "time duration = 6.113, best val acc = 0.820, test acc = 0.780\n",
      "time duration = 2.682, best val acc = 0.816, test acc = 0.766\n",
      "time duration = 3.074, best val acc = 0.814, test acc = 0.786\n",
      "time duration = 5.032, best val acc = 0.808, test acc = 0.786\n",
      "time duration = 2.587, best val acc = 0.814, test acc = 0.791\n",
      "time duration = 2.693, best val acc = 0.808, test acc = 0.783\n",
      "time duration = 3.173, best val acc = 0.812, test acc = 0.777\n",
      "time duration = 3.590, best val acc = 0.816, test acc = 0.784\n",
      "time duration = 3.776, best val acc = 0.812, test acc = 0.780\n",
      "time duration = 3.593, best val acc = 0.810, test acc = 0.794\n",
      "time duration = 5.545, best val acc = 0.816, test acc = 0.792\n",
      "time duration = 4.386, best val acc = 0.816, test acc = 0.793\n",
      "time duration = 3.798, best val acc = 0.810, test acc = 0.783\n",
      "time duration = 4.139, best val acc = 0.812, test acc = 0.792\n",
      "time duration = 3.878, best val acc = 0.812, test acc = 0.782\n",
      "time duration = 3.499, best val acc = 0.810, test acc = 0.785\n",
      "time duration = 3.213, best val acc = 0.808, test acc = 0.776\n",
      "time duration = 2.445, best val acc = 0.810, test acc = 0.769\n",
      "time duration = 2.342, best val acc = 0.812, test acc = 0.796\n",
      "time duration = 2.495, best val acc = 0.810, test acc = 0.769\n",
      "time duration = 3.145, best val acc = 0.810, test acc = 0.775\n",
      "time duration = 3.033, best val acc = 0.810, test acc = 0.788\n",
      "time duration = 2.958, best val acc = 0.812, test acc = 0.792\n",
      "time duration = 4.850, best val acc = 0.816, test acc = 0.784\n",
      "test acc (mean, std):  0.784000039100647 0.007593859081783433\n",
      "test acc (mean, std) after filter:  0.7847083707650503 0.005078383222382145\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 0\n",
      "time duration = 4.966, best val acc = 0.918, test acc = 0.923\n",
      "time duration = 7.413, best val acc = 0.920, test acc = 0.924\n",
      "time duration = 4.975, best val acc = 0.918, test acc = 0.926\n",
      "time duration = 5.600, best val acc = 0.916, test acc = 0.926\n",
      "time duration = 8.998, best val acc = 0.924, test acc = 0.927\n",
      "time duration = 5.361, best val acc = 0.918, test acc = 0.930\n",
      "time duration = 10.021, best val acc = 0.924, test acc = 0.928\n",
      "time duration = 7.445, best val acc = 0.927, test acc = 0.921\n",
      "time duration = 7.071, best val acc = 0.916, test acc = 0.918\n",
      "time duration = 9.046, best val acc = 0.922, test acc = 0.928\n",
      "time duration = 9.361, best val acc = 0.924, test acc = 0.927\n",
      "time duration = 4.343, best val acc = 0.916, test acc = 0.928\n",
      "time duration = 9.343, best val acc = 0.922, test acc = 0.928\n",
      "time duration = 7.671, best val acc = 0.918, test acc = 0.928\n",
      "time duration = 6.137, best val acc = 0.922, test acc = 0.928\n",
      "time duration = 4.862, best val acc = 0.922, test acc = 0.925\n",
      "time duration = 8.994, best val acc = 0.918, test acc = 0.927\n",
      "time duration = 8.674, best val acc = 0.922, test acc = 0.927\n",
      "time duration = 9.873, best val acc = 0.916, test acc = 0.930\n",
      "time duration = 5.953, best val acc = 0.920, test acc = 0.929\n",
      "time duration = 4.422, best val acc = 0.920, test acc = 0.929\n",
      "time duration = 4.337, best val acc = 0.916, test acc = 0.927\n",
      "time duration = 4.476, best val acc = 0.916, test acc = 0.927\n",
      "time duration = 5.598, best val acc = 0.920, test acc = 0.926\n",
      "time duration = 4.658, best val acc = 0.918, test acc = 0.931\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time duration = 5.971, best val acc = 0.924, test acc = 0.924\n",
      "time duration = 6.342, best val acc = 0.924, test acc = 0.927\n",
      "time duration = 4.741, best val acc = 0.918, test acc = 0.929\n",
      "time duration = 7.726, best val acc = 0.918, test acc = 0.927\n",
      "time duration = 7.316, best val acc = 0.918, test acc = 0.924\n",
      "test acc (mean, std):  0.9266298870245616 0.002571971260689455\n",
      "test acc (mean, std) after filter:  0.926910936832428 0.0012879681772163518\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 1\n",
      "time duration = 6.109, best val acc = 0.931, test acc = 0.925\n",
      "time duration = 6.658, best val acc = 0.931, test acc = 0.921\n",
      "time duration = 8.117, best val acc = 0.929, test acc = 0.926\n",
      "time duration = 5.251, best val acc = 0.929, test acc = 0.925\n",
      "time duration = 4.864, best val acc = 0.929, test acc = 0.925\n",
      "time duration = 5.575, best val acc = 0.931, test acc = 0.923\n",
      "time duration = 8.298, best val acc = 0.931, test acc = 0.918\n",
      "time duration = 7.803, best val acc = 0.931, test acc = 0.916\n",
      "time duration = 5.253, best val acc = 0.933, test acc = 0.920\n",
      "time duration = 6.761, best val acc = 0.931, test acc = 0.920\n",
      "time duration = 5.018, best val acc = 0.929, test acc = 0.921\n",
      "time duration = 7.453, best val acc = 0.931, test acc = 0.923\n",
      "time duration = 4.269, best val acc = 0.927, test acc = 0.926\n",
      "time duration = 5.383, best val acc = 0.927, test acc = 0.924\n",
      "time duration = 8.210, best val acc = 0.933, test acc = 0.921\n",
      "time duration = 4.815, best val acc = 0.929, test acc = 0.925\n",
      "time duration = 4.043, best val acc = 0.927, test acc = 0.925\n",
      "time duration = 6.695, best val acc = 0.931, test acc = 0.921\n",
      "time duration = 4.726, best val acc = 0.927, test acc = 0.921\n",
      "time duration = 4.797, best val acc = 0.936, test acc = 0.922\n",
      "time duration = 5.480, best val acc = 0.927, test acc = 0.922\n",
      "time duration = 8.348, best val acc = 0.933, test acc = 0.920\n",
      "time duration = 4.454, best val acc = 0.931, test acc = 0.923\n",
      "time duration = 6.529, best val acc = 0.933, test acc = 0.922\n",
      "time duration = 4.935, best val acc = 0.931, test acc = 0.922\n",
      "time duration = 8.789, best val acc = 0.931, test acc = 0.920\n",
      "time duration = 8.881, best val acc = 0.931, test acc = 0.916\n",
      "time duration = 8.460, best val acc = 0.933, test acc = 0.928\n",
      "time duration = 4.768, best val acc = 0.929, test acc = 0.926\n",
      "time duration = 7.976, best val acc = 0.931, test acc = 0.923\n",
      "test acc (mean, std):  0.9222639163335165 0.002870713493944935\n",
      "test acc (mean, std) after filter:  0.9223895122607549 0.0018825307951772213\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 2\n",
      "time duration = 4.384, best val acc = 0.920, test acc = 0.919\n",
      "time duration = 5.444, best val acc = 0.920, test acc = 0.921\n",
      "time duration = 6.277, best val acc = 0.920, test acc = 0.924\n",
      "time duration = 7.110, best val acc = 0.924, test acc = 0.923\n",
      "time duration = 4.561, best val acc = 0.929, test acc = 0.923\n",
      "time duration = 6.904, best val acc = 0.922, test acc = 0.920\n",
      "time duration = 5.121, best val acc = 0.920, test acc = 0.919\n",
      "time duration = 6.566, best val acc = 0.922, test acc = 0.918\n",
      "time duration = 7.540, best val acc = 0.922, test acc = 0.923\n",
      "time duration = 4.973, best val acc = 0.924, test acc = 0.922\n",
      "time duration = 5.337, best val acc = 0.927, test acc = 0.918\n",
      "time duration = 4.095, best val acc = 0.920, test acc = 0.922\n",
      "time duration = 5.417, best val acc = 0.924, test acc = 0.921\n",
      "time duration = 6.967, best val acc = 0.922, test acc = 0.918\n",
      "time duration = 5.197, best val acc = 0.927, test acc = 0.919\n",
      "time duration = 4.696, best val acc = 0.920, test acc = 0.915\n",
      "time duration = 5.667, best val acc = 0.924, test acc = 0.918\n",
      "time duration = 4.338, best val acc = 0.920, test acc = 0.920\n",
      "time duration = 7.763, best val acc = 0.924, test acc = 0.918\n",
      "time duration = 8.491, best val acc = 0.927, test acc = 0.922\n",
      "time duration = 6.021, best val acc = 0.924, test acc = 0.923\n",
      "time duration = 6.738, best val acc = 0.924, test acc = 0.918\n",
      "time duration = 5.196, best val acc = 0.918, test acc = 0.917\n",
      "time duration = 4.426, best val acc = 0.922, test acc = 0.920\n",
      "time duration = 6.589, best val acc = 0.924, test acc = 0.924\n",
      "time duration = 4.607, best val acc = 0.918, test acc = 0.920\n",
      "time duration = 4.690, best val acc = 0.924, test acc = 0.924\n",
      "time duration = 8.304, best val acc = 0.922, test acc = 0.924\n",
      "time duration = 5.254, best val acc = 0.922, test acc = 0.923\n",
      "time duration = 4.164, best val acc = 0.927, test acc = 0.921\n",
      "test acc (mean, std):  0.920576689640681 0.002411747590684216\n",
      "test acc (mean, std) after filter:  0.9206074972947439 0.0019508534524734953\n"
     ]
    }
   ],
   "source": [
    "# Per-dataset hyper-parameters, assigned onto `args` (in this order) before training.\n",
    "HPARAMS = {\n",
    "    'Cora':        dict(dropout=0.6, dropout2=0.6, weight_decay=5e-4, learning_rate=0.01),\n",
    "    'Citeseer':    dict(dropout=0.8, dropout2=0.6, weight_decay=5e-4, learning_rate=0.01),\n",
    "    'Pubmed':      dict(dropout=0.8, dropout2=0.5, weight_decay=5e-4, learning_rate=0.02),\n",
    "    'Coauthor-CS': dict(dropout=0.6, dropout2=0.2, weight_decay=5e-3, learning_rate=0.01),\n",
    "}\n",
    "# Coauthor-CS is evaluated on several label splits, selected via `args.test_id`.\n",
    "NUM_CS_SPLITS = 3\n",
    "\n",
    "accs = []\n",
    "datasets = ['Cora', 'Citeseer', 'Pubmed', 'Coauthor-CS']\n",
    "for dataset in datasets:\n",
    "    args.data = dataset\n",
    "    for hp_name, hp_value in HPARAMS[dataset].items():\n",
    "        setattr(args, hp_name, hp_value)\n",
    "    # Reset to the class default so re-running this cell does not inherit the\n",
    "    # last Coauthor-CS split index left on `args` by a previous run.\n",
    "    args.test_id = type(args).test_id\n",
    "    if dataset == 'Coauthor-CS':\n",
    "        for i in range(NUM_CS_SPLITS):\n",
    "            print('\\nstart testing on ' + dataset + ' dataset with random split: ' + str(i))\n",
    "            args.test_id = i\n",
    "            accs.append(main(args))\n",
    "    else:\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the results (one entry per `main` call, in the order run above) and\n",
    "# name the file after the configured model and hop count so it cannot go stale.\n",
    "acc_table = np.array(accs)\n",
    "np.savetxt(f'acc_{args.model}_hop{args.hop}.txt', acc_table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
