{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# USHCN - SmallChunkedSporadic"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Input Parsing (for command line use)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-28T18:17:42.780840Z",
     "iopub.status.busy": "2022-09-28T18:17:42.780663Z",
     "iopub.status.idle": "2022-09-28T18:17:42.789225Z",
     "shell.execute_reply": "2022-09-28T18:17:42.788922Z",
     "shell.execute_reply.started": "2022-09-28T18:17:42.780828Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Namespace(quiet=False, run_id=None, config=None, epochs=100, fold=0, batch_size=32, learn_rate=0.001, betas=(0.9, 0.999), weight_decay=0.001, hidden_size=32, latent_size=64, kernel_init='skew-symmetric', note='', seed=None)\n"
     ]
    }
   ],
   "source": [
    "import argparse\n",
    "\n",
    "# fmt: off\n",
    "parser = argparse.ArgumentParser(description=\"Training Script for USHCN dataset.\")\n",
    "parser.add_argument(\"-q\",  \"--quiet\",        default=False,  const=True, help=\"kernel-inititialization\", nargs=\"?\")\n",
    "parser.add_argument(\"-r\",  \"--run_id\",       default=None,   type=str,   help=\"run_id\")\n",
    "parser.add_argument(\"-c\",  \"--config\",       default=None,   type=str,   help=\"load external config\", nargs=2)\n",
    "parser.add_argument(\"-e\",  \"--epochs\",       default=100,    type=int,   help=\"maximum epochs\")\n",
    "parser.add_argument(\"-f\",  \"--fold\",         default=0,      type=int,   help=\"fold number\")\n",
    "parser.add_argument(\"-bs\", \"--batch-size\",   default=32,     type=int,   help=\"batch-size\")\n",
    "parser.add_argument(\"-lr\", \"--learn-rate\",   default=0.001,  type=float, help=\"learn-rate\")\n",
    "parser.add_argument(\"-b\",  \"--betas\", default=(0.9, 0.999),  type=float, help=\"adam betas\", nargs=2)\n",
    "parser.add_argument(\"-wd\", \"--weight-decay\", default=0.001,  type=float, help=\"weight-decay\")\n",
    "parser.add_argument(\"-hs\", \"--hidden-size\",  default=32,    type=int,   help=\"hidden-size\")\n",
    "parser.add_argument(\"-ls\", \"--latent-size\",  default=64,    type=int,   help=\"latent-size\")\n",
    "parser.add_argument(\"-ki\", \"--kernel-init\",  default=\"skew-symmetric\",   help=\"kernel-inititialization\")\n",
    "parser.add_argument(\"-n\",  \"--note\",         default=\"\",     type=str,   help=\"Note that can be added\")\n",
    "parser.add_argument(\"-s\",  \"--seed\",         default=None,   type=int,   help=\"Set the random seed.\")\n",
    "# fmt: on\n",
    "\n",
    "try:\n",
    "    get_ipython().run_line_magic(\n",
    "        \"config\", \"InteractiveShell.ast_node_interactivity='last_expr_or_assign'\"\n",
    "    )\n",
    "except NameError:\n",
    "    ARGS = parser.parse_args()\n",
    "else:\n",
    "    ARGS = parser.parse_args(\"\")\n",
    "\n",
    "print(ARGS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load from config file if provided"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-28T18:17:43.488278Z",
     "iopub.status.busy": "2022-09-28T18:17:43.488105Z",
     "iopub.status.idle": "2022-09-28T18:17:43.491113Z",
     "shell.execute_reply": "2022-09-28T18:17:43.490786Z",
     "shell.execute_reply.started": "2022-09-28T18:17:43.488266Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Namespace(quiet=False, run_id=None, config=None, epochs=100, fold=0, batch_size=32, learn_rate=0.001, betas=(0.9, 0.999), weight_decay=0.001, hidden_size=32, latent_size=64, kernel_init='skew-symmetric', note='', seed=None)\n"
     ]
    }
   ],
   "source": [
    "import yaml\n",
    "\n",
    "if ARGS.config is not None:\n",
    "    cfg_file, cfg_id = ARGS.config\n",
    "    with open(cfg_file, \"r\") as file:\n",
    "        cfg_dict = yaml.safe_load(file)\n",
    "        vars(ARGS).update(**cfg_dict[int(cfg_id)])\n",
    "\n",
    "print(ARGS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Global Variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-28T18:17:43.970561Z",
     "iopub.status.busy": "2022-09-28T18:17:43.970383Z",
     "iopub.status.idle": "2022-09-28T18:17:43.975125Z",
     "shell.execute_reply": "2022-09-28T18:17:43.974779Z",
     "shell.execute_reply.started": "2022-09-28T18:17:43.970549Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>.jp-OutputArea-prompt:empty {padding: 0; border: 0;}</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import logging\n",
    "import os\n",
    "import random\n",
    "import warnings\n",
    "from datetime import datetime\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchinfo\n",
    "from IPython.core.display import HTML\n",
    "from torch import Tensor, jit\n",
    "from tqdm.autonotebook import tqdm, trange\n",
    "\n",
    "os.environ[\"CUDA_LAUNCH_BLOCKING\"] = \"1\"\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "torch.backends.cudnn.allow_tf32 = True\n",
    "# torch.jit.enable_onednn_fusion(True)\n",
    "torch.backends.cudnn.benchmark = True\n",
    "# torch.multiprocessing.set_start_method('spawn')\n",
    "\n",
    "warnings.filterwarnings(action=\"ignore\", category=UserWarning, module=\"torch\")\n",
    "logging.basicConfig(level=logging.WARN)\n",
    "HTML(\"<style>.jp-OutputArea-prompt:empty {padding: 0; border: 0;}</style>\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameter choices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-28T18:17:44.625505Z",
     "iopub.status.busy": "2022-09-28T18:17:44.625330Z",
     "iopub.status.idle": "2022-09-28T18:17:44.630611Z",
     "shell.execute_reply": "2022-09-28T18:17:44.630154Z",
     "shell.execute_reply.started": "2022-09-28T18:17:44.625494Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'dataset': 'USHCN',\n",
       " 'model': 'LinODEnet',\n",
       " 'fold': 0,\n",
       " 'seed': None,\n",
       " 'max_epochs': 100,\n",
       " 'batch_size': 32,\n",
       " 'hidden_size': 32,\n",
       " 'latent_size': 64,\n",
       " 'kernel-initialization': 'skew-symmetric',\n",
       " 'lr': 0.001,\n",
       " 'betas': tensor([0.9000, 0.9990]),\n",
       " 'weight_decay': 0.001}"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "if ARGS.seed is not None:\n",
    "    torch.manual_seed(ARGS.seed)\n",
    "    random.seed(ARGS.seed)\n",
    "    np.random.seed(ARGS.seed)\n",
    "\n",
    "OPTIMIZER_CONFIG = {\n",
    "    \"lr\": ARGS.learn_rate,\n",
    "    \"betas\": torch.tensor(ARGS.betas),\n",
    "    \"weight_decay\": ARGS.weight_decay,\n",
    "}\n",
    "\n",
    "hparam_dict = {\n",
    "    \"dataset\": (DATASET := \"USHCN\"),\n",
    "    \"model\": (MODEL_NAME := \"LinODEnet\"),\n",
    "    \"fold\": ARGS.fold,\n",
    "    \"seed\": ARGS.seed,\n",
    "    \"max_epochs\": ARGS.epochs,\n",
    "    \"batch_size\": ARGS.batch_size,\n",
    "    \"hidden_size\": ARGS.hidden_size,\n",
    "    \"latent_size\": ARGS.latent_size,\n",
    "    \"kernel-initialization\": ARGS.kernel_init,\n",
    "} | OPTIMIZER_CONFIG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-28T18:17:44.849403Z",
     "iopub.status.busy": "2022-09-28T18:17:44.849224Z",
     "iopub.status.idle": "2022-09-28T18:17:44.853616Z",
     "shell.execute_reply": "2022-09-28T18:17:44.853268Z",
     "shell.execute_reply.started": "2022-09-28T18:17:44.849392Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "CONFIG_STR = f\"f={ARGS.fold}_bs={ARGS.batch_size}_lr={ARGS.learn_rate}_hs={ARGS.hidden_size}_ls={ARGS.latent_size}\"\n",
    "RUN_ID = ARGS.run_id or datetime.now().isoformat(timespec=\"seconds\")\n",
    "CFG_ID = 0 if ARGS.config is None else ARGS.config[1]\n",
    "HOME = Path.cwd()\n",
    "\n",
    "LOGGING_DIR = HOME / \"tensorboard\" / DATASET / MODEL_NAME / RUN_ID / CONFIG_STR\n",
    "CKPOINT_DIR = HOME / \"checkpoints\" / DATASET / MODEL_NAME / RUN_ID / CONFIG_STR\n",
    "RESULTS_DIR = HOME / \"results\" / DATASET / MODEL_NAME / RUN_ID\n",
    "LOGGING_DIR.mkdir(parents=True, exist_ok=True)\n",
    "CKPOINT_DIR.mkdir(parents=True, exist_ok=True)\n",
    "RESULTS_DIR.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-28T18:17:45.452285Z",
     "iopub.status.busy": "2022-09-28T18:17:45.452047Z",
     "iopub.status.idle": "2022-09-28T18:17:45.514900Z",
     "shell.execute_reply": "2022-09-28T18:17:45.514568Z",
     "shell.execute_reply.started": "2022-09-28T18:17:45.452272Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/rscholz/Projects/KIWI/tsdm/src/tsdm/datasets/base.py:285: UserWarning: File 'USHCN_DeBrouwer2019.parquet' failed to validate!File hash 'f70e1d40b069796f89eaa15e85df0970d9f2c5bb1b8f36fbac635a836959cf56' does not match reference 'bbd12ab38b4b7f9c69a07409c26967fe16af3b608daae9816312859199b5ce86'.𝗜𝗴𝗻𝗼𝗿𝗲 𝘁𝗵𝗶𝘀 𝘄𝗮𝗿𝗻𝗶𝗻𝗴 𝗶𝗳 𝘁𝗵𝗲 𝗳𝗶𝗹𝗲 𝗳𝗼𝗿𝗺𝗮𝘁 𝗶𝘀 𝗽𝗮𝗿𝗾𝘂𝗲𝘁.\n",
      "  warnings.warn(\n",
      "/home/rscholz/Projects/KIWI/tsdm/src/tsdm/datasets/base.py:285: UserWarning: File 'USHCN_DeBrouwer2019.parquet' failed to validate!File hash 'f70e1d40b069796f89eaa15e85df0970d9f2c5bb1b8f36fbac635a836959cf56' does not match reference 'bbd12ab38b4b7f9c69a07409c26967fe16af3b608daae9816312859199b5ce86'.𝗜𝗴𝗻𝗼𝗿𝗲 𝘁𝗵𝗶𝘀 𝘄𝗮𝗿𝗻𝗶𝗻𝗴 𝗶𝗳 𝘁𝗵𝗲 𝗳𝗶𝗹𝗲 𝗳𝗼𝗿𝗺𝗮𝘁 𝗶𝘀 𝗽𝗮𝗿𝗾𝘂𝗲𝘁.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "USHCN_DeBrouwer2019(test_metric=MSELoss)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tsdm.tasks import USHCN_DeBrouwer2019\n",
    "\n",
    "TASK = USHCN_DeBrouwer2019(normalize_time=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize DataLoaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-28T18:17:46.049535Z",
     "iopub.status.busy": "2022-09-28T18:17:46.049136Z",
     "iopub.status.idle": "2022-09-28T18:17:46.177690Z",
     "shell.execute_reply": "2022-09-28T18:17:46.177119Z",
     "shell.execute_reply.started": "2022-09-28T18:17:46.049520Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'train': <torch.utils.data.dataloader.DataLoader at 0x7f6daef897b0>,\n",
       " 'valid': <torch.utils.data.dataloader.DataLoader at 0x7f6caad73a00>,\n",
       " 'test': <torch.utils.data.dataloader.DataLoader at 0x7f6caad721d0>}"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tsdm.tasks.ushcn_debrouwer2019 import ushcn_collate as task_collate_fn\n",
    "\n",
    "dloader_config_train = {\n",
    "    \"batch_size\": ARGS.batch_size,\n",
    "    \"shuffle\": True,\n",
    "    \"drop_last\": True,\n",
    "    \"pin_memory\": True,\n",
    "    \"num_workers\": 4,\n",
    "    \"collate_fn\": task_collate_fn,\n",
    "}\n",
    "\n",
    "dloader_config_infer = {\n",
    "    \"batch_size\": 256,\n",
    "    \"shuffle\": False,\n",
    "    \"drop_last\": False,\n",
    "    \"pin_memory\": True,\n",
    "    \"num_workers\": 0,\n",
    "    \"collate_fn\": task_collate_fn,\n",
    "}\n",
    "\n",
    "TRAIN_LOADER = TASK.get_dataloader((ARGS.fold, \"train\"), **dloader_config_train)\n",
    "INFER_LOADER = TASK.get_dataloader((ARGS.fold, \"train\"), **dloader_config_infer)\n",
    "VALID_LOADER = TASK.get_dataloader((ARGS.fold, \"valid\"), **dloader_config_infer)\n",
    "TEST_LOADER = TASK.get_dataloader((ARGS.fold, \"test\"), **dloader_config_infer)\n",
    "EVAL_LOADERS = {\"train\": INFER_LOADER, \"valid\": VALID_LOADER, \"test\": TEST_LOADER}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-28T18:17:46.907551Z",
     "iopub.status.busy": "2022-09-28T18:17:46.907379Z",
     "iopub.status.idle": "2022-09-28T18:17:46.913866Z",
     "shell.execute_reply": "2022-09-28T18:17:46.913501Z",
     "shell.execute_reply.started": "2022-09-28T18:17:46.907540Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.jit.ScriptFunction at 0x7f6cf08cb6a0>"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def MSE(y: Tensor, yhat: Tensor) -> Tensor:\n",
    "    return torch.mean((y - yhat) ** 2)\n",
    "\n",
    "\n",
    "def MAE(y: Tensor, yhat: Tensor) -> Tensor:\n",
    "    return torch.mean(torch.abs(y - yhat))\n",
    "\n",
    "\n",
    "def RMSE(y: Tensor, yhat: Tensor) -> Tensor:\n",
    "    return torch.sqrt(torch.mean((y - yhat) ** 2))\n",
    "\n",
    "\n",
    "METRICS = {\n",
    "    \"RMSE\": jit.script(RMSE),\n",
    "    \"MSE\": jit.script(MSE),\n",
    "    \"MAE\": jit.script(MAE),\n",
    "}\n",
    "LOSS = jit.script(MSE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-28T18:17:47.542582Z",
     "iopub.status.busy": "2022-09-28T18:17:47.542326Z",
     "iopub.status.idle": "2022-09-28T18:17:47.671414Z",
     "shell.execute_reply": "2022-09-28T18:17:47.671035Z",
     "shell.execute_reply.started": "2022-09-28T18:17:47.542566Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "=================================================================\n",
       "Layer (type:depth-idx)                   Param #\n",
       "=================================================================\n",
       "LinODEnet                                64\n",
       "├─ConcatEmbedding: 1-1                   32\n",
       "├─ResNet: 1-2                            --\n",
       "│    └─ResNetBlock: 2-1                  --\n",
       "│    │    └─ReverseDense: 3-1            4,160\n",
       "│    │    └─ReverseDense: 3-2            4,160\n",
       "│    │    └─ReZeroCell: 3-3              1\n",
       "│    └─ResNetBlock: 2-2                  --\n",
       "│    │    └─ReverseDense: 3-4            4,160\n",
       "│    │    └─ReverseDense: 3-5            4,160\n",
       "│    │    └─ReZeroCell: 3-6              1\n",
       "│    └─ResNetBlock: 2-3                  --\n",
       "│    │    └─ReverseDense: 3-7            4,160\n",
       "│    │    └─ReverseDense: 3-8            4,160\n",
       "│    │    └─ReZeroCell: 3-9              1\n",
       "│    └─ResNetBlock: 2-4                  --\n",
       "│    │    └─ReverseDense: 3-10           4,160\n",
       "│    │    └─ReverseDense: 3-11           4,160\n",
       "│    │    └─ReZeroCell: 3-12             1\n",
       "│    └─ResNetBlock: 2-5                  --\n",
       "│    │    └─ReverseDense: 3-13           4,160\n",
       "│    │    └─ReverseDense: 3-14           4,160\n",
       "│    │    └─ReZeroCell: 3-15             1\n",
       "├─LinODECell: 1-3                        4,097\n",
       "├─ResNet: 1-4                            --\n",
       "│    └─ResNetBlock: 2-6                  --\n",
       "│    │    └─ReverseDense: 3-16           4,160\n",
       "│    │    └─ReverseDense: 3-17           4,160\n",
       "│    │    └─ReZeroCell: 3-18             1\n",
       "│    └─ResNetBlock: 2-7                  --\n",
       "│    │    └─ReverseDense: 3-19           4,160\n",
       "│    │    └─ReverseDense: 3-20           4,160\n",
       "│    │    └─ReZeroCell: 3-21             1\n",
       "│    └─ResNetBlock: 2-8                  --\n",
       "│    │    └─ReverseDense: 3-22           4,160\n",
       "│    │    └─ReverseDense: 3-23           4,160\n",
       "│    │    └─ReZeroCell: 3-24             1\n",
       "│    └─ResNetBlock: 2-9                  --\n",
       "│    │    └─ReverseDense: 3-25           4,160\n",
       "│    │    └─ReverseDense: 3-26           4,160\n",
       "│    │    └─ReZeroCell: 3-27             1\n",
       "│    └─ResNetBlock: 2-10                 --\n",
       "│    │    └─ReverseDense: 3-28           4,160\n",
       "│    │    └─ReverseDense: 3-29           4,160\n",
       "│    │    └─ReZeroCell: 3-30             1\n",
       "├─ConcatProjection: 1-5                  32\n",
       "├─SequentialFilter: 1-6                  --\n",
       "│    └─LinearFilter: 2-11                2,051\n",
       "│    └─NonLinearFilter: 2-12             2,049\n",
       "│    │    └─Sequential: 3-31             2,048\n",
       "│    └─NonLinearFilter: 2-13             2,049\n",
       "│    │    └─Sequential: 3-32             2,048\n",
       "=================================================================\n",
       "Total params: 97,680\n",
       "Trainable params: 97,679\n",
       "Non-trainable params: 1\n",
       "================================================================="
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from linodenet.models import LinODEnet, ResNet, embeddings, filters, system\n",
    "\n",
    "MODEL_CONFIG = {\n",
    "    \"__name__\": \"LinODEnet\",\n",
    "    \"input_size\": TASK.dataset.shape[-1],\n",
    "    \"hidden_size\": ARGS.hidden_size,\n",
    "    \"latent_size\": ARGS.latent_size,\n",
    "    \"Filter\": filters.SequentialFilter.HP | {\"autoregressive\": True},\n",
    "    \"System\": system.LinODECell.HP | {\"kernel_initialization\": ARGS.kernel_init},\n",
    "    \"Encoder\": ResNet.HP,\n",
    "    \"Decoder\": ResNet.HP,\n",
    "    \"Embedding\": embeddings.ConcatEmbedding.HP,\n",
    "    \"Projection\": embeddings.ConcatProjection.HP,\n",
    "}\n",
    "\n",
    "MODEL = LinODEnet(**MODEL_CONFIG).to(DEVICE)\n",
    "MODEL = torch.jit.script(MODEL)\n",
    "torchinfo.summary(MODEL)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Warm-Up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def predict_fn(model, batch) -> tuple[Tensor, Tensor]:\n",
    "    \"\"\"Get targets and predictions.\"\"\"\n",
    "    T, X, M, _, Y, MY = (tensor.to(DEVICE) for tensor in batch)\n",
    "    YHAT = model(T, X)\n",
    "    return Y[MY], YHAT[M]\n",
    "\n",
    "\n",
    "batch = next(iter(TRAIN_LOADER))\n",
    "MODEL.zero_grad(set_to_none=True)\n",
    "\n",
    "# Forward\n",
    "Y, YHAT = predict_fn(MODEL, batch)\n",
    "\n",
    "# Backward\n",
    "R = LOSS(Y, YHAT)\n",
    "assert torch.isfinite(R).item(), \"Model Collapsed!\"\n",
    "R.backward()\n",
    "\n",
    "# Reset\n",
    "MODEL.zero_grad(set_to_none=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim import AdamW\n",
    "\n",
    "OPTIMIZER = AdamW(MODEL.parameters(), **OPTIMIZER_CONFIG)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "from tsdm.logutils import StandardLogger\n",
    "\n",
    "WRITER = SummaryWriter(LOGGING_DIR)\n",
    "LOGGER = StandardLogger(\n",
    "    writer=WRITER,\n",
    "    model=MODEL,\n",
    "    optimizer=OPTIMIZER,\n",
    "    metrics=METRICS,\n",
    "    dataloaders=EVAL_LOADERS,\n",
    "    hparam_dict=hparam_dict,\n",
    "    checkpoint_dir=CKPOINT_DIR,\n",
    "    predict_fn=predict_fn,\n",
    "    results_dir=RESULTS_DIR,\n",
    ")\n",
    "LOGGER.log_epoch_end(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "total_num_batches = 0\n",
    "for epoch in trange(1, ARGS.epochs, desc=\"Epoch\", position=0):\n",
    "    for batch in tqdm(\n",
    "        TRAIN_LOADER, desc=\"Batch\", leave=False, position=1, disable=ARGS.quiet\n",
    "    ):\n",
    "        total_num_batches += 1\n",
    "        MODEL.zero_grad(set_to_none=True)\n",
    "\n",
    "        # Forward\n",
    "        Y, YHAT = predict_fn(MODEL, batch)\n",
    "        R = LOSS(Y, YHAT)\n",
    "        assert torch.isfinite(R).item(), \"Model Collapsed!\"\n",
    "\n",
    "        # Backward\n",
    "        R.backward()\n",
    "        OPTIMIZER.step()\n",
    "\n",
    "        # Logging\n",
    "        LOGGER.log_batch_end(total_num_batches, targets=Y, predics=YHAT)\n",
    "    LOGGER.log_epoch_end(epoch)\n",
    "\n",
    "LOGGER.log_history(CFG_ID)\n",
    "LOGGER.log_hparams(CFG_ID)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "bacf51866a64549aa5a449210e6b149172b67532fa1d272ce9d18986a6d5ce6e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
