{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "class config:\n",
    "    \"\"\"Experiment settings, passed as `args` to `main` from the `train` module.\n",
    "\n",
    "    The experiment loop in a later cell overrides `data`, `dropout`,\n",
    "    `weight_decay` and `test_id` per dataset, so the values below for those\n",
    "    fields are only defaults.\n",
    "    \"\"\"\n",
    "    data = 'Cora'  # choices: \"Cora\", \"Citeseer\", \"Pubmed\", \"Coauthor-CS\"\n",
    "    model = 'DAGNN'  # choices: \"GCN\", \"GAT\", \"APPNP\", \"DAGNN\", \"TreeLSTM\", \"GTCN\", \"GTAN\", \"GTCN2\"\n",
    "    n_in = 0\n",
    "    n_hid = 64\n",
    "    n_out = 0\n",
    "    num_heads = 1\n",
    "    num_out_heads = 1\n",
    "    device = 'cpu'\n",
    "    dropout = 0.8\n",
    "    dropout2 = 0.6\n",
    "    learning_rate = 0.01\n",
    "    weight_decay = 5e-3\n",
    "    patience = 100\n",
    "    num_iter = 1000\n",
    "    num_test = 30  # number of repeated runs per dataset / label split\n",
    "    hop = 10\n",
    "    alpha = 0.1  # used by APPNP only\n",
    "    random_label_split = False\n",
    "    num_train = 20  # for random label split only\n",
    "    num_val = 30  # for random label split only\n",
    "    data_load = True  # load the saved label split so a test can be rerun (reproducibility)\n",
    "    test_id = 1  # index of the test run; only used to record which random label split was used (reproducibility)\n",
    "    filter_pct = 0.1  # drop the top and bottom filter_pct of test accuracies before computing the 'after filter' statistics\n",
    "    log = False  # whether to show the training log\n",
    "    eval_metric = 'acc'  # evaluation metric, choices: \"acc\", \"f1-macro\", \"f1-micro\"\n",
    "    root_dir = '../..'  # directory containing the source code (the `train` module)\n",
    "\n",
    "args = config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "\n",
    "import numpy as np  # used explicitly later (np.asarray / np.savetxt); no longer leaked via a star import\n",
    "\n",
    "# Make the project's source directory importable; the guard keeps re-runs of\n",
    "# this cell from appending the same entry to sys.path again.\n",
    "if args.root_dir not in sys.path:\n",
    "    sys.path.append(args.root_dir)\n",
    "from train import main  # runs the experiments described by `args`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "start testing on Cora dataset\n",
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 3.595, best val acc = 0.824, test acc = 0.844\n",
      "time duration = 3.400, best val acc = 0.832, test acc = 0.838\n",
      "time duration = 4.737, best val acc = 0.832, test acc = 0.849\n",
      "time duration = 5.121, best val acc = 0.828, test acc = 0.831\n",
      "time duration = 5.305, best val acc = 0.822, test acc = 0.833\n",
      "time duration = 4.085, best val acc = 0.830, test acc = 0.851\n",
      "time duration = 3.272, best val acc = 0.826, test acc = 0.843\n",
      "time duration = 4.139, best val acc = 0.824, test acc = 0.842\n",
      "time duration = 3.048, best val acc = 0.828, test acc = 0.853\n",
      "time duration = 4.468, best val acc = 0.826, test acc = 0.851\n",
      "time duration = 4.902, best val acc = 0.830, test acc = 0.829\n",
      "time duration = 5.250, best val acc = 0.844, test acc = 0.838\n",
      "time duration = 4.334, best val acc = 0.828, test acc = 0.846\n",
      "time duration = 3.887, best val acc = 0.828, test acc = 0.841\n",
      "time duration = 4.907, best val acc = 0.820, test acc = 0.851\n",
      "time duration = 4.769, best val acc = 0.828, test acc = 0.837\n",
      "time duration = 4.387, best val acc = 0.824, test acc = 0.845\n",
      "time duration = 5.093, best val acc = 0.822, test acc = 0.808\n",
      "time duration = 4.134, best val acc = 0.820, test acc = 0.842\n",
      "time duration = 3.093, best val acc = 0.828, test acc = 0.844\n",
      "time duration = 4.620, best val acc = 0.824, test acc = 0.835\n",
      "time duration = 5.433, best val acc = 0.824, test acc = 0.836\n",
      "time duration = 2.921, best val acc = 0.828, test acc = 0.837\n",
      "time duration = 3.043, best val acc = 0.826, test acc = 0.825\n",
      "time duration = 3.932, best val acc = 0.818, test acc = 0.836\n",
      "time duration = 4.903, best val acc = 0.824, test acc = 0.845\n",
      "time duration = 3.874, best val acc = 0.824, test acc = 0.844\n",
      "time duration = 3.082, best val acc = 0.822, test acc = 0.835\n",
      "time duration = 3.595, best val acc = 0.822, test acc = 0.836\n",
      "time duration = 4.969, best val acc = 0.832, test acc = 0.836\n",
      "test acc (mean, std):  0.8393667022387187 0.008931153607990641\n",
      "test acc (mean, std) after filter:  0.8401667028665543 0.005096293501537869\n",
      "\n",
      "start testing on Citeseer dataset\n",
      "  NumNodes: 3327\n",
      "  NumEdges: 9228\n",
      "  NumFeats: 3703\n",
      "  NumClasses: 6\n",
      "  NumTrainingSamples: 120\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 6.901, best val acc = 0.748, test acc = 0.732\n",
      "time duration = 8.285, best val acc = 0.750, test acc = 0.726\n",
      "time duration = 6.023, best val acc = 0.736, test acc = 0.724\n",
      "time duration = 5.827, best val acc = 0.746, test acc = 0.726\n",
      "time duration = 5.314, best val acc = 0.744, test acc = 0.719\n",
      "time duration = 7.853, best val acc = 0.748, test acc = 0.731\n",
      "time duration = 8.815, best val acc = 0.748, test acc = 0.721\n",
      "time duration = 7.197, best val acc = 0.744, test acc = 0.723\n",
      "time duration = 7.261, best val acc = 0.756, test acc = 0.730\n",
      "time duration = 3.432, best val acc = 0.732, test acc = 0.730\n",
      "time duration = 6.314, best val acc = 0.744, test acc = 0.717\n",
      "time duration = 5.710, best val acc = 0.746, test acc = 0.729\n",
      "time duration = 7.017, best val acc = 0.744, test acc = 0.729\n",
      "time duration = 6.551, best val acc = 0.744, test acc = 0.727\n",
      "time duration = 6.074, best val acc = 0.744, test acc = 0.727\n",
      "time duration = 4.606, best val acc = 0.754, test acc = 0.720\n",
      "time duration = 6.252, best val acc = 0.744, test acc = 0.719\n",
      "time duration = 5.958, best val acc = 0.746, test acc = 0.734\n",
      "time duration = 8.056, best val acc = 0.748, test acc = 0.718\n",
      "time duration = 5.952, best val acc = 0.746, test acc = 0.734\n",
      "time duration = 6.128, best val acc = 0.742, test acc = 0.718\n",
      "time duration = 6.329, best val acc = 0.752, test acc = 0.733\n",
      "time duration = 5.752, best val acc = 0.750, test acc = 0.731\n",
      "time duration = 5.103, best val acc = 0.742, test acc = 0.723\n",
      "time duration = 6.853, best val acc = 0.744, test acc = 0.732\n",
      "time duration = 6.363, best val acc = 0.744, test acc = 0.721\n",
      "time duration = 7.134, best val acc = 0.742, test acc = 0.726\n",
      "time duration = 5.786, best val acc = 0.748, test acc = 0.726\n",
      "time duration = 7.765, best val acc = 0.754, test acc = 0.718\n",
      "time duration = 5.485, best val acc = 0.740, test acc = 0.734\n",
      "test acc (mean, std):  0.7259333650271098 0.005458527470238962\n",
      "test acc (mean, std) after filter:  0.7259583647052447 0.004532097744325579\n",
      "\n",
      "start testing on Pubmed dataset\n",
      "  NumNodes: 19717\n",
      "  NumEdges: 88651\n",
      "  NumFeats: 500\n",
      "  NumClasses: 3\n",
      "  NumTrainingSamples: 60\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 3.675, best val acc = 0.818, test acc = 0.794\n",
      "time duration = 3.484, best val acc = 0.818, test acc = 0.786\n",
      "time duration = 5.659, best val acc = 0.828, test acc = 0.807\n",
      "time duration = 3.712, best val acc = 0.826, test acc = 0.802\n",
      "time duration = 4.337, best val acc = 0.824, test acc = 0.796\n",
      "time duration = 4.519, best val acc = 0.824, test acc = 0.795\n",
      "time duration = 3.391, best val acc = 0.822, test acc = 0.800\n",
      "time duration = 3.948, best val acc = 0.816, test acc = 0.793\n",
      "time duration = 7.546, best val acc = 0.824, test acc = 0.802\n",
      "time duration = 7.373, best val acc = 0.824, test acc = 0.794\n",
      "time duration = 3.921, best val acc = 0.824, test acc = 0.794\n",
      "time duration = 4.699, best val acc = 0.830, test acc = 0.804\n",
      "time duration = 3.553, best val acc = 0.818, test acc = 0.794\n",
      "time duration = 3.477, best val acc = 0.826, test acc = 0.797\n",
      "time duration = 3.740, best val acc = 0.824, test acc = 0.797\n",
      "time duration = 3.649, best val acc = 0.822, test acc = 0.791\n",
      "time duration = 3.452, best val acc = 0.820, test acc = 0.797\n",
      "time duration = 5.284, best val acc = 0.818, test acc = 0.794\n",
      "time duration = 4.331, best val acc = 0.828, test acc = 0.788\n",
      "time duration = 3.684, best val acc = 0.818, test acc = 0.790\n",
      "time duration = 3.772, best val acc = 0.818, test acc = 0.795\n",
      "time duration = 3.601, best val acc = 0.822, test acc = 0.794\n",
      "time duration = 4.807, best val acc = 0.822, test acc = 0.790\n",
      "time duration = 3.845, best val acc = 0.820, test acc = 0.792\n",
      "time duration = 6.141, best val acc = 0.830, test acc = 0.806\n",
      "time duration = 3.758, best val acc = 0.832, test acc = 0.799\n",
      "time duration = 4.454, best val acc = 0.820, test acc = 0.793\n",
      "time duration = 5.144, best val acc = 0.822, test acc = 0.805\n",
      "time duration = 5.169, best val acc = 0.828, test acc = 0.801\n",
      "time duration = 4.694, best val acc = 0.830, test acc = 0.798\n",
      "test acc (mean, std):  0.7962667028109233 0.005208546253606751\n",
      "test acc (mean, std) after filter:  0.7960833683609962 0.0036161318278548587\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 0\n",
      "time duration = 11.604, best val acc = 0.891, test acc = 0.901\n",
      "time duration = 9.809, best val acc = 0.889, test acc = 0.910\n",
      "time duration = 8.096, best val acc = 0.889, test acc = 0.887\n",
      "time duration = 6.209, best val acc = 0.891, test acc = 0.904\n",
      "time duration = 9.721, best val acc = 0.893, test acc = 0.908\n",
      "time duration = 6.251, best val acc = 0.893, test acc = 0.895\n",
      "time duration = 8.518, best val acc = 0.889, test acc = 0.903\n",
      "time duration = 8.292, best val acc = 0.889, test acc = 0.905\n",
      "time duration = 6.503, best val acc = 0.882, test acc = 0.909\n",
      "time duration = 9.312, best val acc = 0.896, test acc = 0.911\n",
      "time duration = 5.893, best val acc = 0.884, test acc = 0.901\n",
      "time duration = 8.968, best val acc = 0.887, test acc = 0.906\n",
      "time duration = 8.710, best val acc = 0.893, test acc = 0.899\n",
      "time duration = 11.357, best val acc = 0.896, test acc = 0.906\n",
      "time duration = 7.669, best val acc = 0.884, test acc = 0.902\n",
      "time duration = 4.709, best val acc = 0.882, test acc = 0.898\n",
      "time duration = 9.837, best val acc = 0.896, test acc = 0.907\n",
      "time duration = 6.682, best val acc = 0.889, test acc = 0.897\n",
      "time duration = 11.371, best val acc = 0.898, test acc = 0.906\n",
      "time duration = 8.845, best val acc = 0.891, test acc = 0.907\n",
      "time duration = 7.287, best val acc = 0.898, test acc = 0.904\n",
      "time duration = 11.106, best val acc = 0.893, test acc = 0.909\n",
      "time duration = 7.186, best val acc = 0.889, test acc = 0.903\n",
      "time duration = 5.911, best val acc = 0.893, test acc = 0.903\n",
      "time duration = 6.640, best val acc = 0.884, test acc = 0.905\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time duration = 9.088, best val acc = 0.893, test acc = 0.901\n",
      "time duration = 5.121, best val acc = 0.898, test acc = 0.899\n",
      "time duration = 12.746, best val acc = 0.898, test acc = 0.904\n",
      "time duration = 12.227, best val acc = 0.887, test acc = 0.905\n",
      "time duration = 5.863, best val acc = 0.882, test acc = 0.904\n",
      "test acc (mean, std):  0.9032929420471192 0.004775457294182246\n",
      "test acc (mean, std) after filter:  0.903763567407926 0.0026988925627195208\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 1\n",
      "time duration = 7.514, best val acc = 0.920, test acc = 0.906\n",
      "time duration = 11.783, best val acc = 0.916, test acc = 0.909\n",
      "time duration = 6.743, best val acc = 0.922, test acc = 0.916\n",
      "time duration = 5.622, best val acc = 0.916, test acc = 0.906\n",
      "time duration = 7.261, best val acc = 0.918, test acc = 0.892\n",
      "time duration = 7.126, best val acc = 0.916, test acc = 0.906\n",
      "time duration = 5.638, best val acc = 0.916, test acc = 0.901\n",
      "time duration = 9.461, best val acc = 0.911, test acc = 0.888\n",
      "time duration = 5.575, best val acc = 0.909, test acc = 0.897\n",
      "time duration = 8.930, best val acc = 0.911, test acc = 0.908\n",
      "time duration = 8.701, best val acc = 0.920, test acc = 0.904\n",
      "time duration = 5.590, best val acc = 0.918, test acc = 0.907\n",
      "time duration = 9.247, best val acc = 0.918, test acc = 0.898\n",
      "time duration = 6.161, best val acc = 0.913, test acc = 0.903\n",
      "time duration = 8.419, best val acc = 0.909, test acc = 0.895\n",
      "time duration = 8.617, best val acc = 0.916, test acc = 0.897\n",
      "time duration = 6.558, best val acc = 0.920, test acc = 0.902\n",
      "time duration = 6.091, best val acc = 0.916, test acc = 0.902\n",
      "time duration = 11.181, best val acc = 0.920, test acc = 0.905\n",
      "time duration = 5.390, best val acc = 0.916, test acc = 0.906\n",
      "time duration = 9.190, best val acc = 0.913, test acc = 0.891\n",
      "time duration = 10.587, best val acc = 0.922, test acc = 0.896\n",
      "time duration = 6.900, best val acc = 0.918, test acc = 0.910\n",
      "time duration = 6.421, best val acc = 0.916, test acc = 0.902\n",
      "time duration = 8.624, best val acc = 0.918, test acc = 0.899\n",
      "time duration = 5.506, best val acc = 0.924, test acc = 0.910\n",
      "time duration = 6.414, best val acc = 0.916, test acc = 0.900\n",
      "time duration = 6.787, best val acc = 0.913, test acc = 0.915\n",
      "time duration = 6.545, best val acc = 0.913, test acc = 0.904\n",
      "time duration = 5.715, best val acc = 0.916, test acc = 0.908\n",
      "test acc (mean, std):  0.9027981559435526 0.00658627999603236\n",
      "test acc (mean, std) after filter:  0.9029626150925955 0.00431086295823306\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 2\n",
      "time duration = 15.066, best val acc = 0.898, test acc = 0.882\n",
      "time duration = 7.509, best val acc = 0.896, test acc = 0.886\n",
      "time duration = 5.308, best val acc = 0.893, test acc = 0.890\n",
      "time duration = 6.933, best val acc = 0.891, test acc = 0.908\n",
      "time duration = 7.069, best val acc = 0.898, test acc = 0.902\n",
      "time duration = 7.102, best val acc = 0.896, test acc = 0.901\n",
      "time duration = 5.713, best val acc = 0.891, test acc = 0.908\n",
      "time duration = 11.757, best val acc = 0.893, test acc = 0.896\n",
      "time duration = 6.602, best val acc = 0.900, test acc = 0.897\n",
      "time duration = 9.316, best val acc = 0.893, test acc = 0.895\n",
      "time duration = 6.280, best val acc = 0.896, test acc = 0.892\n",
      "time duration = 12.927, best val acc = 0.896, test acc = 0.898\n",
      "time duration = 5.659, best val acc = 0.896, test acc = 0.903\n",
      "time duration = 6.172, best val acc = 0.893, test acc = 0.894\n",
      "time duration = 11.188, best val acc = 0.909, test acc = 0.903\n",
      "time duration = 7.589, best val acc = 0.902, test acc = 0.904\n",
      "time duration = 5.809, best val acc = 0.896, test acc = 0.900\n",
      "time duration = 6.495, best val acc = 0.889, test acc = 0.892\n",
      "time duration = 6.375, best val acc = 0.898, test acc = 0.896\n",
      "time duration = 7.277, best val acc = 0.889, test acc = 0.886\n",
      "time duration = 7.035, best val acc = 0.891, test acc = 0.902\n",
      "time duration = 7.075, best val acc = 0.884, test acc = 0.886\n",
      "time duration = 9.151, best val acc = 0.893, test acc = 0.899\n",
      "time duration = 6.396, best val acc = 0.887, test acc = 0.900\n",
      "time duration = 10.254, best val acc = 0.893, test acc = 0.908\n",
      "time duration = 5.782, best val acc = 0.898, test acc = 0.898\n",
      "time duration = 6.372, best val acc = 0.898, test acc = 0.895\n",
      "time duration = 5.156, best val acc = 0.893, test acc = 0.879\n",
      "time duration = 6.576, best val acc = 0.887, test acc = 0.886\n",
      "time duration = 8.105, best val acc = 0.898, test acc = 0.890\n",
      "test acc (mean, std):  0.895889961719513 0.007677878829918623\n",
      "test acc (mean, std) after filter:  0.896043044825395 0.0055038796357322894\n"
     ]
    }
   ],
   "source": [
    "# Per-dataset hyperparameters that override the config-cell defaults.\n",
    "hparams = {\n",
    "    'Cora': {'dropout': 0.8, 'weight_decay': 0.005},\n",
    "    'Citeseer': {'dropout': 0.5, 'weight_decay': 0.02},\n",
    "    'Pubmed': {'dropout': 0.8, 'weight_decay': 0.005},\n",
    "    'Coauthor-CS': {'dropout': 0.8, 'weight_decay': 0.005},\n",
    "}\n",
    "datasets = list(hparams)  # evaluation order\n",
    "# Coauthor-CS is run once per label split, selected via args.test_id.\n",
    "# With data_load = True (config cell) the saved splits are reloaded; setting\n",
    "# random_label_split = True there presumably draws fresh ones instead.\n",
    "num_coauthor_splits = 3\n",
    "\n",
    "# main() reads its settings from args, which this loop mutates. Snapshot the\n",
    "# overridden fields and restore them afterwards so re-running the cell does not\n",
    "# inherit leftover state (e.g. test_id = 2 from the last Coauthor-CS split).\n",
    "overridden = ('data', 'dropout', 'weight_decay', 'test_id')\n",
    "saved_args = {name: getattr(args, name) for name in overridden}\n",
    "accs = []  # one main() result per run, in run order\n",
    "try:\n",
    "    for dataset in datasets:\n",
    "        args.data = dataset\n",
    "        args.dropout = hparams[dataset]['dropout']\n",
    "        args.weight_decay = hparams[dataset]['weight_decay']\n",
    "        if dataset != 'Coauthor-CS':\n",
    "            print('\\nstart testing on ' + dataset + ' dataset')\n",
    "            accs.append(main(args))\n",
    "        else:\n",
    "            for i in range(num_coauthor_splits):\n",
    "                print('\\nstart testing on ' + dataset + ' dataset with random split: ' + str(i))\n",
    "                args.test_id = i\n",
    "                accs.append(main(args))\n",
    "finally:\n",
    "    for name, value in saved_args.items():\n",
    "        setattr(args, name, value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist all collected results (one entry per main() call, in run order).\n",
    "# The file name follows args.model so it matches the model that produced them.\n",
    "accs = np.asarray(accs)\n",
    "np.savetxt(f'acc_{args.model}.txt', accs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
