{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aaf596da",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "def load_results(dirname, dname, gnn, diagram_type, fb_one=False, no_ofst=False):\n",
    "    \"\"\"Load the saved ``.results`` object for one (dataset, GNN, diagram type) run.\n",
    "\n",
    "    The directory layout hard-codes the remaining hyper-parameters. For\n",
    "    ``forward_backward`` runs, a truthy ``fb_one`` / ``no_ofst`` appends the\n",
    "    ``_fbone<int>`` / ``_noofst<int>`` suffix to the file stem.\n",
    "\n",
    "    A missing file propagates the error raised by ``torch.load``; ``make_table``\n",
    "    catches ``FileNotFoundError`` to mark such runs as N/A.\n",
    "    \"\"\"\n",
    "    stem = f\"{dirname}/{dname}/2/8/16_hidden/64_outdim/True_dim1/{diagram_type}_{gnn}_2\"\n",
    "    suffix = \"\"\n",
    "    if diagram_type == \"forward_backward\":\n",
    "        if fb_one:\n",
    "            suffix += f\"_fbone{int(fb_one)}\"\n",
    "        if no_ofst:\n",
    "            suffix += f\"_noofst{int(no_ofst)}\"\n",
    "    # NOTE: torch.load unpickles the file; only load result files you trust.\n",
    "    return torch.load(f\"{stem}{suffix}.results\")\n",
    "\n",
    "def make_table(dirname, gnn, mode=\"all\"):\n",
    "    \"\"\"Collect test accuracies for every (dataset, method) pair.\n",
    "\n",
    "    Args:\n",
    "        dirname: results root passed to ``load_results``.\n",
    "        gnn: GNN name used in the result file names (e.g. \"gin\", \"gcn\").\n",
    "        mode: content of each table cell:\n",
    "            \"all\"    -> tuple (test_accuracies, max, val_best)\n",
    "            \"max\"    -> max of the stored ``test_accuracies``\n",
    "            \"es_max\" -> test accuracy at the argmax of ``val_accuracies``\n",
    "                        (early-stopping style selection)\n",
    "\n",
    "    Returns:\n",
    "        List of ``(dataset_name, row)`` pairs. ``row`` has one entry per\n",
    "        method, in the order of ``methods`` below; a missing result file\n",
    "        yields the string \"N/A\" so columns stay aligned.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``mode`` is not one of the values above.\n",
    "    \"\"\"\n",
    "    valid_modes = (\"all\", \"max\", \"es_max\")\n",
    "    if mode not in valid_modes:\n",
    "        # Fail fast: an unknown mode would leave rows short and misalign columns.\n",
    "        raise ValueError(f\"mode must be one of {valid_modes}, got {mode!r}\")\n",
    "\n",
    "    datasets = [\"NCI109\", \"PROTEINS\", \"IMDB-BINARY\", \"NCI1\"]\n",
    "    methods = [\n",
    "        (\"standard\", False, False),\n",
    "        (\"rephine\", False, False),\n",
    "        (\"forward_backward\", True, True),   # fb_one + no_ofst\n",
    "        (\"forward_backward\", True, False),  # fb_one only\n",
    "        (\"forward_backward\", False, False),  # plain forward_backward\n",
    "    ]\n",
    "\n",
    "    table = []\n",
    "    for dname in datasets:\n",
    "        row = []\n",
    "        for diagram_type, fb_one, no_ofst in methods:\n",
    "            try:\n",
    "                res = load_results(dirname, dname, gnn, diagram_type, fb_one, no_ofst)\n",
    "            except FileNotFoundError:\n",
    "                row.append(\"N/A\")\n",
    "                continue\n",
    "            test_acc = res['test_accuracies']\n",
    "            max_acc = test_acc.max().item()\n",
    "            # Test accuracy at the index where validation accuracy peaks.\n",
    "            val_best_test_acc = test_acc[res['val_accuracies'].argmax()].item()\n",
    "\n",
    "            if mode == \"max\":\n",
    "                row.append(max_acc)\n",
    "            elif mode == \"es_max\":\n",
    "                row.append(val_best_test_acc)\n",
    "            else:  # \"all\"\n",
    "                row.append((test_acc, max_acc, val_best_test_acc))\n",
    "        table.append((dname, row))\n",
    "    return table\n",
    "\n",
    "\n",
    "def print_table(table, gnn, mode=\"max\"):\n",
    "    \"\"\"Pretty-print a ``make_table`` result built in \"max\" or \"es_max\" mode.\n",
    "\n",
    "    Each row's first two entries are the baselines (standard, rephine); the\n",
    "    remaining entries are our variants, collapsed into one \"Ours\" column that\n",
    "    shows the largest numeric value. Non-numeric entries such as \"N/A\" are\n",
    "    printed as-is. ``mode`` only affects the title line.\n",
    "    \"\"\"\n",
    "    def fmt(value):\n",
    "        return f\"{value:.4f}\" if isinstance(value, (int, float)) else str(value)\n",
    "\n",
    "    def render(label, cells):\n",
    "        return f\"{label:<12} \" + \" | \".join(f\"{c:<14}\" for c in cells)\n",
    "\n",
    "    print(f\"\\n=== Results for {gnn.upper()} ({mode}) ===\")\n",
    "    print(render(\"Dataset\", [\"standard\", \"rephine\", \"Ours\"]))\n",
    "    print(\"-\" * 60)\n",
    "    for dname, row in table:\n",
    "        baselines, variants = row[:2], row[2:]\n",
    "        numeric = [v for v in variants if isinstance(v, (int, float))]\n",
    "        # NOTE(review): taking the best variant by its reported test accuracy is\n",
    "        # selection on the test set; consider choosing the variant by validation.\n",
    "        ours = f\"{max(numeric):.4f}\" if numeric else \"N/A\"\n",
    "        print(render(dname, [fmt(b) for b in baselines] + [ours]))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1fc2f6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "dirname = \"results/main\"\n",
    "gnn = \"gin\"\n",
    "mode = \"es_max\"\n",
    "# Use the same mode for computing and labelling so the header matches the numbers.\n",
    "print_table(make_table(dirname, gnn, mode=mode), gnn, mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5114d034",
   "metadata": {},
   "outputs": [],
   "source": [
    "dirname = \"results/main\"\n",
    "gnn = \"gcn\"\n",
    "mode = \"es_max\"\n",
    "# Use the same mode for computing and labelling so the header matches the numbers.\n",
    "print_table(make_table(dirname, gnn, mode=mode), gnn, mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a752f52",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae4dad45",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
