{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "class config:\n",
    "    data = 'Cora' #choices=[\"Cora\", \"Citeseer\", \"Pubmed\", \"Coauthor-CS\"]\n",
    "    model = 'GTCN2' #choices=[\"GCN\", \"GAT\", \"APPNP\", \"DAGNN\", \"TreeLSTM\", \"GTCN\", \"GTAN\", \"GTCN2\"]\n",
    "    n_in = 0\n",
    "    n_hid = 64\n",
    "    n_out = 0\n",
    "    num_heads = 1\n",
    "    num_out_heads = 1\n",
    "    device = 'cpu'\n",
    "    dropout = 0.6\n",
    "    dropout2 = 0.6\n",
    "    learning_rate = 0.01\n",
    "    weight_decay = 5e-4\n",
    "    patience = 200\n",
    "    num_iter = 1000\n",
    "    num_test = 30\n",
    "    hop = 5\n",
    "    alpha = 0.1 # used by APPNP only\n",
    "    random_label_split = False\n",
    "    num_train = 20 # for random label split only\n",
    "    num_val = 30 # for random label split only\n",
    "    data_load = True # load the saved label split to rerun the test (for reproduce purpose)\n",
    "    test_id = 1 # number of the test, only used to record the ith number of the random label split (for reproduce purpose)\n",
    "    filter_pct = 0.1 # remove the top and bottom filer_pct points before obtaining statistics of test accuracy\n",
    "    log = False # whether to show the training log or not\n",
    "    eval_metric = 'acc' # evaluation metrics, choices=[\"acc\", \"f1-macro\", \"f1-micro\"]\n",
    "    root_dir = '../..' # dir of the source code\n",
    "\n",
    "args = config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append(args.root_dir)\n",
    "from train import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "start testing on Cora dataset\n",
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 6.363, best val acc = 0.818, test acc = 0.821\n",
      "time duration = 10.277, best val acc = 0.816, test acc = 0.797\n",
      "time duration = 4.445, best val acc = 0.826, test acc = 0.800\n",
      "time duration = 7.340, best val acc = 0.802, test acc = 0.790\n",
      "time duration = 4.331, best val acc = 0.804, test acc = 0.801\n",
      "time duration = 5.604, best val acc = 0.814, test acc = 0.816\n",
      "time duration = 4.678, best val acc = 0.816, test acc = 0.810\n",
      "time duration = 4.248, best val acc = 0.808, test acc = 0.805\n",
      "time duration = 6.572, best val acc = 0.798, test acc = 0.797\n",
      "time duration = 4.427, best val acc = 0.810, test acc = 0.815\n",
      "time duration = 6.325, best val acc = 0.810, test acc = 0.808\n",
      "time duration = 5.367, best val acc = 0.808, test acc = 0.817\n",
      "time duration = 5.338, best val acc = 0.804, test acc = 0.800\n",
      "time duration = 4.047, best val acc = 0.800, test acc = 0.818\n",
      "time duration = 5.362, best val acc = 0.808, test acc = 0.809\n",
      "time duration = 4.399, best val acc = 0.808, test acc = 0.802\n",
      "time duration = 4.533, best val acc = 0.806, test acc = 0.800\n",
      "time duration = 5.045, best val acc = 0.800, test acc = 0.795\n",
      "time duration = 4.601, best val acc = 0.796, test acc = 0.791\n",
      "time duration = 4.849, best val acc = 0.808, test acc = 0.810\n",
      "time duration = 4.499, best val acc = 0.806, test acc = 0.797\n",
      "time duration = 4.796, best val acc = 0.816, test acc = 0.805\n",
      "time duration = 5.625, best val acc = 0.796, test acc = 0.807\n",
      "time duration = 6.208, best val acc = 0.818, test acc = 0.813\n",
      "time duration = 4.476, best val acc = 0.804, test acc = 0.809\n",
      "time duration = 5.124, best val acc = 0.806, test acc = 0.832\n",
      "time duration = 4.424, best val acc = 0.808, test acc = 0.828\n",
      "time duration = 4.479, best val acc = 0.816, test acc = 0.802\n",
      "time duration = 7.658, best val acc = 0.798, test acc = 0.802\n",
      "time duration = 4.198, best val acc = 0.806, test acc = 0.802\n",
      "test acc (mean, std):  0.8066333711147309 0.009944794779337166\n",
      "test acc (mean, std) after filter:  0.8059167067209879 0.006428558576015201\n"
     ]
    }
   ],
   "source": [
    "accs = []\n",
    "datasets = ['Cora']\n",
    "for dataset in datasets:\n",
    "    args.data = dataset\n",
    "    if dataset == 'Cora':\n",
    "        args.dropout = 0.6\n",
    "        args.dropout2 = 0.6\n",
    "        args.weight_decay = 5e-4\n",
    "        args.learning_rate = 0.01\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Citeseer':\n",
    "        args.dropout = 0.8\n",
    "        args.dropout2 = 0.6\n",
    "        args.weight_decay = 5e-4\n",
    "        args.learning_rate = 0.01\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Pubmed':\n",
    "        args.dropout = 0.8\n",
    "        args.dropout2 = 0.5\n",
    "        args.weight_decay = 5e-4\n",
    "        args.learning_rate = 0.02\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Coauthor-CS':\n",
    "        args.dropout = 0.6\n",
    "        args.dropout2 = 0.2\n",
    "        args.weight_decay = 5e-3\n",
    "        args.learning_rate = 0.01\n",
    "        for i in range(3):\n",
    "            print('\\nstart testing on ' + dataset + ' dataset with random split: ' + str(i))\n",
    "            args.test_id = i\n",
    "            accs.append(main(args))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "accs = np.array(accs)\n",
    "np.savetxt('acc_GTCN2_hop5.txt',accs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
