{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "class config:\n",
    "    data = 'Cora' #choices=[\"Cora\", \"Citeseer\", \"Pubmed\", \"Coauthor-CS\"]\n",
    "    model = 'GAT' #choices=[\"GCN\", \"GAT\", \"APPNP\", \"DAGNN\", \"TreeLSTM\", \"GTCN\", \"GTAN\", \"GTCN2\"]\n",
    "    n_in = 0\n",
    "    n_hid = 64\n",
    "    n_out = 0\n",
    "    num_heads = 1\n",
    "    num_out_heads = 1\n",
    "    device = 'cpu'\n",
    "    dropout = 0.8\n",
    "    dropout2 = 0\n",
    "    learning_rate = 0.01\n",
    "    weight_decay = 5e-4\n",
    "    patience = 200\n",
    "    num_iter = 1000\n",
    "    num_test = 30\n",
    "    hop = 5\n",
    "    alpha = 0.1 # used by APPNP only\n",
    "    random_label_split = False\n",
    "    num_train = 20 # for random label split only\n",
    "    num_val = 30 # for random label split only\n",
    "    data_load = True # load the saved label split to rerun the test (for reproduce purpose)\n",
    "    test_id = 1 # number of the test, only used to record the ith number of the random label split (for reproduce purpose)\n",
    "    filter_pct = 0.1 # remove the top and bottom filer_pct points before obtaining statistics of test accuracy\n",
    "    log = False # whether to show the training log or not\n",
    "    eval_metric = 'acc' # evaluation metrics, choices=[\"acc\", \"f1-macro\", \"f1-micro\"]\n",
    "    root_dir = '../..' # dir of the source code\n",
    "\n",
    "args = config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append(args.root_dir)\n",
    "from train import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "start testing on Cora dataset\n",
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 8.943, best val acc = 0.802, test acc = 0.791\n",
      "time duration = 13.083, best val acc = 0.794, test acc = 0.796\n",
      "time duration = 7.850, best val acc = 0.810, test acc = 0.805\n",
      "time duration = 7.320, best val acc = 0.816, test acc = 0.799\n",
      "time duration = 7.715, best val acc = 0.810, test acc = 0.794\n",
      "time duration = 6.949, best val acc = 0.808, test acc = 0.802\n",
      "time duration = 7.397, best val acc = 0.816, test acc = 0.811\n",
      "time duration = 7.412, best val acc = 0.798, test acc = 0.793\n",
      "time duration = 7.632, best val acc = 0.802, test acc = 0.797\n",
      "time duration = 9.265, best val acc = 0.808, test acc = 0.789\n",
      "time duration = 7.272, best val acc = 0.794, test acc = 0.801\n",
      "time duration = 7.863, best val acc = 0.808, test acc = 0.811\n",
      "time duration = 9.078, best val acc = 0.820, test acc = 0.814\n",
      "time duration = 7.839, best val acc = 0.816, test acc = 0.808\n",
      "time duration = 8.079, best val acc = 0.816, test acc = 0.812\n",
      "time duration = 8.182, best val acc = 0.814, test acc = 0.795\n",
      "time duration = 7.625, best val acc = 0.792, test acc = 0.796\n",
      "time duration = 6.928, best val acc = 0.798, test acc = 0.784\n",
      "time duration = 8.790, best val acc = 0.812, test acc = 0.796\n",
      "time duration = 7.804, best val acc = 0.800, test acc = 0.794\n",
      "time duration = 8.991, best val acc = 0.810, test acc = 0.797\n",
      "time duration = 7.055, best val acc = 0.816, test acc = 0.806\n",
      "time duration = 7.398, best val acc = 0.816, test acc = 0.815\n",
      "time duration = 7.320, best val acc = 0.804, test acc = 0.808\n",
      "time duration = 7.644, best val acc = 0.814, test acc = 0.803\n",
      "time duration = 8.891, best val acc = 0.812, test acc = 0.801\n",
      "time duration = 7.127, best val acc = 0.800, test acc = 0.800\n",
      "time duration = 9.305, best val acc = 0.810, test acc = 0.811\n",
      "time duration = 7.488, best val acc = 0.816, test acc = 0.807\n",
      "time duration = 10.825, best val acc = 0.816, test acc = 0.795\n",
      "test acc (mean, std):  0.8010333736737569 0.007842124954779268\n",
      "test acc (mean, std) after filter:  0.801083376010259 0.005865982241176457\n",
      "\n",
      "start testing on Citeseer dataset\n",
      "  NumNodes: 3327\n",
      "  NumEdges: 9228\n",
      "  NumFeats: 3703\n",
      "  NumClasses: 6\n",
      "  NumTrainingSamples: 120\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 14.079, best val acc = 0.688, test acc = 0.678\n",
      "time duration = 7.183, best val acc = 0.688, test acc = 0.688\n",
      "time duration = 6.565, best val acc = 0.682, test acc = 0.680\n",
      "time duration = 8.533, best val acc = 0.672, test acc = 0.675\n",
      "time duration = 7.035, best val acc = 0.700, test acc = 0.671\n",
      "time duration = 7.483, best val acc = 0.680, test acc = 0.689\n",
      "time duration = 10.047, best val acc = 0.680, test acc = 0.672\n",
      "time duration = 6.470, best val acc = 0.710, test acc = 0.681\n",
      "time duration = 10.001, best val acc = 0.680, test acc = 0.689\n",
      "time duration = 7.170, best val acc = 0.722, test acc = 0.713\n",
      "time duration = 7.621, best val acc = 0.700, test acc = 0.697\n",
      "time duration = 7.092, best val acc = 0.712, test acc = 0.703\n",
      "time duration = 7.387, best val acc = 0.680, test acc = 0.658\n",
      "time duration = 8.698, best val acc = 0.682, test acc = 0.670\n",
      "time duration = 22.830, best val acc = 0.684, test acc = 0.675\n",
      "time duration = 11.050, best val acc = 0.668, test acc = 0.657\n",
      "time duration = 7.132, best val acc = 0.688, test acc = 0.689\n",
      "time duration = 19.159, best val acc = 0.674, test acc = 0.659\n",
      "time duration = 7.024, best val acc = 0.696, test acc = 0.673\n",
      "time duration = 6.816, best val acc = 0.698, test acc = 0.676\n",
      "time duration = 10.398, best val acc = 0.692, test acc = 0.674\n",
      "time duration = 9.758, best val acc = 0.688, test acc = 0.683\n",
      "time duration = 7.047, best val acc = 0.674, test acc = 0.691\n",
      "time duration = 13.626, best val acc = 0.676, test acc = 0.668\n",
      "time duration = 7.017, best val acc = 0.720, test acc = 0.699\n",
      "time duration = 6.793, best val acc = 0.670, test acc = 0.682\n",
      "time duration = 9.100, best val acc = 0.694, test acc = 0.689\n",
      "time duration = 7.018, best val acc = 0.698, test acc = 0.669\n",
      "time duration = 7.009, best val acc = 0.690, test acc = 0.677\n",
      "time duration = 7.125, best val acc = 0.714, test acc = 0.699\n",
      "test acc (mean, std):  0.6808000306288401 0.013242358696458466\n",
      "test acc (mean, std) after filter:  0.6806250289082527 0.008750297776633295\n",
      "\n",
      "start testing on Pubmed dataset\n",
      "  NumNodes: 19717\n",
      "  NumEdges: 88651\n",
      "  NumFeats: 500\n",
      "  NumClasses: 3\n",
      "  NumTrainingSamples: 60\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 12.432, best val acc = 0.808, test acc = 0.776\n",
      "time duration = 12.281, best val acc = 0.796, test acc = 0.766\n",
      "time duration = 13.613, best val acc = 0.806, test acc = 0.767\n",
      "time duration = 12.522, best val acc = 0.798, test acc = 0.772\n",
      "time duration = 12.410, best val acc = 0.806, test acc = 0.778\n",
      "time duration = 14.309, best val acc = 0.798, test acc = 0.774\n",
      "time duration = 12.784, best val acc = 0.800, test acc = 0.774\n",
      "time duration = 15.151, best val acc = 0.802, test acc = 0.778\n",
      "time duration = 13.455, best val acc = 0.808, test acc = 0.764\n",
      "time duration = 12.831, best val acc = 0.802, test acc = 0.781\n",
      "time duration = 12.840, best val acc = 0.800, test acc = 0.766\n",
      "time duration = 12.673, best val acc = 0.808, test acc = 0.770\n",
      "time duration = 13.451, best val acc = 0.802, test acc = 0.771\n",
      "time duration = 12.859, best val acc = 0.792, test acc = 0.761\n",
      "time duration = 12.793, best val acc = 0.814, test acc = 0.770\n",
      "time duration = 12.648, best val acc = 0.804, test acc = 0.759\n",
      "time duration = 12.612, best val acc = 0.794, test acc = 0.746\n",
      "time duration = 13.040, best val acc = 0.800, test acc = 0.760\n",
      "time duration = 12.764, best val acc = 0.798, test acc = 0.766\n",
      "time duration = 13.343, best val acc = 0.804, test acc = 0.765\n",
      "time duration = 12.589, best val acc = 0.814, test acc = 0.775\n",
      "time duration = 14.012, best val acc = 0.790, test acc = 0.756\n",
      "time duration = 12.819, best val acc = 0.794, test acc = 0.755\n",
      "time duration = 12.997, best val acc = 0.808, test acc = 0.778\n",
      "time duration = 12.845, best val acc = 0.800, test acc = 0.765\n",
      "time duration = 14.298, best val acc = 0.808, test acc = 0.777\n",
      "time duration = 13.738, best val acc = 0.808, test acc = 0.774\n",
      "time duration = 12.609, best val acc = 0.806, test acc = 0.769\n",
      "time duration = 12.341, best val acc = 0.812, test acc = 0.765\n",
      "time duration = 13.328, best val acc = 0.800, test acc = 0.769\n",
      "test acc (mean, std):  0.7682333747545879 0.007855708582849473\n",
      "test acc (mean, std) after filter:  0.7688750401139259 0.005278193766564821\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 2\n",
      "time duration = 21.287, best val acc = 0.849, test acc = 0.857\n",
      "time duration = 19.000, best val acc = 0.827, test acc = 0.863\n",
      "time duration = 30.359, best val acc = 0.864, test acc = 0.869\n",
      "time duration = 17.294, best val acc = 0.776, test acc = 0.852\n",
      "time duration = 19.002, best val acc = 0.836, test acc = 0.869\n",
      "time duration = 17.512, best val acc = 0.762, test acc = 0.844\n",
      "time duration = 27.122, best val acc = 0.860, test acc = 0.875\n",
      "time duration = 20.523, best val acc = 0.811, test acc = 0.815\n",
      "time duration = 19.630, best val acc = 0.851, test acc = 0.867\n",
      "time duration = 20.845, best val acc = 0.851, test acc = 0.872\n",
      "time duration = 19.400, best val acc = 0.820, test acc = 0.865\n",
      "time duration = 16.794, best val acc = 0.207, test acc = 0.301\n",
      "time duration = 17.038, best val acc = 0.722, test acc = 0.799\n",
      "time duration = 25.001, best val acc = 0.858, test acc = 0.887\n",
      "time duration = 17.710, best val acc = 0.822, test acc = 0.829\n",
      "time duration = 17.281, best val acc = 0.791, test acc = 0.833\n",
      "time duration = 18.131, best val acc = 0.824, test acc = 0.869\n",
      "time duration = 16.829, best val acc = 0.418, test acc = 0.497\n",
      "time duration = 17.044, best val acc = 0.209, test acc = 0.195\n",
      "time duration = 22.981, best val acc = 0.864, test acc = 0.883\n",
      "time duration = 23.169, best val acc = 0.862, test acc = 0.897\n",
      "time duration = 17.955, best val acc = 0.827, test acc = 0.874\n",
      "time duration = 18.303, best val acc = 0.791, test acc = 0.833\n",
      "time duration = 18.376, best val acc = 0.813, test acc = 0.866\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time duration = 16.846, best val acc = 0.398, test acc = 0.501\n",
      "time duration = 16.828, best val acc = 0.356, test acc = 0.235\n",
      "time duration = 17.060, best val acc = 0.758, test acc = 0.815\n",
      "time duration = 17.452, best val acc = 0.780, test acc = 0.850\n",
      "time duration = 16.931, best val acc = 0.187, test acc = 0.222\n",
      "time duration = 16.804, best val acc = 0.229, test acc = 0.217\n",
      "test acc (mean, std):  0.7283266002933184 0.23954205862307648\n",
      "test acc (mean, std) after filter:  0.7729056452711424 0.1818278199283646\n"
     ]
    }
   ],
   "source": [
    "accs = []\n",
    "datasets = ['Cora', 'Citeseer', 'Pubmed', 'Coauthor-CS']\n",
    "for dataset in datasets:\n",
    "    args.data = dataset\n",
    "    if dataset == 'Cora':\n",
    "        args.dropout = 0.8\n",
    "        args.dropout2 = 0\n",
    "        args.weight_decay = 5e-4\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Citeseer':\n",
    "        args.dropout = 0.8\n",
    "        args.dropout2 = 0\n",
    "        args.weight_decay = 5e-4\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Pubmed':\n",
    "        args.dropout = 0.8\n",
    "        args.dropout2 = 0\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Coauthor-CS':\n",
    "        args.dropout = 0.8\n",
    "        args.dropout2 = 0\n",
    "        for i in [2]:\n",
    "            print('\\nstart testing on ' + dataset + ' dataset with random split: ' + str(i))\n",
    "            args.test_id = i\n",
    "            accs.append(main(args))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "accs = np.array(accs)\n",
    "np.savetxt('acc_GAT_hop5.txt',accs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import remove_edge_pts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test accuracy (mean, std) after filter for Coauthor-CS dataset 2\n",
      " 0.8594538900587294 0.015591932342979451\n"
     ]
    }
   ],
   "source": [
    "# remove some bad points for Coauthor-CS data\n",
    "acc3 = accs[3]\n",
    "acc3 = acc3[acc3 > 0.8]\n",
    "acc3 = remove_edge_pts(acc3, args.filter_pct)\n",
    "\n",
    "print('test accuracy (mean, std) after filter for Coauthor-CS dataset 2\\n', acc3.mean(), acc3.std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
