{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MIMIC-III: LinODEnet Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Input Parsing (for command line use)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-28T18:19:17.408293Z",
     "iopub.status.busy": "2022-09-28T18:19:17.407915Z",
     "iopub.status.idle": "2022-09-28T18:19:17.416701Z",
     "shell.execute_reply": "2022-09-28T18:19:17.416386Z",
     "shell.execute_reply.started": "2022-09-28T18:19:17.408277Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Namespace(quiet=False, run_id=None, config=None, epochs=50, fold=0, batch_size=64, learn_rate=0.001, betas=(0.9, 0.999), weight_decay=0.001, hidden_size=128, latent_size=160, kernel_init='skew-symmetric', note='', seed=None)\n"
     ]
    }
   ],
   "source": [
    "import argparse\n",
    "\n",
    "# fmt: off\n",
    "# Command-line options. Every option has a default, so the notebook also runs without arguments.\n",
    "parser = argparse.ArgumentParser(description=\"Training Script for MIMIC dataset.\")\n",
    "# `--quiet` takes an optional value (nargs=\"?\"); a bare `-q` stores True.\n",
    "parser.add_argument(\"-q\",  \"--quiet\",        default=False,  const=True, help=\"disable the per-batch progress bar\", nargs=\"?\")\n",
    "parser.add_argument(\"-r\",  \"--run_id\",       default=None,   type=str,   help=\"run_id\")\n",
    "# Two values: the path of a YAML file and the id of the entry to load from it.\n",
    "parser.add_argument(\"-c\",  \"--config\",       default=None,   type=str,   help=\"load external config (YAML file and entry id)\", nargs=2, metavar=(\"FILE\", \"ID\"))\n",
    "parser.add_argument(\"-e\",  \"--epochs\",       default=50,     type=int,   help=\"maximum epochs\")\n",
    "parser.add_argument(\"-f\",  \"--fold\",         default=0,      type=int,   help=\"fold number\")\n",
    "parser.add_argument(\"-bs\", \"--batch-size\",   default=64,     type=int,   help=\"batch-size\")\n",
    "parser.add_argument(\"-lr\", \"--learn-rate\",   default=0.001,  type=float, help=\"learn-rate\")\n",
    "parser.add_argument(\"-b\",  \"--betas\", default=(0.9, 0.999),  type=float, help=\"adam betas\", nargs=2)\n",
    "parser.add_argument(\"-wd\", \"--weight-decay\", default=0.001,  type=float, help=\"weight-decay\")\n",
    "parser.add_argument(\"-hs\", \"--hidden-size\",  default=128,    type=int,   help=\"hidden-size\")\n",
    "parser.add_argument(\"-ls\", \"--latent-size\",  default=160,    type=int,   help=\"latent-size\")\n",
    "parser.add_argument(\"-ki\", \"--kernel-init\",  default=\"skew-symmetric\",   help=\"kernel-initialization\")\n",
    "parser.add_argument(\"-n\",  \"--note\",         default=\"\",     type=str,   help=\"Note that can be added\")\n",
    "parser.add_argument(\"-s\",  \"--seed\",         default=None,   type=int,   help=\"Set the random seed.\")\n",
    "# fmt: on\n",
    "\n",
    "try:\n",
    "    # `get_ipython` only exists when running under IPython/Jupyter.\n",
    "    get_ipython().run_line_magic(\n",
    "        \"config\", \"InteractiveShell.ast_node_interactivity='last_expr_or_assign'\"\n",
    "    )\n",
    "except NameError:\n",
    "    # Plain Python interpreter: parse the real command line.\n",
    "    ARGS = parser.parse_args()\n",
    "else:\n",
    "    # Inside a notebook: parse an empty argument list so only the defaults apply.\n",
    "    ARGS = parser.parse_args([])\n",
    "\n",
    "print(ARGS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load from config file if provided"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-28T18:19:18.057470Z",
     "iopub.status.busy": "2022-09-28T18:19:18.057289Z",
     "iopub.status.idle": "2022-09-28T18:19:18.060361Z",
     "shell.execute_reply": "2022-09-28T18:19:18.059985Z",
     "shell.execute_reply.started": "2022-09-28T18:19:18.057458Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Namespace(quiet=False, run_id=None, config=None, epochs=50, fold=0, batch_size=64, learn_rate=0.001, betas=(0.9, 0.999), weight_decay=0.001, hidden_size=128, latent_size=160, kernel_init='skew-symmetric', note='', seed=None)\n"
     ]
    }
   ],
   "source": [
    "import yaml\n",
    "\n",
    "# Optionally override the parsed arguments with one entry of an external YAML file.\n",
    "if ARGS.config is not None:\n",
    "    config_path, config_id = ARGS.config\n",
    "    with open(config_path, \"r\") as stream:\n",
    "        all_configs = yaml.safe_load(stream)\n",
    "    # The id arrives as a string from argparse, hence the int() conversion.\n",
    "    overrides = all_configs[int(config_id)]\n",
    "    vars(ARGS).update(**overrides)\n",
    "\n",
    "print(ARGS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Global Variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-28T18:19:18.346543Z",
     "iopub.status.busy": "2022-09-28T18:19:18.346304Z",
     "iopub.status.idle": "2022-09-28T18:19:18.351812Z",
     "shell.execute_reply": "2022-09-28T18:19:18.351528Z",
     "shell.execute_reply.started": "2022-09-28T18:19:18.346527Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>.jp-OutputArea-prompt:empty {padding: 0; border: 0;}</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import logging\n",
    "import os\n",
    "import random\n",
    "import warnings\n",
    "from datetime import datetime\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchinfo\n",
    "from IPython.core.display import HTML\n",
    "from torch import Tensor, jit\n",
    "from tqdm.autonotebook import tqdm, trange\n",
    "\n",
    "# NOTE(review): PyTorch documents CUDA_LAUNCH_BLOCKING=1 as a debugging aid that makes\n",
    "# CUDA calls synchronous; confirm it is intended for regular training runs.\n",
    "os.environ[\"CUDA_LAUNCH_BLOCKING\"] = \"1\"\n",
    "# Allow TF32 for CUDA matrix multiplications and for cuDNN operations.\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "torch.backends.cudnn.allow_tf32 = True\n",
    "# torch.jit.enable_onednn_fusion(True)\n",
    "# Turn on the cuDNN benchmark mode.\n",
    "torch.backends.cudnn.benchmark = True\n",
    "# torch.multiprocessing.set_start_method('spawn')\n",
    "\n",
    "# Hide UserWarnings coming from modules matching \"torch\"; log at WARN level and above.\n",
    "warnings.filterwarnings(action=\"ignore\", category=UserWarning, module=\"torch\")\n",
    "logging.basicConfig(level=logging.WARN)\n",
    "# CSS rule for empty output prompts; returned as the cell result so the notebook renders it.\n",
    "HTML(\"<style>.jp-OutputArea-prompt:empty {padding: 0; border: 0;}</style>\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameter choices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-28T18:19:18.843610Z",
     "iopub.status.busy": "2022-09-28T18:19:18.843426Z",
     "iopub.status.idle": "2022-09-28T18:19:18.848799Z",
     "shell.execute_reply": "2022-09-28T18:19:18.848422Z",
     "shell.execute_reply.started": "2022-09-28T18:19:18.843598Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'dataset': 'MIMIC-III',\n",
       " 'model': 'LinODEnet',\n",
       " 'fold': 0,\n",
       " 'seed': None,\n",
       " 'max_epochs': 50,\n",
       " 'batch_size': 64,\n",
       " 'hidden_size': 128,\n",
       " 'latent_size': 160,\n",
       " 'kernel-initialization': 'skew-symmetric',\n",
       " 'lr': 0.001,\n",
       " 'betas': tensor([0.9000, 0.9990]),\n",
       " 'weight_decay': 0.001}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Seed every random number generator used here, but only when a seed was requested.\n",
    "if ARGS.seed is not None:\n",
    "    torch.manual_seed(ARGS.seed)\n",
    "    random.seed(ARGS.seed)\n",
    "    np.random.seed(ARGS.seed)\n",
    "\n",
    "# Keyword arguments for the optimizer constructor. `betas` is a plain tuple of floats,\n",
    "# which is the type the torch.optim optimizers document for this argument.\n",
    "OPTIMIZER_CONFIG = {\n",
    "    \"lr\": ARGS.learn_rate,\n",
    "    \"betas\": tuple(float(beta) for beta in ARGS.betas),\n",
    "    \"weight_decay\": ARGS.weight_decay,\n",
    "}\n",
    "\n",
    "# Hyperparameters handed to the logger. `betas` is stored as a tensor in this dict.\n",
    "# NOTE(review): presumably because the hparam logging does not accept tuples -- confirm.\n",
    "hparam_dict = {\n",
    "    \"dataset\": (DATASET := \"MIMIC-III\"),\n",
    "    \"model\": (MODEL_NAME := \"LinODEnet\"),\n",
    "    \"fold\": ARGS.fold,\n",
    "    \"seed\": ARGS.seed,\n",
    "    \"max_epochs\": ARGS.epochs,\n",
    "    \"batch_size\": ARGS.batch_size,\n",
    "    \"hidden_size\": ARGS.hidden_size,\n",
    "    \"latent_size\": ARGS.latent_size,\n",
    "    \"kernel-initialization\": ARGS.kernel_init,\n",
    "} | OPTIMIZER_CONFIG | {\"betas\": torch.tensor(ARGS.betas)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-28T18:19:19.092108Z",
     "iopub.status.busy": "2022-09-28T18:19:19.091928Z",
     "iopub.status.idle": "2022-09-28T18:19:19.095824Z",
     "shell.execute_reply": "2022-09-28T18:19:19.095522Z",
     "shell.execute_reply.started": "2022-09-28T18:19:19.092095Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# Short tag of the main hyperparameters, used as the innermost directory name.\n",
    "CONFIG_STR = f\"f={ARGS.fold}_bs={ARGS.batch_size}_lr={ARGS.learn_rate}_hs={ARGS.hidden_size}_ls={ARGS.latent_size}\"\n",
    "# Fall back to the current timestamp when no run id was given.\n",
    "RUN_ID = ARGS.run_id or datetime.now().isoformat(timespec=\"seconds\")\n",
    "# The config id is parsed as a string; convert it so CFG_ID is an int in both branches.\n",
    "CFG_ID = 0 if ARGS.config is None else int(ARGS.config[1])\n",
    "HOME = Path.cwd()\n",
    "\n",
    "LOGGING_DIR = HOME / \"tensorboard\" / DATASET / MODEL_NAME / RUN_ID / CONFIG_STR\n",
    "CKPOINT_DIR = HOME / \"checkpoints\" / DATASET / MODEL_NAME / RUN_ID / CONFIG_STR\n",
    "RESULTS_DIR = HOME / \"results\" / DATASET / MODEL_NAME / RUN_ID\n",
    "# Create all output directories up front; existing ones are left untouched.\n",
    "for directory in (LOGGING_DIR, CKPOINT_DIR, RESULTS_DIR):\n",
    "    directory.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-28T18:19:19.787148Z",
     "iopub.status.busy": "2022-09-28T18:19:19.786833Z",
     "iopub.status.idle": "2022-09-28T18:19:20.089145Z",
     "shell.execute_reply": "2022-09-28T18:19:20.088733Z",
     "shell.execute_reply.started": "2022-09-28T18:19:19.787129Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/rscholz/Projects/KIWI/tsdm/src/tsdm/datasets/base.py:308: UserWarning: File 'timeseries.parquet' failed to validate!File hash 'e9eea0cf21e141fc7725a0a6b37f97833cc12199cc302b7aebd6f6b0b120b3c5' does not match reference '2ebb7da820560f420f71c0b6fb068a46449ef89b238e97ba81659220fae8151b'.𝗜𝗴𝗻𝗼𝗿𝗲 𝘁𝗵𝗶𝘀 𝘄𝗮𝗿𝗻𝗶𝗻𝗴 𝗶𝗳 𝘁𝗵𝗲 𝗳𝗶𝗹𝗲 𝗳𝗼𝗿𝗺𝗮𝘁 𝗶𝘀 𝗽𝗮𝗿𝗾𝘂𝗲𝘁.\n",
      "  warnings.warn(\n",
      "/home/rscholz/Projects/KIWI/tsdm/src/tsdm/datasets/base.py:308: UserWarning: File 'metadata.parquet' failed to validate!File hash '554b9759a5e8d79c3d98619ac371c35bd6f03dbfca26750252be9af9137bb71e' does not match reference '4779aa3639f468126ea263645510d5395d85b73caf1c7abb0a486561b761f5b4'.𝗜𝗴𝗻𝗼𝗿𝗲 𝘁𝗵𝗶𝘀 𝘄𝗮𝗿𝗻𝗶𝗻𝗴 𝗶𝗳 𝘁𝗵𝗲 𝗳𝗶𝗹𝗲 𝗳𝗼𝗿𝗺𝗮𝘁 𝗶𝘀 𝗽𝗮𝗿𝗾𝘂𝗲𝘁.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "MIMIC_III_DeBrouwer2019(test_metric=MSELoss)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tsdm.tasks.mimic_iii_debrouwer2019 import MIMIC_III_DeBrouwer2019\n",
    "\n",
    "# Task object; later cells read `TASK.dataset` and request loaders via `TASK.get_dataloader`.\n",
    "TASK = MIMIC_III_DeBrouwer2019()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize DataLoaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-28T18:19:22.374413Z",
     "iopub.status.busy": "2022-09-28T18:19:22.373780Z",
     "iopub.status.idle": "2022-09-28T18:19:24.910227Z",
     "shell.execute_reply": "2022-09-28T18:19:24.909742Z",
     "shell.execute_reply.started": "2022-09-28T18:19:22.374396Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'train': <torch.utils.data.dataloader.DataLoader at 0x7fe2f04ccaf0>,\n",
       " 'valid': <torch.utils.data.dataloader.DataLoader at 0x7fe316a43670>,\n",
       " 'test': <torch.utils.data.dataloader.DataLoader at 0x7fe2f0517880>}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tsdm.tasks.mimic_iii_debrouwer2019 import mimic_collate as task_collate_fn\n",
    "\n",
    "# Settings shared by every loader.\n",
    "shared_dloader_config = {\"pin_memory\": True, \"collate_fn\": task_collate_fn}\n",
    "\n",
    "# Training: shuffled, user-chosen batch size, incomplete last batch dropped.\n",
    "dloader_config_train = shared_dloader_config | {\n",
    "    \"batch_size\": ARGS.batch_size,\n",
    "    \"shuffle\": True,\n",
    "    \"drop_last\": True,\n",
    "    \"num_workers\": 4,\n",
    "}\n",
    "\n",
    "# Evaluation: fixed order, fixed batch size, no samples dropped, no worker processes.\n",
    "dloader_config_infer = shared_dloader_config | {\n",
    "    \"batch_size\": 256,\n",
    "    \"shuffle\": False,\n",
    "    \"drop_last\": False,\n",
    "    \"num_workers\": 0,\n",
    "}\n",
    "\n",
    "TRAIN_LOADER = TASK.get_dataloader((ARGS.fold, \"train\"), **dloader_config_train)\n",
    "# INFER_LOADER serves the training split again, but with the evaluation settings.\n",
    "INFER_LOADER, VALID_LOADER, TEST_LOADER = (\n",
    "    TASK.get_dataloader((ARGS.fold, split), **dloader_config_infer)\n",
    "    for split in (\"train\", \"valid\", \"test\")\n",
    ")\n",
    "EVAL_LOADERS = dict(train=INFER_LOADER, valid=VALID_LOADER, test=TEST_LOADER)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-28T18:19:24.910970Z",
     "iopub.status.busy": "2022-09-28T18:19:24.910818Z",
     "iopub.status.idle": "2022-09-28T18:19:24.916217Z",
     "shell.execute_reply": "2022-09-28T18:19:24.915926Z",
     "shell.execute_reply.started": "2022-09-28T18:19:24.910960Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.jit.ScriptFunction at 0x7fe31321c680>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def MSE(y: Tensor, yhat: Tensor) -> Tensor:\n",
    "    \"\"\"Return the mean of the squared element-wise differences between `y` and `yhat`.\"\"\"\n",
    "    residual = y - yhat\n",
    "    return residual.square().mean()\n",
    "\n",
    "\n",
    "def MAE(y: Tensor, yhat: Tensor) -> Tensor:\n",
    "    \"\"\"Return the mean of the absolute element-wise differences between `y` and `yhat`.\"\"\"\n",
    "    residual = y - yhat\n",
    "    return residual.abs().mean()\n",
    "\n",
    "\n",
    "def RMSE(y: Tensor, yhat: Tensor) -> Tensor:\n",
    "    \"\"\"Return the square root of the mean squared difference between `y` and `yhat`.\"\"\"\n",
    "    residual = y - yhat\n",
    "    return residual.square().mean().sqrt()\n",
    "\n",
    "\n",
    "# Metrics compiled with TorchScript, keyed by name.\n",
    "METRICS = {\n",
    "    name: jit.script(metric)\n",
    "    for name, metric in ((\"RMSE\", RMSE), (\"MSE\", MSE), (\"MAE\", MAE))\n",
    "}\n",
    "# The loss minimized during training.\n",
    "LOSS = jit.script(MSE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-28T18:19:25.562138Z",
     "iopub.status.busy": "2022-09-28T18:19:25.561777Z",
     "iopub.status.idle": "2022-09-28T18:19:25.702885Z",
     "shell.execute_reply": "2022-09-28T18:19:25.702375Z",
     "shell.execute_reply.started": "2022-09-28T18:19:25.562122Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "=================================================================\n",
       "Layer (type:depth-idx)                   Param #\n",
       "=================================================================\n",
       "LinODEnet                                160\n",
       "├─ConcatEmbedding: 1-1                   32\n",
       "├─ResNet: 1-2                            --\n",
       "│    └─ResNetBlock: 2-1                  --\n",
       "│    │    └─ReverseDense: 3-1            25,760\n",
       "│    │    └─ReverseDense: 3-2            25,760\n",
       "│    │    └─ReZeroCell: 3-3              1\n",
       "│    └─ResNetBlock: 2-2                  --\n",
       "│    │    └─ReverseDense: 3-4            25,760\n",
       "│    │    └─ReverseDense: 3-5            25,760\n",
       "│    │    └─ReZeroCell: 3-6              1\n",
       "│    └─ResNetBlock: 2-3                  --\n",
       "│    │    └─ReverseDense: 3-7            25,760\n",
       "│    │    └─ReverseDense: 3-8            25,760\n",
       "│    │    └─ReZeroCell: 3-9              1\n",
       "│    └─ResNetBlock: 2-4                  --\n",
       "│    │    └─ReverseDense: 3-10           25,760\n",
       "│    │    └─ReverseDense: 3-11           25,760\n",
       "│    │    └─ReZeroCell: 3-12             1\n",
       "│    └─ResNetBlock: 2-5                  --\n",
       "│    │    └─ReverseDense: 3-13           25,760\n",
       "│    │    └─ReverseDense: 3-14           25,760\n",
       "│    │    └─ReZeroCell: 3-15             1\n",
       "├─LinODECell: 1-3                        25,601\n",
       "├─ResNet: 1-4                            --\n",
       "│    └─ResNetBlock: 2-6                  --\n",
       "│    │    └─ReverseDense: 3-16           25,760\n",
       "│    │    └─ReverseDense: 3-17           25,760\n",
       "│    │    └─ReZeroCell: 3-18             1\n",
       "│    └─ResNetBlock: 2-7                  --\n",
       "│    │    └─ReverseDense: 3-19           25,760\n",
       "│    │    └─ReverseDense: 3-20           25,760\n",
       "│    │    └─ReZeroCell: 3-21             1\n",
       "│    └─ResNetBlock: 2-8                  --\n",
       "│    │    └─ReverseDense: 3-22           25,760\n",
       "│    │    └─ReverseDense: 3-23           25,760\n",
       "│    │    └─ReZeroCell: 3-24             1\n",
       "│    └─ResNetBlock: 2-9                  --\n",
       "│    │    └─ReverseDense: 3-25           25,760\n",
       "│    │    └─ReverseDense: 3-26           25,760\n",
       "│    │    └─ReZeroCell: 3-27             1\n",
       "│    └─ResNetBlock: 2-10                 --\n",
       "│    │    └─ReverseDense: 3-28           25,760\n",
       "│    │    └─ReverseDense: 3-29           25,760\n",
       "│    │    └─ReZeroCell: 3-30             1\n",
       "├─ConcatProjection: 1-5                  32\n",
       "├─SequentialFilter: 1-6                  --\n",
       "│    └─LinearFilter: 2-11                32,771\n",
       "│    └─NonLinearFilter: 2-12             32,769\n",
       "│    │    └─Sequential: 3-31             32,768\n",
       "│    └─NonLinearFilter: 2-13             32,769\n",
       "│    │    └─Sequential: 3-32             32,768\n",
       "=================================================================\n",
       "Total params: 704,880\n",
       "Trainable params: 704,879\n",
       "Non-trainable params: 1\n",
       "================================================================="
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from linodenet.models import LinODEnet, ResNet, embeddings, filters, system\n",
    "\n",
    "# Each component's `HP` mapping, merged with the overrides used in this experiment.\n",
    "filter_hp = filters.SequentialFilter.HP | {\"autoregressive\": True}\n",
    "system_hp = system.LinODECell.HP | {\"kernel_initialization\": ARGS.kernel_init}\n",
    "\n",
    "MODEL_CONFIG = {\n",
    "    \"__name__\": \"LinODEnet\",\n",
    "    \"input_size\": TASK.dataset.shape[-1],\n",
    "    \"hidden_size\": ARGS.hidden_size,\n",
    "    \"latent_size\": ARGS.latent_size,\n",
    "    \"Filter\": filter_hp,\n",
    "    \"System\": system_hp,\n",
    "    \"Encoder\": ResNet.HP,\n",
    "    \"Decoder\": ResNet.HP,\n",
    "    \"Embedding\": embeddings.ConcatEmbedding.HP,\n",
    "    \"Projection\": embeddings.ConcatProjection.HP,\n",
    "}\n",
    "\n",
    "# Build the model on DEVICE, then compile it with TorchScript.\n",
    "MODEL = jit.script(LinODEnet(**MODEL_CONFIG).to(DEVICE))\n",
    "torchinfo.summary(MODEL)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Warm-Up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def predict_fn(model, batch) -> tuple[Tensor, Tensor]:\n",
    "    \"\"\"Run `model` on one batch and return the masked targets and predictions.\n",
    "\n",
    "    The six entries of `batch` are moved to `DEVICE`. The model is called with the\n",
    "    first two entries; the fourth is not used. The fifth entry is indexed with the\n",
    "    sixth and the model output is indexed with the third.\n",
    "\n",
    "    NOTE(review): what each entry means is decided by the collate function; the\n",
    "    local names below are assumptions -- confirm against it.\n",
    "    \"\"\"\n",
    "    on_device = (item.to(DEVICE) for item in batch)\n",
    "    times, values, value_mask, _, targets, target_mask = on_device\n",
    "    predictions = model(times, values)\n",
    "    return targets[target_mask], predictions[value_mask]\n",
    "\n",
    "\n",
    "# Warm-up: one forward/backward pass on a single training batch, without an optimizer step.\n",
    "# NOTE(review): presumably meant to surface failures before the training loop starts -- confirm.\n",
    "batch = next(iter(TRAIN_LOADER))\n",
    "MODEL.zero_grad(set_to_none=True)\n",
    "\n",
    "# Forward\n",
    "Y, YHAT = predict_fn(MODEL, batch)\n",
    "\n",
    "# Backward\n",
    "R = LOSS(Y, YHAT)\n",
    "# Stop right away if the loss is NaN or infinite.\n",
    "assert torch.isfinite(R).item(), \"Model Collapsed!\"\n",
    "R.backward()\n",
    "\n",
    "# Reset: discard the gradients produced by the warm-up pass.\n",
    "MODEL.zero_grad(set_to_none=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Initialize Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torch.optim import AdamW\n",
    "\n",
    "# AdamW over all model parameters; lr, betas and weight_decay come from OPTIMIZER_CONFIG.\n",
    "OPTIMIZER = AdamW(MODEL.parameters(), **OPTIMIZER_CONFIG)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "from tsdm.logutils import StandardLogger\n",
    "\n",
    "WRITER = SummaryWriter(LOGGING_DIR)\n",
    "\n",
    "# Everything the logger is constructed with, collected in one mapping.\n",
    "logger_options = {\n",
    "    \"writer\": WRITER,\n",
    "    \"model\": MODEL,\n",
    "    \"optimizer\": OPTIMIZER,\n",
    "    \"metrics\": METRICS,\n",
    "    \"dataloaders\": EVAL_LOADERS,\n",
    "    \"hparam_dict\": hparam_dict,\n",
    "    \"checkpoint_dir\": CKPOINT_DIR,\n",
    "    \"predict_fn\": predict_fn,\n",
    "    \"results_dir\": RESULTS_DIR,\n",
    "}\n",
    "LOGGER = StandardLogger(**logger_options)\n",
    "\n",
    "# Log epoch 0 before any optimizer step has been taken.\n",
    "LOGGER.log_epoch_end(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Running count of processed batches across all epochs, used as the step for batch logging.\n",
    "total_num_batches = 0\n",
    "# Epoch 0 is the untrained state logged before this loop, so the training epochs are\n",
    "# numbered 1..ARGS.epochs inclusive; the upper bound must therefore be epochs + 1.\n",
    "for epoch in trange(1, ARGS.epochs + 1, desc=\"Epoch\", position=0):\n",
    "    for batch in tqdm(\n",
    "        TRAIN_LOADER, desc=\"Batch\", leave=False, position=1, disable=ARGS.quiet\n",
    "    ):\n",
    "        total_num_batches += 1\n",
    "        MODEL.zero_grad(set_to_none=True)\n",
    "\n",
    "        # Forward\n",
    "        Y, YHAT = predict_fn(MODEL, batch)\n",
    "        R = LOSS(Y, YHAT)\n",
    "        # Stop right away if the loss is NaN or infinite.\n",
    "        assert torch.isfinite(R).item(), \"Model Collapsed!\"\n",
    "\n",
    "        # Backward\n",
    "        R.backward()\n",
    "        OPTIMIZER.step()\n",
    "\n",
    "        # Logging\n",
    "        LOGGER.log_batch_end(total_num_batches, targets=Y, predics=YHAT)\n",
    "    LOGGER.log_epoch_end(epoch)\n",
    "\n",
    "# Final bookkeeping for this configuration once training has finished.\n",
    "LOGGER.log_history(CFG_ID)\n",
    "LOGGER.log_hparams(CFG_ID)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "ec22ec550e91a21331b7b83c9c957f9d692fda7e80c00eddaeb353649ec6f8f4"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
