{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "class config:\n",
    "    data = 'Cora' #choices=[\"Cora\", \"Citeseer\", \"Pubmed\", \"Coauthor-CS\"]\n",
    "    model = 'GTCN' #choices=[\"GCN\", \"GAT\", \"APPNP\", \"DAGNN\", \"TreeLSTM\", \"GTCN\", \"GTAN\", \"GTCN2\"]\n",
    "    n_in = 0\n",
    "    n_hid = 64\n",
    "    n_out = 0\n",
    "    num_heads = 1\n",
    "    num_out_heads = 1\n",
    "    device = 'cpu'\n",
    "    dropout = 0.6\n",
    "    dropout2 = 0.6\n",
    "    learning_rate = 0.01\n",
    "    weight_decay = 5e-4\n",
    "    patience = 200\n",
    "    num_iter = 1000\n",
    "    num_test = 30\n",
    "    hop = 10\n",
    "    alpha = 0.1 # used by APPNP only\n",
    "    random_label_split = False\n",
    "    num_train = 20 # for random label split only\n",
    "    num_val = 30 # for random label split only\n",
    "    data_load = True # load the saved label split to rerun the test (for reproduce purpose)\n",
    "    test_id = 1 # number of the test, only used to record the ith number of the random label split (for reproduce purpose)\n",
    "    filter_pct = 0.1 # remove the top and bottom filer_pct points before obtaining statistics of test accuracy\n",
    "    log = False # whether to show the training log or not\n",
    "    eval_metric = 'acc' # evaluation metrics, choices=[\"acc\", \"f1-macro\", \"f1-micro\"]\n",
    "    root_dir = '../..' # dir of the source code\n",
    "\n",
    "args = config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append(args.root_dir)\n",
    "from train import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "start testing on Cora dataset\n",
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 9.789, best val acc = 0.836, test acc = 0.837\n",
      "time duration = 7.205, best val acc = 0.834, test acc = 0.847\n",
      "time duration = 7.525, best val acc = 0.828, test acc = 0.830\n",
      "time duration = 5.810, best val acc = 0.828, test acc = 0.838\n",
      "time duration = 6.920, best val acc = 0.814, test acc = 0.835\n",
      "time duration = 5.131, best val acc = 0.820, test acc = 0.823\n",
      "time duration = 7.152, best val acc = 0.822, test acc = 0.849\n",
      "time duration = 13.345, best val acc = 0.838, test acc = 0.838\n",
      "time duration = 8.851, best val acc = 0.838, test acc = 0.852\n",
      "time duration = 14.271, best val acc = 0.840, test acc = 0.851\n",
      "time duration = 5.907, best val acc = 0.826, test acc = 0.847\n",
      "time duration = 11.995, best val acc = 0.818, test acc = 0.850\n",
      "time duration = 15.269, best val acc = 0.834, test acc = 0.852\n",
      "time duration = 7.225, best val acc = 0.832, test acc = 0.846\n",
      "time duration = 7.160, best val acc = 0.836, test acc = 0.842\n",
      "time duration = 7.943, best val acc = 0.840, test acc = 0.859\n",
      "time duration = 6.571, best val acc = 0.832, test acc = 0.852\n",
      "time duration = 5.811, best val acc = 0.838, test acc = 0.831\n",
      "time duration = 9.366, best val acc = 0.838, test acc = 0.852\n",
      "time duration = 5.414, best val acc = 0.828, test acc = 0.837\n",
      "time duration = 8.337, best val acc = 0.838, test acc = 0.845\n",
      "time duration = 10.874, best val acc = 0.842, test acc = 0.847\n",
      "time duration = 14.154, best val acc = 0.836, test acc = 0.837\n",
      "time duration = 9.879, best val acc = 0.842, test acc = 0.848\n",
      "time duration = 6.029, best val acc = 0.816, test acc = 0.826\n",
      "time duration = 8.233, best val acc = 0.832, test acc = 0.849\n",
      "time duration = 8.835, best val acc = 0.832, test acc = 0.848\n",
      "time duration = 9.179, best val acc = 0.844, test acc = 0.853\n",
      "time duration = 9.047, best val acc = 0.836, test acc = 0.846\n",
      "time duration = 12.770, best val acc = 0.832, test acc = 0.850\n",
      "test acc (mean, std):  0.84390003879865 0.00863076357512145\n",
      "test acc (mean, std) after filter:  0.8447500392794609 0.006091323200614306\n",
      "\n",
      "start testing on Citeseer dataset\n",
      "  NumNodes: 3327\n",
      "  NumEdges: 9228\n",
      "  NumFeats: 3703\n",
      "  NumClasses: 6\n",
      "  NumTrainingSamples: 120\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 8.854, best val acc = 0.734, test acc = 0.712\n",
      "time duration = 6.092, best val acc = 0.726, test acc = 0.706\n",
      "time duration = 7.501, best val acc = 0.724, test acc = 0.728\n",
      "time duration = 6.021, best val acc = 0.732, test acc = 0.720\n",
      "time duration = 8.451, best val acc = 0.734, test acc = 0.728\n",
      "time duration = 6.063, best val acc = 0.738, test acc = 0.710\n",
      "time duration = 9.133, best val acc = 0.738, test acc = 0.728\n",
      "time duration = 11.353, best val acc = 0.742, test acc = 0.723\n",
      "time duration = 4.813, best val acc = 0.748, test acc = 0.733\n",
      "time duration = 7.286, best val acc = 0.736, test acc = 0.734\n",
      "time duration = 6.934, best val acc = 0.738, test acc = 0.735\n",
      "time duration = 5.408, best val acc = 0.732, test acc = 0.741\n",
      "time duration = 5.249, best val acc = 0.746, test acc = 0.729\n",
      "time duration = 8.037, best val acc = 0.746, test acc = 0.736\n",
      "time duration = 10.790, best val acc = 0.722, test acc = 0.729\n",
      "time duration = 6.168, best val acc = 0.726, test acc = 0.724\n",
      "time duration = 7.925, best val acc = 0.736, test acc = 0.740\n",
      "time duration = 5.719, best val acc = 0.730, test acc = 0.727\n",
      "time duration = 5.805, best val acc = 0.722, test acc = 0.717\n",
      "time duration = 7.267, best val acc = 0.736, test acc = 0.724\n",
      "time duration = 5.914, best val acc = 0.734, test acc = 0.725\n",
      "time duration = 8.067, best val acc = 0.734, test acc = 0.732\n",
      "time duration = 12.007, best val acc = 0.734, test acc = 0.729\n",
      "time duration = 6.006, best val acc = 0.736, test acc = 0.723\n",
      "time duration = 6.273, best val acc = 0.736, test acc = 0.737\n",
      "time duration = 6.835, best val acc = 0.738, test acc = 0.734\n",
      "time duration = 6.481, best val acc = 0.728, test acc = 0.727\n",
      "time duration = 6.438, best val acc = 0.728, test acc = 0.740\n",
      "time duration = 5.297, best val acc = 0.748, test acc = 0.738\n",
      "time duration = 5.755, best val acc = 0.724, test acc = 0.730\n",
      "test acc (mean, std):  0.7279667019844055 0.008588689230950007\n",
      "test acc (mean, std) after filter:  0.7287500376502672 0.005317034572849062\n",
      "\n",
      "start testing on Pubmed dataset\n",
      "  NumNodes: 19717\n",
      "  NumEdges: 88651\n",
      "  NumFeats: 500\n",
      "  NumClasses: 3\n",
      "  NumTrainingSamples: 60\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 16.431, best val acc = 0.820, test acc = 0.797\n",
      "time duration = 19.509, best val acc = 0.822, test acc = 0.793\n",
      "time duration = 11.802, best val acc = 0.810, test acc = 0.800\n",
      "time duration = 17.642, best val acc = 0.820, test acc = 0.793\n",
      "time duration = 21.028, best val acc = 0.820, test acc = 0.789\n",
      "time duration = 12.304, best val acc = 0.820, test acc = 0.795\n",
      "time duration = 11.207, best val acc = 0.820, test acc = 0.800\n",
      "time duration = 9.193, best val acc = 0.814, test acc = 0.789\n",
      "time duration = 10.833, best val acc = 0.824, test acc = 0.797\n",
      "time duration = 9.768, best val acc = 0.812, test acc = 0.784\n",
      "time duration = 11.128, best val acc = 0.818, test acc = 0.792\n",
      "time duration = 12.464, best val acc = 0.818, test acc = 0.791\n",
      "time duration = 16.487, best val acc = 0.820, test acc = 0.795\n",
      "time duration = 12.259, best val acc = 0.816, test acc = 0.792\n",
      "time duration = 17.375, best val acc = 0.820, test acc = 0.794\n",
      "time duration = 12.231, best val acc = 0.820, test acc = 0.787\n",
      "time duration = 12.051, best val acc = 0.820, test acc = 0.793\n",
      "time duration = 12.243, best val acc = 0.814, test acc = 0.794\n",
      "time duration = 14.045, best val acc = 0.820, test acc = 0.792\n",
      "time duration = 11.379, best val acc = 0.824, test acc = 0.790\n",
      "time duration = 11.012, best val acc = 0.812, test acc = 0.799\n",
      "time duration = 12.644, best val acc = 0.820, test acc = 0.786\n",
      "time duration = 9.213, best val acc = 0.818, test acc = 0.789\n",
      "time duration = 9.379, best val acc = 0.816, test acc = 0.781\n",
      "time duration = 20.889, best val acc = 0.820, test acc = 0.796\n",
      "time duration = 11.469, best val acc = 0.816, test acc = 0.779\n",
      "time duration = 22.510, best val acc = 0.818, test acc = 0.801\n",
      "time duration = 10.200, best val acc = 0.818, test acc = 0.790\n",
      "time duration = 9.635, best val acc = 0.814, test acc = 0.794\n",
      "time duration = 10.288, best val acc = 0.812, test acc = 0.789\n",
      "test acc (mean, std):  0.7920333683490753 0.005218450505056188\n",
      "test acc (mean, std) after filter:  0.7923333694537481 0.003236083387848514\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 0\n",
      "time duration = 10.506, best val acc = 0.929, test acc = 0.930\n",
      "time duration = 10.588, best val acc = 0.927, test acc = 0.926\n",
      "time duration = 16.155, best val acc = 0.927, test acc = 0.926\n",
      "time duration = 10.625, best val acc = 0.929, test acc = 0.930\n",
      "time duration = 10.290, best val acc = 0.927, test acc = 0.922\n",
      "time duration = 12.834, best val acc = 0.924, test acc = 0.927\n",
      "time duration = 20.528, best val acc = 0.927, test acc = 0.925\n",
      "time duration = 11.710, best val acc = 0.931, test acc = 0.924\n",
      "time duration = 10.465, best val acc = 0.924, test acc = 0.925\n",
      "time duration = 9.552, best val acc = 0.927, test acc = 0.927\n",
      "time duration = 17.711, best val acc = 0.927, test acc = 0.927\n",
      "time duration = 14.767, best val acc = 0.929, test acc = 0.927\n",
      "time duration = 13.144, best val acc = 0.931, test acc = 0.930\n",
      "time duration = 10.934, best val acc = 0.936, test acc = 0.929\n",
      "time duration = 9.295, best val acc = 0.927, test acc = 0.925\n",
      "time duration = 10.112, best val acc = 0.933, test acc = 0.926\n",
      "time duration = 9.681, best val acc = 0.924, test acc = 0.924\n",
      "time duration = 11.664, best val acc = 0.922, test acc = 0.929\n",
      "time duration = 10.739, best val acc = 0.924, test acc = 0.926\n",
      "time duration = 15.705, best val acc = 0.929, test acc = 0.926\n",
      "time duration = 15.107, best val acc = 0.924, test acc = 0.927\n",
      "time duration = 11.715, best val acc = 0.936, test acc = 0.926\n",
      "time duration = 13.173, best val acc = 0.931, test acc = 0.929\n",
      "time duration = 10.501, best val acc = 0.931, test acc = 0.926\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time duration = 10.063, best val acc = 0.929, test acc = 0.926\n",
      "time duration = 12.963, best val acc = 0.929, test acc = 0.926\n",
      "time duration = 14.082, best val acc = 0.924, test acc = 0.930\n",
      "time duration = 11.638, best val acc = 0.931, test acc = 0.930\n",
      "time duration = 20.674, best val acc = 0.936, test acc = 0.924\n",
      "time duration = 10.712, best val acc = 0.929, test acc = 0.928\n",
      "test acc (mean, std):  0.9266696949799855 0.0020392781233601032\n",
      "test acc (mean, std) after filter:  0.9266810715198517 0.001451343285107006\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 1\n",
      "time duration = 10.144, best val acc = 0.942, test acc = 0.927\n",
      "time duration = 11.580, best val acc = 0.933, test acc = 0.923\n",
      "time duration = 9.503, best val acc = 0.936, test acc = 0.927\n",
      "time duration = 9.552, best val acc = 0.938, test acc = 0.927\n",
      "time duration = 22.203, best val acc = 0.936, test acc = 0.928\n",
      "time duration = 17.441, best val acc = 0.936, test acc = 0.927\n",
      "time duration = 14.088, best val acc = 0.933, test acc = 0.926\n",
      "time duration = 11.312, best val acc = 0.936, test acc = 0.928\n",
      "time duration = 17.860, best val acc = 0.940, test acc = 0.923\n",
      "time duration = 14.614, best val acc = 0.933, test acc = 0.926\n",
      "time duration = 10.679, best val acc = 0.933, test acc = 0.927\n",
      "time duration = 10.363, best val acc = 0.936, test acc = 0.925\n",
      "time duration = 14.478, best val acc = 0.938, test acc = 0.923\n",
      "time duration = 11.247, best val acc = 0.936, test acc = 0.924\n",
      "time duration = 21.347, best val acc = 0.940, test acc = 0.926\n",
      "time duration = 11.069, best val acc = 0.933, test acc = 0.926\n",
      "time duration = 13.325, best val acc = 0.938, test acc = 0.922\n",
      "time duration = 14.828, best val acc = 0.936, test acc = 0.926\n",
      "time duration = 14.049, best val acc = 0.936, test acc = 0.926\n",
      "time duration = 10.328, best val acc = 0.938, test acc = 0.929\n",
      "time duration = 15.994, best val acc = 0.936, test acc = 0.925\n",
      "time duration = 11.495, best val acc = 0.936, test acc = 0.930\n",
      "time duration = 12.886, best val acc = 0.938, test acc = 0.925\n",
      "time duration = 9.563, best val acc = 0.933, test acc = 0.924\n",
      "time duration = 10.145, best val acc = 0.933, test acc = 0.924\n",
      "time duration = 17.124, best val acc = 0.940, test acc = 0.927\n",
      "time duration = 9.670, best val acc = 0.938, test acc = 0.922\n",
      "time duration = 11.881, best val acc = 0.938, test acc = 0.923\n",
      "time duration = 13.971, best val acc = 0.942, test acc = 0.923\n",
      "time duration = 15.251, best val acc = 0.940, test acc = 0.925\n",
      "test acc (mean, std):  0.9254658738772075 0.002012896416534765\n",
      "test acc (mean, std) after filter:  0.9254322250684103 0.001473250010027242\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 2\n",
      "time duration = 23.729, best val acc = 0.931, test acc = 0.929\n",
      "time duration = 11.500, best val acc = 0.929, test acc = 0.926\n",
      "time duration = 13.309, best val acc = 0.931, test acc = 0.924\n",
      "time duration = 14.155, best val acc = 0.929, test acc = 0.924\n",
      "time duration = 19.114, best val acc = 0.927, test acc = 0.917\n",
      "time duration = 15.143, best val acc = 0.933, test acc = 0.928\n",
      "time duration = 13.044, best val acc = 0.927, test acc = 0.923\n",
      "time duration = 11.620, best val acc = 0.924, test acc = 0.925\n",
      "time duration = 18.678, best val acc = 0.927, test acc = 0.923\n",
      "time duration = 12.071, best val acc = 0.927, test acc = 0.926\n",
      "time duration = 9.444, best val acc = 0.927, test acc = 0.924\n",
      "time duration = 17.182, best val acc = 0.924, test acc = 0.925\n",
      "time duration = 17.722, best val acc = 0.929, test acc = 0.923\n",
      "time duration = 16.086, best val acc = 0.929, test acc = 0.920\n",
      "time duration = 13.543, best val acc = 0.927, test acc = 0.922\n",
      "time duration = 11.577, best val acc = 0.924, test acc = 0.922\n",
      "time duration = 23.278, best val acc = 0.929, test acc = 0.924\n",
      "time duration = 12.247, best val acc = 0.927, test acc = 0.924\n",
      "time duration = 13.667, best val acc = 0.927, test acc = 0.924\n",
      "time duration = 11.970, best val acc = 0.924, test acc = 0.922\n",
      "time duration = 10.853, best val acc = 0.922, test acc = 0.927\n",
      "time duration = 13.432, best val acc = 0.924, test acc = 0.925\n",
      "time duration = 16.611, best val acc = 0.927, test acc = 0.925\n",
      "time duration = 12.622, best val acc = 0.924, test acc = 0.923\n",
      "time duration = 11.934, best val acc = 0.924, test acc = 0.922\n",
      "time duration = 18.146, best val acc = 0.931, test acc = 0.927\n",
      "time duration = 12.860, best val acc = 0.929, test acc = 0.927\n",
      "time duration = 11.035, best val acc = 0.931, test acc = 0.928\n",
      "time duration = 18.175, best val acc = 0.929, test acc = 0.926\n",
      "time duration = 14.437, best val acc = 0.929, test acc = 0.924\n",
      "test acc (mean, std):  0.9242393136024475 0.002541332120022132\n",
      "test acc (mean, std) after filter:  0.924351637562116 0.0015702282660770583\n"
     ]
    }
   ],
   "source": [
    "accs = []\n",
    "datasets = ['Cora', 'Citeseer', 'Pubmed', 'Coauthor-CS']\n",
    "for dataset in datasets:\n",
    "    args.data = dataset\n",
    "    if dataset == 'Cora':\n",
    "        args.dropout = 0.6\n",
    "        args.dropout2 = 0.6\n",
    "        args.weight_decay = 5e-4\n",
    "        args.learning_rate = 0.01\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Citeseer':\n",
    "        args.dropout = 0.8\n",
    "        args.dropout2 = 0.6\n",
    "        args.weight_decay = 5e-4\n",
    "        args.learning_rate = 0.01\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Pubmed':\n",
    "        args.dropout = 0.8\n",
    "        args.dropout2 = 0.5\n",
    "        args.weight_decay = 5e-4\n",
    "        args.learning_rate = 0.02\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Coauthor-CS':\n",
    "        args.dropout = 0.6\n",
    "        args.dropout2 = 0.2\n",
    "        args.weight_decay = 5e-3\n",
    "        args.learning_rate = 0.01\n",
    "        for i in range(3):\n",
    "            print('\\nstart testing on ' + dataset + ' dataset with random split: ' + str(i))\n",
    "            args.test_id = i\n",
    "            accs.append(main(args))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "accs = np.array(accs)\n",
    "np.savetxt('acc_GTCN_hop10.txt',accs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
