{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "from train_ogbn_arxiv import main"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class config:\n",
    "    model = 'TreeLSTM'\n",
    "    n_hid = 256\n",
    "    num_heads = 1\n",
    "    num_out_heads = 1\n",
    "    device = 'cpu'\n",
    "    dropout = 0.2\n",
    "    dropout2 = 0\n",
    "    learning_rate = 0.01\n",
    "    weight_decay = 0\n",
    "    num_iter = 1000\n",
    "    num_test = 10\n",
    "    hop = 3\n",
    "    eval_metric = 'acc'\n",
    "    log_steps = 100\n",
    "\n",
    "args = config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Parameters: 405544\n",
      "Run: 01, Time elapsed: 2.15, Epoch: 00, Loss: 3.7049, Train: 42.17%, Valid: 44.64% Test: 43.11%\n",
      "Run: 01, Time elapsed: 62.17, Epoch: 100, Loss: 0.8281, Train: 77.52%, Valid: 71.84% Test: 70.36%\n",
      "Run: 01, Time elapsed: 62.27, Epoch: 200, Loss: 0.6106, Train: 83.90%, Valid: 72.29% Test: 71.64%\n",
      "Run: 01, Time elapsed: 62.12, Epoch: 300, Loss: 0.5066, Train: 87.35%, Valid: 72.09% Test: 71.30%\n",
      "Run: 01, Time elapsed: 62.10, Epoch: 400, Loss: 0.4477, Train: 89.64%, Valid: 71.69% Test: 70.28%\n",
      "Run: 01, Time elapsed: 62.25, Epoch: 500, Loss: 0.3973, Train: 91.36%, Valid: 71.57% Test: 70.50%\n",
      "Run: 01, Time elapsed: 62.16, Epoch: 600, Loss: 0.3561, Train: 92.47%, Valid: 70.89% Test: 69.74%\n",
      "Run: 01, Time elapsed: 62.10, Epoch: 700, Loss: 0.3434, Train: 93.45%, Valid: 70.59% Test: 69.35%\n",
      "Run: 01, Time elapsed: 62.11, Epoch: 800, Loss: 0.3091, Train: 93.94%, Valid: 70.76% Test: 69.36%\n",
      "Run: 01, Time elapsed: 62.10, Epoch: 900, Loss: 0.2924, Train: 94.68%, Valid: 70.86% Test: 69.76%\n",
      "Run 01:\n",
      "Highest Train: 95.39\n",
      "Highest Valid: 72.44\n",
      "  Final Train: 82.60\n",
      "   Final Test: 71.77\n",
      "Run: 02, Time elapsed: 0.77, Epoch: 00, Loss: 3.7260, Train: 43.77%, Valid: 47.02% Test: 47.34%\n",
      "Run: 02, Time elapsed: 62.17, Epoch: 100, Loss: 0.8155, Train: 77.97%, Valid: 71.72% Test: 70.10%\n",
      "Run: 02, Time elapsed: 62.18, Epoch: 200, Loss: 0.6092, Train: 83.92%, Valid: 72.37% Test: 71.39%\n",
      "Run: 02, Time elapsed: 62.16, Epoch: 300, Loss: 0.5084, Train: 87.30%, Valid: 71.46% Test: 69.58%\n",
      "Run: 02, Time elapsed: 62.21, Epoch: 400, Loss: 0.4388, Train: 89.68%, Valid: 71.72% Test: 70.43%\n",
      "Run: 02, Time elapsed: 62.11, Epoch: 500, Loss: 0.3914, Train: 91.35%, Valid: 71.25% Test: 70.29%\n",
      "Run: 02, Time elapsed: 62.39, Epoch: 600, Loss: 0.3574, Train: 92.72%, Valid: 70.70% Test: 69.02%\n",
      "Run: 02, Time elapsed: 62.09, Epoch: 700, Loss: 0.3344, Train: 93.31%, Valid: 70.65% Test: 69.15%\n",
      "Run: 02, Time elapsed: 63.77, Epoch: 800, Loss: 0.3048, Train: 94.10%, Valid: 70.47% Test: 69.46%\n",
      "Run: 02, Time elapsed: 62.19, Epoch: 900, Loss: 0.2937, Train: 94.42%, Valid: 70.15% Test: 68.62%\n",
      "Run 02:\n",
      "Highest Train: 95.43\n",
      "Highest Valid: 72.45\n",
      "  Final Train: 83.61\n",
      "   Final Test: 71.37\n",
      "Run: 03, Time elapsed: 0.77, Epoch: 00, Loss: 3.6792, Train: 42.81%, Valid: 43.88% Test: 42.10%\n",
      "Run: 03, Time elapsed: 62.13, Epoch: 100, Loss: 0.8142, Train: 77.73%, Valid: 72.00% Test: 70.57%\n",
      "Run: 03, Time elapsed: 62.09, Epoch: 200, Loss: 0.6201, Train: 83.43%, Valid: 71.83% Test: 70.87%\n",
      "Run: 03, Time elapsed: 62.12, Epoch: 300, Loss: 0.5077, Train: 87.34%, Valid: 71.81% Test: 70.62%\n",
      "Run: 03, Time elapsed: 62.11, Epoch: 400, Loss: 0.4408, Train: 89.43%, Valid: 71.29% Test: 70.13%\n",
      "Run: 03, Time elapsed: 62.08, Epoch: 500, Loss: 0.3952, Train: 91.52%, Valid: 71.15% Test: 69.99%\n",
      "Run: 03, Time elapsed: 62.10, Epoch: 600, Loss: 0.3607, Train: 92.66%, Valid: 70.91% Test: 69.92%\n",
      "Run: 03, Time elapsed: 62.07, Epoch: 700, Loss: 0.3273, Train: 93.37%, Valid: 70.76% Test: 69.27%\n",
      "Run: 03, Time elapsed: 62.08, Epoch: 800, Loss: 0.3100, Train: 94.08%, Valid: 70.63% Test: 69.40%\n",
      "Run: 03, Time elapsed: 62.07, Epoch: 900, Loss: 0.2908, Train: 94.76%, Valid: 70.69% Test: 69.54%\n",
      "Run 03:\n",
      "Highest Train: 95.32\n",
      "Highest Valid: 72.43\n",
      "  Final Train: 79.89\n",
      "   Final Test: 71.10\n",
      "Run: 04, Time elapsed: 0.76, Epoch: 00, Loss: 3.7123, Train: 43.25%, Valid: 43.39% Test: 40.32%\n",
      "Run: 04, Time elapsed: 62.71, Epoch: 100, Loss: 0.8162, Train: 77.79%, Valid: 71.97% Test: 70.51%\n",
      "Run: 04, Time elapsed: 62.87, Epoch: 200, Loss: 0.6076, Train: 84.00%, Valid: 72.12% Test: 70.40%\n",
      "Run: 04, Time elapsed: 62.11, Epoch: 300, Loss: 0.5025, Train: 87.49%, Valid: 72.19% Test: 70.71%\n",
      "Run: 04, Time elapsed: 62.08, Epoch: 400, Loss: 0.4404, Train: 89.62%, Valid: 71.19% Test: 69.21%\n",
      "Run: 04, Time elapsed: 62.14, Epoch: 500, Loss: 0.3906, Train: 91.31%, Valid: 71.66% Test: 70.04%\n",
      "Run: 04, Time elapsed: 62.09, Epoch: 600, Loss: 0.3563, Train: 92.72%, Valid: 71.43% Test: 70.02%\n",
      "Run: 04, Time elapsed: 62.22, Epoch: 700, Loss: 0.3311, Train: 93.38%, Valid: 71.17% Test: 69.93%\n",
      "Run: 04, Time elapsed: 63.89, Epoch: 800, Loss: 0.3074, Train: 94.35%, Valid: 71.21% Test: 69.56%\n",
      "Run: 04, Time elapsed: 62.39, Epoch: 900, Loss: 0.2852, Train: 94.79%, Valid: 70.88% Test: 69.53%\n",
      "Run 04:\n",
      "Highest Train: 95.43\n",
      "Highest Valid: 72.52\n",
      "  Final Train: 80.14\n",
      "   Final Test: 71.43\n",
      "Run: 05, Time elapsed: 0.78, Epoch: 00, Loss: 3.6909, Train: 39.69%, Valid: 40.57% Test: 39.25%\n",
      "Run: 05, Time elapsed: 62.07, Epoch: 100, Loss: 0.8197, Train: 78.02%, Valid: 71.90% Test: 70.22%\n",
      "Run: 05, Time elapsed: 62.19, Epoch: 200, Loss: 0.6225, Train: 83.23%, Valid: 71.83% Test: 70.61%\n",
      "Run: 05, Time elapsed: 62.20, Epoch: 300, Loss: 0.5096, Train: 87.40%, Valid: 71.46% Test: 70.43%\n",
      "Run: 05, Time elapsed: 62.09, Epoch: 400, Loss: 0.4403, Train: 89.71%, Valid: 71.48% Test: 70.22%\n",
      "Run: 05, Time elapsed: 62.10, Epoch: 500, Loss: 0.3945, Train: 91.47%, Valid: 70.67% Test: 69.28%\n",
      "Run: 05, Time elapsed: 62.74, Epoch: 600, Loss: 0.3662, Train: 92.36%, Valid: 70.60% Test: 68.99%\n",
      "Run: 05, Time elapsed: 65.75, Epoch: 700, Loss: 0.3268, Train: 93.51%, Valid: 70.71% Test: 69.34%\n",
      "Run: 05, Time elapsed: 65.65, Epoch: 800, Loss: 0.3087, Train: 94.17%, Valid: 70.31% Test: 69.23%\n",
      "Run: 05, Time elapsed: 64.58, Epoch: 900, Loss: 0.2883, Train: 94.94%, Valid: 69.97% Test: 68.69%\n",
      "Run 05:\n",
      "Highest Train: 95.42\n",
      "Highest Valid: 72.47\n",
      "  Final Train: 79.62\n",
      "   Final Test: 71.51\n",
      "Run: 06, Time elapsed: 0.78, Epoch: 00, Loss: 3.6978, Train: 43.17%, Valid: 44.34% Test: 42.85%\n",
      "Run: 06, Time elapsed: 64.28, Epoch: 100, Loss: 0.8176, Train: 77.85%, Valid: 72.09% Test: 70.61%\n",
      "Run: 06, Time elapsed: 64.31, Epoch: 200, Loss: 0.6101, Train: 83.92%, Valid: 71.99% Test: 70.43%\n",
      "Run: 06, Time elapsed: 63.81, Epoch: 300, Loss: 0.5135, Train: 86.95%, Valid: 71.52% Test: 70.02%\n",
      "Run: 06, Time elapsed: 62.96, Epoch: 400, Loss: 0.4411, Train: 89.56%, Valid: 71.57% Test: 70.07%\n",
      "Run: 06, Time elapsed: 62.50, Epoch: 500, Loss: 0.3942, Train: 91.32%, Valid: 71.00% Test: 69.31%\n",
      "Run: 06, Time elapsed: 63.60, Epoch: 600, Loss: 0.3624, Train: 92.31%, Valid: 71.06% Test: 69.21%\n",
      "Run: 06, Time elapsed: 62.56, Epoch: 700, Loss: 0.3308, Train: 93.48%, Valid: 71.20% Test: 69.74%\n",
      "Run: 06, Time elapsed: 62.12, Epoch: 800, Loss: 0.3189, Train: 93.91%, Valid: 70.72% Test: 69.37%\n",
      "Run: 06, Time elapsed: 62.11, Epoch: 900, Loss: 0.2898, Train: 94.78%, Valid: 70.41% Test: 68.49%\n",
      "Run 06:\n",
      "Highest Train: 95.21\n",
      "Highest Valid: 72.52\n",
      "  Final Train: 84.29\n",
      "   Final Test: 71.72\n",
      "Run: 07, Time elapsed: 0.76, Epoch: 00, Loss: 3.7083, Train: 42.49%, Valid: 43.94% Test: 42.25%\n",
      "Run: 07, Time elapsed: 62.10, Epoch: 100, Loss: 0.8198, Train: 77.61%, Valid: 71.71% Test: 70.39%\n",
      "Run: 07, Time elapsed: 62.09, Epoch: 200, Loss: 0.6136, Train: 83.81%, Valid: 71.95% Test: 70.44%\n",
      "Run: 07, Time elapsed: 62.08, Epoch: 300, Loss: 0.5089, Train: 87.25%, Valid: 71.46% Test: 69.57%\n",
      "Run: 07, Time elapsed: 62.10, Epoch: 400, Loss: 0.4403, Train: 89.38%, Valid: 70.63% Test: 68.72%\n",
      "Run: 07, Time elapsed: 62.09, Epoch: 500, Loss: 0.3935, Train: 90.56%, Valid: 70.16% Test: 68.02%\n",
      "Run: 07, Time elapsed: 62.09, Epoch: 600, Loss: 0.3577, Train: 92.52%, Valid: 70.77% Test: 69.49%\n",
      "Run: 07, Time elapsed: 62.08, Epoch: 700, Loss: 0.3262, Train: 93.56%, Valid: 70.31% Test: 68.48%\n",
      "Run: 07, Time elapsed: 62.09, Epoch: 800, Loss: 0.3115, Train: 94.06%, Valid: 70.41% Test: 68.91%\n",
      "Run: 07, Time elapsed: 62.09, Epoch: 900, Loss: 0.2879, Train: 94.91%, Valid: 70.52% Test: 69.08%\n",
      "Run 07:\n",
      "Highest Train: 95.36\n",
      "Highest Valid: 72.26\n",
      "  Final Train: 84.47\n",
      "   Final Test: 71.19\n",
      "Run: 08, Time elapsed: 0.75, Epoch: 00, Loss: 3.6564, Train: 45.47%, Valid: 49.23% Test: 49.80%\n",
      "Run: 08, Time elapsed: 62.12, Epoch: 100, Loss: 0.8065, Train: 78.21%, Valid: 72.29% Test: 71.26%\n",
      "Run: 08, Time elapsed: 62.12, Epoch: 200, Loss: 0.6111, Train: 83.70%, Valid: 72.24% Test: 71.11%\n",
      "Run: 08, Time elapsed: 62.10, Epoch: 300, Loss: 0.5042, Train: 87.43%, Valid: 71.48% Test: 69.64%\n",
      "Run: 08, Time elapsed: 62.12, Epoch: 400, Loss: 0.4401, Train: 89.88%, Valid: 71.50% Test: 69.95%\n",
      "Run: 08, Time elapsed: 62.12, Epoch: 500, Loss: 0.3953, Train: 91.04%, Valid: 71.01% Test: 69.19%\n",
      "Run: 08, Time elapsed: 62.09, Epoch: 600, Loss: 0.3638, Train: 92.15%, Valid: 70.96% Test: 69.35%\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 08, Time elapsed: 62.14, Epoch: 700, Loss: 0.3319, Train: 93.54%, Valid: 70.60% Test: 69.00%\n",
      "Run: 08, Time elapsed: 62.10, Epoch: 800, Loss: 0.3090, Train: 94.23%, Valid: 70.45% Test: 68.64%\n",
      "Run: 08, Time elapsed: 62.12, Epoch: 900, Loss: 0.2868, Train: 94.60%, Valid: 70.28% Test: 69.06%\n",
      "Run 08:\n",
      "Highest Train: 95.44\n",
      "Highest Valid: 72.61\n",
      "  Final Train: 83.08\n",
      "   Final Test: 70.94\n",
      "Run: 09, Time elapsed: 0.76, Epoch: 00, Loss: 3.6948, Train: 44.23%, Valid: 48.41% Test: 48.98%\n",
      "Run: 09, Time elapsed: 62.11, Epoch: 100, Loss: 0.8080, Train: 78.08%, Valid: 72.09% Test: 70.70%\n",
      "Run: 09, Time elapsed: 62.08, Epoch: 200, Loss: 0.6086, Train: 84.05%, Valid: 71.71% Test: 70.06%\n",
      "Run: 09, Time elapsed: 62.11, Epoch: 300, Loss: 0.5042, Train: 87.41%, Valid: 72.08% Test: 71.23%\n",
      "Run: 09, Time elapsed: 62.09, Epoch: 400, Loss: 0.4339, Train: 90.03%, Valid: 71.84% Test: 70.52%\n",
      "Run: 09, Time elapsed: 62.08, Epoch: 500, Loss: 0.3925, Train: 90.80%, Valid: 71.11% Test: 69.27%\n",
      "Run: 09, Time elapsed: 62.08, Epoch: 600, Loss: 0.3541, Train: 92.68%, Valid: 70.88% Test: 69.31%\n",
      "Run: 09, Time elapsed: 62.10, Epoch: 700, Loss: 0.3257, Train: 93.60%, Valid: 70.95% Test: 69.33%\n",
      "Run: 09, Time elapsed: 62.10, Epoch: 800, Loss: 0.3036, Train: 94.07%, Valid: 70.89% Test: 69.85%\n",
      "Run: 09, Time elapsed: 62.10, Epoch: 900, Loss: 0.2887, Train: 94.89%, Valid: 70.61% Test: 68.90%\n",
      "Run 09:\n",
      "Highest Train: 95.40\n",
      "Highest Valid: 72.55\n",
      "  Final Train: 81.09\n",
      "   Final Test: 71.52\n",
      "Run: 10, Time elapsed: 0.76, Epoch: 00, Loss: 3.7154, Train: 43.57%, Valid: 42.73% Test: 40.53%\n",
      "Run: 10, Time elapsed: 62.13, Epoch: 100, Loss: 0.8176, Train: 77.82%, Valid: 71.90% Test: 70.04%\n",
      "Run: 10, Time elapsed: 62.09, Epoch: 200, Loss: 0.6133, Train: 83.86%, Valid: 72.23% Test: 70.96%\n",
      "Run: 10, Time elapsed: 62.11, Epoch: 300, Loss: 0.5043, Train: 87.59%, Valid: 71.92% Test: 70.74%\n",
      "Run: 10, Time elapsed: 62.14, Epoch: 400, Loss: 0.4426, Train: 89.83%, Valid: 71.53% Test: 70.41%\n",
      "Run: 10, Time elapsed: 62.13, Epoch: 500, Loss: 0.3932, Train: 91.40%, Valid: 71.35% Test: 70.66%\n",
      "Run: 10, Time elapsed: 62.11, Epoch: 600, Loss: 0.3571, Train: 92.38%, Valid: 70.92% Test: 69.84%\n",
      "Run: 10, Time elapsed: 62.13, Epoch: 700, Loss: 0.3265, Train: 93.39%, Valid: 70.88% Test: 69.75%\n",
      "Run: 10, Time elapsed: 62.10, Epoch: 800, Loss: 0.3066, Train: 94.31%, Valid: 70.08% Test: 68.37%\n",
      "Run: 10, Time elapsed: 62.11, Epoch: 900, Loss: 0.2829, Train: 94.99%, Valid: 70.23% Test: 68.63%\n",
      "Run 10:\n",
      "Highest Train: 95.44\n",
      "Highest Valid: 72.67\n",
      "  Final Train: 83.26\n",
      "   Final Test: 71.54\n",
      "All runs:\n",
      "Highest Train: 95.38 ± 0.07\n",
      "Highest Valid: 72.49 ± 0.11\n",
      "  Final Train: 82.21 ± 1.86\n",
      "   Final Test: 71.41 ± 0.26\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<train_obgn_arxiv.Logger at 0x217b375da60>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "main(args)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
