{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffb61892-8908-4369-a18b-27323cfb3e48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | default_exp model.optimization.nn.tsc.vittsc.face_detection_training_mask_tune\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2aeed576-9a64-45e1-8c49-307b94f252fe",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# declare a list tasks whose products you want to use as inputs\n",
    "upstream = ['tabular_to_timeseries_face_detection']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "756ee18a-123b-4428-b5da-6cf8de54c1e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "upstream = {\n",
    "    \"tabular_to_timeseries_face_detection\": {\n",
    "        \"nb\": \"/home/ubuntu/vitmtsc_nbdev/output/301_feature_preprocessing.face_detection.tabular_to_timeseries.html\",\n",
    "        \"FaceDetection_TRAIN_MODEL_INPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/target_encoding-nn/train\",\n",
    "        \"FaceDetection_VALID_MODEL_INPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/target_encoding-nn/valid\",\n",
    "        \"FaceDetection_TEST_MODEL_INPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/target_encoding-nn/test\",\n",
    "    }\n",
    "}\n",
    "product = {\n",
    "    \"nb\": \"/home/ubuntu/vitmtsc_nbdev/output/401_model.optimization.nn.tsc.vittsc.face_detection_training_mask_tune.html\",\n",
    "    \"FaceDetection_MODEL_TUNE_OUTPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/ray_results\",\n",
    "    \"FaceDetection_MODEL_TRAINING_OUTPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result\",\n",
    "    \"FaceDetection_MODEL_TRAINING_CHECKPOINT_OUTPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result/checkpoint\",\n",
    "    \"FaceDetection_BEST_MODEL\": \"/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result/best_model.ckpt\",\n",
    "    \"FaceDetection_BEST_MODEL_CONFIG\": \"/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result/best_model_config.json\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd0246d5-a032-40ba-831e-896da94df143",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33ed5f57",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "outputs": [],
   "source": [
    "#| export\n",
    "import sys\n",
    "import pathlib as p\n",
    "\n",
    "def is_running_from_ipython():\n",
    "    from IPython import get_ipython\n",
    "    return get_ipython() is not None\n",
    "\n",
    "if not is_running_from_ipython() and __package__ is None:\n",
    "    DIR = p.Path(__file__).resolve().parent\n",
    "    sys.path.insert(0, str(DIR.parent))\n",
    "    __package__ = DIR.name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1021f6aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import torch\n",
    "import pytorch_lightning\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import math \n",
    "\n",
    "from torch.nn import functional as F\n",
    "from torch import nn\n",
    "from torchmetrics import functional as FM\n",
    "from pytorch_lightning import loggers as pl_loggers\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "from pytorch_lightning.callbacks import LearningRateMonitor\n",
    "from petastorm import make_batch_reader\n",
    "from petastorm.pytorch import DataLoader\n",
    "from einops import rearrange, repeat\n",
    "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
    "from pytorch_lightning.loggers import TensorBoardLogger\n",
    "import pytorch_lightning as pl"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "168b9d1c",
   "metadata": {},
   "source": [
    "## Vision Transformer for Multivariate Time-Series Classification (VitMTSC) Model Training with Masking - Hyperparameter search\n",
    "\n",
    "> Classification Task\n",
    "\n",
    "> Data Loader Module\n",
    "\n",
    "> Hyperparameter Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70008e38-2085-453a-a79f-e716eb49699f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "DATASET_NAME = \"FaceDetection\"\n",
    "NUM_TARGET = 2\n",
    "SEQUENCE_LENGTH = 62\n",
    "NUMBER_OF_FEATURES = 144\n",
    "NUM_WORKERS = 1\n",
    "NUM_GPUS = 1\n",
    "MAX_EPOCHS = 50\n",
    "TUNE_EPOCHS = 5\n",
    "NUM_SAMPLES = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "017dc42a-3310-4a8c-b1f3-d31d1566b8d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import dask_cudf\n",
    "import numpy as np\n",
    "import sklearn.utils.class_weight\n",
    "\n",
    "def get_train_dataset_size():\n",
    "    gdf = dask_cudf.read_parquet(upstream['tabular_to_timeseries_face_detection']['FaceDetection_TRAIN_MODEL_INPUT'], columns = ['case_id'])\n",
    "    return gdf.case_id.nunique().compute()\n",
    "\n",
    "def get_valid_dataset_size():\n",
    "    gdf = dask_cudf.read_parquet(upstream['tabular_to_timeseries_face_detection']['FaceDetection_VALID_MODEL_INPUT'], columns = ['case_id'])\n",
    "    return gdf.case_id.nunique().compute()\n",
    "\n",
    "def get_test_dataset_size():\n",
    "    gdf = dask_cudf.read_parquet(upstream['tabular_to_timeseries_face_detection']['FaceDetection_TEST_MODEL_INPUT'], columns = ['case_id'])\n",
    "    return gdf.case_id.nunique().compute()\n",
    "\n",
    "def get_class_weight():\n",
    "    train_gdf = dask_cudf.read_parquet(upstream['tabular_to_timeseries_face_detection']['FaceDetection_TRAIN_MODEL_INPUT'], columns = ['case_id', 'class_vals'])\n",
    "    y_train = train_gdf['class_vals'].compute().to_numpy()\n",
    "    class_weight = sklearn.utils.class_weight.compute_class_weight('balanced', classes = np.unique(y_train), y = y_train)\n",
    "    class_weight = class_weight/2\n",
    "    print(f'class_weight: {class_weight}')\n",
    "    return class_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca526f2b-c1fb-47cf-9015-9cd2cbf82f22",
   "metadata": {},
   "outputs": [],
   "source": [
    "get_train_dataset_size(), get_valid_dataset_size(), get_test_dataset_size(), get_class_weight()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6bf025e0",
   "metadata": {},
   "source": [
    "### 1. Model Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfa4328d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "class Residual(nn.Module):\n",
    "    def __init__(self, fn):\n",
    "        super().__init__()\n",
    "\n",
    "        self.fn = fn\n",
    "\n",
    "    def forward(self, x, **kwargs):\n",
    "        return self.fn(x, **kwargs) + x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc5e1911",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "class PreNorm(nn.Module):\n",
    "    def __init__(self, dim, fn):\n",
    "        super().__init__()\n",
    "\n",
    "        self.norm = nn.LayerNorm(dim)\n",
    "        self.fn = fn\n",
    "\n",
    "    def forward(self, x, **kwargs):\n",
    "        return self.fn(self.norm(x), **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bd6cb12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "class FeedForward(nn.Module):\n",
    "    def __init__(self, dim, hidden_dim, dropout=0.0):\n",
    "        super().__init__()\n",
    "\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(dim, hidden_dim),\n",
    "            nn.GELU(),\n",
    "            nn.Dropout(dropout),\n",
    "            nn.Linear(hidden_dim, dim),\n",
    "            nn.Dropout(dropout),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5b48fb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "class Attention(nn.Module):\n",
    "    def __init__(self, dim, heads=10, dim_head=32, dropout=0.0):\n",
    "        super().__init__()\n",
    "\n",
    "        inner_dim = dim_head * heads\n",
    "        self.heads = heads\n",
    "        self.scale = dim_head**-0.5\n",
    "\n",
    "        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)\n",
    "        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))\n",
    "\n",
    "        self.attn_gradients = None\n",
    "        self.attention_map = None\n",
    "\n",
    "    def save_attn_gradients(self, attn_gradients):\n",
    "        self.attn_gradients = attn_gradients\n",
    "\n",
    "    def get_attn_gradients(self):\n",
    "        return self.attn_gradients\n",
    "\n",
    "    def save_attention_map(self, attention_map):\n",
    "        self.attention_map = attention_map\n",
    "\n",
    "    def get_attention_map(self):\n",
    "        return self.attention_map\n",
    "\n",
    "    def forward(self, x, mask=None, register_hook=False):\n",
    "        b, n, _, h = *x.shape, self.heads\n",
    "        qkv = self.to_qkv(x).chunk(3, dim=-1)\n",
    "        q, k, v = map(lambda t: rearrange(t, \"b n (h d) -> b h n d\", h=h), qkv)\n",
    "\n",
    "        dots = torch.einsum(\"bhid,bhjd->bhij\", q, k) * self.scale\n",
    "        mask_value = -torch.finfo(dots.dtype).max\n",
    "\n",
    "        # print('mask1.shape', mask.shape)\n",
    "        if mask is not None:\n",
    "            # mask = F.pad(mask, (1, 0), value = True)\n",
    "            mask = F.pad(mask.flatten(1), (1, 0), value=True)\n",
    "            mask = mask.unsqueeze(1).unsqueeze(2)\n",
    "\n",
    "            # print('mask2.shape', mask.shape)\n",
    "            # print('mask:', mask)\n",
    "            assert mask.shape[-1] == dots.shape[-1], \"mask has incorrect dimensions\"\n",
    "            dots.masked_fill_(mask == 0.0, mask_value)\n",
    "            del mask\n",
    "\n",
    "        attn = dots.softmax(dim=-1)\n",
    "        # print('attn.shape: ', attn.shape)\n",
    "        # print('attn: ', attn)\n",
    "\n",
    "        out = torch.einsum(\"bhij,bhjd->bhid\", attn, v)\n",
    "\n",
    "        if register_hook:\n",
    "            self.save_attention_map(attn)\n",
    "            attn.register_hook(self.save_attn_gradients)\n",
    "\n",
    "        out = rearrange(out, \"b h n d -> b n (h d)\")\n",
    "        out = self.to_out(out)\n",
    "        # print('out.shape: ', out.shape)\n",
    "        # print('out: ', out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de09b9e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "class Transformer(nn.Module):\n",
    "    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):\n",
    "        super().__init__()\n",
    "\n",
    "        self.layers = nn.ModuleList([])\n",
    "        for _ in range(depth):\n",
    "            self.layers.append(\n",
    "                nn.ModuleList(\n",
    "                    [\n",
    "                        Residual(\n",
    "                            PreNorm(\n",
    "                                dim,\n",
    "                                Attention(\n",
    "                                    dim, heads=heads, dim_head=dim_head, dropout=dropout\n",
    "                                ),\n",
    "                            )\n",
    "                        ),\n",
    "                        Residual(\n",
    "                            PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))\n",
    "                        ),\n",
    "                    ]\n",
    "                )\n",
    "            )\n",
    "\n",
    "    def forward(self, x, mask=None, register_hook=False):\n",
    "        for attn, ff in self.layers:\n",
    "            x = attn(x, mask=mask, register_hook=register_hook)\n",
    "            x = ff(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "665a3104",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "def petastorm_collate_fn(rows):\n",
    "    data_df = pd.DataFrame(rows)\n",
    "    # print(f'data_df.shape: {data_df.shape}') # data_df.shape: (2, 4402) 22 * 200 + 2\n",
    "\n",
    "    case_id_df = data_df.iloc[\n",
    "        :,\n",
    "        NUMBER_OF_FEATURES * SEQUENCE_LENGTH\n",
    "        + 1 : NUMBER_OF_FEATURES * SEQUENCE_LENGTH\n",
    "        + 2,\n",
    "    ]  # NUMBER_OF_FEATURES*SEQUENCE_LENGTH+1:NUMBER_OF_FEATURES*SEQUENCE_LENGTH+2\n",
    "    case_id_tensor = torch.tensor(case_id_df.values.astype(np.float64))\n",
    "\n",
    "    target_df = data_df.iloc[\n",
    "        :,\n",
    "        NUMBER_OF_FEATURES * SEQUENCE_LENGTH\n",
    "        + 0 : NUMBER_OF_FEATURES * SEQUENCE_LENGTH\n",
    "        + 1,\n",
    "    ]  # NUMBER_OF_FEATURES*SEQUENCE_LENGTH+0:NUMBER_OF_FEATURES*SEQUENCE_LENGTH+1\n",
    "    target_tensor = torch.tensor(target_df.values.astype(np.float32))\n",
    "\n",
    "    data_tensor_df = data_df.iloc[\n",
    "        :, 0 * SEQUENCE_LENGTH : NUMBER_OF_FEATURES * SEQUENCE_LENGTH\n",
    "    ]  # 0*SEQUENCE_LENGTH:NUMBER_OF_FEATURES*SEQUENCE_LENGTH\n",
    "    data_tensor = torch.tensor(data_tensor_df.values.astype(np.float32))\n",
    "    data_tensor = rearrange(data_tensor, \"t (b h)-> t h b\", h=SEQUENCE_LENGTH)\n",
    "\n",
    "    mask_df = data_df.iloc[\n",
    "        :, 0 * SEQUENCE_LENGTH : 1 * SEQUENCE_LENGTH\n",
    "    ]  # 0*SEQUENCE_LENGTH:1*SEQUENCE_LENGTH\n",
    "    mask_tensor = torch.tensor(mask_df.values.astype(np.float32))\n",
    "\n",
    "    return (\n",
    "        data_tensor,\n",
    "        target_tensor.squeeze(),\n",
    "        case_id_tensor.squeeze(),\n",
    "        mask_tensor.squeeze(),\n",
    "    )\n",
    "\n",
    "\n",
    "class VitMTSCPetastormDataModule(pl.LightningDataModule):\n",
    "    def __init__(\n",
    "        self,\n",
    "        config,\n",
    "        #data_dir=f\"file:///home/ubuntu/vitmtsc_nbdev/Multivariate_parquet/{DATASET_NAME}/target_encoding-nn/\",\n",
    "        num_workers=NUM_WORKERS,\n",
    "        transform_spec=None,\n",
    "        shard_count=NUM_GPUS,\n",
    "        num_epochs=MAX_EPOCHS,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.train_files = f\"file://{upstream['tabular_to_timeseries_face_detection']['FaceDetection_TRAIN_MODEL_INPUT']}\"\n",
    "        self.valid_files = f\"file://{upstream['tabular_to_timeseries_face_detection']['FaceDetection_VALID_MODEL_INPUT']}\"\n",
    "        self.test_files = f\"file://{upstream['tabular_to_timeseries_face_detection']['FaceDetection_TEST_MODEL_INPUT']}\"\n",
    "        self.batch_size = config[\"batch_size\"]\n",
    "        self.num_workers = num_workers\n",
    "        self.transform_spec = transform_spec\n",
    "        self.shard_count = shard_count\n",
    "        self.num_epochs = num_epochs\n",
    "\n",
    "    def train_dataloader(self):\n",
    "\n",
    "        self.train_ds = make_batch_reader(\n",
    "            self.train_files,\n",
    "            workers_count=self.num_workers,\n",
    "            transform_spec=self.transform_spec,\n",
    "            cur_shard=int(os.environ[\"LOCAL_RANK\"]),\n",
    "            shard_count=self.shard_count,\n",
    "            num_epochs=self.num_epochs,\n",
    "        )\n",
    "        return DataLoader(\n",
    "            self.train_ds, batch_size=self.batch_size, collate_fn=petastorm_collate_fn\n",
    "        )\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        print(\n",
    "            \"val_dataloader: local rank :\",\n",
    "            int(os.environ[\"LOCAL_RANK\"]),\n",
    "            \"shard count: \",\n",
    "            self.shard_count,\n",
    "        )\n",
    "        self.val_ds = make_batch_reader(\n",
    "            self.valid_files,\n",
    "            workers_count=self.num_workers,\n",
    "            transform_spec=self.transform_spec,\n",
    "            cur_shard=int(os.environ[\"LOCAL_RANK\"]),\n",
    "            shard_count=self.shard_count,\n",
    "            num_epochs=self.num_epochs,\n",
    "        )\n",
    "        return DataLoader(\n",
    "            self.val_ds, batch_size=self.batch_size, collate_fn=petastorm_collate_fn\n",
    "        )\n",
    "\n",
    "    def test_dataloader(self):\n",
    "        print(\n",
    "            \"test_dataloader: local rank :\",\n",
    "            int(os.environ[\"LOCAL_RANK\"]),\n",
    "            \"shard count: \",\n",
    "            self.shard_count,\n",
    "        )\n",
    "        self.test_ds = make_batch_reader(\n",
    "            self.test_files,\n",
    "            workers_count=self.num_workers,\n",
    "            transform_spec=self.transform_spec,\n",
    "            cur_shard=int(os.environ[\"LOCAL_RANK\"]),\n",
    "            shard_count=self.shard_count,\n",
    "            num_epochs=self.num_epochs,\n",
    "        )\n",
    "        return DataLoader(\n",
    "            self.test_ds, batch_size=self.batch_size, collate_fn=petastorm_collate_fn\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ed0d695",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "class VitTimeSeriesTransformer(pl.LightningModule):\n",
    "    def __init__(\n",
    "        self,\n",
    "        config,\n",
    "        c_in=NUMBER_OF_FEATURES,\n",
    "        c_out=NUM_TARGET,\n",
    "        seq_len=SEQUENCE_LENGTH,\n",
    "        class_weight=torch.FloatTensor(get_class_weight()),\n",
    "    ):\n",
    "        super(VitTimeSeriesTransformer, self).__init__()\n",
    "\n",
    "        self.d_model = config[\"d_model\"]\n",
    "        self.depth = config[\"depth\"]\n",
    "        self.heads = config[\"heads\"]\n",
    "        self.mlp_dim = config[\"mlp_dim\"]\n",
    "        self.dim_head = config[\"dim_head\"]\n",
    "        self.dropout_p = config[\"dropout\"]\n",
    "        self.emb_dropout_p = config[\"emb_dropout\"]\n",
    "        self.lr = config[\"lr\"]\n",
    "        self.weight_decay = config[\"weight_decay\"]\n",
    "        self.patience = config[\"patience\"]\n",
    "\n",
    "        self.pos_embedding = nn.Parameter(torch.randn(1, seq_len + 1, self.d_model))\n",
    "        self.patch_to_embedding = nn.Linear(c_in, self.d_model)\n",
    "        self.cls_token = nn.Parameter(torch.randn(1, 1, self.d_model))\n",
    "        self.dropout = nn.Dropout(self.emb_dropout_p)\n",
    "        self.transformer = Transformer(\n",
    "            self.d_model,\n",
    "            self.depth,\n",
    "            self.heads,\n",
    "            self.dim_head,\n",
    "            self.mlp_dim,\n",
    "            self.dropout_p,\n",
    "        )\n",
    "        self.to_cls_token = nn.Identity()\n",
    "        self.mlp_head = nn.Sequential(\n",
    "            nn.LayerNorm(self.d_model),\n",
    "            nn.Linear(self.d_model, self.mlp_dim),\n",
    "            nn.GELU(),\n",
    "            nn.Dropout(self.dropout_p),\n",
    "            nn.Linear(self.mlp_dim, c_out),\n",
    "        )\n",
    "\n",
    "        self.c_out = c_out\n",
    "        self.register_buffer(\"class_weight\", class_weight)\n",
    "\n",
    "    def forward(self, x, mask=None, register_hook=False):\n",
    "        # x = rearrange(x, 'b v s-> b s v') # bs x nvars x seq_len ->  bs x seq_len x nvars\n",
    "        x = self.patch_to_embedding(x)  # bs x seq_len x nvars -> bs x seq_len x d_model\n",
    "        b, n, _ = x.shape  # bs, seq_len\n",
    "\n",
    "        cls_tokens = repeat(self.cls_token, \"() n d -> b n d\", b=b)  # bs x 1 x d_model\n",
    "        x = torch.cat((cls_tokens, x), dim=1)  # bs x (seq_len + 1) x d_model\n",
    "        x += self.pos_embedding[\n",
    "            :, : (n + 1)\n",
    "        ]  # +=  1 x (seq_len + 1) x d_model -> # bs x (seq_len + 1) x d_model\n",
    "        x = self.dropout(x)  # bs x (seq_len + 1) x d_model\n",
    "\n",
    "        x = self.transformer(\n",
    "            x, mask=mask, register_hook=register_hook\n",
    "        )  # bs x (seq_len + 1) x d_model\n",
    "\n",
    "        x = self.to_cls_token(x[:, 0])  # bs x d_model\n",
    "        return self.mlp_head(x)  # bs x num_classes\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        # optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\n",
    "        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.step_size, gamma=self.gamma)\n",
    "        optimizer = torch.optim.AdamW(\n",
    "            self.parameters(), lr=self.lr, weight_decay=self.weight_decay\n",
    "        )\n",
    "        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
    "            optimizer, patience=self.patience\n",
    "        )\n",
    "        return {\n",
    "            \"optimizer\": optimizer,\n",
    "            \"lr_scheduler\": scheduler,\n",
    "            \"monitor\": \"train_loss\",\n",
    "        }\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        x, y, _, mask = batch\n",
    "        y_hat = self(x, mask)\n",
    "        y = y.long()\n",
    "        train_loss = F.cross_entropy(y_hat, y, weight=self.class_weight)\n",
    "        train_auc = FM.accuracy(F.softmax(y_hat, dim=1), y, num_classes=self.c_out)\n",
    "        train_auroc = FM.auroc(F.softmax(y_hat, dim=1), y, num_classes=self.c_out)\n",
    "        self.log(\n",
    "            \"train_loss\",\n",
    "            train_loss,\n",
    "            on_step=False,\n",
    "            on_epoch=True,\n",
    "            prog_bar=True,\n",
    "            logger=True,\n",
    "        )\n",
    "        self.log(\n",
    "            \"train_auc\",\n",
    "            train_auc,\n",
    "            on_step=False,\n",
    "            on_epoch=True,\n",
    "            prog_bar=True,\n",
    "            logger=True,\n",
    "        )\n",
    "        return train_loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        x, y, _, mask = batch\n",
    "        y_hat = self(x, mask)\n",
    "        y = y.long()\n",
    "        val_loss = F.cross_entropy(y_hat, y, weight=self.class_weight)\n",
    "        val_auc = FM.accuracy(F.softmax(y_hat, dim=1), y, num_classes=self.c_out)\n",
    "        val_auroc = FM.auroc(F.softmax(y_hat, dim=1), y, num_classes=self.c_out)\n",
    "        self.log(\n",
    "            \"val_loss\",\n",
    "            val_loss,\n",
    "            on_step=False,\n",
    "            on_epoch=True,\n",
    "            prog_bar=True,\n",
    "            logger=True,\n",
    "            sync_dist=True,\n",
    "        )\n",
    "        self.log(\n",
    "            \"val_auc\", val_auc, on_step=False, on_epoch=True, prog_bar=True, logger=True\n",
    "        )\n",
    "        return val_loss\n",
    "\n",
    "    def test_step(self, batch, batch_idx):\n",
    "        x, y, _, mask = batch\n",
    "        y_hat = self(x, mask)\n",
    "        y = y.long()\n",
    "        test_loss = F.cross_entropy(y_hat, y, weight=self.class_weight)\n",
    "        test_auc = FM.accuracy(F.softmax(y_hat, dim=1), y, num_classes=self.c_out)\n",
    "        test_auroc = FM.auroc(F.softmax(y_hat, dim=1), y, num_classes=self.c_out)\n",
    "        self.log(\n",
    "            \"test_loss\",\n",
    "            test_loss,\n",
    "            on_step=False,\n",
    "            on_epoch=True,\n",
    "            prog_bar=True,\n",
    "            logger=True,\n",
    "            sync_dist=True,\n",
    "        )\n",
    "        self.log(\n",
    "            \"test_auc\",\n",
    "            test_auc,\n",
    "            on_step=False,\n",
    "            on_epoch=True,\n",
    "            prog_bar=True,\n",
    "            logger=True,\n",
    "        )\n",
    "        return test_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7da96b65",
   "metadata": {},
   "source": [
    "### 2. Routine for Single/Multi-GPU DDP Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08994660",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "def get_model(config):\n",
    "    model = VitTimeSeriesTransformer(config)\n",
    "    return model\n",
    "\n",
    "\n",
    "def get_datamodule(config):\n",
    "    return VitMTSCPetastormDataModule(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9a6b1f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "%env LOCAL_RANK=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "666c97b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import json\n",
    "import ray\n",
    "from ray.tune import ExperimentAnalysis\n",
    "from ray import tune\n",
    "from ray.tune import CLIReporter\n",
    "from ray.tune.schedulers import ASHAScheduler\n",
    "from ray.tune.integration.pytorch_lightning import TuneReportCallback\n",
    "\n",
    "def tune_training(config, num_epochs = TUNE_EPOCHS, num_gpus = NUM_GPUS):\n",
    "    pl.seed_everything(42, workers=True)\n",
    "    model = get_model(config)\n",
    "    dm = get_datamodule(config)\n",
    "    metrics = {\"loss\": \"val_loss\", \"auc\": \"val_auc\"}\n",
    "    callbacks = [TuneReportCallback(metrics, on=\"validation_end\")]\n",
    "    \n",
    "    trainer = pl.Trainer(\n",
    "        max_epochs=num_epochs,\n",
    "        # If fractional GPUs passed in, convert to int.\n",
    "        #gpus= math.ceil(num_gpus),\n",
    "        accelerator='gpu', devices=math.ceil(num_gpus),\n",
    "        strategy= \"dp\",\n",
    "        callbacks=callbacks,\n",
    "        limit_train_batches= math.ceil(get_train_dataset_size()/config['batch_size']), \n",
    "        limit_val_batches= math.ceil(get_valid_dataset_size()/config['batch_size']), \n",
    "        val_check_interval= math.ceil(get_train_dataset_size()/config['batch_size']), \n",
    "        num_sanity_val_steps=0,\n",
    "        reload_dataloaders_every_n_epochs=1,\n",
    "        deterministic=True\n",
    "    )\n",
    "    \n",
    "    trainer.fit(model, dm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db685738",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "def tune_training_asha(\n",
    "    num_samples=NUM_SAMPLES,\n",
    "    num_epochs=TUNE_EPOCHS,\n",
    "    num_gpus=NUM_GPUS,\n",
    "    gpus_per_trial=0.2,\n",
    "):\n",
    "    config = {\n",
    "        \"d_model\": tune.choice([16, 32, 48, 64, 128, 256, 512]),\n",
    "        \"depth\": tune.choice([2, 4, 6, 8]),\n",
    "        \"heads\": tune.choice([2, 4, 6, 8]),\n",
    "        \"mlp_dim\": tune.choice([8, 10, 12, 14, 16, 20, 24, 32]),\n",
    "        \"dim_head\": tune.choice([8, 10, 12, 14, 16]),\n",
    "        \"dropout\": tune.loguniform(1e-6, 1e-3),\n",
    "        \"emb_dropout\": tune.loguniform(1e-6, 1e-3),\n",
    "        \"weight_decay\": tune.loguniform(1e-3, 1e-1),\n",
    "        \"lr\": tune.loguniform(1e-6, 1e-1),\n",
    "        \"patience\": tune.choice([1, 2]),\n",
    "        \"batch_size\": tune.choice([64, 128, 256, 512, 1024]),\n",
    "    }\n",
    "\n",
    "    scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)\n",
    "\n",
    "    reporter = CLIReporter(\n",
    "        parameter_columns=[\n",
    "            \"d_model\",\n",
    "            \"depth\",\n",
    "            \"heads\",\n",
    "            \"mlp_dim\",\n",
    "            \"dim_head\",\n",
    "            \"dropout\",\n",
    "            \"emb_dropout\",\n",
    "            \"weight_decay\",\n",
    "            \"lr\",\n",
    "            \"patience\",\n",
    "            \"batch_size\",\n",
    "        ],\n",
    "        metric_columns=[\"loss\", \"auc\", \"training_iteration\"],\n",
    "    )\n",
    "\n",
    "    trainable = tune.with_parameters(\n",
    "        tune_training, num_epochs=num_epochs, num_gpus=num_gpus\n",
    "    )\n",
    "    analysis = tune.run(\n",
    "        trainable,\n",
    "        resources_per_trial={\"cpu\": 1, \"gpu\": gpus_per_trial},\n",
    "        metric=\"loss\",\n",
    "        mode=\"min\",\n",
    "        config=config,\n",
    "        num_samples=num_samples,\n",
    "        scheduler=scheduler,\n",
    "        progress_reporter=reporter,\n",
    "        verbose=1,\n",
    "        name=\"FaceDetection\",\n",
    "        raise_on_failed_trial = False\n",
    "    )\n",
    "\n",
    "    print(\"Best hyperparameters found were: \", analysis.best_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0705054b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%env LOCAL_RANK=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0828d3ee-7968-4425-b68a-0d67ce7626dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -rf ~/ray_results/FaceDetection/ \n",
    "!rm -rf ./output/FaceDetection/ray_results/\n",
    "!rm -rf ./output/FaceDetection/experiments_result\n",
    "!mkdir -p  output/FaceDetection/experiments_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8786fb0d-3c51-4c95-9a53-83e75a83093d",
   "metadata": {
    "tags": [
     "tune"
    ]
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "#ray.init(num_cpus=16, num_gpus=8) # p3dn.24xlarge\n",
    "ray.init(num_cpus=4, num_gpus=4) # g4dn.12xlarge\n",
    "# ray.init(num_cpus=1, num_gpus=1)\n",
    "tune_training_asha(num_gpus=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aa05271-bf28-45d1-9f17-9612f13861c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "!cp -rf ~/ray_results/FaceDetection/ output/FaceDetection/ray_results/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee0ee043",
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.set_option('display.max_columns', None)\n",
    "pd.set_option('display.expand_frame_repr', False)\n",
    "pd.set_option('max_colwidth', None)\n",
    "\n",
    "#analysis = ExperimentAnalysis(f'~/ray_results/{DATASET_NAME}')\n",
    "analysis = ExperimentAnalysis(product['FaceDetection_MODEL_TUNE_OUTPUT'])\n",
    "tune_result_df = analysis.results_df[['loss', 'auc', 'training_iteration', 'experiment_tag']]\n",
    "tune_result_df.nsmallest(5, 'loss')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "164f4c6b-7a33-40fe-b3fc-308dd4bc5b87",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_config = analysis.get_best_config('loss', 'min')\n",
    "print(best_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ceed76c-4cf2-4dfa-9563-f1e4cbbf13cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import json\n",
    "def write_best_model_config():\n",
    "    \"\"\"Save the best hyperparameter configuration of the tuning run as JSON.\n",
    "\n",
    "    Reads the experiment stored under ``product['FaceDetection_MODEL_TUNE_OUTPUT']``,\n",
    "    selects the configuration that minimises ``loss`` and writes it, indented by four\n",
    "    spaces, to ``product['FaceDetection_BEST_MODEL_CONFIG']``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no best configuration can be determined. Nothing is written in\n",
    "            that case, so an existing config file is not replaced by ``null``.\n",
    "    \"\"\"\n",
    "    analysis = ExperimentAnalysis(product['FaceDetection_MODEL_TUNE_OUTPUT'])\n",
    "    best_config = analysis.get_best_config('loss', 'min')\n",
    "    if best_config is None:\n",
    "        # Fail here rather than writing `null`, which would only surface later as an\n",
    "        # unrelated error when the training code indexes into the loaded config.\n",
    "        raise ValueError(\n",
    "            \"No best trial found for metric 'loss' in \"\n",
    "            f\"{product['FaceDetection_MODEL_TUNE_OUTPUT']}; cannot write best model config\"\n",
    "        )\n",
    "    # Serialise before opening the file so a serialisation error cannot truncate an existing config.\n",
    "    json_object = json.dumps(best_config, indent=4)\n",
    "    with open(product['FaceDetection_BEST_MODEL_CONFIG'], 'w') as outfile:\n",
    "        outfile.write(json_object)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c22dfaa2-eb16-4d84-9c0e-90ffec7a5fe3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "# Write the best tuned configuration to its JSON file so it can be reloaded from disk.\n",
    "write_best_model_config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c9d265b-36e6-4dd3-842c-c81adfe61193",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def get_best_model_config():\n",
    "    \"\"\"Return the configuration parsed from the JSON file at product['FaceDetection_BEST_MODEL_CONFIG'].\"\"\"\n",
    "    config_path = product['FaceDetection_BEST_MODEL_CONFIG']\n",
    "    with open(config_path, 'r') as config_file:\n",
    "        loaded_config = json.load(config_file)\n",
    "    return loaded_config\n",
    "\n",
    "# Module-level copy; training_loop uses it as the default value of its `config` argument.\n",
    "best_config = get_best_model_config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70ee758b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def training_loop(TB_LOG_DIR, max_epochs = MAX_EPOCHS, config = best_config):\n",
    "    \"\"\"Train the model on a single GPU with the given hyperparameters.\n",
    "\n",
    "    Args:\n",
    "        TB_LOG_DIR: directory handed to the TensorBoard logger.\n",
    "        max_epochs: upper bound on the number of training epochs.\n",
    "        config: hyperparameter dict passed to ``get_model`` and ``get_datamodule``;\n",
    "            its ``batch_size`` entry is also used to derive the number of batches per epoch.\n",
    "            The default is the module-level ``best_config`` as it was when this function\n",
    "            was defined.\n",
    "\n",
    "    Training stops early once ``val_loss`` has failed to improve by at least 0.001 for three\n",
    "    validation checks. A single checkpoint, the one with the lowest ``val_loss``, is kept under\n",
    "    ``product['FaceDetection_MODEL_TRAINING_CHECKPOINT_OUTPUT']``.\n",
    "    \"\"\"\n",
    "    pl.seed_everything(42, workers=True)\n",
    "    model = get_model(config)\n",
    "    dm = get_datamodule(config)\n",
    "\n",
    "    # Monitor val_loss explicitly: without a monitored metric ModelCheckpoint keeps only the\n",
    "    # checkpoint of the last epoch, which under early stopping can be up to `patience`\n",
    "    # validation checks past the best one.\n",
    "    checkpoint_callback = ModelCheckpoint(dirpath=product['FaceDetection_MODEL_TRAINING_CHECKPOINT_OUTPUT'],\n",
    "                                          monitor=\"val_loss\",\n",
    "                                          mode=\"min\",\n",
    "                                          save_top_k=1,\n",
    "                                          filename=f\"{DATASET_NAME}\" + '-vittsc-mask-{epoch:02d}')\n",
    "    tb_logger = pl_loggers.TensorBoardLogger(TB_LOG_DIR)\n",
    "    lr_monitor = LearningRateMonitor(logging_interval='step')\n",
    "    early_stop_callback = EarlyStopping(monitor=\"val_loss\", min_delta=0.001, patience=3, verbose=False, mode=\"min\")\n",
    "\n",
    "    # Number of batches that make up one pass over each split.\n",
    "    train_batches_per_epoch = math.ceil(get_train_dataset_size()/config['batch_size'])\n",
    "    valid_batches_per_epoch = math.ceil(get_valid_dataset_size()/config['batch_size'])\n",
    "\n",
    "    trainer = pl.Trainer(\n",
    "        accelerator='gpu', devices=1,\n",
    "        max_epochs=max_epochs,\n",
    "        strategy='dp',\n",
    "        logger=tb_logger,\n",
    "        callbacks=[lr_monitor, checkpoint_callback, early_stop_callback],\n",
    "        limit_train_batches=train_batches_per_epoch,\n",
    "        limit_val_batches=valid_batches_per_epoch,\n",
    "        # Validate once per pass over the training data.\n",
    "        val_check_interval=train_batches_per_epoch,\n",
    "        num_sanity_val_steps=0,\n",
    "        reload_dataloaders_every_n_epochs=1,\n",
    "        deterministic=True\n",
    "    )\n",
    "\n",
    "    trainer.fit(model, dm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1376c72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set LOCAL_RANK=0 in the kernel's environment before training starts.\n",
    "# NOTE(review): presumably so the training code treats this process as rank 0 -- confirm.\n",
    "%env LOCAL_RANK=0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "293a4295",
   "metadata": {},
   "source": [
    "### 4. Model Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "393a6510",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "# Entry point: train with the configuration reloaded from the saved best-model JSON file.\n",
    "if __name__ == \"__main__\":\n",
    "    training_loop(\n",
    "        TB_LOG_DIR=product['FaceDetection_MODEL_TRAINING_OUTPUT'],\n",
    "        max_epochs=MAX_EPOCHS,\n",
    "        config=get_best_model_config(),\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83be4acd",
   "metadata": {},
   "source": [
    "### 5. Training Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6471d8bc-b2e1-4384-964e-3ba7f32fde91",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import os\n",
    "import shutil\n",
    "\n",
    "# Copy the checkpoint found in the training checkpoint directory to the fixed best-model path.\n",
    "checkpoint_files = glob.glob(product['FaceDetection_MODEL_TRAINING_CHECKPOINT_OUTPUT'] + '/*.ckpt')\n",
    "if not checkpoint_files:\n",
    "    # Give a clear error instead of the bare IndexError raised by indexing an empty list.\n",
    "    raise FileNotFoundError(\n",
    "        'No .ckpt file found in ' + product['FaceDetection_MODEL_TRAINING_CHECKPOINT_OUTPUT']\n",
    "    )\n",
    "# glob returns files in arbitrary order; if several checkpoints exist, take the most recently written one.\n",
    "source_file = max(checkpoint_files, key=os.path.getmtime)\n",
    "print(source_file)\n",
    "shutil.copyfile(source_file, product['FaceDetection_BEST_MODEL'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e80ba31c-1483-4a83-b179-818aadb68d51",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: uncomment to browse the training logs in TensorBoard.\n",
    "# NOTE(review): the logdir below differs from product['FaceDetection_MODEL_TRAINING_OUTPUT'],\n",
    "# which is what training_loop receives as TB_LOG_DIR -- confirm the path before enabling.\n",
    "#%load_ext tensorboard\n",
    "#%tensorboard --logdir experiments_result/FaceDetection/vittsc_mask --port 8199"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "121c62ba",
   "metadata": {},
   "source": [
    "__We shut down the kernel!!!__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "943620bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run nbdev's export, which writes cells marked `#| export` into their target modules\n",
    "# (this notebook's target is set by the `default_exp` directive in its first cell).\n",
    "from nbdev import nbdev_export\n",
    "nbdev_export()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1c362d3",
   "metadata": {},
   "source": [
    "__Multi-GPU Training__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71d015aa-39a4-4e14-9893-b90f1b0dada4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rapids-22.08_ploomber",
   "language": "python",
   "name": "rapids-22.08_ploomber"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "papermill": {
   "environment_variables": {},
   "parameters": {
    "product": {
     "FaceDetection_BEST_MODEL": "/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result/best_model.ckpt",
     "FaceDetection_BEST_MODEL_CONFIG": "/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result/best_model_config.json",
     "FaceDetection_MODEL_TRAINING_CHECKPOINT_OUTPUT": "/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result/checkpoint",
     "FaceDetection_MODEL_TRAINING_OUTPUT": "/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result",
     "FaceDetection_MODEL_TUNE_OUTPUT": "/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/ray_results",
     "nb": "/home/ubuntu/vitmtsc_nbdev/output/401_model.optimization.nn.tsc.vittsc.face_detection_training_mask_tune.html"
    },
    "upstream": {
     "tabular_to_timeseries_face_detection": {
      "FaceDetection_TEST_MODEL_INPUT": "/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/target_encoding-nn/test",
      "FaceDetection_TRAIN_MODEL_INPUT": "/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/target_encoding-nn/train",
      "FaceDetection_VALID_MODEL_INPUT": "/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/target_encoding-nn/valid",
      "nb": "/home/ubuntu/vitmtsc_nbdev/output/301_feature_preprocessing.face_detection.tabular_to_timeseries.html"
     }
    }
   },
   "version": null
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
