{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "class config:\n",
    "    data = 'Cora' #choices=[\"Cora\", \"Citeseer\", \"Pubmed\", \"Coauthor-CS\"]\n",
    "    model = 'TreeLSTM' #choices=[\"GCN\", \"GAT\", \"APPNP\", \"DAGNN\", \"TreeLSTM\", \"GTCN\", \"GTAN\", \"GTCN2\"]\n",
    "    n_in = 0\n",
    "    n_hid = 64\n",
    "    n_out = 0\n",
    "    num_heads = 1\n",
    "    num_out_heads = 1\n",
    "    device = 'cpu'\n",
    "    dropout = 0.8\n",
    "    dropout2 = 0.6\n",
    "    learning_rate = 0.01\n",
    "    weight_decay = 5e-4\n",
    "    patience = 200\n",
    "    num_iter = 1000\n",
    "    num_test = 30\n",
    "    hop = 2\n",
    "    alpha = 0.1 # used by APPNP only\n",
    "    random_label_split = False\n",
    "    num_train = 20 # for random label split only\n",
    "    num_val = 30 # for random label split only\n",
    "    data_load = True # load the saved label split to rerun the test (for reproduce purpose)\n",
    "    test_id = 1 # number of the test, only used to record the ith number of the random label split (for reproduce purpose)\n",
    "    filter_pct = 0.1 # remove the top and bottom filer_pct points before obtaining statistics of test accuracy\n",
    "    log = False # whether to show the training log or not\n",
    "    eval_metric = 'acc' # evaluation metrics, choices=[\"acc\", \"f1-macro\", \"f1-micro\"]\n",
    "    root_dir = '../..' # dir of the source code\n",
    "\n",
    "args = config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append(args.root_dir)\n",
    "from train import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "start testing on Cora dataset\n",
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 4.455, best val acc = 0.814, test acc = 0.818\n",
      "time duration = 3.576, best val acc = 0.832, test acc = 0.819\n",
      "time duration = 3.432, best val acc = 0.816, test acc = 0.801\n",
      "time duration = 3.084, best val acc = 0.814, test acc = 0.830\n",
      "time duration = 2.088, best val acc = 0.818, test acc = 0.824\n",
      "time duration = 2.424, best val acc = 0.810, test acc = 0.819\n",
      "time duration = 2.220, best val acc = 0.822, test acc = 0.828\n",
      "time duration = 1.996, best val acc = 0.812, test acc = 0.830\n",
      "time duration = 2.109, best val acc = 0.810, test acc = 0.832\n",
      "time duration = 3.640, best val acc = 0.818, test acc = 0.813\n",
      "time duration = 2.168, best val acc = 0.806, test acc = 0.798\n",
      "time duration = 2.584, best val acc = 0.814, test acc = 0.811\n",
      "time duration = 2.820, best val acc = 0.810, test acc = 0.811\n",
      "time duration = 2.125, best val acc = 0.814, test acc = 0.836\n",
      "time duration = 2.607, best val acc = 0.806, test acc = 0.811\n",
      "time duration = 2.417, best val acc = 0.822, test acc = 0.817\n",
      "time duration = 3.481, best val acc = 0.822, test acc = 0.826\n",
      "time duration = 2.389, best val acc = 0.814, test acc = 0.818\n",
      "time duration = 3.496, best val acc = 0.810, test acc = 0.807\n",
      "time duration = 2.117, best val acc = 0.808, test acc = 0.818\n",
      "time duration = 3.290, best val acc = 0.816, test acc = 0.800\n",
      "time duration = 2.451, best val acc = 0.818, test acc = 0.830\n",
      "time duration = 2.297, best val acc = 0.814, test acc = 0.829\n",
      "time duration = 2.753, best val acc = 0.816, test acc = 0.819\n",
      "time duration = 2.375, best val acc = 0.820, test acc = 0.819\n",
      "time duration = 2.750, best val acc = 0.808, test acc = 0.800\n",
      "time duration = 2.266, best val acc = 0.812, test acc = 0.816\n",
      "time duration = 2.237, best val acc = 0.808, test acc = 0.824\n",
      "time duration = 4.304, best val acc = 0.818, test acc = 0.824\n",
      "time duration = 3.094, best val acc = 0.812, test acc = 0.822\n",
      "test acc (mean, std):  0.818333375453949 0.009994446091676431\n",
      "test acc (mean, std) after filter:  0.8189167131980261 0.007268178321645288\n",
      "\n",
      "start testing on Citeseer dataset\n",
      "  NumNodes: 3327\n",
      "  NumEdges: 9228\n",
      "  NumFeats: 3703\n",
      "  NumClasses: 6\n",
      "  NumTrainingSamples: 120\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 6.326, best val acc = 0.686, test acc = 0.663\n",
      "time duration = 5.254, best val acc = 0.700, test acc = 0.668\n",
      "time duration = 2.525, best val acc = 0.712, test acc = 0.684\n",
      "time duration = 2.440, best val acc = 0.700, test acc = 0.685\n",
      "time duration = 2.406, best val acc = 0.712, test acc = 0.712\n",
      "time duration = 2.725, best val acc = 0.702, test acc = 0.683\n",
      "time duration = 2.559, best val acc = 0.684, test acc = 0.657\n",
      "time duration = 2.736, best val acc = 0.706, test acc = 0.686\n",
      "time duration = 2.731, best val acc = 0.706, test acc = 0.688\n",
      "time duration = 3.206, best val acc = 0.708, test acc = 0.707\n",
      "time duration = 2.867, best val acc = 0.704, test acc = 0.688\n",
      "time duration = 4.407, best val acc = 0.696, test acc = 0.672\n",
      "time duration = 2.650, best val acc = 0.704, test acc = 0.686\n",
      "time duration = 2.669, best val acc = 0.690, test acc = 0.688\n",
      "time duration = 2.455, best val acc = 0.720, test acc = 0.690\n",
      "time duration = 2.521, best val acc = 0.710, test acc = 0.715\n",
      "time duration = 2.530, best val acc = 0.712, test acc = 0.706\n",
      "time duration = 3.192, best val acc = 0.714, test acc = 0.700\n",
      "time duration = 2.722, best val acc = 0.708, test acc = 0.688\n",
      "time duration = 2.579, best val acc = 0.700, test acc = 0.691\n",
      "time duration = 2.689, best val acc = 0.694, test acc = 0.690\n",
      "time duration = 2.719, best val acc = 0.700, test acc = 0.679\n",
      "time duration = 3.278, best val acc = 0.694, test acc = 0.671\n",
      "time duration = 2.534, best val acc = 0.710, test acc = 0.701\n",
      "time duration = 2.666, best val acc = 0.704, test acc = 0.679\n",
      "time duration = 2.430, best val acc = 0.694, test acc = 0.668\n",
      "time duration = 2.757, best val acc = 0.704, test acc = 0.686\n",
      "time duration = 2.768, best val acc = 0.698, test acc = 0.679\n",
      "time duration = 2.563, best val acc = 0.708, test acc = 0.703\n",
      "time duration = 2.894, best val acc = 0.696, test acc = 0.676\n",
      "test acc (mean, std):  0.6863000333309174 0.013950267493976208\n",
      "test acc (mean, std) after filter:  0.6861250350872675 0.00954948730286091\n",
      "\n",
      "start testing on Pubmed dataset\n",
      "  NumNodes: 19717\n",
      "  NumEdges: 88651\n",
      "  NumFeats: 500\n",
      "  NumClasses: 3\n",
      "  NumTrainingSamples: 60\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 3.173, best val acc = 0.792, test acc = 0.770\n",
      "time duration = 5.028, best val acc = 0.794, test acc = 0.769\n",
      "time duration = 3.366, best val acc = 0.790, test acc = 0.763\n",
      "time duration = 3.737, best val acc = 0.800, test acc = 0.771\n",
      "time duration = 5.414, best val acc = 0.796, test acc = 0.765\n",
      "time duration = 3.850, best val acc = 0.794, test acc = 0.772\n",
      "time duration = 4.529, best val acc = 0.798, test acc = 0.768\n",
      "time duration = 4.212, best val acc = 0.794, test acc = 0.772\n",
      "time duration = 3.139, best val acc = 0.792, test acc = 0.747\n",
      "time duration = 3.789, best val acc = 0.788, test acc = 0.755\n",
      "time duration = 3.528, best val acc = 0.794, test acc = 0.772\n",
      "time duration = 5.382, best val acc = 0.796, test acc = 0.758\n",
      "time duration = 3.208, best val acc = 0.798, test acc = 0.763\n",
      "time duration = 3.565, best val acc = 0.794, test acc = 0.752\n",
      "time duration = 4.135, best val acc = 0.798, test acc = 0.755\n",
      "time duration = 5.844, best val acc = 0.796, test acc = 0.776\n",
      "time duration = 3.487, best val acc = 0.790, test acc = 0.761\n",
      "time duration = 3.309, best val acc = 0.792, test acc = 0.772\n",
      "time duration = 4.457, best val acc = 0.792, test acc = 0.762\n",
      "time duration = 3.953, best val acc = 0.792, test acc = 0.773\n",
      "time duration = 4.916, best val acc = 0.800, test acc = 0.768\n",
      "time duration = 3.517, best val acc = 0.792, test acc = 0.751\n",
      "time duration = 4.066, best val acc = 0.798, test acc = 0.777\n",
      "time duration = 4.159, best val acc = 0.790, test acc = 0.770\n",
      "time duration = 3.196, best val acc = 0.792, test acc = 0.762\n",
      "time duration = 5.297, best val acc = 0.796, test acc = 0.775\n",
      "time duration = 3.243, best val acc = 0.792, test acc = 0.760\n",
      "time duration = 4.830, best val acc = 0.794, test acc = 0.772\n",
      "time duration = 3.553, best val acc = 0.784, test acc = 0.761\n",
      "time duration = 4.730, best val acc = 0.796, test acc = 0.763\n",
      "test acc (mean, std):  0.765166695912679 0.007784951247184311\n",
      "test acc (mean, std) after filter:  0.7657083620627722 0.0056011794887851525\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 0\n",
      "time duration = 9.921, best val acc = 0.893, test acc = 0.911\n",
      "time duration = 8.137, best val acc = 0.893, test acc = 0.913\n",
      "time duration = 14.478, best val acc = 0.896, test acc = 0.910\n",
      "time duration = 8.698, best val acc = 0.889, test acc = 0.909\n",
      "time duration = 8.402, best val acc = 0.887, test acc = 0.913\n",
      "time duration = 8.903, best val acc = 0.891, test acc = 0.911\n",
      "time duration = 8.049, best val acc = 0.893, test acc = 0.909\n",
      "time duration = 8.335, best val acc = 0.893, test acc = 0.910\n",
      "time duration = 8.438, best val acc = 0.891, test acc = 0.912\n",
      "time duration = 7.988, best val acc = 0.891, test acc = 0.903\n",
      "time duration = 8.822, best val acc = 0.887, test acc = 0.910\n",
      "time duration = 9.917, best val acc = 0.887, test acc = 0.912\n",
      "time duration = 9.061, best val acc = 0.896, test acc = 0.911\n",
      "time duration = 8.080, best val acc = 0.893, test acc = 0.909\n",
      "time duration = 8.237, best val acc = 0.893, test acc = 0.906\n",
      "time duration = 8.208, best val acc = 0.889, test acc = 0.904\n",
      "time duration = 8.577, best val acc = 0.898, test acc = 0.909\n",
      "time duration = 8.170, best val acc = 0.893, test acc = 0.908\n",
      "time duration = 8.597, best val acc = 0.893, test acc = 0.910\n",
      "time duration = 9.374, best val acc = 0.896, test acc = 0.911\n",
      "time duration = 8.169, best val acc = 0.889, test acc = 0.906\n",
      "time duration = 8.776, best val acc = 0.900, test acc = 0.911\n",
      "time duration = 10.705, best val acc = 0.893, test acc = 0.910\n",
      "time duration = 8.108, best val acc = 0.900, test acc = 0.909\n",
      "time duration = 13.970, best val acc = 0.893, test acc = 0.912\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time duration = 8.027, best val acc = 0.902, test acc = 0.907\n",
      "time duration = 10.803, best val acc = 0.889, test acc = 0.912\n",
      "time duration = 8.137, best val acc = 0.889, test acc = 0.907\n",
      "time duration = 8.200, best val acc = 0.898, test acc = 0.911\n",
      "time duration = 8.137, best val acc = 0.893, test acc = 0.915\n",
      "test acc (mean, std):  0.9096551517645518 0.0027188995436304327\n",
      "test acc (mean, std) after filter:  0.9098347748319308 0.0016503664325685748\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 1\n",
      "time duration = 8.707, best val acc = 0.913, test acc = 0.909\n",
      "time duration = 14.787, best val acc = 0.911, test acc = 0.910\n",
      "time duration = 9.247, best val acc = 0.918, test acc = 0.911\n",
      "time duration = 8.404, best val acc = 0.918, test acc = 0.908\n",
      "time duration = 9.580, best val acc = 0.918, test acc = 0.912\n",
      "time duration = 11.194, best val acc = 0.920, test acc = 0.909\n",
      "time duration = 8.500, best val acc = 0.918, test acc = 0.901\n",
      "time duration = 16.995, best val acc = 0.918, test acc = 0.902\n",
      "time duration = 11.460, best val acc = 0.911, test acc = 0.909\n",
      "time duration = 8.186, best val acc = 0.920, test acc = 0.909\n",
      "time duration = 11.903, best val acc = 0.916, test acc = 0.904\n",
      "time duration = 8.671, best val acc = 0.916, test acc = 0.910\n",
      "time duration = 9.988, best val acc = 0.918, test acc = 0.912\n",
      "time duration = 9.141, best val acc = 0.911, test acc = 0.911\n",
      "time duration = 12.763, best val acc = 0.913, test acc = 0.908\n",
      "time duration = 8.829, best val acc = 0.922, test acc = 0.911\n",
      "time duration = 8.532, best val acc = 0.918, test acc = 0.905\n",
      "time duration = 15.771, best val acc = 0.918, test acc = 0.908\n",
      "time duration = 8.799, best val acc = 0.918, test acc = 0.908\n",
      "time duration = 9.704, best val acc = 0.916, test acc = 0.911\n",
      "time duration = 10.161, best val acc = 0.916, test acc = 0.907\n",
      "time duration = 13.678, best val acc = 0.918, test acc = 0.907\n",
      "time duration = 8.718, best val acc = 0.916, test acc = 0.910\n",
      "time duration = 9.325, best val acc = 0.920, test acc = 0.907\n",
      "time duration = 9.338, best val acc = 0.922, test acc = 0.909\n",
      "time duration = 8.174, best val acc = 0.922, test acc = 0.905\n",
      "time duration = 8.902, best val acc = 0.913, test acc = 0.910\n",
      "time duration = 9.252, best val acc = 0.920, test acc = 0.910\n",
      "time duration = 11.117, best val acc = 0.918, test acc = 0.907\n",
      "time duration = 8.938, best val acc = 0.922, test acc = 0.906\n",
      "test acc (mean, std):  0.9081404368082683 0.0026845940136275386\n",
      "test acc (mean, std) after filter:  0.9083987350265185 0.0017674899000668526\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 2\n",
      "time duration = 10.797, best val acc = 0.898, test acc = 0.909\n",
      "time duration = 9.994, best val acc = 0.898, test acc = 0.907\n",
      "time duration = 9.360, best val acc = 0.898, test acc = 0.904\n",
      "time duration = 21.773, best val acc = 0.909, test acc = 0.908\n",
      "time duration = 12.246, best val acc = 0.900, test acc = 0.909\n",
      "time duration = 11.714, best val acc = 0.902, test acc = 0.908\n",
      "time duration = 9.292, best val acc = 0.898, test acc = 0.905\n",
      "time duration = 10.848, best val acc = 0.907, test acc = 0.910\n",
      "time duration = 10.652, best val acc = 0.904, test acc = 0.907\n",
      "time duration = 8.860, best val acc = 0.900, test acc = 0.912\n",
      "time duration = 8.620, best val acc = 0.896, test acc = 0.909\n",
      "time duration = 16.432, best val acc = 0.904, test acc = 0.907\n",
      "time duration = 16.339, best val acc = 0.907, test acc = 0.910\n",
      "time duration = 13.227, best val acc = 0.907, test acc = 0.908\n",
      "time duration = 10.129, best val acc = 0.904, test acc = 0.909\n",
      "time duration = 13.194, best val acc = 0.902, test acc = 0.903\n",
      "time duration = 15.916, best val acc = 0.907, test acc = 0.909\n",
      "time duration = 17.301, best val acc = 0.900, test acc = 0.896\n",
      "time duration = 15.067, best val acc = 0.904, test acc = 0.912\n",
      "time duration = 16.466, best val acc = 0.904, test acc = 0.910\n",
      "time duration = 9.872, best val acc = 0.902, test acc = 0.906\n",
      "time duration = 8.514, best val acc = 0.907, test acc = 0.906\n",
      "time duration = 12.423, best val acc = 0.896, test acc = 0.904\n",
      "time duration = 14.316, best val acc = 0.902, test acc = 0.907\n",
      "time duration = 8.574, best val acc = 0.902, test acc = 0.908\n",
      "time duration = 18.471, best val acc = 0.904, test acc = 0.908\n",
      "time duration = 17.750, best val acc = 0.902, test acc = 0.909\n",
      "time duration = 14.973, best val acc = 0.907, test acc = 0.903\n",
      "time duration = 12.998, best val acc = 0.896, test acc = 0.909\n",
      "time duration = 10.021, best val acc = 0.902, test acc = 0.913\n",
      "test acc (mean, std):  0.907556535800298 0.0031511546018612177\n",
      "test acc (mean, std) after filter:  0.9078157842159271 0.001757881725644659\n"
     ]
    }
   ],
   "source": [
    "accs = []\n",
    "datasets = ['Cora', 'Citeseer', 'Pubmed', 'Coauthor-CS']\n",
    "for dataset in datasets:\n",
    "    args.data = dataset\n",
    "    if dataset == 'Cora':\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Citeseer':\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Pubmed':\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Coauthor-CS':\n",
    "        for i in range(3):\n",
    "            print('\\nstart testing on ' + dataset + ' dataset with random split: ' + str(i))\n",
    "            args.test_id = i\n",
    "            accs.append(main(args))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "accs = np.array(accs)\n",
    "np.savetxt('acc_TreeLSTM_hop2.txt',accs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
