{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration: a plain attribute container; an instance of it is created below as `args`.\n",
    "class config:\n",
    "    # ---- what to run ----\n",
    "    data = 'Cora'  # one of \"Cora\", \"Citeseer\", \"Pubmed\", \"Coauthor-CS\"\n",
    "    model = 'GTAN'  # one of \"GCN\", \"GAT\", \"APPNP\", \"DAGNN\", \"TreeLSTM\", \"GTCN\", \"GTAN\", \"GTCN2\"\n",
    "    eval_metric = 'acc'  # evaluation metric: \"acc\", \"f1-macro\" or \"f1-micro\"\n",
    "    device = 'cpu'\n",
    "    root_dir = '../..'  # directory of the source code\n",
    "    log = False  # whether to show the training log\n",
    "\n",
    "    # ---- model shape ----\n",
    "    n_in = 0  # NOTE(review): looks like a placeholder filled in by the training code -- confirm\n",
    "    n_hid = 64\n",
    "    n_out = 0  # NOTE(review): looks like a placeholder filled in by the training code -- confirm\n",
    "    num_heads = 1\n",
    "    num_out_heads = 1\n",
    "    hop = 5\n",
    "    alpha = 0.1  # only APPNP uses this\n",
    "\n",
    "    # ---- optimisation ----\n",
    "    dropout = 0.6\n",
    "    dropout2 = 0.6\n",
    "    learning_rate = 0.01\n",
    "    weight_decay = 5e-4\n",
    "    patience = 300\n",
    "    num_iter = 1000\n",
    "\n",
    "    # ---- evaluation protocol ----\n",
    "    num_test = 30\n",
    "    # fraction of the highest and lowest test accuracies dropped before the\n",
    "    # \"after filter\" statistics are computed\n",
    "    filter_pct = 0.1\n",
    "\n",
    "    # ---- label split ----\n",
    "    random_label_split = False\n",
    "    num_train = 20  # only relevant for a random label split\n",
    "    num_val = 30  # only relevant for a random label split\n",
    "    data_load = True  # load the saved label split so a test can be rerun (reproducibility)\n",
    "    test_id = 1  # index of the random label split being recorded (reproducibility)\n",
    "\n",
    "\n",
    "args = config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append(args.root_dir)\n",
    "from train import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "start testing on Cora dataset\n",
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 16.412, best val acc = 0.828, test acc = 0.847\n",
      "time duration = 15.420, best val acc = 0.830, test acc = 0.850\n",
      "time duration = 14.134, best val acc = 0.822, test acc = 0.842\n",
      "time duration = 12.937, best val acc = 0.828, test acc = 0.845\n",
      "time duration = 13.286, best val acc = 0.814, test acc = 0.821\n",
      "time duration = 19.060, best val acc = 0.822, test acc = 0.839\n",
      "time duration = 25.212, best val acc = 0.822, test acc = 0.855\n",
      "time duration = 17.923, best val acc = 0.822, test acc = 0.843\n",
      "time duration = 11.823, best val acc = 0.816, test acc = 0.832\n",
      "time duration = 15.713, best val acc = 0.822, test acc = 0.852\n",
      "time duration = 23.253, best val acc = 0.830, test acc = 0.842\n",
      "time duration = 11.492, best val acc = 0.830, test acc = 0.841\n",
      "time duration = 25.280, best val acc = 0.822, test acc = 0.843\n",
      "time duration = 30.794, best val acc = 0.818, test acc = 0.846\n",
      "time duration = 30.828, best val acc = 0.824, test acc = 0.845\n",
      "time duration = 23.275, best val acc = 0.830, test acc = 0.849\n",
      "time duration = 12.023, best val acc = 0.832, test acc = 0.857\n",
      "time duration = 20.714, best val acc = 0.818, test acc = 0.839\n",
      "time duration = 30.880, best val acc = 0.828, test acc = 0.846\n",
      "time duration = 11.864, best val acc = 0.824, test acc = 0.848\n",
      "time duration = 19.276, best val acc = 0.830, test acc = 0.844\n",
      "time duration = 10.718, best val acc = 0.828, test acc = 0.837\n",
      "time duration = 30.886, best val acc = 0.826, test acc = 0.846\n",
      "time duration = 19.327, best val acc = 0.826, test acc = 0.838\n",
      "time duration = 14.796, best val acc = 0.816, test acc = 0.843\n",
      "time duration = 13.961, best val acc = 0.816, test acc = 0.831\n",
      "time duration = 10.093, best val acc = 0.816, test acc = 0.822\n",
      "time duration = 18.197, best val acc = 0.824, test acc = 0.840\n",
      "time duration = 11.055, best val acc = 0.838, test acc = 0.837\n",
      "time duration = 13.976, best val acc = 0.820, test acc = 0.852\n",
      "test acc (mean, std):  0.8424000382423401 0.008138799291172085\n",
      "test acc (mean, std) after filter:  0.8430833717187246 0.004599972942443291\n",
      "\n",
      "start testing on Citeseer dataset\n",
      "  NumNodes: 3327\n",
      "  NumEdges: 9228\n",
      "  NumFeats: 3703\n",
      "  NumClasses: 6\n",
      "  NumTrainingSamples: 120\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 20.868, best val acc = 0.728, test acc = 0.708\n",
      "time duration = 10.429, best val acc = 0.724, test acc = 0.721\n",
      "time duration = 10.563, best val acc = 0.738, test acc = 0.724\n",
      "time duration = 10.895, best val acc = 0.728, test acc = 0.708\n",
      "time duration = 12.094, best val acc = 0.734, test acc = 0.718\n",
      "time duration = 15.707, best val acc = 0.734, test acc = 0.711\n",
      "time duration = 9.648, best val acc = 0.732, test acc = 0.724\n",
      "time duration = 19.516, best val acc = 0.724, test acc = 0.703\n",
      "time duration = 10.281, best val acc = 0.736, test acc = 0.708\n",
      "time duration = 10.186, best val acc = 0.732, test acc = 0.715\n",
      "time duration = 10.229, best val acc = 0.730, test acc = 0.707\n",
      "time duration = 9.760, best val acc = 0.724, test acc = 0.737\n",
      "time duration = 10.387, best val acc = 0.734, test acc = 0.719\n",
      "time duration = 13.048, best val acc = 0.726, test acc = 0.715\n",
      "time duration = 10.155, best val acc = 0.720, test acc = 0.706\n",
      "time duration = 16.165, best val acc = 0.730, test acc = 0.710\n",
      "time duration = 10.120, best val acc = 0.740, test acc = 0.720\n",
      "time duration = 9.552, best val acc = 0.742, test acc = 0.726\n",
      "time duration = 30.493, best val acc = 0.726, test acc = 0.702\n",
      "time duration = 26.975, best val acc = 0.740, test acc = 0.729\n",
      "time duration = 10.182, best val acc = 0.738, test acc = 0.712\n",
      "time duration = 9.899, best val acc = 0.744, test acc = 0.709\n",
      "time duration = 22.699, best val acc = 0.728, test acc = 0.719\n",
      "time duration = 10.851, best val acc = 0.744, test acc = 0.724\n",
      "time duration = 11.116, best val acc = 0.732, test acc = 0.732\n",
      "time duration = 12.468, best val acc = 0.746, test acc = 0.719\n",
      "time duration = 14.995, best val acc = 0.742, test acc = 0.722\n",
      "time duration = 10.711, best val acc = 0.734, test acc = 0.716\n",
      "time duration = 14.916, best val acc = 0.736, test acc = 0.719\n",
      "time duration = 10.111, best val acc = 0.740, test acc = 0.742\n",
      "test acc (mean, std):  0.7175000309944153 0.009573408869796645\n",
      "test acc (mean, std) after filter:  0.7167916968464851 0.0063375868676542315\n",
      "\n",
      "start testing on Pubmed dataset\n",
      "  NumNodes: 19717\n",
      "  NumEdges: 88651\n",
      "  NumFeats: 500\n",
      "  NumClasses: 3\n",
      "  NumTrainingSamples: 60\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 21.713, best val acc = 0.822, test acc = 0.796\n",
      "time duration = 19.880, best val acc = 0.822, test acc = 0.790\n",
      "time duration = 22.321, best val acc = 0.820, test acc = 0.792\n",
      "time duration = 22.535, best val acc = 0.826, test acc = 0.794\n",
      "time duration = 19.939, best val acc = 0.816, test acc = 0.790\n",
      "time duration = 23.747, best val acc = 0.828, test acc = 0.798\n",
      "time duration = 20.731, best val acc = 0.828, test acc = 0.794\n",
      "time duration = 19.920, best val acc = 0.824, test acc = 0.795\n",
      "time duration = 21.785, best val acc = 0.822, test acc = 0.795\n",
      "time duration = 29.250, best val acc = 0.824, test acc = 0.802\n",
      "time duration = 40.938, best val acc = 0.822, test acc = 0.795\n",
      "time duration = 25.204, best val acc = 0.818, test acc = 0.798\n",
      "time duration = 21.243, best val acc = 0.822, test acc = 0.790\n",
      "time duration = 26.415, best val acc = 0.822, test acc = 0.800\n",
      "time duration = 21.857, best val acc = 0.814, test acc = 0.801\n",
      "time duration = 24.469, best val acc = 0.820, test acc = 0.783\n",
      "time duration = 23.715, best val acc = 0.826, test acc = 0.799\n",
      "time duration = 20.772, best val acc = 0.820, test acc = 0.798\n",
      "time duration = 23.089, best val acc = 0.822, test acc = 0.790\n",
      "time duration = 23.840, best val acc = 0.820, test acc = 0.793\n",
      "time duration = 20.843, best val acc = 0.820, test acc = 0.789\n",
      "time duration = 21.118, best val acc = 0.816, test acc = 0.785\n",
      "time duration = 29.655, best val acc = 0.820, test acc = 0.791\n",
      "time duration = 22.061, best val acc = 0.822, test acc = 0.792\n",
      "time duration = 22.848, best val acc = 0.824, test acc = 0.796\n",
      "time duration = 20.742, best val acc = 0.818, test acc = 0.780\n",
      "time duration = 21.603, best val acc = 0.820, test acc = 0.793\n",
      "time duration = 20.034, best val acc = 0.820, test acc = 0.787\n",
      "time duration = 21.374, best val acc = 0.822, test acc = 0.787\n",
      "time duration = 24.274, best val acc = 0.820, test acc = 0.787\n",
      "test acc (mean, std):  0.7926667034626007 0.005268353208704212\n",
      "test acc (mean, std) after filter:  0.7928750365972519 0.0035859942754177233\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 2\n",
      "time duration = 29.290, best val acc = 0.911, test acc = 0.918\n",
      "time duration = 33.298, best val acc = 0.920, test acc = 0.919\n",
      "time duration = 31.767, best val acc = 0.911, test acc = 0.914\n",
      "time duration = 54.216, best val acc = 0.922, test acc = 0.919\n",
      "time duration = 56.359, best val acc = 0.918, test acc = 0.918\n",
      "time duration = 66.550, best val acc = 0.922, test acc = 0.917\n",
      "time duration = 33.520, best val acc = 0.920, test acc = 0.923\n",
      "time duration = 26.703, best val acc = 0.900, test acc = 0.912\n",
      "time duration = 29.104, best val acc = 0.911, test acc = 0.915\n",
      "time duration = 30.220, best val acc = 0.916, test acc = 0.921\n",
      "time duration = 27.719, best val acc = 0.916, test acc = 0.916\n",
      "time duration = 27.813, best val acc = 0.911, test acc = 0.915\n",
      "time duration = 30.091, best val acc = 0.918, test acc = 0.916\n",
      "time duration = 28.549, best val acc = 0.913, test acc = 0.915\n",
      "time duration = 47.538, best val acc = 0.920, test acc = 0.918\n",
      "time duration = 35.967, best val acc = 0.911, test acc = 0.915\n",
      "time duration = 28.550, best val acc = 0.913, test acc = 0.919\n",
      "time duration = 29.996, best val acc = 0.913, test acc = 0.918\n",
      "time duration = 33.892, best val acc = 0.913, test acc = 0.920\n",
      "time duration = 37.389, best val acc = 0.918, test acc = 0.913\n",
      "time duration = 33.183, best val acc = 0.916, test acc = 0.914\n",
      "time duration = 35.347, best val acc = 0.916, test acc = 0.919\n",
      "time duration = 26.876, best val acc = 0.911, test acc = 0.915\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time duration = 30.090, best val acc = 0.909, test acc = 0.921\n",
      "time duration = 46.439, best val acc = 0.918, test acc = 0.918\n",
      "time duration = 32.706, best val acc = 0.913, test acc = 0.922\n",
      "time duration = 52.799, best val acc = 0.920, test acc = 0.917\n",
      "time duration = 27.814, best val acc = 0.909, test acc = 0.920\n",
      "time duration = 29.399, best val acc = 0.911, test acc = 0.916\n",
      "time duration = 36.432, best val acc = 0.916, test acc = 0.917\n",
      "test acc (mean, std):  0.9173879981040954 0.0025438110600012214\n",
      "test acc (mean, std) after filter:  0.917363353073597 0.0017899251023138823\n"
     ]
    }
   ],
   "source": [
    "# Per-dataset overrides written onto `args` before main() is called.\n",
    "# `split_ids`: ids forwarded to args.test_id, one run per id; None means a single\n",
    "# run with args.test_id left as it is.\n",
    "run_settings = {\n",
    "    'Cora': {'dropout': 0.6, 'dropout2': 0.6, 'weight_decay': 5e-4, 'split_ids': None},\n",
    "    'Citeseer': {'dropout': 0.6, 'dropout2': 0.6, 'weight_decay': 5e-4, 'split_ids': None},\n",
    "    'Pubmed': {'dropout': 0.6, 'dropout2': 0, 'weight_decay': 5e-4, 'split_ids': None},\n",
    "    'Coauthor-CS': {'dropout': 0.2, 'dropout2': 0.2, 'weight_decay': 5e-3, 'split_ids': [2]},\n",
    "}\n",
    "\n",
    "accs = []  # one entry per main() call, in run order\n",
    "datasets = list(run_settings)\n",
    "for dataset in datasets:\n",
    "    settings = run_settings[dataset]\n",
    "    args.data = dataset\n",
    "    args.dropout = settings['dropout']\n",
    "    args.dropout2 = settings['dropout2']\n",
    "    args.weight_decay = settings['weight_decay']\n",
    "\n",
    "    banner = '\\nstart testing on ' + dataset + ' dataset'\n",
    "    if settings['split_ids'] is None:\n",
    "        print(banner)\n",
    "        accs.append(main(args))\n",
    "        continue\n",
    "\n",
    "    for split_id in settings['split_ids']:\n",
    "        print(banner + ' with random split: ' + str(split_id))\n",
    "        args.test_id = split_id\n",
    "        accs.append(main(args))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the results collected in `accs` and write them to a text file.\n",
    "# The file name is built from the configured model and hop count so that runs with a\n",
    "# different configuration do not overwrite each other's results; with model 'GTAN' and\n",
    "# hop 5 the name is 'acc_GTAN_hop5.txt'.\n",
    "# NOTE(review): `np` is not imported explicitly in this notebook; it presumably comes in\n",
    "# through `from train import *` -- confirm.\n",
    "accs = np.array(accs)\n",
    "np.savetxt('acc_' + str(args.model) + '_hop' + str(args.hop) + '.txt', accs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
