{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload modified Python modules (e.g. the local gnnverification package) before each cell executes.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import os\n",
    "from torch_geometric.datasets import TUDataset, Planetoid, MNISTSuperpixels, ShapeNet \n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.data import Data\n",
    "\n",
    "import torch\n",
    "from torch.optim import Optimizer\n",
    "\n",
    "from gnnverification.experiment_config import ExperimentConfig, LogConfig, TrainingConfig\n",
    "from gnnverification.experiment import Experiment\n",
    "from gnnverification.models import MyGCN\n",
    "from gnnverification.exporting import export_model\n",
    "from gnnverification.utils import load_saved_model\n",
    "from gnnverification.utils import StandardEarlyStopper\n",
    "\n",
    "from hyperparams import *\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameter preset for Cora, brought in by `from hyperparams import *` in the imports cell.\n",
    "hyperparams = CORA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_config = ExperimentConfig(\n",
    "    dataset_type=\"Planetoid\",\n",
    "    dataset_name=\"Cora\",\n",
    "    batch_size=hyperparams[\"experiment.batch_size\"],\n",
    "    # presumably the training share of the train/validation split - confirm in ExperimentConfig\n",
    "    train_validation_split=0.8,\n",
    "    normalization=hyperparams[\"experiment.normalization\"],\n",
    ")\n",
    "\n",
    "log_config = LogConfig(\n",
    "    working_dir=\"models\",\n",
    "    data_dir=\"data\",\n",
    "    # Random run name such as \"Cora_717338\". randint needs int bounds: a float like 1e6\n",
    "    # is deprecated since Python 3.10 and raises TypeError from Python 3.12 on.\n",
    "    run=f\"{experiment_config.dataset_name}_{random.randint(0, 1_000_000)}\",\n",
    "    wandb_log=True,\n",
    "    wandb_mode=\"online\",\n",
    "    wandb_project=\"gnn-verification\",\n",
    "    wandb_entity=\"meichelbeck\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_config = TrainingConfig(\n",
    "    # Counts are ints: float literals like 2e4 break any consumer that passes them to range().\n",
    "    epochs=20_000,\n",
    "    eval_freq=100,  # presumably: epochs between evaluation rounds\n",
    "    # patience appears to count evaluation rounds: with eval_freq=100 the logged run\n",
    "    # stopped at epoch 1100 after its best checkpoint at epoch 100.\n",
    "    early_stopper=StandardEarlyStopper(patience=10),\n",
    "    model=MyGCN,\n",
    "    mdl_kwargs=dict(\n",
    "        hidden_channels=hyperparams[\"model.hidden_channels\"],\n",
    "        num_layers=hyperparams[\"model.num_layers\"],\n",
    "        num_lin_layers=hyperparams[\"model.num_lin_layers\"],\n",
    "        num_rnn_layers=hyperparams[\"model.num_rnn_layers\"],\n",
    "        glob_pool_mode=hyperparams[\"model.glob_pool_mode\"],\n",
    "        act=hyperparams[\"model.act\"],\n",
    "        dropout=hyperparams[\"model.dropout\"],\n",
    "    ),\n",
    "    optimizer=torch.optim.Adam,\n",
    "    optimizer_kwargs=dict(\n",
    "        lr=hyperparams[\"optimizer.lr\"],\n",
    "        weight_decay=hyperparams[\"optimizer.weight_decay\"],\n",
    "        betas=(0.9, 0.999),  # PyTorch's default Adam betas, stated explicitly\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine the dataset and logging configuration; the training setup is passed later to train().\n",
    "experiment = Experiment(\n",
    "    log_config=log_config,\n",
    "    experiment_config=experiment_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmeichelbeck\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "wandb version 0.16.3 is available!  To upgrade, please run:\n",
       " $ pip install wandb --upgrade"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.16.1"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>c:\\Users\\meich\\DEV\\gnn-verification\\experiments\\wandb\\run-20240305_143806-22sddm1u</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/meichelbeck/gnn-verification/runs/22sddm1u' target=\"_blank\">Cora_717338</a></strong> to <a href='https://wandb.ai/meichelbeck/gnn-verification' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/meichelbeck/gnn-verification' target=\"_blank\">https://wandb.ai/meichelbeck/gnn-verification</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/meichelbeck/gnn-verification/runs/22sddm1u' target=\"_blank\">https://wandb.ai/meichelbeck/gnn-verification/runs/22sddm1u</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 2/20000 [00:00<21:09, 15.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|          | 101/20000 [00:04<16:04, 20.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 1100/20000 [00:51<14:47, 21.30it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Early stopping at epoch 1100. Using checkpoint from epoch 100\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a18270cb60f34f64b74dd7cf1cd10d26",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<style>\n",
       "    table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
       "    .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
       "    .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
       "    </style>\n",
       "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch_loss_mean</td><td>█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>epoch_loss_std</td><td>▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>eval_accuracy_mean</td><td>▁███████████</td></tr><tr><td>eval_accuracy_std</td><td>▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>eval_loss_mean</td><td>█▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>eval_loss_std</td><td>▁▁▁▁▁▁▁▁▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch_loss_mean</td><td>0.59448</td></tr><tr><td>epoch_loss_std</td><td>0.0</td></tr><tr><td>eval_accuracy_mean</td><td>0.744</td></tr><tr><td>eval_accuracy_std</td><td>0.0</td></tr><tr><td>eval_loss_mean</td><td>1.08686</td></tr><tr><td>eval_loss_std</td><td>0.0</td></tr></table><br/></div></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run <strong style=\"color:#cdcd00\">Cora_717338</strong> at: <a href='https://wandb.ai/meichelbeck/gnn-verification/runs/22sddm1u' target=\"_blank\">https://wandb.ai/meichelbeck/gnn-verification/runs/22sddm1u</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Find logs at: <code>.\\wandb\\run-20240305_143806-22sddm1u\\logs</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'test_loss_mean': 1.0793123535216036, 'test_loss_std': 0.0, 'test_accuracy': 0.75}\n"
     ]
    }
   ],
   "source": [
    "# Train with the configuration above; with_final_eval=True also reports test-set metrics at the end.\n",
    "experiment.train(training_config, with_final_eval=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and export model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run name generated in the config cell; also used as the run's directory name under working_dir.\n",
    "run_id = log_config.run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint file expected at <working_dir>/<run_id>/current_model.pt\n",
    "model_path = os.path.join(\n",
    "    log_config.working_dir,\n",
    "    run_id,\n",
    "    \"current_model.pt\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved checkpoint (via load_saved_model on experiment.model) and export the result\n",
    "# into the run directory.\n",
    "export_model(\n",
    "    load_saved_model(experiment.model, model_path),\n",
    "    os.path.join(log_config.working_dir, run_id),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'test_loss_mean': 1.0793123535216036, 'test_loss_std': 0.0, 'test_accuracy': 0.75}\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: evaluating the reloaded checkpoint should reproduce the final test metrics\n",
    "# reported by the training cell.\n",
    "experiment.evaluate(\n",
    "    load_saved_model(experiment.model, model_path)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
