{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default experiment settings, held as class attributes. An instance (`args`) is passed to main();\n",
    "# the experiment loop later overrides data, dropout, dropout2, weight_decay and test_id per dataset.\n",
    "class config:\n",
    "    data = 'Cora' #choices=[\"Cora\", \"Citeseer\", \"Pubmed\", \"Coauthor-CS\"]\n",
    "    model = 'GTAN' #choices=[\"GCN\", \"GAT\", \"APPNP\", \"DAGNN\", \"TreeLSTM\", \"GTCN\", \"GTAN\", \"GTCN2\"]\n",
    "    n_in = 0 # placeholder; presumably filled in from the dataset by the training code - TODO confirm\n",
    "    n_hid = 64\n",
    "    n_out = 0 # placeholder; presumably filled in from the dataset by the training code - TODO confirm\n",
    "    num_heads = 1\n",
    "    num_out_heads = 1\n",
    "    device = 'cpu'\n",
    "    dropout = 0.6\n",
    "    dropout2 = 0\n",
    "    learning_rate = 0.01\n",
    "    weight_decay = 5e-4\n",
    "    patience = 300 # presumably early-stopping patience in epochs - TODO confirm\n",
    "    num_iter = 1000 # presumably the maximum number of training iterations - TODO confirm\n",
    "    num_test = 30 # number of repeated runs per dataset/split (the log shows 30 runs each)\n",
    "    hop = 10\n",
    "    alpha = 0.1 # used by APPNP only\n",
    "    random_label_split = False\n",
    "    num_train = 20 # for random label split only\n",
    "    num_val = 30 # for random label split only\n",
    "    data_load = True # load the saved label split to rerun the test (for reproduce purpose)\n",
    "    test_id = 1 # index of the random label split; used to record/reload that split (for reproduce purpose)\n",
    "    filter_pct = 0.1 # fraction of top and bottom runs dropped before computing the filtered test-accuracy statistics\n",
    "    log = False # whether to show the training log or not\n",
    "    eval_metric = 'acc' # evaluation metrics, choices=[\"acc\", \"f1-macro\", \"f1-micro\"]\n",
    "    root_dir = '../..' # dir of the source code\n",
    "\n",
    "args = config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "\n",
    "import numpy as np  # imported explicitly instead of relying on `from train import *`\n",
    "\n",
    "# Make the project's train.py importable; the guard keeps re-runs from appending duplicate entries.\n",
    "if args.root_dir not in sys.path:\n",
    "    sys.path.append(args.root_dir)\n",
    "\n",
    "# Only main() is used from the training module; import it by name rather than with a wildcard.\n",
    "from train import main"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "start testing on Cora dataset\n",
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 27.043, best val acc = 0.810, test acc = 0.824\n",
      "time duration = 31.746, best val acc = 0.834, test acc = 0.850\n",
      "time duration = 30.702, best val acc = 0.826, test acc = 0.850\n",
      "time duration = 27.659, best val acc = 0.816, test acc = 0.828\n",
      "time duration = 28.504, best val acc = 0.816, test acc = 0.839\n",
      "time duration = 19.788, best val acc = 0.806, test acc = 0.825\n",
      "time duration = 48.808, best val acc = 0.824, test acc = 0.845\n",
      "time duration = 28.332, best val acc = 0.822, test acc = 0.848\n",
      "time duration = 26.886, best val acc = 0.820, test acc = 0.837\n",
      "time duration = 27.481, best val acc = 0.826, test acc = 0.834\n",
      "time duration = 29.905, best val acc = 0.818, test acc = 0.828\n",
      "time duration = 19.271, best val acc = 0.816, test acc = 0.825\n",
      "time duration = 40.540, best val acc = 0.814, test acc = 0.842\n",
      "time duration = 33.494, best val acc = 0.826, test acc = 0.854\n",
      "time duration = 19.777, best val acc = 0.814, test acc = 0.820\n",
      "time duration = 19.364, best val acc = 0.818, test acc = 0.831\n",
      "time duration = 19.223, best val acc = 0.808, test acc = 0.820\n",
      "time duration = 26.550, best val acc = 0.818, test acc = 0.834\n",
      "time duration = 19.362, best val acc = 0.816, test acc = 0.822\n",
      "time duration = 18.331, best val acc = 0.818, test acc = 0.807\n",
      "time duration = 27.528, best val acc = 0.824, test acc = 0.847\n",
      "time duration = 32.711, best val acc = 0.808, test acc = 0.848\n",
      "time duration = 48.251, best val acc = 0.820, test acc = 0.835\n",
      "time duration = 18.537, best val acc = 0.814, test acc = 0.819\n",
      "time duration = 27.122, best val acc = 0.824, test acc = 0.833\n",
      "time duration = 20.173, best val acc = 0.810, test acc = 0.834\n",
      "time duration = 27.963, best val acc = 0.814, test acc = 0.848\n",
      "time duration = 34.842, best val acc = 0.808, test acc = 0.852\n",
      "time duration = 27.557, best val acc = 0.822, test acc = 0.818\n",
      "time duration = 19.225, best val acc = 0.822, test acc = 0.819\n",
      "test acc (mean, std):  0.8338667134443919 0.01228476677144218\n",
      "test acc (mean, std) after filter:  0.8340000485380491 0.009869984906546373\n",
      "\n",
      "start testing on Citeseer dataset\n",
      "  NumNodes: 3327\n",
      "  NumEdges: 9228\n",
      "  NumFeats: 3703\n",
      "  NumClasses: 6\n",
      "  NumTrainingSamples: 120\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 20.115, best val acc = 0.732, test acc = 0.720\n",
      "time duration = 18.041, best val acc = 0.740, test acc = 0.714\n",
      "time duration = 27.418, best val acc = 0.730, test acc = 0.709\n",
      "time duration = 19.939, best val acc = 0.738, test acc = 0.713\n",
      "time duration = 20.417, best val acc = 0.734, test acc = 0.710\n",
      "time duration = 19.405, best val acc = 0.730, test acc = 0.705\n",
      "time duration = 18.932, best val acc = 0.742, test acc = 0.713\n",
      "time duration = 19.886, best val acc = 0.738, test acc = 0.705\n",
      "time duration = 18.262, best val acc = 0.730, test acc = 0.713\n",
      "time duration = 19.593, best val acc = 0.740, test acc = 0.708\n",
      "time duration = 18.986, best val acc = 0.728, test acc = 0.712\n",
      "time duration = 29.156, best val acc = 0.736, test acc = 0.723\n",
      "time duration = 19.933, best val acc = 0.740, test acc = 0.710\n",
      "time duration = 19.109, best val acc = 0.724, test acc = 0.716\n",
      "time duration = 19.736, best val acc = 0.746, test acc = 0.715\n",
      "time duration = 20.893, best val acc = 0.736, test acc = 0.712\n",
      "time duration = 19.238, best val acc = 0.742, test acc = 0.714\n",
      "time duration = 20.936, best val acc = 0.736, test acc = 0.727\n",
      "time duration = 17.737, best val acc = 0.742, test acc = 0.726\n",
      "time duration = 19.511, best val acc = 0.734, test acc = 0.711\n",
      "time duration = 18.398, best val acc = 0.728, test acc = 0.714\n",
      "time duration = 18.636, best val acc = 0.730, test acc = 0.710\n",
      "time duration = 18.233, best val acc = 0.736, test acc = 0.721\n",
      "time duration = 19.314, best val acc = 0.738, test acc = 0.699\n",
      "time duration = 49.134, best val acc = 0.738, test acc = 0.731\n",
      "time duration = 18.794, best val acc = 0.738, test acc = 0.711\n",
      "time duration = 20.790, best val acc = 0.732, test acc = 0.711\n",
      "time duration = 18.754, best val acc = 0.740, test acc = 0.714\n",
      "time duration = 18.671, best val acc = 0.746, test acc = 0.718\n",
      "time duration = 19.328, best val acc = 0.736, test acc = 0.719\n",
      "test acc (mean, std):  0.714133369922638 0.006701902851751179\n",
      "test acc (mean, std) after filter:  0.7137917031844457 0.0038619876530834787\n",
      "\n",
      "start testing on Pubmed dataset\n",
      "  NumNodes: 19717\n",
      "  NumEdges: 88651\n",
      "  NumFeats: 500\n",
      "  NumClasses: 3\n",
      "  NumTrainingSamples: 60\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 40.701, best val acc = 0.822, test acc = 0.795\n",
      "time duration = 40.650, best val acc = 0.820, test acc = 0.792\n",
      "time duration = 38.615, best val acc = 0.822, test acc = 0.791\n",
      "time duration = 50.744, best val acc = 0.820, test acc = 0.795\n",
      "time duration = 44.441, best val acc = 0.822, test acc = 0.789\n",
      "time duration = 83.039, best val acc = 0.822, test acc = 0.797\n",
      "time duration = 38.531, best val acc = 0.820, test acc = 0.784\n",
      "time duration = 41.224, best val acc = 0.820, test acc = 0.793\n",
      "time duration = 45.347, best val acc = 0.820, test acc = 0.794\n",
      "time duration = 68.921, best val acc = 0.806, test acc = 0.797\n",
      "time duration = 74.350, best val acc = 0.818, test acc = 0.794\n",
      "time duration = 39.250, best val acc = 0.818, test acc = 0.793\n",
      "time duration = 81.400, best val acc = 0.820, test acc = 0.796\n",
      "time duration = 76.767, best val acc = 0.822, test acc = 0.784\n",
      "time duration = 46.080, best val acc = 0.818, test acc = 0.792\n",
      "time duration = 45.247, best val acc = 0.822, test acc = 0.793\n",
      "time duration = 67.937, best val acc = 0.820, test acc = 0.795\n",
      "time duration = 38.699, best val acc = 0.822, test acc = 0.795\n",
      "time duration = 39.811, best val acc = 0.814, test acc = 0.790\n",
      "time duration = 42.434, best val acc = 0.820, test acc = 0.796\n",
      "time duration = 52.751, best val acc = 0.820, test acc = 0.792\n",
      "time duration = 47.364, best val acc = 0.818, test acc = 0.806\n",
      "time duration = 50.509, best val acc = 0.818, test acc = 0.794\n",
      "time duration = 40.633, best val acc = 0.818, test acc = 0.787\n",
      "time duration = 40.662, best val acc = 0.818, test acc = 0.790\n",
      "time duration = 41.921, best val acc = 0.824, test acc = 0.795\n",
      "time duration = 43.864, best val acc = 0.818, test acc = 0.795\n",
      "time duration = 40.033, best val acc = 0.816, test acc = 0.796\n",
      "time duration = 52.576, best val acc = 0.820, test acc = 0.792\n",
      "time duration = 45.579, best val acc = 0.824, test acc = 0.797\n",
      "test acc (mean, std):  0.7933000385761261 0.004116230077575387\n",
      "test acc (mean, std) after filter:  0.7935000360012054 0.0021015891886059274\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 0\n",
      "time duration = 66.406, best val acc = 0.911, test acc = 0.924\n",
      "time duration = 86.728, best val acc = 0.916, test acc = 0.916\n",
      "time duration = 53.788, best val acc = 0.907, test acc = 0.919\n",
      "time duration = 94.382, best val acc = 0.922, test acc = 0.928\n",
      "time duration = 50.839, best val acc = 0.900, test acc = 0.924\n",
      "time duration = 54.584, best val acc = 0.913, test acc = 0.921\n",
      "time duration = 54.163, best val acc = 0.902, test acc = 0.925\n",
      "time duration = 50.627, best val acc = 0.904, test acc = 0.920\n",
      "time duration = 94.330, best val acc = 0.916, test acc = 0.924\n",
      "time duration = 49.466, best val acc = 0.902, test acc = 0.919\n",
      "time duration = 51.609, best val acc = 0.909, test acc = 0.918\n",
      "time duration = 50.381, best val acc = 0.916, test acc = 0.922\n",
      "time duration = 50.356, best val acc = 0.911, test acc = 0.927\n",
      "time duration = 88.528, best val acc = 0.922, test acc = 0.922\n",
      "time duration = 51.461, best val acc = 0.893, test acc = 0.923\n",
      "time duration = 50.019, best val acc = 0.907, test acc = 0.922\n",
      "time duration = 59.076, best val acc = 0.916, test acc = 0.907\n",
      "time duration = 50.285, best val acc = 0.907, test acc = 0.924\n",
      "time duration = 53.358, best val acc = 0.902, test acc = 0.913\n",
      "time duration = 50.268, best val acc = 0.916, test acc = 0.927\n",
      "time duration = 50.327, best val acc = 0.907, test acc = 0.918\n",
      "time duration = 50.586, best val acc = 0.911, test acc = 0.924\n",
      "time duration = 53.493, best val acc = 0.916, test acc = 0.928\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time duration = 56.056, best val acc = 0.911, test acc = 0.923\n",
      "time duration = 58.150, best val acc = 0.916, test acc = 0.923\n",
      "time duration = 51.107, best val acc = 0.902, test acc = 0.923\n",
      "time duration = 55.508, best val acc = 0.918, test acc = 0.912\n",
      "time duration = 57.369, best val acc = 0.909, test acc = 0.917\n",
      "time duration = 58.330, best val acc = 0.913, test acc = 0.928\n",
      "time duration = 51.467, best val acc = 0.918, test acc = 0.918\n",
      "test acc (mean, std):  0.9213198304176331 0.004771730493584739\n",
      "test acc (mean, std) after filter:  0.9217757657170296 0.002920900285200657\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 1\n",
      "time duration = 62.050, best val acc = 0.942, test acc = 0.925\n",
      "time duration = 91.996, best val acc = 0.942, test acc = 0.931\n",
      "time duration = 50.027, best val acc = 0.944, test acc = 0.927\n",
      "time duration = 159.919, best val acc = 0.942, test acc = 0.927\n",
      "time duration = 56.787, best val acc = 0.944, test acc = 0.932\n",
      "time duration = 50.364, best val acc = 0.933, test acc = 0.924\n",
      "time duration = 158.972, best val acc = 0.949, test acc = 0.924\n",
      "time duration = 49.631, best val acc = 0.931, test acc = 0.913\n",
      "time duration = 52.362, best val acc = 0.931, test acc = 0.922\n",
      "time duration = 51.343, best val acc = 0.938, test acc = 0.923\n",
      "time duration = 79.010, best val acc = 0.933, test acc = 0.924\n",
      "time duration = 49.962, best val acc = 0.931, test acc = 0.924\n",
      "time duration = 51.295, best val acc = 0.933, test acc = 0.918\n",
      "time duration = 51.707, best val acc = 0.933, test acc = 0.919\n",
      "time duration = 116.153, best val acc = 0.949, test acc = 0.925\n",
      "time duration = 50.599, best val acc = 0.933, test acc = 0.911\n",
      "time duration = 50.908, best val acc = 0.933, test acc = 0.920\n",
      "time duration = 61.135, best val acc = 0.944, test acc = 0.927\n",
      "time duration = 51.539, best val acc = 0.940, test acc = 0.922\n",
      "time duration = 61.207, best val acc = 0.940, test acc = 0.922\n",
      "time duration = 50.139, best val acc = 0.933, test acc = 0.923\n",
      "time duration = 51.010, best val acc = 0.936, test acc = 0.928\n",
      "time duration = 52.485, best val acc = 0.929, test acc = 0.926\n",
      "time duration = 50.692, best val acc = 0.942, test acc = 0.925\n",
      "time duration = 50.936, best val acc = 0.933, test acc = 0.921\n",
      "time duration = 54.719, best val acc = 0.944, test acc = 0.926\n",
      "time duration = 50.216, best val acc = 0.933, test acc = 0.927\n",
      "time duration = 69.108, best val acc = 0.936, test acc = 0.925\n",
      "time duration = 100.615, best val acc = 0.947, test acc = 0.929\n",
      "time duration = 64.620, best val acc = 0.944, test acc = 0.923\n",
      "test acc (mean, std):  0.9238165577252706 0.004400253859507569\n",
      "test acc (mean, std) after filter:  0.9241407364606857 0.00238079901058741\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 2\n",
      "time duration = 99.880, best val acc = 0.913, test acc = 0.918\n",
      "time duration = 61.690, best val acc = 0.922, test acc = 0.920\n",
      "time duration = 56.371, best val acc = 0.922, test acc = 0.925\n",
      "time duration = 50.126, best val acc = 0.907, test acc = 0.918\n",
      "time duration = 62.412, best val acc = 0.911, test acc = 0.925\n",
      "time duration = 86.672, best val acc = 0.911, test acc = 0.920\n",
      "time duration = 50.834, best val acc = 0.916, test acc = 0.918\n",
      "time duration = 51.277, best val acc = 0.918, test acc = 0.924\n",
      "time duration = 51.184, best val acc = 0.907, test acc = 0.914\n",
      "time duration = 66.800, best val acc = 0.913, test acc = 0.925\n",
      "time duration = 51.663, best val acc = 0.916, test acc = 0.919\n",
      "time duration = 82.102, best val acc = 0.913, test acc = 0.919\n",
      "time duration = 64.140, best val acc = 0.916, test acc = 0.918\n",
      "time duration = 105.594, best val acc = 0.922, test acc = 0.924\n",
      "time duration = 57.338, best val acc = 0.916, test acc = 0.918\n",
      "time duration = 49.467, best val acc = 0.904, test acc = 0.915\n",
      "time duration = 160.016, best val acc = 0.920, test acc = 0.923\n",
      "time duration = 66.271, best val acc = 0.922, test acc = 0.924\n",
      "time duration = 52.353, best val acc = 0.913, test acc = 0.919\n",
      "time duration = 74.808, best val acc = 0.920, test acc = 0.924\n",
      "time duration = 48.905, best val acc = 0.864, test acc = 0.891\n",
      "time duration = 57.250, best val acc = 0.911, test acc = 0.917\n",
      "time duration = 77.912, best val acc = 0.918, test acc = 0.922\n",
      "time duration = 52.983, best val acc = 0.916, test acc = 0.927\n",
      "time duration = 61.836, best val acc = 0.918, test acc = 0.924\n",
      "time duration = 53.089, best val acc = 0.922, test acc = 0.917\n",
      "time duration = 51.307, best val acc = 0.922, test acc = 0.921\n",
      "time duration = 51.469, best val acc = 0.907, test acc = 0.919\n",
      "time duration = 64.104, best val acc = 0.913, test acc = 0.920\n",
      "time duration = 48.569, best val acc = 0.822, test acc = 0.895\n",
      "test acc (mean, std):  0.9187491655349731 0.007647034538467201\n",
      "test acc (mean, std) after filter:  0.9202401886383692 0.002724320814719172\n"
     ]
    }
   ],
   "source": [
    "# Per-dataset regularisation settings. n_splits=None -> the dataset's single fixed split;\n",
    "# an integer -> run main() once per random label split, with test_id = 0..n_splits-1.\n",
    "HPARAMS = {\n",
    "    'Cora':        dict(dropout=0.6, dropout2=0,   weight_decay=5e-4, n_splits=None),\n",
    "    'Citeseer':    dict(dropout=0.6, dropout2=0,   weight_decay=5e-4, n_splits=None),\n",
    "    'Pubmed':      dict(dropout=0.6, dropout2=0,   weight_decay=5e-4, n_splits=None),\n",
    "    'Coauthor-CS': dict(dropout=0.2, dropout2=0.2, weight_decay=5e-3, n_splits=3),\n",
    "}\n",
    "\n",
    "# accs receives one entry per main() call: one per single-split dataset, one per Coauthor-CS split.\n",
    "accs = []\n",
    "datasets = ['Cora', 'Citeseer', 'Pubmed', 'Coauthor-CS']\n",
    "for dataset in datasets:\n",
    "    hp = HPARAMS[dataset]\n",
    "    args.data = dataset\n",
    "    args.dropout = hp['dropout']\n",
    "    args.dropout2 = hp['dropout2']\n",
    "    args.weight_decay = hp['weight_decay']\n",
    "    if hp['n_splits'] is None:\n",
    "        # Reset test_id so a re-run of this cell does not inherit the last Coauthor-CS split index.\n",
    "        args.test_id = config.test_id\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    else:\n",
    "        for i in range(hp['n_splits']):\n",
    "            print('\\nstart testing on ' + dataset + ' dataset with random split: ' + str(i))\n",
    "            args.test_id = i\n",
    "            accs.append(main(args))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np  # explicit, so this cell does not depend on names leaked by a wildcard import\n",
    "\n",
    "# Stack the values returned by each main() call; assumes each is a fixed-length sequence of per-run scores - TODO confirm.\n",
    "accs = np.array(accs)\n",
    "# Derive the filename from the config so results are not mislabelled when model or hop changes.\n",
    "np.savetxt(f'acc_{args.model}_hop{args.hop}.txt', accs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
