{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "class config:\n",
    "    data = 'Cora' #choices=[\"Cora\", \"Citeseer\", \"Pubmed\", \"Coauthor-CS\"]\n",
    "    model = 'GAT' #choices=[\"GCN\", \"GAT\", \"APPNP\", \"DAGNN\", \"TreeLSTM\", \"GTCN\", \"GTAN\", \"GTCN2\"]\n",
    "    n_in = 0\n",
    "    n_hid = 64\n",
    "    n_out = 0\n",
    "    num_heads = 1\n",
    "    num_out_heads = 1\n",
    "    device = 'cpu'\n",
    "    dropout = 0.2\n",
    "    dropout2 = 0\n",
    "    learning_rate = 0.01\n",
    "    weight_decay = 5e-3\n",
    "    patience = 200\n",
    "    num_iter = 1000\n",
    "    num_test = 30\n",
    "    hop = 10\n",
    "    alpha = 0.1 # used by APPNP only\n",
    "    random_label_split = False\n",
    "    num_train = 20 # for random label split only\n",
    "    num_val = 30 # for random label split only\n",
    "    data_load = True # load the saved label split to rerun the test (for reproduce purpose)\n",
    "    test_id = 1 # number of the test, only used to record the ith number of the random label split (for reproduce purpose)\n",
    "    filter_pct = 0.1 # remove the top and bottom filer_pct points before obtaining statistics of test accuracy\n",
    "    log = False # whether to show the training log or not\n",
    "    eval_metric = 'acc' # evaluation metrics, choices=[\"acc\", \"f1-macro\", \"f1-micro\"]\n",
    "    root_dir = '../..' # dir of the source code\n",
    "\n",
    "args = config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append(args.root_dir)\n",
    "from train import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "start testing on Cora dataset\n",
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 26.728, best val acc = 0.794, test acc = 0.778\n",
      "time duration = 30.385, best val acc = 0.800, test acc = 0.784\n",
      "time duration = 28.717, best val acc = 0.786, test acc = 0.793\n",
      "time duration = 34.173, best val acc = 0.788, test acc = 0.763\n",
      "time duration = 28.807, best val acc = 0.786, test acc = 0.783\n",
      "time duration = 30.727, best val acc = 0.796, test acc = 0.798\n",
      "time duration = 34.444, best val acc = 0.790, test acc = 0.795\n",
      "time duration = 23.216, best val acc = 0.804, test acc = 0.812\n",
      "time duration = 24.881, best val acc = 0.790, test acc = 0.775\n",
      "time duration = 28.407, best val acc = 0.796, test acc = 0.792\n",
      "time duration = 30.587, best val acc = 0.802, test acc = 0.782\n",
      "time duration = 40.871, best val acc = 0.790, test acc = 0.801\n",
      "time duration = 38.639, best val acc = 0.790, test acc = 0.789\n",
      "time duration = 27.704, best val acc = 0.796, test acc = 0.802\n",
      "time duration = 37.976, best val acc = 0.798, test acc = 0.802\n",
      "time duration = 29.079, best val acc = 0.804, test acc = 0.791\n",
      "time duration = 29.652, best val acc = 0.796, test acc = 0.785\n",
      "time duration = 32.725, best val acc = 0.798, test acc = 0.784\n",
      "time duration = 27.394, best val acc = 0.784, test acc = 0.770\n",
      "time duration = 28.197, best val acc = 0.796, test acc = 0.805\n",
      "time duration = 28.288, best val acc = 0.814, test acc = 0.809\n",
      "time duration = 73.488, best val acc = 0.784, test acc = 0.766\n",
      "time duration = 27.530, best val acc = 0.796, test acc = 0.773\n",
      "time duration = 26.705, best val acc = 0.802, test acc = 0.798\n",
      "time duration = 28.132, best val acc = 0.788, test acc = 0.782\n",
      "time duration = 22.648, best val acc = 0.804, test acc = 0.782\n",
      "time duration = 26.675, best val acc = 0.794, test acc = 0.809\n",
      "time duration = 28.122, best val acc = 0.794, test acc = 0.792\n",
      "time duration = 31.399, best val acc = 0.804, test acc = 0.791\n",
      "time duration = 20.985, best val acc = 0.776, test acc = 0.779\n",
      "test acc (mean, std):  0.7888333737850189 0.01264142060838947\n",
      "test acc (mean, std) after filter:  0.7890000442663828 0.008897563144521533\n",
      "\n",
      "start testing on Citeseer dataset\n",
      "  NumNodes: 3327\n",
      "  NumEdges: 9228\n",
      "  NumFeats: 3703\n",
      "  NumClasses: 6\n",
      "  NumTrainingSamples: 120\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 21.712, best val acc = 0.656, test acc = 0.648\n",
      "time duration = 24.652, best val acc = 0.674, test acc = 0.671\n",
      "time duration = 22.418, best val acc = 0.694, test acc = 0.675\n",
      "time duration = 31.761, best val acc = 0.678, test acc = 0.662\n",
      "time duration = 20.646, best val acc = 0.680, test acc = 0.681\n",
      "time duration = 31.707, best val acc = 0.664, test acc = 0.652\n",
      "time duration = 24.426, best val acc = 0.688, test acc = 0.677\n",
      "time duration = 34.329, best val acc = 0.664, test acc = 0.650\n",
      "time duration = 26.833, best val acc = 0.652, test acc = 0.667\n",
      "time duration = 28.937, best val acc = 0.652, test acc = 0.649\n",
      "time duration = 22.372, best val acc = 0.682, test acc = 0.684\n",
      "time duration = 28.916, best val acc = 0.700, test acc = 0.686\n",
      "time duration = 25.925, best val acc = 0.688, test acc = 0.675\n",
      "time duration = 40.027, best val acc = 0.650, test acc = 0.647\n",
      "time duration = 25.172, best val acc = 0.680, test acc = 0.660\n",
      "time duration = 34.199, best val acc = 0.684, test acc = 0.683\n",
      "time duration = 34.116, best val acc = 0.672, test acc = 0.664\n",
      "time duration = 28.118, best val acc = 0.678, test acc = 0.654\n",
      "time duration = 23.955, best val acc = 0.650, test acc = 0.642\n",
      "time duration = 20.017, best val acc = 0.666, test acc = 0.679\n",
      "time duration = 23.948, best val acc = 0.678, test acc = 0.651\n",
      "time duration = 22.021, best val acc = 0.674, test acc = 0.668\n",
      "time duration = 27.418, best val acc = 0.674, test acc = 0.665\n",
      "time duration = 30.326, best val acc = 0.676, test acc = 0.681\n",
      "time duration = 24.153, best val acc = 0.678, test acc = 0.663\n",
      "time duration = 27.362, best val acc = 0.662, test acc = 0.660\n",
      "time duration = 20.571, best val acc = 0.658, test acc = 0.663\n",
      "time duration = 26.996, best val acc = 0.686, test acc = 0.673\n",
      "time duration = 21.911, best val acc = 0.666, test acc = 0.674\n",
      "time duration = 31.917, best val acc = 0.676, test acc = 0.681\n",
      "test acc (mean, std):  0.6661666989326477 0.012601811536291937\n",
      "test acc (mean, std) after filter:  0.6664583683013916 0.01017750570761952\n",
      "\n",
      "start testing on Pubmed dataset\n",
      "  NumNodes: 19717\n",
      "  NumEdges: 88651\n",
      "  NumFeats: 500\n",
      "  NumClasses: 3\n",
      "  NumTrainingSamples: 60\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "time duration = 93.647, best val acc = 0.804, test acc = 0.783\n",
      "time duration = 137.742, best val acc = 0.794, test acc = 0.766\n",
      "time duration = 125.965, best val acc = 0.780, test acc = 0.758\n",
      "time duration = 97.762, best val acc = 0.800, test acc = 0.766\n",
      "time duration = 95.864, best val acc = 0.792, test acc = 0.756\n",
      "time duration = 95.811, best val acc = 0.800, test acc = 0.769\n",
      "time duration = 162.545, best val acc = 0.802, test acc = 0.761\n",
      "time duration = 96.807, best val acc = 0.810, test acc = 0.782\n",
      "time duration = 170.391, best val acc = 0.798, test acc = 0.771\n",
      "time duration = 103.752, best val acc = 0.778, test acc = 0.755\n",
      "time duration = 99.779, best val acc = 0.802, test acc = 0.771\n",
      "time duration = 100.989, best val acc = 0.822, test acc = 0.782\n",
      "time duration = 100.585, best val acc = 0.802, test acc = 0.774\n",
      "time duration = 104.181, best val acc = 0.800, test acc = 0.772\n",
      "time duration = 111.869, best val acc = 0.806, test acc = 0.792\n",
      "time duration = 97.289, best val acc = 0.818, test acc = 0.787\n",
      "time duration = 157.810, best val acc = 0.800, test acc = 0.776\n",
      "time duration = 93.716, best val acc = 0.796, test acc = 0.763\n",
      "time duration = 131.156, best val acc = 0.796, test acc = 0.770\n",
      "time duration = 107.801, best val acc = 0.802, test acc = 0.784\n",
      "time duration = 109.838, best val acc = 0.808, test acc = 0.779\n",
      "time duration = 96.415, best val acc = 0.806, test acc = 0.777\n",
      "time duration = 96.892, best val acc = 0.810, test acc = 0.764\n",
      "time duration = 133.129, best val acc = 0.782, test acc = 0.760\n",
      "time duration = 95.350, best val acc = 0.816, test acc = 0.768\n",
      "time duration = 100.540, best val acc = 0.804, test acc = 0.767\n",
      "time duration = 96.419, best val acc = 0.806, test acc = 0.769\n",
      "time duration = 109.463, best val acc = 0.788, test acc = 0.764\n",
      "time duration = 97.805, best val acc = 0.810, test acc = 0.781\n",
      "time duration = 97.739, best val acc = 0.802, test acc = 0.768\n",
      "test acc (mean, std):  0.7711667040983836 0.009331253764408187\n",
      "test acc (mean, std) after filter:  0.770958368976911 0.00676066808059299\n",
      "\n",
      "start testing on Coauthor-CS dataset with random split: 2\n",
      "time duration = 134.704, best val acc = 0.329, test acc = 0.335\n",
      "time duration = 140.581, best val acc = 0.571, test acc = 0.620\n",
      "time duration = 170.448, best val acc = 0.822, test acc = 0.841\n",
      "time duration = 143.167, best val acc = 0.482, test acc = 0.374\n",
      "time duration = 188.631, best val acc = 0.824, test acc = 0.843\n",
      "time duration = 180.008, best val acc = 0.818, test acc = 0.843\n",
      "time duration = 136.596, best val acc = 0.400, test acc = 0.392\n",
      "time duration = 191.577, best val acc = 0.818, test acc = 0.855\n",
      "time duration = 149.661, best val acc = 0.816, test acc = 0.840\n",
      "time duration = 175.166, best val acc = 0.818, test acc = 0.837\n",
      "time duration = 158.046, best val acc = 0.833, test acc = 0.851\n",
      "time duration = 152.197, best val acc = 0.809, test acc = 0.821\n",
      "time duration = 169.635, best val acc = 0.809, test acc = 0.831\n",
      "time duration = 168.368, best val acc = 0.811, test acc = 0.839\n",
      "time duration = 188.515, best val acc = 0.824, test acc = 0.836\n",
      "time duration = 138.407, best val acc = 0.607, test acc = 0.606\n",
      "time duration = 136.590, best val acc = 0.429, test acc = 0.521\n",
      "time duration = 133.922, best val acc = 0.338, test acc = 0.261\n",
      "time duration = 300.673, best val acc = 0.824, test acc = 0.840\n",
      "time duration = 192.433, best val acc = 0.816, test acc = 0.830\n",
      "time duration = 136.033, best val acc = 0.522, test acc = 0.621\n",
      "time duration = 186.471, best val acc = 0.833, test acc = 0.844\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time duration = 141.243, best val acc = 0.684, test acc = 0.720\n",
      "time duration = 157.919, best val acc = 0.811, test acc = 0.830\n",
      "time duration = 139.825, best val acc = 0.742, test acc = 0.796\n",
      "time duration = 138.581, best val acc = 0.571, test acc = 0.604\n",
      "time duration = 145.631, best val acc = 0.744, test acc = 0.757\n",
      "time duration = 146.236, best val acc = 0.780, test acc = 0.828\n",
      "time duration = 138.503, best val acc = 0.529, test acc = 0.530\n",
      "time duration = 169.539, best val acc = 0.829, test acc = 0.844\n",
      "test acc (mean, std):  0.7129575868447622 0.17825357009639542\n",
      "test acc (mean, std) after filter:  0.7445425453285376 0.1308575205498569\n"
     ]
    }
   ],
   "source": [
    "args = config()\n",
    "accs = []\n",
    "datasets = ['Cora', 'Citeseer', 'Pubmed', 'Coauthor-CS']\n",
    "for dataset in datasets:\n",
    "    args.data = dataset\n",
    "    if dataset == 'Cora':\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Citeseer':\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Pubmed':\n",
    "        #args.learning_rate = 0.01\n",
    "        print('\\nstart testing on ' + dataset + ' dataset')\n",
    "        accs.append(main(args))\n",
    "    elif dataset == 'Coauthor-CS':\n",
    "        for i in [2]:\n",
    "            print('\\nstart testing on ' + dataset + ' dataset with random split: ' + str(i))\n",
    "            args.test_id = i\n",
    "            accs.append(main(args))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "accs = np.array(accs)\n",
    "np.savetxt('acc_GAT_hop10.txt',accs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import remove_edge_pts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test accuracy (mean, std) after filter for Coauthor-CS dataset 2\n",
      " 0.8384424289067586 0.006411895280113244\n"
     ]
    }
   ],
   "source": [
    "# remove some bad points for Coauthor-CS data\n",
    "acc3 = accs[3]\n",
    "acc3 = acc3[acc3 > 0.8]\n",
    "acc3 = remove_edge_pts(acc3, args.filter_pct)\n",
    "\n",
    "print('test accuracy (mean, std) after filter for Coauthor-CS dataset 2\\n', acc3.mean(), acc3.std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
