{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MIMIC-IV"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Input Parsing (for command line use)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import argparse\n",
    "\n",
    "# Command-line interface of the training script. Every option has a default,\n",
    "# so the notebook can also be executed interactively without any arguments.\n",
    "# fmt: off\n",
    "parser = argparse.ArgumentParser(description=\"Training Script for MIMIC-IV dataset.\")\n",
    "parser.add_argument(\"-q\",  \"--quiet\",        default=False,  const=True, help=\"disable the per-batch progress bar\", nargs=\"?\")\n",
    "parser.add_argument(\"-r\",  \"--run_id\",       default=None,   type=str,   help=\"run_id\")\n",
    "parser.add_argument(\"-c\",  \"--config\",       default=None,   type=str,   help=\"load external config\", nargs=2)\n",
    "parser.add_argument(\"-e\",  \"--epochs\",       default=50,    type=int,   help=\"maximum epochs\")\n",
    "parser.add_argument(\"-f\",  \"--fold\",         default=0,      type=int,   help=\"fold number\")\n",
    "parser.add_argument(\"-bs\", \"--batch-size\",   default=64,     type=int,   help=\"batch-size\")\n",
    "parser.add_argument(\"-lr\", \"--learn-rate\",   default=0.001,  type=float, help=\"learn-rate\")\n",
    "parser.add_argument(\"-b\",  \"--betas\", default=(0.9, 0.999),  type=float, help=\"adam betas\", nargs=2)\n",
    "parser.add_argument(\"-wd\", \"--weight-decay\", default=0.001,  type=float, help=\"weight-decay\")\n",
    "parser.add_argument(\"-hs\", \"--hidden-size\",  default=128,    type=int,   help=\"hidden-size\")\n",
    "parser.add_argument(\"-ls\", \"--latent-size\",  default=256,    type=int,   help=\"latent-size\")\n",
    "parser.add_argument(\"-ki\", \"--kernel-init\",  default=\"skew-symmetric\",   help=\"kernel-initialization\")\n",
    "parser.add_argument(\"-n\",  \"--note\",         default=\"\",     type=str,   help=\"Note that can be added\")\n",
    "parser.add_argument(\"-s\",  \"--seed\",         default=None,   type=int,   help=\"Set the random seed.\")\n",
    "# fmt: on\n",
    "\n",
    "# `get_ipython` is only defined inside IPython/Jupyter. A NameError therefore\n",
    "# means the file runs as a plain script and the real command line is parsed;\n",
    "# otherwise the kernel's own argv is ignored and all defaults are used.\n",
    "try:\n",
    "    get_ipython().run_line_magic(\n",
    "        \"config\", \"InteractiveShell.ast_node_interactivity='last_expr_or_assign'\"\n",
    "    )\n",
    "except NameError:\n",
    "    ARGS = parser.parse_args()\n",
    "else:\n",
    "    # An explicit empty argument list yields the defaults.\n",
    "    ARGS = parser.parse_args([])\n",
    "\n",
    "print(ARGS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load from config file if provided"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import yaml\n",
    "\n",
    "# `--config` carries two strings: the path of a YAML file and the id of the\n",
    "# entry inside it. That entry overwrites the matching attributes of ARGS.\n",
    "if ARGS.config is not None:\n",
    "    cfg_file, cfg_id = ARGS.config\n",
    "    with open(cfg_file, \"r\") as file:\n",
    "        cfg_dict = yaml.safe_load(file)\n",
    "    # The id arrives as a string and is converted to int for the lookup.\n",
    "    overrides = cfg_dict[int(cfg_id)]\n",
    "    vars(ARGS).update(**overrides)\n",
    "\n",
    "print(ARGS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Global Variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "import os\n",
    "import random\n",
    "import warnings\n",
    "from datetime import datetime\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchinfo\n",
    "from IPython.core.display import HTML\n",
    "from torch import Tensor, jit\n",
    "from tqdm.autonotebook import tqdm, trange\n",
    "\n",
    "# CUDA_LAUNCH_BLOCKING=1 requests synchronous CUDA kernel launches.\n",
    "os.environ[\"CUDA_LAUNCH_BLOCKING\"] = \"1\"\n",
    "\n",
    "# Allow TF32 for CUDA matmuls and for cuDNN.\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "torch.backends.cudnn.allow_tf32 = True\n",
    "# torch.jit.enable_onednn_fusion(True)\n",
    "\n",
    "# Turn on the cuDNN benchmark mode.\n",
    "torch.backends.cudnn.benchmark = True\n",
    "# torch.multiprocessing.set_start_method('spawn')\n",
    "\n",
    "# Hide UserWarnings coming from torch modules and log from WARN level upwards.\n",
    "warnings.filterwarnings(action=\"ignore\", category=UserWarning, module=\"torch\")\n",
    "logging.basicConfig(level=logging.WARN)\n",
    "\n",
    "# CSS that removes padding and border of empty output prompts. This has to stay\n",
    "# the last expression of the cell so that the notebook renders it.\n",
    "HTML(\"<style>.jp-OutputArea-prompt:empty {padding: 0; border: 0;}</style>\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameter choices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Seed all three random number generators, but only when a seed was given.\n",
    "if ARGS.seed is not None:\n",
    "    random.seed(ARGS.seed)\n",
    "    np.random.seed(ARGS.seed)\n",
    "    torch.manual_seed(ARGS.seed)\n",
    "\n",
    "DATASET = \"MIMIC-IV\"\n",
    "MODEL_NAME = \"LinODEnet\"\n",
    "\n",
    "# Keyword arguments that are later unpacked into the optimizer constructor.\n",
    "OPTIMIZER_CONFIG = dict(\n",
    "    lr=ARGS.learn_rate,\n",
    "    betas=torch.tensor(ARGS.betas),\n",
    "    weight_decay=ARGS.weight_decay,\n",
    ")\n",
    "\n",
    "# Flat record of every hyperparameter of the run; the optimizer settings are\n",
    "# merged in last, after the experiment-level entries.\n",
    "hparam_dict = {\n",
    "    \"dataset\": DATASET,\n",
    "    \"model\": MODEL_NAME,\n",
    "    \"fold\": ARGS.fold,\n",
    "    \"seed\": ARGS.seed,\n",
    "    \"max_epochs\": ARGS.epochs,\n",
    "    \"batch_size\": ARGS.batch_size,\n",
    "    \"hidden_size\": ARGS.hidden_size,\n",
    "    \"latent_size\": ARGS.latent_size,\n",
    "    \"kernel-initialization\": ARGS.kernel_init,\n",
    "    **OPTIMIZER_CONFIG,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# Short tag of the main hyperparameters; used as the innermost directory name.\n",
    "CONFIG_STR = f\"f={ARGS.fold}_bs={ARGS.batch_size}_lr={ARGS.learn_rate}_hs={ARGS.hidden_size}_ls={ARGS.latent_size}\"\n",
    "# Without an explicit run id the current timestamp (second resolution) is used.\n",
    "RUN_ID = ARGS.run_id or datetime.now().isoformat(timespec=\"seconds\")\n",
    "# The config id is parsed as a string; cast it so that CFG_ID is an int both\n",
    "# with and without `--config` (the config lookup also uses it as an int).\n",
    "CFG_ID = 0 if ARGS.config is None else int(ARGS.config[1])\n",
    "HOME = Path.cwd()\n",
    "\n",
    "LOGGING_DIR = HOME / \"tensorboard\" / DATASET / MODEL_NAME / RUN_ID / CONFIG_STR\n",
    "CKPOINT_DIR = HOME / \"checkpoints\" / DATASET / MODEL_NAME / RUN_ID / CONFIG_STR\n",
    "RESULTS_DIR = HOME / \"results\" / DATASET / MODEL_NAME / RUN_ID\n",
    "\n",
    "# Create all output directories up front; existing ones are left untouched.\n",
    "for directory in (LOGGING_DIR, CKPOINT_DIR, RESULTS_DIR):\n",
    "    directory.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from tsdm.tasks.mimic_iv_bilos2021 import MIMIC_IV_Bilos2021\n",
    "\n",
    "# The task object is the source of the dataloaders and of the dataset whose\n",
    "# last dimension determines the model's input size.\n",
    "TASK = MIMIC_IV_Bilos2021()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize DataLoaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from tsdm.tasks.mimic_iv_bilos2021 import mimic_collate as task_collate_fn\n",
    "\n",
    "# Settings for training: batch size from the command line, shuffling enabled.\n",
    "dloader_config_train = dict(\n",
    "    batch_size=ARGS.batch_size,\n",
    "    shuffle=True,\n",
    "    drop_last=True,\n",
    "    pin_memory=True,\n",
    "    num_workers=4,\n",
    "    collate_fn=task_collate_fn,\n",
    ")\n",
    "\n",
    "# Settings for evaluation: fixed batch size, no shuffling, nothing dropped.\n",
    "dloader_config_infer = dict(\n",
    "    batch_size=256,\n",
    "    shuffle=False,\n",
    "    drop_last=False,\n",
    "    pin_memory=True,\n",
    "    num_workers=0,\n",
    "    collate_fn=task_collate_fn,\n",
    ")\n",
    "\n",
    "# One loader for optimisation on the training split of the selected fold ...\n",
    "TRAIN_LOADER = TASK.get_dataloader((ARGS.fold, \"train\"), **dloader_config_train)\n",
    "# ... and three evaluation loaders; INFER_LOADER reads the same training split\n",
    "# as TRAIN_LOADER but with the evaluation settings.\n",
    "INFER_LOADER = TASK.get_dataloader((ARGS.fold, \"train\"), **dloader_config_infer)\n",
    "VALID_LOADER = TASK.get_dataloader((ARGS.fold, \"valid\"), **dloader_config_infer)\n",
    "TEST_LOADER = TASK.get_dataloader((ARGS.fold, \"test\"), **dloader_config_infer)\n",
    "EVAL_LOADERS = dict(train=INFER_LOADER, valid=VALID_LOADER, test=TEST_LOADER)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def MSE(y: Tensor, yhat: Tensor) -> Tensor:\n",
    "    \"\"\"Return the mean of the squared differences between y and yhat.\"\"\"\n",
    "    residual = y - yhat\n",
    "    return residual.pow(2).mean()\n",
    "\n",
    "\n",
    "def MAE(y: Tensor, yhat: Tensor) -> Tensor:\n",
    "    \"\"\"Return the mean of the absolute differences between y and yhat.\"\"\"\n",
    "    residual = y - yhat\n",
    "    return residual.abs().mean()\n",
    "\n",
    "\n",
    "def RMSE(y: Tensor, yhat: Tensor) -> Tensor:\n",
    "    \"\"\"Return the square root of the mean squared difference between y and yhat.\"\"\"\n",
    "    residual = y - yhat\n",
    "    return residual.pow(2).mean().sqrt()\n",
    "\n",
    "\n",
    "# Scripted versions of all metrics, keyed by the function name.\n",
    "METRICS = {fn.__name__: jit.script(fn) for fn in (RMSE, MSE, MAE)}\n",
    "# The training objective is the (scripted) mean squared error.\n",
    "LOSS = jit.script(MSE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from linodenet.models import LinODEnet, ResNet, embeddings, filters, system\n",
    "\n",
    "# Each component contributes its `HP` mapping (presumably its default\n",
    "# hyperparameters - confirm in linodenet); only the filter and the system\n",
    "# get an entry overridden for this experiment.\n",
    "MODEL_CONFIG = dict(\n",
    "    __name__=\"LinODEnet\",\n",
    "    input_size=TASK.dataset.shape[-1],\n",
    "    hidden_size=ARGS.hidden_size,\n",
    "    latent_size=ARGS.latent_size,\n",
    "    Filter=filters.SequentialFilter.HP | {\"autoregressive\": True},\n",
    "    System=system.LinODECell.HP | {\"kernel_initialization\": ARGS.kernel_init},\n",
    "    Encoder=ResNet.HP,\n",
    "    Decoder=ResNet.HP,\n",
    "    Embedding=embeddings.ConcatEmbedding.HP,\n",
    "    Projection=embeddings.ConcatProjection.HP,\n",
    ")\n",
    "\n",
    "# Build the model, move it to the target device, then compile it with TorchScript.\n",
    "MODEL = jit.script(LinODEnet(**MODEL_CONFIG).to(DEVICE))\n",
    "torchinfo.summary(MODEL)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Warm-Up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def predict_fn(model, batch) -> tuple[Tensor, Tensor]:\n",
    "    \"\"\"Return the selected targets and predictions for one batch.\n",
    "\n",
    "    All six tensors of ``batch`` are moved to ``DEVICE``. The model is called on\n",
    "    the first two, its output is indexed by the third, and the fifth tensor\n",
    "    (the targets) is indexed by the sixth. The fourth tensor is unused.\n",
    "    \"\"\"\n",
    "    t, x, pred_index, _, target, target_index = (item.to(DEVICE) for item in batch)\n",
    "    prediction = model(t, x)\n",
    "    return target[target_index], prediction[pred_index]\n",
    "\n",
    "\n",
    "# Warm-up: a single forward and backward pass on one training batch, without an\n",
    "# optimizer step. Gradients are cleared before and after.\n",
    "batch = next(iter(TRAIN_LOADER))\n",
    "MODEL.zero_grad(set_to_none=True)\n",
    "\n",
    "# Forward pass and loss\n",
    "Y, YHAT = predict_fn(MODEL, batch)\n",
    "R = LOSS(Y, YHAT)\n",
    "\n",
    "# Stop right away if the loss is already NaN or infinite.\n",
    "assert torch.isfinite(R).item(), \"Model Collapsed!\"\n",
    "\n",
    "# Backward pass\n",
    "R.backward()\n",
    "\n",
    "# Discard the gradients of the warm-up pass.\n",
    "MODEL.zero_grad(set_to_none=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torch.optim import AdamW\n",
    "\n",
    "# AdamW over all model parameters, configured with learning rate, betas and\n",
    "# weight decay from OPTIMIZER_CONFIG.\n",
    "OPTIMIZER = AdamW(params=MODEL.parameters(), **OPTIMIZER_CONFIG)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "from tsdm.logutils import StandardLogger\n",
    "\n",
    "# TensorBoard writer that stores its event files in the run's logging directory.\n",
    "WRITER = SummaryWriter(LOGGING_DIR)\n",
    "\n",
    "# The logger is handed everything it needs to evaluate, checkpoint and report.\n",
    "logger_options = dict(\n",
    "    writer=WRITER,\n",
    "    model=MODEL,\n",
    "    optimizer=OPTIMIZER,\n",
    "    metrics=METRICS,\n",
    "    dataloaders=EVAL_LOADERS,\n",
    "    predict_fn=predict_fn,\n",
    "    hparam_dict=hparam_dict,\n",
    "    checkpoint_dir=CKPOINT_DIR,\n",
    "    results_dir=RESULTS_DIR,\n",
    ")\n",
    "LOGGER = StandardLogger(**logger_options)\n",
    "\n",
    "# Log once as epoch 0, before any training step has been taken.\n",
    "LOGGER.log_epoch_end(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Global step counter over all epochs; handed to the batch-level logger.\n",
    "total_num_batches = 0\n",
    "\n",
    "# `trange` excludes its upper bound, so `ARGS.epochs + 1` is required to train\n",
    "# for the full `ARGS.epochs` epochs. Epochs are numbered 1..ARGS.epochs because\n",
    "# epoch 0 was already logged before training started.\n",
    "for epoch in trange(1, ARGS.epochs + 1, desc=\"Epoch\", position=0):\n",
    "    for batch in tqdm(\n",
    "        TRAIN_LOADER, desc=\"Batch\", leave=False, position=1, disable=ARGS.quiet\n",
    "    ):\n",
    "        total_num_batches += 1\n",
    "        MODEL.zero_grad(set_to_none=True)\n",
    "\n",
    "        # Forward\n",
    "        Y, YHAT = predict_fn(MODEL, batch)\n",
    "        R = LOSS(Y, YHAT)\n",
    "        # Abort as soon as the loss becomes NaN or infinite.\n",
    "        assert torch.isfinite(R).item(), \"Model Collapsed!\"\n",
    "\n",
    "        # Backward\n",
    "        R.backward()\n",
    "        OPTIMIZER.step()\n",
    "\n",
    "        # Logging\n",
    "        LOGGER.log_batch_end(total_num_batches, targets=Y, predics=YHAT)\n",
    "    LOGGER.log_epoch_end(epoch)\n",
    "\n",
    "# Persist the training history and the hyperparameters under the config id.\n",
    "LOGGER.log_history(CFG_ID)\n",
    "LOGGER.log_hparams(CFG_ID)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "ec22ec550e91a21331b7b83c9c957f9d692fda7e80c00eddaeb353649ec6f8f4"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
