{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "class config:\n",
    "    \"\"\"Run settings for this benchmark; the instance `args` below is what gets passed to `main`.\"\"\"\n",
    "\n",
    "    # --- dataset & architecture ---\n",
    "    data = 'Cora'  # choices: Cora, Citeseer, Pubmed, Coauthor-CS\n",
    "    model = 'APPNP'  # choices: GCN, GAT, APPNP, DAGNN, TreeLSTM, GTCN, GTAN, GTCN2\n",
    "    n_in = 0  # NOTE(review): presumably filled in from the dataset at run time; confirm in train.py\n",
    "    n_hid = 64\n",
    "    n_out = 0  # NOTE(review): presumably filled in from the dataset at run time; confirm in train.py\n",
    "    num_heads = 1\n",
    "    num_out_heads = 1\n",
    "    device = 'cpu'\n",
    "\n",
    "    # --- training ---\n",
    "    dropout = 0.5\n",
    "    dropout2 = 0.5\n",
    "    learning_rate = 0.01\n",
    "    weight_decay = 5e-4\n",
    "    patience = 200\n",
    "    num_iter = 1000\n",
    "    num_test = 30  # number of repeated runs whose test accuracies are summarised\n",
    "\n",
    "    # --- model-specific ---\n",
    "    hop = 10\n",
    "    alpha = 0.1  # used by APPNP only\n",
    "\n",
    "    # --- label split ---\n",
    "    random_label_split = False\n",
    "    num_train = 20  # random label split only\n",
    "    num_val = 30  # random label split only\n",
    "    data_load = True  # load the saved label split to rerun the test (for reproducibility)\n",
    "    test_id = 1  # index of the test run; only used to record the i-th random label split (for reproducibility)\n",
    "\n",
    "    # --- evaluation, logging & paths ---\n",
    "    filter_pct = 0.1  # remove the top and bottom filter_pct of test accuracies before computing statistics\n",
    "    log = False  # whether to show the training log\n",
    "    eval_metric = 'acc'  # evaluation metric, choices: acc, f1-macro, f1-micro\n",
    "    root_dir = '../..'  # directory containing the source code\n",
    "\n",
    "\n",
    "args = config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "\n",
    "# Make the project sources under root_dir importable. Guard against re-running this cell,\n",
    "# which would otherwise append a duplicate sys.path entry each time.\n",
    "if args.root_dir not in sys.path:\n",
    "    sys.path.append(args.root_dir)\n",
    "\n",
    "# NOTE(review): star import kept because later cells use names it provides (e.g. `main`);\n",
    "# prefer explicit imports once the required names are confirmed.\n",
    "from train import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "start testing on Cora dataset\n",
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 13.567, best val acc = 0.814, test acc = 0.833\n",
      "time duration = 7.102, best val acc = 0.816, test acc = 0.828\n",
      "time duration = 8.736, best val acc = 0.814, test acc = 0.841\n",
      "time duration = 7.602, best val acc = 0.814, test acc = 0.834\n",
      "time duration = 10.769, best val acc = 0.824, test acc = 0.843\n",
      "time duration = 15.342, best val acc = 0.814, test acc = 0.829\n",
      "time duration = 9.751, best val acc = 0.814, test acc = 0.842\n",
      "time duration = 10.267, best val acc = 0.812, test acc = 0.830\n",
      "time duration = 15.660, best val acc = 0.824, test acc = 0.832\n",
      "time duration = 7.633, best val acc = 0.810, test acc = 0.826\n",
      "time duration = 9.230, best val acc = 0.816, test acc = 0.840\n",
      "time duration = 14.887, best val acc = 0.824, test acc = 0.829\n",
      "time duration = 7.709, best val acc = 0.826, test acc = 0.840\n",
      "time duration = 6.823, best val acc = 0.808, test acc = 0.835\n",
      "time duration = 8.944, best val acc = 0.812, test acc = 0.839\n",
      "time duration = 7.498, best val acc = 0.818, test acc = 0.843\n",
      "time duration = 16.066, best val acc = 0.820, test acc = 0.834\n",
      "time duration = 11.354, best val acc = 0.814, test acc = 0.840\n",
      "time duration = 19.302, best val acc = 0.826, test acc = 0.831\n",
      "time duration = 10.689, best val acc = 0.814, test acc = 0.832\n",
      "time duration = 8.969, best val acc = 0.820, test acc = 0.842\n",
      "time duration = 8.598, best val acc = 0.820, test acc = 0.842\n",
      "time duration = 11.598, best val acc = 0.820, test acc = 0.840\n",
      "time duration = 9.610, best val acc = 0.820, test acc = 0.838\n",
      "time duration = 11.680, best val acc = 0.818, test acc = 0.842\n",
      "time duration = 11.642, best val acc = 0.818, test acc = 0.829\n",
      "time duration = 18.497, best val acc = 0.820, test acc = 0.838\n",
      "time duration = 14.329, best val acc = 0.818, test acc = 0.830\n",
      "time duration = 8.716, best val acc = 0.816, test acc = 0.830\n",
      "time duration = 10.425, best val acc = 0.818, test acc = 0.837\n",
      "test acc (mean, std):  0.8356333792209625 0.005288249074248025\n",
      "test acc (mean, std) after filter:  0.8357500433921814 0.004539183294980775\n",
      "\n",
      "start testing on Citeseer dataset\n",
      "  NumNodes: 3327\n",
      "  NumEdges: 9228\n",
      "  NumFeats: 3703\n",
      "  NumClasses: 6\n",
      "  NumTrainingSamples: 120\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 10.239, best val acc = 0.742, test acc = 0.720\n",
      "time duration = 9.379, best val acc = 0.740, test acc = 0.719\n",
      "time duration = 16.148, best val acc = 0.738, test acc = 0.722\n",
      "time duration = 8.386, best val acc = 0.736, test acc = 0.712\n",
      "time duration = 7.846, best val acc = 0.752, test acc = 0.714\n",
      "time duration = 8.625, best val acc = 0.738, test acc = 0.718\n",
      "time duration = 8.524, best val acc = 0.734, test acc = 0.714\n",
      "time duration = 7.540, best val acc = 0.742, test acc = 0.710\n",
      "time duration = 6.579, best val acc = 0.744, test acc = 0.724\n",
      "time duration = 13.069, best val acc = 0.734, test acc = 0.701\n",
      "time duration = 9.002, best val acc = 0.744, test acc = 0.720\n",
      "time duration = 8.159, best val acc = 0.746, test acc = 0.720\n",
      "time duration = 8.394, best val acc = 0.736, test acc = 0.714\n",
      "time duration = 6.803, best val acc = 0.744, test acc = 0.725\n",
      "time duration = 12.553, best val acc = 0.734, test acc = 0.710\n",
      "time duration = 11.911, best val acc = 0.738, test acc = 0.705\n",
      "time duration = 7.542, best val acc = 0.744, test acc = 0.728\n",
      "time duration = 8.418, best val acc = 0.744, test acc = 0.728\n",
      "time duration = 7.636, best val acc = 0.754, test acc = 0.726\n",
      "time duration = 7.582, best val acc = 0.740, test acc = 0.710\n",
      "time duration = 9.190, best val acc = 0.736, test acc = 0.709\n",
      "time duration = 10.008, best val acc = 0.734, test acc = 0.710\n",
      "time duration = 12.309, best val acc = 0.736, test acc = 0.719\n",
      "time duration = 8.285, best val acc = 0.738, test acc = 0.714\n",
      "time duration = 7.219, best val acc = 0.746, test acc = 0.711\n",
      "time duration = 9.109, best val acc = 0.738, test acc = 0.717\n",
      "time duration = 8.121, best val acc = 0.746, test acc = 0.715\n",
      "time duration = 8.694, best val acc = 0.742, test acc = 0.715\n",
      "time duration = 7.386, best val acc = 0.736, test acc = 0.721\n",
      "time duration = 7.818, best val acc = 0.740, test acc = 0.712\n",
      "test acc (mean, std):  0.7161000351111094 0.006528653874709649\n",
      "test acc (mean, std) after filter:  0.7160833676656088 0.004545296933008932\n",
      "\n",
      "start testing on Pubmed dataset\n",
      "  NumNodes: 19717\n",
      "  NumEdges: 88651\n",
      "  NumFeats: 500\n",
      "  NumClasses: 3\n",
      "  NumTrainingSamples: 60\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 14.599, best val acc = 0.816, test acc = 0.799\n",
      "time duration = 14.153, best val acc = 0.820, test acc = 0.796\n",
      "time duration = 13.789, best val acc = 0.820, test acc = 0.797\n",
      "time duration = 12.138, best val acc = 0.820, test acc = 0.795\n",
      "time duration = 12.286, best val acc = 0.816, test acc = 0.795\n",
      "time duration = 15.988, best val acc = 0.818, test acc = 0.791\n",
      "time duration = 12.983, best val acc = 0.824, test acc = 0.795\n",
      "time duration = 10.252, best val acc = 0.824, test acc = 0.792\n",
      "time duration = 14.041, best val acc = 0.820, test acc = 0.800\n",
      "time duration = 10.739, best val acc = 0.820, test acc = 0.795\n",
      "time duration = 10.421, best val acc = 0.822, test acc = 0.799\n",
      "time duration = 12.148, best val acc = 0.824, test acc = 0.798\n",
      "time duration = 9.684, best val acc = 0.818, test acc = 0.794\n",
      "time duration = 12.017, best val acc = 0.826, test acc = 0.795\n",
      "time duration = 10.733, best val acc = 0.828, test acc = 0.802\n",
      "time duration = 12.327, best val acc = 0.820, test acc = 0.798\n",
      "time duration = 10.944, best val acc = 0.822, test acc = 0.803\n",
      "time duration = 10.591, best val acc = 0.818, test acc = 0.792\n",
      "time duration = 14.000, best val acc = 0.818, test acc = 0.798\n",
      "time duration = 14.585, best val acc = 0.824, test acc = 0.801\n",
      "time duration = 12.246, best val acc = 0.818, test acc = 0.796\n",
      "time duration = 11.071, best val acc = 0.820, test acc = 0.796\n",
      "time duration = 9.559, best val acc = 0.818, test acc = 0.796\n",
      "time duration = 13.911, best val acc = 0.822, test acc = 0.795\n",
      "time duration = 11.766, best val acc = 0.820, test acc = 0.797\n",
      "time duration = 14.955, best val acc = 0.818, test acc = 0.804\n",
      "time duration = 12.112, best val acc = 0.820, test acc = 0.790\n",
      "time duration = 10.775, best val acc = 0.816, test acc = 0.794\n",
      "time duration = 17.275, best val acc = 0.818, test acc = 0.799\n",
      "time duration = 11.232, best val acc = 0.820, test acc = 0.791\n",
      "test acc (mean, std):  0.7964333673318227 0.0034514110378355695\n",
      "test acc (mean, std) after filter:  0.7963333701093992 0.002285217035178998\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 0\n",
      "time duration = 16.347, best val acc = 0.913, test acc = 0.920\n",
      "time duration = 13.364, best val acc = 0.902, test acc = 0.911\n",
      "time duration = 21.275, best val acc = 0.913, test acc = 0.923\n",
      "time duration = 21.562, best val acc = 0.911, test acc = 0.919\n",
      "time duration = 26.359, best val acc = 0.907, test acc = 0.915\n",
      "time duration = 12.305, best val acc = 0.902, test acc = 0.922\n",
      "time duration = 26.800, best val acc = 0.911, test acc = 0.918\n",
      "time duration = 13.624, best val acc = 0.909, test acc = 0.925\n",
      "time duration = 28.537, best val acc = 0.913, test acc = 0.918\n",
      "time duration = 24.441, best val acc = 0.913, test acc = 0.903\n",
      "time duration = 32.811, best val acc = 0.916, test acc = 0.919\n",
      "time duration = 23.328, best val acc = 0.913, test acc = 0.918\n",
      "time duration = 30.771, best val acc = 0.913, test acc = 0.921\n",
      "time duration = 21.211, best val acc = 0.907, test acc = 0.913\n",
      "time duration = 14.444, best val acc = 0.909, test acc = 0.922\n",
      "time duration = 20.452, best val acc = 0.907, test acc = 0.910\n",
      "time duration = 13.936, best val acc = 0.907, test acc = 0.923\n",
      "time duration = 27.074, best val acc = 0.913, test acc = 0.921\n",
      "time duration = 18.490, best val acc = 0.916, test acc = 0.920\n",
      "time duration = 24.918, best val acc = 0.911, test acc = 0.912\n",
      "time duration = 32.065, best val acc = 0.909, test acc = 0.919\n",
      "time duration = 19.819, best val acc = 0.909, test acc = 0.919\n",
      "time duration = 25.049, best val acc = 0.916, test acc = 0.918\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time duration = 19.319, best val acc = 0.916, test acc = 0.908\n",
      "time duration = 15.077, best val acc = 0.907, test acc = 0.920\n",
      "time duration = 22.345, best val acc = 0.909, test acc = 0.904\n",
      "time duration = 23.246, best val acc = 0.911, test acc = 0.918\n",
      "time duration = 19.727, best val acc = 0.913, test acc = 0.918\n",
      "time duration = 31.302, best val acc = 0.920, test acc = 0.903\n",
      "time duration = 34.155, best val acc = 0.913, test acc = 0.921\n",
      "test acc (mean, std):  0.916733960310618 0.005916607790025345\n",
      "test acc (mean, std) after filter:  0.9175126502911249 0.003804439435789541\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 1\n",
      "time duration = 14.302, best val acc = 0.931, test acc = 0.923\n",
      "time duration = 21.206, best val acc = 0.936, test acc = 0.919\n",
      "time duration = 13.702, best val acc = 0.933, test acc = 0.923\n",
      "time duration = 11.975, best val acc = 0.938, test acc = 0.919\n",
      "time duration = 21.628, best val acc = 0.931, test acc = 0.920\n",
      "time duration = 14.980, best val acc = 0.933, test acc = 0.920\n",
      "time duration = 21.513, best val acc = 0.936, test acc = 0.922\n",
      "time duration = 15.365, best val acc = 0.933, test acc = 0.916\n",
      "time duration = 19.928, best val acc = 0.933, test acc = 0.913\n",
      "time duration = 14.817, best val acc = 0.933, test acc = 0.924\n",
      "time duration = 11.812, best val acc = 0.931, test acc = 0.916\n",
      "time duration = 15.419, best val acc = 0.933, test acc = 0.911\n",
      "time duration = 15.059, best val acc = 0.933, test acc = 0.915\n",
      "time duration = 20.641, best val acc = 0.929, test acc = 0.917\n",
      "time duration = 13.696, best val acc = 0.938, test acc = 0.919\n",
      "time duration = 14.260, best val acc = 0.933, test acc = 0.912\n",
      "time duration = 12.132, best val acc = 0.931, test acc = 0.918\n",
      "time duration = 12.028, best val acc = 0.933, test acc = 0.921\n",
      "time duration = 16.951, best val acc = 0.929, test acc = 0.916\n",
      "time duration = 12.534, best val acc = 0.940, test acc = 0.912\n",
      "time duration = 12.725, best val acc = 0.933, test acc = 0.923\n",
      "time duration = 12.044, best val acc = 0.936, test acc = 0.912\n",
      "time duration = 12.581, best val acc = 0.936, test acc = 0.920\n",
      "time duration = 18.231, best val acc = 0.938, test acc = 0.921\n",
      "time duration = 18.664, best val acc = 0.933, test acc = 0.915\n",
      "time duration = 19.705, best val acc = 0.929, test acc = 0.924\n",
      "time duration = 19.028, best val acc = 0.940, test acc = 0.922\n",
      "time duration = 17.531, best val acc = 0.931, test acc = 0.920\n",
      "time duration = 16.794, best val acc = 0.936, test acc = 0.912\n",
      "time duration = 13.172, best val acc = 0.936, test acc = 0.917\n",
      "test acc (mean, std):  0.9181159714857737 0.003915161246012908\n",
      "test acc (mean, std) after filter:  0.9182117084662119 0.0030773713566526074\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 2\n",
      "time duration = 13.243, best val acc = 0.916, test acc = 0.921\n",
      "time duration = 14.223, best val acc = 0.916, test acc = 0.915\n",
      "time duration = 32.552, best val acc = 0.922, test acc = 0.908\n",
      "time duration = 18.676, best val acc = 0.922, test acc = 0.915\n",
      "time duration = 14.804, best val acc = 0.918, test acc = 0.904\n",
      "time duration = 11.560, best val acc = 0.916, test acc = 0.915\n",
      "time duration = 18.907, best val acc = 0.916, test acc = 0.919\n",
      "time duration = 14.363, best val acc = 0.920, test acc = 0.916\n",
      "time duration = 14.794, best val acc = 0.918, test acc = 0.919\n",
      "time duration = 17.467, best val acc = 0.920, test acc = 0.916\n",
      "time duration = 17.639, best val acc = 0.918, test acc = 0.917\n",
      "time duration = 21.431, best val acc = 0.922, test acc = 0.917\n",
      "time duration = 33.854, best val acc = 0.927, test acc = 0.919\n",
      "time duration = 30.674, best val acc = 0.922, test acc = 0.917\n",
      "time duration = 21.621, best val acc = 0.922, test acc = 0.908\n",
      "time duration = 21.876, best val acc = 0.920, test acc = 0.917\n",
      "time duration = 26.436, best val acc = 0.920, test acc = 0.908\n",
      "time duration = 18.348, best val acc = 0.920, test acc = 0.916\n",
      "time duration = 16.809, best val acc = 0.922, test acc = 0.912\n",
      "time duration = 28.605, best val acc = 0.924, test acc = 0.910\n",
      "time duration = 14.892, best val acc = 0.916, test acc = 0.916\n",
      "time duration = 16.226, best val acc = 0.911, test acc = 0.913\n",
      "time duration = 12.686, best val acc = 0.916, test acc = 0.906\n",
      "time duration = 25.602, best val acc = 0.918, test acc = 0.914\n",
      "time duration = 26.439, best val acc = 0.920, test acc = 0.908\n",
      "time duration = 14.123, best val acc = 0.918, test acc = 0.912\n",
      "time duration = 14.353, best val acc = 0.916, test acc = 0.914\n",
      "time duration = 23.518, best val acc = 0.916, test acc = 0.912\n",
      "time duration = 13.480, best val acc = 0.913, test acc = 0.915\n",
      "time duration = 23.334, best val acc = 0.920, test acc = 0.918\n",
      "test acc (mean, std):  0.9139206329981486 0.004189479467874759\n",
      "test acc (mean, std) after filter:  0.9141808276375135 0.0030767416751829443\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the model on each benchmark. The first three datasets use the default (non-random)\n",
    "# label split; Coauthor-CS is evaluated on `num_random_splits` saved random label splits.\n",
    "datasets = ['Cora', 'Citeseer', 'Pubmed', 'Coauthor-CS']\n",
    "alphas = {'Cora': 0.1, 'Citeseer': 0.1, 'Pubmed': 0.1, 'Coauthor-CS': 0.2}  # APPNP alpha per dataset\n",
    "random_split_datasets = {'Coauthor-CS'}\n",
    "num_random_splits = 3\n",
    "\n",
    "accs = []\n",
    "for dataset in datasets:\n",
    "    args.data = dataset\n",
    "    args.alpha = alphas[dataset]\n",
    "    if dataset not in random_split_datasets:\n",
    "        # Reset the split settings explicitly: `args` is shared across cells, so re-running this\n",
    "        # cell would otherwise inherit random_label_split=True / test_id from the last Coauthor-CS run.\n",
    "        args.random_label_split = config.random_label_split\n",
    "        args.test_id = config.test_id\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    else:\n",
    "        args.random_label_split = True\n",
    "        for i in range(num_random_splits):\n",
    "            print('\\nstart testing on ' + dataset + ' dataset with random split: ' + str(i))\n",
    "            args.test_id = i\n",
    "            accs.append(main(args))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Persist the collected results; the file is named after the evaluated model (acc_APPNP.txt here).\n",
    "# NOTE(review): np.savetxt needs a 1-D or 2-D array, so every main() result is assumed to have\n",
    "# the same length — confirm against train.main.\n",
    "acc_table = np.array(accs)\n",
    "np.savetxt(f'acc_{args.model}.txt', acc_table)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
