{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "class config:\n",
    "    data = 'Cora' #choices=[\"Cora\", \"Citeseer\", \"Pubmed\", \"Coauthor-CS\"]\n",
    "    model = 'TreeLSTM' #choices=[\"GCN\", \"GAT\", \"APPNP\", \"DAGNN\", \"TreeLSTM\", \"GTCN\", \"GTAN\", \"GTCN2\"]\n",
    "    n_in = 0\n",
    "    n_hid = 64\n",
    "    n_out = 0\n",
    "    num_heads = 1\n",
    "    num_out_heads = 1\n",
    "    device = 'cpu'\n",
    "    dropout = 0.8\n",
    "    dropout2 = 0.6\n",
    "    learning_rate = 0.01\n",
    "    weight_decay = 5e-4\n",
    "    patience = 200\n",
    "    num_iter = 1000\n",
    "    num_test = 30\n",
    "    hop = 5\n",
    "    alpha = 0.1 # used by APPNP only\n",
    "    random_label_split = False\n",
    "    num_train = 20 # for random label split only\n",
    "    num_val = 30 # for random label split only\n",
    "    data_load = True # load the saved label split to rerun the test (for reproduce purpose)\n",
    "    test_id = 1 # number of the test, only used to record the ith number of the random label split (for reproduce purpose)\n",
    "    filter_pct = 0.1 # remove the top and bottom filer_pct points before obtaining statistics of test accuracy\n",
    "    log = False # whether to show the training log or not\n",
    "    eval_metric = 'acc' # evaluation metrics, choices=[\"acc\", \"f1-macro\", \"f1-micro\"]\n",
    "    root_dir = '../..' # dir of the source code\n",
    "\n",
    "args = config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append(args.root_dir)\n",
    "from train import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "start testing on Cora dataset\n",
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 13.253, best val acc = 0.808, test acc = 0.813\n",
      "time duration = 8.430, best val acc = 0.812, test acc = 0.796\n",
      "time duration = 8.987, best val acc = 0.814, test acc = 0.787\n",
      "time duration = 11.320, best val acc = 0.816, test acc = 0.811\n",
      "time duration = 9.716, best val acc = 0.806, test acc = 0.796\n",
      "time duration = 16.511, best val acc = 0.812, test acc = 0.808\n",
      "time duration = 8.672, best val acc = 0.818, test acc = 0.816\n",
      "time duration = 9.590, best val acc = 0.810, test acc = 0.819\n",
      "time duration = 20.154, best val acc = 0.814, test acc = 0.815\n",
      "time duration = 10.683, best val acc = 0.804, test acc = 0.810\n",
      "time duration = 8.933, best val acc = 0.818, test acc = 0.819\n",
      "time duration = 9.260, best val acc = 0.818, test acc = 0.818\n",
      "time duration = 14.922, best val acc = 0.812, test acc = 0.820\n",
      "time duration = 12.022, best val acc = 0.804, test acc = 0.802\n",
      "time duration = 14.261, best val acc = 0.814, test acc = 0.819\n",
      "time duration = 12.840, best val acc = 0.806, test acc = 0.803\n",
      "time duration = 16.133, best val acc = 0.798, test acc = 0.798\n",
      "time duration = 10.067, best val acc = 0.816, test acc = 0.806\n",
      "time duration = 13.769, best val acc = 0.802, test acc = 0.818\n",
      "time duration = 9.312, best val acc = 0.814, test acc = 0.812\n",
      "time duration = 9.851, best val acc = 0.810, test acc = 0.806\n",
      "time duration = 12.798, best val acc = 0.818, test acc = 0.809\n",
      "time duration = 10.319, best val acc = 0.812, test acc = 0.799\n",
      "time duration = 17.781, best val acc = 0.814, test acc = 0.808\n",
      "time duration = 15.394, best val acc = 0.812, test acc = 0.816\n",
      "time duration = 10.164, best val acc = 0.804, test acc = 0.788\n",
      "time duration = 8.544, best val acc = 0.806, test acc = 0.796\n",
      "time duration = 17.683, best val acc = 0.804, test acc = 0.806\n",
      "time duration = 8.797, best val acc = 0.810, test acc = 0.817\n",
      "time duration = 9.902, best val acc = 0.810, test acc = 0.816\n",
      "test acc (mean, std):  0.8082333783308665 0.009326244762933337\n",
      "test acc (mean, std) after filter:  0.8090833748380343 0.0071058431949443555\n",
      "\n",
      "start testing on Citeseer dataset\n",
      "  NumNodes: 3327\n",
      "  NumEdges: 9228\n",
      "  NumFeats: 3703\n",
      "  NumClasses: 6\n",
      "  NumTrainingSamples: 120\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 21.569, best val acc = 0.686, test acc = 0.655\n",
      "time duration = 17.445, best val acc = 0.670, test acc = 0.652\n",
      "time duration = 17.272, best val acc = 0.684, test acc = 0.679\n",
      "time duration = 18.444, best val acc = 0.690, test acc = 0.673\n",
      "time duration = 20.027, best val acc = 0.684, test acc = 0.682\n",
      "time duration = 23.712, best val acc = 0.662, test acc = 0.639\n",
      "time duration = 21.447, best val acc = 0.664, test acc = 0.653\n",
      "time duration = 33.323, best val acc = 0.676, test acc = 0.636\n",
      "time duration = 26.508, best val acc = 0.676, test acc = 0.656\n",
      "time duration = 19.319, best val acc = 0.696, test acc = 0.696\n",
      "time duration = 17.393, best val acc = 0.644, test acc = 0.637\n",
      "time duration = 29.907, best val acc = 0.682, test acc = 0.660\n",
      "time duration = 18.124, best val acc = 0.676, test acc = 0.666\n",
      "time duration = 19.153, best val acc = 0.672, test acc = 0.692\n",
      "time duration = 31.073, best val acc = 0.672, test acc = 0.645\n",
      "time duration = 38.104, best val acc = 0.664, test acc = 0.665\n",
      "time duration = 21.446, best val acc = 0.660, test acc = 0.648\n",
      "time duration = 19.618, best val acc = 0.680, test acc = 0.649\n",
      "time duration = 18.718, best val acc = 0.668, test acc = 0.645\n",
      "time duration = 23.509, best val acc = 0.680, test acc = 0.670\n",
      "time duration = 18.953, best val acc = 0.686, test acc = 0.671\n",
      "time duration = 17.029, best val acc = 0.672, test acc = 0.638\n",
      "time duration = 17.695, best val acc = 0.666, test acc = 0.651\n",
      "time duration = 17.156, best val acc = 0.650, test acc = 0.642\n",
      "time duration = 17.559, best val acc = 0.690, test acc = 0.672\n",
      "time duration = 28.885, best val acc = 0.678, test acc = 0.648\n",
      "time duration = 18.974, best val acc = 0.668, test acc = 0.659\n",
      "time duration = 17.753, best val acc = 0.664, test acc = 0.669\n",
      "time duration = 17.616, best val acc = 0.680, test acc = 0.659\n",
      "time duration = 17.290, best val acc = 0.694, test acc = 0.672\n",
      "test acc (mean, std):  0.6593000292778015 0.015727154018695785\n",
      "test acc (mean, std) after filter:  0.6582500288883845 0.011121486602849865\n",
      "\n",
      "start testing on Pubmed dataset\n",
      "  NumNodes: 19717\n",
      "  NumEdges: 88651\n",
      "  NumFeats: 500\n",
      "  NumClasses: 3\n",
      "  NumTrainingSamples: 60\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 27.251, best val acc = 0.790, test acc = 0.759\n",
      "time duration = 27.432, best val acc = 0.804, test acc = 0.769\n",
      "time duration = 41.618, best val acc = 0.804, test acc = 0.777\n",
      "time duration = 28.249, best val acc = 0.800, test acc = 0.784\n",
      "time duration = 40.182, best val acc = 0.792, test acc = 0.771\n",
      "time duration = 31.066, best val acc = 0.798, test acc = 0.763\n",
      "time duration = 30.128, best val acc = 0.794, test acc = 0.767\n",
      "time duration = 33.926, best val acc = 0.790, test acc = 0.766\n",
      "time duration = 26.106, best val acc = 0.806, test acc = 0.775\n",
      "time duration = 39.337, best val acc = 0.796, test acc = 0.765\n",
      "time duration = 27.638, best val acc = 0.796, test acc = 0.772\n",
      "time duration = 35.692, best val acc = 0.798, test acc = 0.768\n",
      "time duration = 25.970, best val acc = 0.808, test acc = 0.779\n",
      "time duration = 36.245, best val acc = 0.796, test acc = 0.762\n",
      "time duration = 26.219, best val acc = 0.792, test acc = 0.773\n",
      "time duration = 26.314, best val acc = 0.788, test acc = 0.759\n",
      "time duration = 25.052, best val acc = 0.800, test acc = 0.763\n",
      "time duration = 29.684, best val acc = 0.802, test acc = 0.773\n",
      "time duration = 26.960, best val acc = 0.800, test acc = 0.767\n",
      "time duration = 26.968, best val acc = 0.798, test acc = 0.773\n",
      "time duration = 38.285, best val acc = 0.798, test acc = 0.771\n",
      "time duration = 26.861, best val acc = 0.804, test acc = 0.768\n",
      "time duration = 34.196, best val acc = 0.786, test acc = 0.756\n",
      "time duration = 36.460, best val acc = 0.788, test acc = 0.766\n",
      "time duration = 27.142, best val acc = 0.796, test acc = 0.782\n",
      "time duration = 26.586, best val acc = 0.800, test acc = 0.776\n",
      "time duration = 52.251, best val acc = 0.794, test acc = 0.770\n",
      "time duration = 26.869, best val acc = 0.802, test acc = 0.768\n",
      "time duration = 26.530, best val acc = 0.792, test acc = 0.768\n",
      "time duration = 32.610, best val acc = 0.784, test acc = 0.748\n",
      "test acc (mean, std):  0.7686000327269237 0.00748598920753327\n",
      "test acc (mean, std) after filter:  0.768750029305617 0.004474465952779865\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 2\n",
      "time duration = 173.665, best val acc = 0.902, test acc = 0.907\n",
      "time duration = 160.117, best val acc = 0.902, test acc = 0.911\n",
      "time duration = 141.780, best val acc = 0.904, test acc = 0.911\n",
      "time duration = 277.379, best val acc = 0.902, test acc = 0.905\n",
      "time duration = 158.698, best val acc = 0.898, test acc = 0.909\n",
      "time duration = 149.075, best val acc = 0.900, test acc = 0.909\n",
      "time duration = 159.495, best val acc = 0.907, test acc = 0.909\n",
      "time duration = 148.112, best val acc = 0.904, test acc = 0.910\n",
      "time duration = 146.003, best val acc = 0.900, test acc = 0.911\n",
      "time duration = 205.934, best val acc = 0.898, test acc = 0.904\n",
      "time duration = 194.665, best val acc = 0.904, test acc = 0.912\n",
      "time duration = 166.552, best val acc = 0.902, test acc = 0.911\n",
      "time duration = 165.898, best val acc = 0.909, test acc = 0.912\n",
      "time duration = 159.715, best val acc = 0.907, test acc = 0.910\n",
      "time duration = 245.933, best val acc = 0.907, test acc = 0.910\n",
      "time duration = 209.023, best val acc = 0.900, test acc = 0.906\n",
      "time duration = 162.364, best val acc = 0.900, test acc = 0.909\n",
      "time duration = 184.496, best val acc = 0.909, test acc = 0.910\n",
      "time duration = 174.380, best val acc = 0.900, test acc = 0.910\n",
      "time duration = 146.315, best val acc = 0.907, test acc = 0.910\n",
      "time duration = 183.420, best val acc = 0.902, test acc = 0.909\n",
      "time duration = 147.341, best val acc = 0.900, test acc = 0.913\n",
      "time duration = 210.459, best val acc = 0.902, test acc = 0.912\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time duration = 272.050, best val acc = 0.902, test acc = 0.909\n",
      "time duration = 137.093, best val acc = 0.898, test acc = 0.909\n",
      "time duration = 136.592, best val acc = 0.900, test acc = 0.906\n",
      "time duration = 158.590, best val acc = 0.904, test acc = 0.913\n",
      "time duration = 214.019, best val acc = 0.900, test acc = 0.909\n",
      "time duration = 155.109, best val acc = 0.902, test acc = 0.912\n",
      "time duration = 148.033, best val acc = 0.902, test acc = 0.911\n",
      "test acc (mean, std):  0.909645672639211 0.0021253199339950585\n",
      "test acc (mean, std) after filter:  0.9098703215519587 0.001357159003338635\n"
     ]
    }
   ],
   "source": [
    "accs = []\n",
    "datasets = ['Cora', 'Citeseer', 'Pubmed', 'Coauthor-CS']\n",
    "for dataset in datasets:\n",
    "    args.data = dataset\n",
    "    if dataset == 'Cora':\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Citeseer':\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Pubmed':\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Coauthor-CS':\n",
    "        for i in [2]:\n",
    "            print('\\nstart testing on ' + dataset + ' dataset with random split: ' + str(i))\n",
    "            args.test_id = i\n",
    "            accs.append(main(args))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "accs = np.array(accs)\n",
    "np.savetxt('acc_TreeLSTM_hop5.txt',accs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
