{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import os\n",
    "from torch_geometric.datasets import TUDataset, Planetoid, MNISTSuperpixels, ShapeNet \n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.data import Data\n",
    "\n",
    "import torch\n",
    "from torch.optim import Optimizer\n",
    "\n",
    "from gnnverification.experiment_config import ExperimentConfig, LogConfig, TrainingConfig\n",
    "from gnnverification.experiment import Experiment\n",
    "from gnnverification.models import MyGCN\n",
    "from gnnverification.exporting import export_model\n",
    "from gnnverification.utils import load_saved_model\n",
    "from gnnverification.utils import StandardEarlyStopper\n",
    "\n",
    "import ray\n",
    "from ray import train, tune\n",
    "from ray.tune import Callback\n",
    "from ray.tune.experiment import Trial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tune_objective(config):\n",
    "    \"\"\"Train and evaluate one ``MyGCN`` trial on TUDataset/PROTEINS for Ray Tune.\n",
    "\n",
    "    Args:\n",
    "        config: Per-trial hyperparameter mapping. Keys read here:\n",
    "            ``experiment.batch_size``, ``experiment.normalization``,\n",
    "            ``model.hidden_channels``, ``model.num_layers``,\n",
    "            ``model.num_lin_layers``, ``model.num_rnn_layers``,\n",
    "            ``model.glob_pool_mode``, ``model.act``, ``model.dropout``,\n",
    "            ``optimizer.lr`` and ``optimizer.weight_decay``.\n",
    "\n",
    "    Returns:\n",
    "        None. NOTE(review): metrics are presumably reported to Tune from inside\n",
    "        ``Experiment.train(tuning=True)`` -- confirm in ``gnnverification``.\n",
    "    \"\"\"\n",
    "\n",
    "    experiment_config = ExperimentConfig(\n",
    "        dataset_type = \"TUDataset\",\n",
    "        dataset_name = \"PROTEINS\",\n",
    "        batch_size = config[\"experiment.batch_size\"],\n",
    "        train_validation_split = 0.8,\n",
    "        normalization=config[\"experiment.normalization\"],\n",
    "    )\n",
    "\n",
    "    # NOTE(review): working_dir/data_dir are machine-specific absolute paths;\n",
    "    # they must be adjusted before running on another machine.\n",
    "    log_config = LogConfig(\n",
    "        working_dir = \"C:\\\\Users\\\\meich\\\\DEV\\\\gnn-verification\\\\experiments\\\\tuning\\\\models\",\n",
    "        data_dir = \"C:\\\\Users\\\\meich\\\\DEV\\\\gnn-verification\\\\experiments\\\\data\",\n",
    "        # randint needs integer bounds: a float bound such as 1e6 is deprecated\n",
    "        # since Python 3.10 and raises TypeError from Python 3.12 on.\n",
    "        run = f\"{experiment_config.dataset_name}_{random.randint(0, 1_000_000)}\",\n",
    "        wandb_log = False,\n",
    "        wandb_mode = \"online\",\n",
    "        wandb_project = \"gnn-verification\",\n",
    "        wandb_entity = \"meichelbeck\",\n",
    "    )\n",
    "\n",
    "    training_config = TrainingConfig(\n",
    "        # NOTE(review): epochs and eval_freq are passed as floats (2e4, 1e2);\n",
    "        # confirm TrainingConfig accepts non-int values.\n",
    "        epochs = 2e4,\n",
    "        eval_freq = 1e2,\n",
    "        early_stopper = StandardEarlyStopper(patience=-1),  # no early stopping but checkpointing\n",
    "        model = MyGCN,\n",
    "        mdl_kwargs= dict(\n",
    "            hidden_channels=config[\"model.hidden_channels\"],\n",
    "            num_layers=config[\"model.num_layers\"],\n",
    "            num_lin_layers=config[\"model.num_lin_layers\"],\n",
    "            num_rnn_layers=config[\"model.num_rnn_layers\"],\n",
    "            glob_pool_mode=config[\"model.glob_pool_mode\"],\n",
    "            act=config[\"model.act\"],\n",
    "            dropout=config[\"model.dropout\"],\n",
    "        ),\n",
    "        optimizer = torch.optim.Adam,\n",
    "        optimizer_kwargs = dict(\n",
    "            lr=config[\"optimizer.lr\"],\n",
    "            weight_decay=config[\"optimizer.weight_decay\"],\n",
    "            betas=(0.9, 0.999)\n",
    "        ),\n",
    "    )\n",
    "\n",
    "    experiment = Experiment(\n",
    "        experiment_config=experiment_config,\n",
    "        log_config=log_config,\n",
    "    )\n",
    "\n",
    "    experiment.train(\n",
    "        training_config,\n",
    "        tuning=True,\n",
    "        with_final_eval=True,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExportDataCallback(Callback):\n",
    "    \"\"\"Tune callback that evaluates a held experiment's model whenever a trial completes.\n",
    "\n",
    "    NOTE(review): this callback is not passed to the Tuner in the visible cells.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, experiment: Experiment):\n",
    "        \"\"\"Keep a reference to the experiment whose model is evaluated later.\"\"\"\n",
    "        super().__init__()\n",
    "        self.experiment = experiment\n",
    "\n",
    "    def on_trial_complete(self, iteration: int, trials: list[Trial], trial: Trial, **info):\n",
    "        \"\"\"Run ``evaluate`` on the stored experiment's model; the hook arguments are unused.\"\"\"\n",
    "        held_experiment = self.experiment\n",
    "        held_experiment.evaluate(held_experiment.model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Search space handed to the Tuner as param_space. Keys are the flat, dotted\n",
    "# names that tune_objective looks up in its config; entries given as plain\n",
    "# values (not tune.* samplers) stay the same for every trial.\n",
    "hyperparameter_space = {\n",
    "    \n",
    "    # Data loading / preprocessing.\n",
    "    \"experiment.batch_size\": tune.choice([128, 256, 512]),\n",
    "    \"experiment.normalization\": True,\n",
    "    \n",
    "    # Adam settings: learning rate is sampled log-uniformly, weight decay is fixed.\n",
    "    \"optimizer.lr\": tune.loguniform(1e-5, 1e-1),\n",
    "    \"optimizer.weight_decay\": 0.0, #tune.loguniform(1e-5, 1e-1),\n",
    "    \n",
    "    # MyGCN architecture. tune.randint excludes its upper bound, so both layer\n",
    "    # counts are drawn from {1, 2, 3}.\n",
    "    \"model.hidden_channels\": tune.choice([32, 64, 128]),\n",
    "    \"model.num_layers\": tune.randint(1, 4),\n",
    "    \"model.num_lin_layers\": tune.randint(1, 4),\n",
    "    \"model.num_rnn_layers\": 0,\n",
    "    \"model.dropout\": 0.0, #tune.choice([0.0, 0.25, 0.5]),\n",
    "    \"model.glob_pool_mode\": \"mean\",\n",
    "    \"model.act\": tune.choice([\"relu\", \"tanh\", \"sigmoid\"]),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ray.tune.search.hyperopt import HyperOptSearch\n",
    "from ray.tune.schedulers import ASHAScheduler\n",
    "from ray.tune.stopper import TrialPlateauStopper\n",
    "\n",
    "# Build each piece of the search separately, then assemble the Tuner from them.\n",
    "# Searcher, scheduler and stopper all track the same metric, eval_accuracy_mean.\n",
    "hyperopt_search = HyperOptSearch(metric=\"eval_accuracy_mean\", mode=\"max\")\n",
    "asha_scheduler = ASHAScheduler(metric=\"eval_accuracy_mean\", mode=\"max\", grace_period=5)\n",
    "\n",
    "# 20 sampled configurations in total, at most 8 of them running at once.\n",
    "search_settings = tune.TuneConfig(\n",
    "    search_alg=hyperopt_search,\n",
    "    scheduler=asha_scheduler,\n",
    "    num_samples=20,\n",
    "    max_concurrent_trials=8,\n",
    ")\n",
    "\n",
    "plateau_stopper = TrialPlateauStopper(\n",
    "    metric=\"eval_accuracy_mean\",\n",
    "    std=0.01,\n",
    "    num_results=5,\n",
    "    grace_period=5,\n",
    ")\n",
    "run_settings = train.RunConfig(stop=plateau_stopper)\n",
    "\n",
    "tuner = tune.Tuner(\n",
    "    tune_objective,\n",
    "    param_space=hyperparameter_space,\n",
    "    tune_config=search_settings,\n",
    "    run_config=run_settings,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"tuneStatus\">\n",
       "  <div style=\"display: flex;flex-direction: row\">\n",
       "    <div style=\"display: flex;flex-direction: column;\">\n",
       "      <h3>Tune Status</h3>\n",
       "      <table>\n",
       "<tbody>\n",
       "<tr><td>Current time:</td><td>2024-03-05 12:27:20</td></tr>\n",
       "<tr><td>Running for: </td><td>01:23:29.55        </td></tr>\n",
       "<tr><td>Memory:      </td><td>16.9/31.7 GiB      </td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "    </div>\n",
       "    <div class=\"vDivider\"></div>\n",
       "    <div class=\"systemInfo\">\n",
       "      <h3>System Info</h3>\n",
       "      Using AsyncHyperBand: num_stopped=9<br>Bracket: Iter 80.000: None | Iter 20.000: None | Iter 5.000: 0.6509009009009009<br>Logical resource usage: 1.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:G)\n",
       "    </div>\n",
       "    \n",
       "  </div>\n",
       "  <div class=\"hDivider\"></div>\n",
       "  <div class=\"trialStatus\">\n",
       "    <h3>Trial Status</h3>\n",
       "    <table>\n",
       "<thead>\n",
       "<tr><th>Trial name             </th><th>status    </th><th>loc            </th><th style=\"text-align: right;\">    experiment.batch_siz\n",
       "e</th><th>experiment.normaliza\n",
       "tion     </th><th>model.act  </th><th style=\"text-align: right;\">  model.dropout</th><th>model.glob_pool_mode  </th><th style=\"text-align: right;\">    model.hidden_channel\n",
       "s</th><th style=\"text-align: right;\">  model.num_layers</th><th style=\"text-align: right;\">  model.num_lin_layers</th><th style=\"text-align: right;\">  model.num_rnn_layers</th><th style=\"text-align: right;\">  optimizer.lr</th><th style=\"text-align: right;\">  optimizer.weight_dec\n",
       "ay</th><th style=\"text-align: right;\">  iter</th><th style=\"text-align: right;\">  total time (s)</th><th style=\"text-align: right;\">  eval_accuracy_mean</th></tr>\n",
       "</thead>\n",
       "<tbody>\n",
       "<tr><td>tune_objective_96029a9f</td><td>TERMINATED</td><td>127.0.0.1:32600</td><td style=\"text-align: right;\">512</td><td>True</td><td>tanh       </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\"> 32</td><td style=\"text-align: right;\">                 3</td><td style=\"text-align: right;\">                     3</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.00272799 </td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">     6</td><td style=\"text-align: right;\">         775.925</td><td style=\"text-align: right;\">            0.752252</td></tr>\n",
       "<tr><td>tune_objective_617c9781</td><td>TERMINATED</td><td>127.0.0.1:27308</td><td style=\"text-align: right;\">128</td><td>True</td><td>sigmoid    </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\"> 64</td><td style=\"text-align: right;\">                 3</td><td style=\"text-align: right;\">                     1</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.000132654</td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">     5</td><td style=\"text-align: right;\">         943.801</td><td style=\"text-align: right;\">            0.601479</td></tr>\n",
       "<tr><td>tune_objective_d9781e14</td><td>TERMINATED</td><td>127.0.0.1:25388</td><td style=\"text-align: right;\">256</td><td>True</td><td>tanh       </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\">128</td><td style=\"text-align: right;\">                 3</td><td style=\"text-align: right;\">                     3</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.0235569  </td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">     5</td><td style=\"text-align: right;\">        2841.51 </td><td style=\"text-align: right;\">            0.585586</td></tr>\n",
       "<tr><td>tune_objective_ebe65210</td><td>TERMINATED</td><td>127.0.0.1:30812</td><td style=\"text-align: right;\">512</td><td>True</td><td>relu       </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\"> 64</td><td style=\"text-align: right;\">                 2</td><td style=\"text-align: right;\">                     1</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.00725904 </td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">     5</td><td style=\"text-align: right;\">         867.883</td><td style=\"text-align: right;\">            0.572072</td></tr>\n",
       "<tr><td>tune_objective_49442c7b</td><td>TERMINATED</td><td>127.0.0.1:15212</td><td style=\"text-align: right;\">128</td><td>True</td><td>tanh       </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\">128</td><td style=\"text-align: right;\">                 1</td><td style=\"text-align: right;\">                     2</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.0364751  </td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">     5</td><td style=\"text-align: right;\">         954.126</td><td style=\"text-align: right;\">            0.355552</td></tr>\n",
       "<tr><td>tune_objective_e41335a7</td><td>TERMINATED</td><td>127.0.0.1:19584</td><td style=\"text-align: right;\">128</td><td>True</td><td>sigmoid    </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\"> 64</td><td style=\"text-align: right;\">                 1</td><td style=\"text-align: right;\">                     3</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.000200029</td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">     5</td><td style=\"text-align: right;\">         565.427</td><td style=\"text-align: right;\">            0.561752</td></tr>\n",
       "<tr><td>tune_objective_3ebd34e0</td><td>TERMINATED</td><td>127.0.0.1:39624</td><td style=\"text-align: right;\">256</td><td>True</td><td>sigmoid    </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\"> 64</td><td style=\"text-align: right;\">                 2</td><td style=\"text-align: right;\">                     1</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.000479479</td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">     8</td><td style=\"text-align: right;\">        1451.09 </td><td style=\"text-align: right;\">            0.698198</td></tr>\n",
       "<tr><td>tune_objective_0f798334</td><td>TERMINATED</td><td>127.0.0.1:36800</td><td style=\"text-align: right;\">512</td><td>True</td><td>tanh       </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\"> 32</td><td style=\"text-align: right;\">                 2</td><td style=\"text-align: right;\">                     2</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.00876753 </td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">     6</td><td style=\"text-align: right;\">         616.486</td><td style=\"text-align: right;\">            0.68018 </td></tr>\n",
       "<tr><td>tune_objective_8eff3eb8</td><td>TERMINATED</td><td>127.0.0.1:19584</td><td style=\"text-align: right;\">256</td><td>True</td><td>relu       </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\">128</td><td style=\"text-align: right;\">                 3</td><td style=\"text-align: right;\">                     1</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.0145339  </td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">     5</td><td style=\"text-align: right;\">        2445.68 </td><td style=\"text-align: right;\">            0.599099</td></tr>\n",
       "<tr><td>tune_objective_c0e20784</td><td>TERMINATED</td><td>127.0.0.1:36800</td><td style=\"text-align: right;\">256</td><td>True</td><td>relu       </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\"> 32</td><td style=\"text-align: right;\">                 2</td><td style=\"text-align: right;\">                     3</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.00156156 </td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">     5</td><td style=\"text-align: right;\">         438.871</td><td style=\"text-align: right;\">            0.594595</td></tr>\n",
       "<tr><td>tune_objective_036bccb8</td><td>TERMINATED</td><td>127.0.0.1:32600</td><td style=\"text-align: right;\">256</td><td>True</td><td>tanh       </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\">128</td><td style=\"text-align: right;\">                 1</td><td style=\"text-align: right;\">                     3</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.00137612 </td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">    12</td><td style=\"text-align: right;\">        2932.48 </td><td style=\"text-align: right;\">            0.693694</td></tr>\n",
       "<tr><td>tune_objective_73f8a538</td><td>TERMINATED</td><td>127.0.0.1:30812</td><td style=\"text-align: right;\">128</td><td>True</td><td>relu       </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\">128</td><td style=\"text-align: right;\">                 3</td><td style=\"text-align: right;\">                     2</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.0268323  </td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">     5</td><td style=\"text-align: right;\">        2466.2  </td><td style=\"text-align: right;\">            0.617104</td></tr>\n",
       "<tr><td>tune_objective_1ffa6ca1</td><td>TERMINATED</td><td>127.0.0.1:27308</td><td style=\"text-align: right;\">512</td><td>True</td><td>tanh       </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\">128</td><td style=\"text-align: right;\">                 3</td><td style=\"text-align: right;\">                     3</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.00028506 </td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">    13</td><td style=\"text-align: right;\">        4048.22 </td><td style=\"text-align: right;\">            0.689189</td></tr>\n",
       "<tr><td>tune_objective_c18eafe2</td><td>TERMINATED</td><td>127.0.0.1:15212</td><td style=\"text-align: right;\">256</td><td>True</td><td>relu       </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\"> 32</td><td style=\"text-align: right;\">                 1</td><td style=\"text-align: right;\">                     2</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.00072707 </td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">     5</td><td style=\"text-align: right;\">         245.032</td><td style=\"text-align: right;\">            0.594595</td></tr>\n",
       "<tr><td>tune_objective_cb3d2f32</td><td>TERMINATED</td><td>127.0.0.1:36800</td><td style=\"text-align: right;\">128</td><td>True</td><td>relu       </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\">128</td><td style=\"text-align: right;\">                 2</td><td style=\"text-align: right;\">                     1</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.00334782 </td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">     5</td><td style=\"text-align: right;\">        1486.88 </td><td style=\"text-align: right;\">            0.625332</td></tr>\n",
       "<tr><td>tune_objective_3f0fd44b</td><td>TERMINATED</td><td>127.0.0.1:15212</td><td style=\"text-align: right;\">256</td><td>True</td><td>relu       </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\">128</td><td style=\"text-align: right;\">                 3</td><td style=\"text-align: right;\">                     2</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.0792366  </td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">     5</td><td style=\"text-align: right;\">        2361.3  </td><td style=\"text-align: right;\">            0.585586</td></tr>\n",
       "<tr><td>tune_objective_88a3949b</td><td>TERMINATED</td><td>127.0.0.1:39624</td><td style=\"text-align: right;\">256</td><td>True</td><td>relu       </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\"> 64</td><td style=\"text-align: right;\">                 3</td><td style=\"text-align: right;\">                     1</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.0236022  </td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">     5</td><td style=\"text-align: right;\">        1054.87 </td><td style=\"text-align: right;\">            0.630631</td></tr>\n",
       "<tr><td>tune_objective_9f814e04</td><td>TERMINATED</td><td>127.0.0.1:39624</td><td style=\"text-align: right;\">256</td><td>True</td><td>sigmoid    </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\"> 64</td><td style=\"text-align: right;\">                 2</td><td style=\"text-align: right;\">                     1</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.000573038</td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">     5</td><td style=\"text-align: right;\">         774.343</td><td style=\"text-align: right;\">            0.644144</td></tr>\n",
       "<tr><td>tune_objective_d4deee10</td><td>TERMINATED</td><td>127.0.0.1:36800</td><td style=\"text-align: right;\">256</td><td>True</td><td>sigmoid    </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\"> 64</td><td style=\"text-align: right;\">                 3</td><td style=\"text-align: right;\">                     1</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.00112487 </td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">     5</td><td style=\"text-align: right;\">         914.736</td><td style=\"text-align: right;\">            0.635135</td></tr>\n",
       "<tr><td>tune_objective_550a8c5d</td><td>TERMINATED</td><td>127.0.0.1:25388</td><td style=\"text-align: right;\">128</td><td>True</td><td>tanh       </td><td style=\"text-align: right;\">              0</td><td>mean                  </td><td style=\"text-align: right;\"> 32</td><td style=\"text-align: right;\">                 2</td><td style=\"text-align: right;\">                     1</td><td style=\"text-align: right;\">                     0</td><td style=\"text-align: right;\">   0.000214757</td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">     5</td><td style=\"text-align: right;\">         379.136</td><td style=\"text-align: right;\">            0.589428</td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "  </div>\n",
       "</div>\n",
       "<style>\n",
       ".tuneStatus {\n",
       "  color: var(--jp-ui-font-color1);\n",
       "}\n",
       ".tuneStatus .systemInfo {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       ".tuneStatus td {\n",
       "  white-space: nowrap;\n",
       "}\n",
       ".tuneStatus .trialStatus {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       ".tuneStatus h3 {\n",
       "  font-weight: bold;\n",
       "}\n",
       ".tuneStatus .hDivider {\n",
       "  border-bottom-width: var(--jp-border-width);\n",
       "  border-bottom-color: var(--jp-border-color0);\n",
       "  border-bottom-style: solid;\n",
       "}\n",
       ".tuneStatus .vDivider {\n",
       "  border-left-width: var(--jp-border-width);\n",
       "  border-left-color: var(--jp-border-color0);\n",
       "  border-left-style: solid;\n",
       "  margin: 0.5em 1em 0.5em 1em;\n",
       "}\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-05 12:27:20,293\tINFO tune.py:1042 -- Total run time: 5009.66 seconds (5009.53 seconds for the tuning loop).\n"
     ]
    }
   ],
   "source": [
    "# Run the whole search; this call blocks until every trial has terminated\n",
    "# (the recorded run below took about 5000 seconds).\n",
    "results = tuner.fit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the per-trial results table to CSV, highest eval_accuracy_mean first.\n",
    "(\n",
    "    results.get_dataframe()\n",
    "    .sort_values(\"eval_accuracy_mean\", ascending=False)\n",
    "    .to_csv(\"tuning_results.csv\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['eval_accuracy_mean', 'timestamp', 'checkpoint_dir_name', 'done',\n",
       "       'training_iteration', 'trial_id', 'date', 'time_this_iter_s',\n",
       "       'time_total_s', 'pid', 'hostname', 'node_ip', 'time_since_restore',\n",
       "       'iterations_since_restore', 'config/experiment.batch_size',\n",
       "       'config/experiment.normalization', 'config/optimizer.lr',\n",
       "       'config/optimizer.weight_decay', 'config/model.hidden_channels',\n",
       "       'config/model.num_layers', 'config/model.num_lin_layers',\n",
       "       'config/model.num_rnn_layers', 'config/model.dropout',\n",
       "       'config/model.glob_pool_mode', 'config/model.act', 'logdir'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# List the columns available in the results table (metric, timing, and config/* fields).\n",
    "results.get_dataframe().columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
