{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp model.optimization.nn.tsc.vittsc.insect_wingbeat_training_mask_tune\n",
    "# Reload edited library modules automatically so changes are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Parameters cell (tagged 'parameters'): declare the list of tasks whose products you want to use as inputs\n",
    "upstream = ['tabular_to_timeseries_insect_wingbeat']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# |export\n",
    "# Task parameters: `upstream` maps the upstream task to the parquet splits it produced;\n",
    "# `product` lists the outputs this notebook writes (tuning results, checkpoints, best model/config).\n",
    "# NOTE(review): absolute, machine-specific paths; presumably regenerated by the pipeline on each run - confirm.\n",
    "upstream = {\n",
    "    \"tabular_to_timeseries_insect_wingbeat\": {\n",
    "        \"nb\": \"/home/ubuntu/vitmtsc_nbdev/output/302_feature_preprocessing.insect_wingbeat.tabular_to_timeseries.html\",\n",
    "        \"InsectWingbeat_TRAIN_MODEL_INPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/InsectWingbeat/target_encoding-nn/train\",\n",
    "        \"InsectWingbeat_VALID_MODEL_INPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/InsectWingbeat/target_encoding-nn/valid\",\n",
    "        \"InsectWingbeat_TEST_MODEL_INPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/InsectWingbeat/target_encoding-nn/test\",\n",
    "    }\n",
    "}\n",
    "product = {\n",
    "    \"nb\": \"/home/ubuntu/vitmtsc_nbdev/output/402_model.optimization.nn.tsc.vittsc.insect_wingbeat_training_mask_tune.html\",\n",
    "    \"InsectWingbeat_MODEL_TUNE_OUTPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/InsectWingbeat/ray_results\",\n",
    "    \"InsectWingbeat_MODEL_TRAINING_OUTPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/InsectWingbeat/experiments_result\",\n",
    "    \"InsectWingbeat_MODEL_TRAINING_CHECKPOINT_OUTPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/InsectWingbeat/experiments_result/checkpoint\",\n",
    "    \"InsectWingbeat_BEST_MODEL\": \"/home/ubuntu/vitmtsc_nbdev/output/InsectWingbeat/experiments_result/best_model.ckpt\",\n",
    "    \"InsectWingbeat_BEST_MODEL_CONFIG\": \"/home/ubuntu/vitmtsc_nbdev/output/InsectWingbeat/experiments_result/best_model_config.json\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import sys\n",
    "import pathlib as p\n",
    "\n",
    "def is_running_from_ipython():\n",
    "    \"\"\"Return True inside an IPython/Jupyter session, False otherwise.\n",
    "\n",
    "    IPython is imported lazily and a missing IPython installation counts as 'not IPython',\n",
    "    so the exported module can also run under a plain Python interpreter.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        from IPython import get_ipython\n",
    "    except ImportError:\n",
    "        return False\n",
    "    return get_ipython() is not None\n",
    "\n",
    "# Executed as a standalone script (not inside IPython, no parent package): put the\n",
    "# script's parent directory on sys.path and set __package__ to the script's directory\n",
    "# name, presumably so package-relative imports in the exported module resolve.\n",
    "if not is_running_from_ipython() and __package__ is None:\n",
    "    DIR = p.Path(__file__).resolve().parent\n",
    "    sys.path.insert(0, str(DIR.parent))\n",
    "    __package__ = DIR.name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import torch\n",
    "import pytorch_lightning\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import math \n",
    "\n",
    "from torch.nn import functional as F\n",
    "from torch import nn\n",
    "from torchmetrics import functional as FM\n",
    "from pytorch_lightning import loggers as pl_loggers\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "from pytorch_lightning.callbacks import LearningRateMonitor\n",
    "from petastorm import make_batch_reader\n",
    "from petastorm.pytorch import DataLoader\n",
    "from einops import rearrange, repeat\n",
    "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
    "from pytorch_lightning.loggers import TensorBoardLogger\n",
    "import pytorch_lightning as pl"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Vision Transformer for Multivariate Time-Series Classification (VitMTSC) Model Training with Masking - Hyperparameter search\n",
    "\n",
    "> Classification Task\n",
    "\n",
    "> Data Loader Module\n",
    "\n",
    "> Hyperparameter Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "DATASET_NAME = 'InsectWingbeat'  # used in checkpoint file names\n",
    "NUM_TARGET = 10  # number of classes (default model output size c_out)\n",
    "SEQUENCE_LENGTH = 22  # time steps per series\n",
    "NUMBER_OF_FEATURES = 200  # variables per time step; a row holds NUMBER_OF_FEATURES * SEQUENCE_LENGTH values + target + case_id\n",
    "NUM_WORKERS = 1  # petastorm reader workers\n",
    "NUM_GPUS = 1  # default Trainer devices per tuning trial and petastorm shard count\n",
    "MAX_EPOCHS = 50  # final-training epochs (also the petastorm readers' num_epochs)\n",
    "TUNE_EPOCHS = 5  # max epochs per Ray Tune trial (ASHA max_t)\n",
    "NUM_SAMPLES = 1000  # number of Ray Tune trials"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import dask_cudf\n",
    "import numpy as np\n",
    "import sklearn.utils.class_weight\n",
    "\n",
    "def _model_input_path(split):\n",
    "    \"\"\"Return the upstream parquet directory for `split` ('TRAIN', 'VALID' or 'TEST').\"\"\"\n",
    "    return upstream['tabular_to_timeseries_insect_wingbeat'][f'InsectWingbeat_{split}_MODEL_INPUT']\n",
    "\n",
    "def _count_cases(split):\n",
    "    \"\"\"Number of distinct `case_id` values (samples) in the given split.\"\"\"\n",
    "    gdf = dask_cudf.read_parquet(_model_input_path(split), columns = ['case_id'])\n",
    "    return gdf.case_id.nunique().compute()\n",
    "\n",
    "def get_train_dataset_size():\n",
    "    \"\"\"Number of distinct cases in the training split.\"\"\"\n",
    "    return _count_cases('TRAIN')\n",
    "\n",
    "def get_valid_dataset_size():\n",
    "    \"\"\"Number of distinct cases in the validation split.\"\"\"\n",
    "    return _count_cases('VALID')\n",
    "\n",
    "def get_test_dataset_size():\n",
    "    \"\"\"Number of distinct cases in the test split.\"\"\"\n",
    "    return _count_cases('TEST')\n",
    "\n",
    "def get_class_weight():\n",
    "    \"\"\"Per-class weights for the weighted cross-entropy loss.\n",
    "\n",
    "    sklearn's 'balanced' heuristic over the training labels (ordered by np.unique of the\n",
    "    labels), then halved.\n",
    "    NOTE(review): the reason for the extra /2 scaling is not documented here.\n",
    "    \"\"\"\n",
    "    train_gdf = dask_cudf.read_parquet(_model_input_path('TRAIN'), columns = ['class_vals'])\n",
    "    y_train = train_gdf['class_vals'].compute().to_numpy()\n",
    "    class_weight = sklearn.utils.class_weight.compute_class_weight('balanced', classes = np.unique(y_train), y = y_train)\n",
    "    class_weight = class_weight/2\n",
    "    print(f'class_weight: {class_weight}')\n",
    "    return class_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: distinct case_ids per split and the class weights used by the loss.\n",
    "get_train_dataset_size(), get_valid_dataset_size(), get_test_dataset_size(), get_class_weight()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Model Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class Residual(nn.Module):\n",
    "    \"\"\"Identity skip connection around `fn`: returns fn(x, **kwargs) + x.\"\"\"\n",
    "\n",
    "    def __init__(self, fn):\n",
    "        super().__init__()\n",
    "        self.fn = fn\n",
    "\n",
    "    def forward(self, x, **kwargs):\n",
    "        skip = x\n",
    "        return self.fn(x, **kwargs) + skip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class PreNorm(nn.Module):\n",
    "    \"\"\"Apply LayerNorm over the last dimension (size `dim`) before calling `fn`.\"\"\"\n",
    "\n",
    "    def __init__(self, dim, fn):\n",
    "        super().__init__()\n",
    "        self.norm = nn.LayerNorm(dim)\n",
    "        self.fn = fn\n",
    "\n",
    "    def forward(self, x, **kwargs):\n",
    "        normalized = self.norm(x)\n",
    "        return self.fn(normalized, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class FeedForward(nn.Module):\n",
    "    \"\"\"Token-wise MLP: Linear(dim->hidden_dim), GELU, Dropout, Linear(hidden_dim->dim), Dropout.\"\"\"\n",
    "\n",
    "    def __init__(self, dim, hidden_dim, dropout = 0.):\n",
    "        super().__init__()\n",
    "        layers = [\n",
    "            nn.Linear(dim, hidden_dim),\n",
    "            nn.GELU(),\n",
    "            nn.Dropout(dropout),\n",
    "            nn.Linear(hidden_dim, dim),\n",
    "            nn.Dropout(dropout),\n",
    "        ]\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class Attention(nn.Module):\n",
    "    \"\"\"Multi-head scaled dot-product self-attention with an optional key mask.\n",
    "\n",
    "    Args:\n",
    "        dim: input/output feature size of each token.\n",
    "        heads: number of attention heads.\n",
    "        dim_head: feature size per head; q/k/v are projected to heads * dim_head.\n",
    "        dropout: dropout applied after the output projection.\n",
    "\n",
    "    With `register_hook=True`, `forward` keeps the attention probabilities\n",
    "    (`get_attention_map`) and records their gradient during backward\n",
    "    (`get_attn_gradients`).\n",
    "    \"\"\"\n",
    "    def __init__(self, dim, heads = 10, dim_head = 32, dropout = 0.):\n",
    "        super().__init__()\n",
    "\n",
    "        inner_dim = dim_head * heads\n",
    "        self.heads = heads\n",
    "        self.scale = dim_head ** -0.5\n",
    "\n",
    "        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)\n",
    "        self.to_out = nn.Sequential(\n",
    "            nn.Linear(inner_dim, dim),\n",
    "            nn.Dropout(dropout)\n",
    "        )\n",
    "\n",
    "        self.attn_gradients = None\n",
    "        self.attention_map = None\n",
    "\n",
    "    def save_attn_gradients(self, attn_gradients):\n",
    "        self.attn_gradients = attn_gradients\n",
    "\n",
    "    def get_attn_gradients(self):\n",
    "        return self.attn_gradients\n",
    "\n",
    "    def save_attention_map(self, attention_map):\n",
    "        self.attention_map = attention_map\n",
    "\n",
    "    def get_attention_map(self):\n",
    "        return self.attention_map\n",
    "\n",
    "    def forward(self, x, mask = None, register_hook = False):\n",
    "        \"\"\"Self-attention over the token axis of `x`.\n",
    "\n",
    "        Args:\n",
    "            x: (batch, tokens, dim) tensor; the first token is the CLS token.\n",
    "            mask: optional mask with batch as its first dimension. It is flattened to\n",
    "                (batch, tokens - 1), a 'keep' entry is prepended for the CLS token, and\n",
    "                key positions where it equals 0 are excluded from attention.\n",
    "            register_hook: store the attention map and capture its gradient.\n",
    "\n",
    "        Returns:\n",
    "            (batch, tokens, dim) tensor.\n",
    "        \"\"\"\n",
    "        h = self.heads\n",
    "        qkv = self.to_qkv(x).chunk(3, dim = -1)\n",
    "        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)\n",
    "\n",
    "        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale  # (batch, heads, tokens, tokens)\n",
    "\n",
    "        if mask is not None:\n",
    "            # Prepend a 'keep' entry for the CLS token and broadcast over heads and query positions.\n",
    "            mask = F.pad(mask.flatten(1), (1, 0), value = True)\n",
    "            mask = mask.unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, tokens)\n",
    "            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'\n",
    "            # Most negative finite value rather than -inf keeps softmax NaN-free even for a fully masked row.\n",
    "            dots.masked_fill_(mask == 0.0, -torch.finfo(dots.dtype).max)\n",
    "            del mask\n",
    "\n",
    "        attn = dots.softmax(dim=-1)\n",
    "        out = torch.einsum('bhij,bhjd->bhid', attn, v)\n",
    "\n",
    "        if register_hook:\n",
    "            self.save_attention_map(attn)\n",
    "            attn.register_hook(self.save_attn_gradients)\n",
    "\n",
    "        out = rearrange(out, 'b h n d -> b n (h d)')\n",
    "        return self.to_out(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class Transformer(nn.Module):\n",
    "    \"\"\"Stack of `depth` pre-norm encoder layers: residual self-attention, then residual MLP.\"\"\"\n",
    "    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):\n",
    "        super().__init__()\n",
    "\n",
    "        def build_layer():\n",
    "            attention = Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)))\n",
    "            feed_forward = Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))\n",
    "            return nn.ModuleList([attention, feed_forward])\n",
    "\n",
    "        self.layers = nn.ModuleList([build_layer() for _ in range(depth)])\n",
    "\n",
    "    def forward(self, x, mask = None, register_hook = False):\n",
    "        # `mask` and `register_hook` are forwarded to every attention block.\n",
    "        for attention, feed_forward in self.layers:\n",
    "            x = feed_forward(attention(x, mask = mask, register_hook = register_hook))\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def petastorm_collate_fn(rows):\n",
    "    \"\"\"Turn a batch of petastorm rows into the tensors the model consumes.\n",
    "\n",
    "    Each row holds NUMBER_OF_FEATURES * SEQUENCE_LENGTH feature values (feature-major:\n",
    "    all SEQUENCE_LENGTH steps of feature 0, then feature 1, ...), followed by the target\n",
    "    and the case id, e.g. 200 * 22 + 2 = 4402 columns for InsectWingbeat.\n",
    "\n",
    "    Returns:\n",
    "        data: float32 (batch, SEQUENCE_LENGTH, NUMBER_OF_FEATURES).\n",
    "        target: float32 (batch,).\n",
    "        case_id: float64 (batch,).\n",
    "        mask: float32 (batch, SEQUENCE_LENGTH), the values of feature 0 at each step;\n",
    "            Attention ignores steps where it is 0.\n",
    "            NOTE(review): assumes padded steps have feature 0 == 0 - confirm upstream.\n",
    "    \"\"\"\n",
    "    data_df = pd.DataFrame(rows)\n",
    "    n_values = NUMBER_OF_FEATURES * SEQUENCE_LENGTH\n",
    "\n",
    "    case_id_df = data_df.iloc[:, n_values + 1:n_values + 2]\n",
    "    case_id_tensor = torch.tensor(case_id_df.values.astype(np.float64))\n",
    "\n",
    "    target_df = data_df.iloc[:, n_values:n_values + 1]\n",
    "    target_tensor = torch.tensor(target_df.values.astype(np.float32))\n",
    "\n",
    "    data_tensor_df = data_df.iloc[:, 0:n_values]\n",
    "    data_tensor = torch.tensor(data_tensor_df.values.astype(np.float32))\n",
    "    data_tensor = rearrange(data_tensor, 't (b h)-> t h b', h = SEQUENCE_LENGTH)\n",
    "\n",
    "    mask_df = data_df.iloc[:, 0:SEQUENCE_LENGTH]\n",
    "    mask_tensor = torch.tensor(mask_df.values.astype(np.float32))\n",
    "\n",
    "    # squeeze(-1) instead of squeeze(): a final batch of size 1 must keep its batch\n",
    "    # dimension, otherwise cross_entropy and the attention mask (mask.flatten(1)) fail.\n",
    "    # mask_tensor is already (batch, SEQUENCE_LENGTH) and is returned as-is.\n",
    "    return data_tensor, target_tensor.squeeze(-1), case_id_tensor.squeeze(-1), mask_tensor\n",
    "\n",
    "class VitMTSCPetastormDataModule(pl.LightningDataModule):\n",
    "    \"\"\"LightningDataModule that streams the InsectWingbeat train/valid/test splits with petastorm.\n",
    "\n",
    "    Args:\n",
    "        config: hyperparameter dict; only config['batch_size'] is used.\n",
    "        num_workers: petastorm reader worker count.\n",
    "        transform_spec: optional petastorm transform applied by the reader.\n",
    "        shard_count: number of shards the data is split into; the shard index is\n",
    "            taken from the LOCAL_RANK environment variable (0 if unset).\n",
    "        num_epochs: number of passes each reader makes over its split.\n",
    "    \"\"\"\n",
    "    def __init__(self, config,\n",
    "                 num_workers=NUM_WORKERS,\n",
    "                 transform_spec = None,\n",
    "                 shard_count = NUM_GPUS, \n",
    "                 num_epochs = MAX_EPOCHS):\n",
    "        super().__init__()\n",
    "        inputs = upstream['tabular_to_timeseries_insect_wingbeat']\n",
    "        self.train_files = f\"file://{inputs['InsectWingbeat_TRAIN_MODEL_INPUT']}\"\n",
    "        self.valid_files = f\"file://{inputs['InsectWingbeat_VALID_MODEL_INPUT']}\"\n",
    "        self.test_files = f\"file://{inputs['InsectWingbeat_TEST_MODEL_INPUT']}\"\n",
    "        self.batch_size = config[\"batch_size\"]\n",
    "        self.num_workers = num_workers\n",
    "        self.transform_spec = transform_spec\n",
    "        self.shard_count = shard_count\n",
    "        self.num_epochs = num_epochs\n",
    "\n",
    "    @staticmethod\n",
    "    def _shard_index():\n",
    "        # Shard read by this process. Default to 0 when LOCAL_RANK is not set (e.g. a plain\n",
    "        # single-process script run) instead of failing with KeyError.\n",
    "        return int(os.environ.get('LOCAL_RANK', 0))\n",
    "\n",
    "    def _make_reader_and_loader(self, files):\n",
    "        \"\"\"Open a sharded petastorm batch reader over `files` and wrap it in a DataLoader.\"\"\"\n",
    "        reader = make_batch_reader(files, workers_count=self.num_workers, transform_spec=self.transform_spec,\n",
    "                                   cur_shard = self._shard_index(), shard_count = self.shard_count, num_epochs = self.num_epochs)\n",
    "        return reader, DataLoader(reader, batch_size = self.batch_size, collate_fn= petastorm_collate_fn)\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        self.train_ds, loader = self._make_reader_and_loader(self.train_files)\n",
    "        return loader\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        print('val_dataloader: local rank :', self._shard_index(), 'shard count: ', self.shard_count)\n",
    "        self.val_ds, loader = self._make_reader_and_loader(self.valid_files)\n",
    "        return loader\n",
    "\n",
    "    def test_dataloader(self):\n",
    "        print('test_dataloader: local rank :', self._shard_index(), 'shard count: ', self.shard_count)\n",
    "        self.test_ds, loader = self._make_reader_and_loader(self.test_files)\n",
    "        return loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class VitTimeSeriesTransformer(pl.LightningModule):\n",
    "    \"\"\"ViT-style classifier for multivariate time series (VitMTSC).\n",
    "\n",
    "    Every time step (a vector of `c_in` variables) is linearly embedded as one token,\n",
    "    a learnable CLS token is prepended, learned positional embeddings are added, and the\n",
    "    transformer output at the CLS position is classified by an MLP head.\n",
    "\n",
    "    Args:\n",
    "        config: hyperparameters d_model, depth, heads, mlp_dim, dim_head, dropout,\n",
    "            emb_dropout, lr, weight_decay and patience.\n",
    "        c_in: variables per time step.\n",
    "        c_out: number of classes.\n",
    "        seq_len: time steps per series (positional embeddings cover seq_len + 1 tokens).\n",
    "        class_weight: per-class cross-entropy weights (tensor). If None, they are computed\n",
    "            with get_class_weight() when the model is constructed, so defining or\n",
    "            importing this class does not read the training data.\n",
    "    \"\"\"\n",
    "    def __init__(self, config, c_in = NUMBER_OF_FEATURES, c_out = NUM_TARGET, \n",
    "                 seq_len = SEQUENCE_LENGTH, class_weight = None):\n",
    "        super().__init__()\n",
    "\n",
    "        self.d_model = config[\"d_model\"]\n",
    "        self.depth = config[\"depth\"]\n",
    "        self.heads = config[\"heads\"]\n",
    "        self.mlp_dim = config[\"mlp_dim\"]\n",
    "        self.dim_head = config[\"dim_head\"]\n",
    "        self.dropout_p = config[\"dropout\"]\n",
    "        self.emb_dropout_p = config[\"emb_dropout\"]\n",
    "        self.lr = config[\"lr\"]\n",
    "        self.weight_decay = config[\"weight_decay\"]\n",
    "        self.patience = config[\"patience\"]\n",
    "\n",
    "        self.pos_embedding = nn.Parameter(torch.randn(1, seq_len + 1, self.d_model))\n",
    "        self.patch_to_embedding = nn.Linear(c_in, self.d_model)\n",
    "        self.cls_token = nn.Parameter(torch.randn(1, 1, self.d_model))\n",
    "        self.dropout = nn.Dropout(self.emb_dropout_p)\n",
    "        self.transformer = Transformer(self.d_model, self.depth, self.heads, self.dim_head, self.mlp_dim, self.dropout_p)\n",
    "        self.to_cls_token = nn.Identity()\n",
    "        self.mlp_head = nn.Sequential(\n",
    "            nn.LayerNorm(self.d_model),\n",
    "            nn.Linear(self.d_model, self.mlp_dim),\n",
    "            nn.GELU(),\n",
    "            nn.Dropout(self.dropout_p),\n",
    "            nn.Linear(self.mlp_dim, c_out)\n",
    "        )\n",
    "\n",
    "        self.c_out = c_out\n",
    "        if class_weight is None:\n",
    "            class_weight = torch.FloatTensor(get_class_weight())\n",
    "        self.register_buffer('class_weight', class_weight)\n",
    "\n",
    "    def forward(self, x, mask = None, register_hook = False):\n",
    "        \"\"\"Return (batch, c_out) class logits.\n",
    "\n",
    "        Args:\n",
    "            x: (batch, seq_len, c_in) tensor of time steps.\n",
    "            mask: optional (batch, seq_len) mask handed to every attention layer;\n",
    "                time steps where it is 0 are not attended to.\n",
    "            register_hook: keep attention maps/gradients in each Attention layer.\n",
    "        \"\"\"\n",
    "        x = self.patch_to_embedding(x) # bs x seq_len x nvars -> bs x seq_len x d_model\n",
    "        b, n, _ = x.shape # bs, seq_len\n",
    "\n",
    "        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) # bs x 1 x d_model\n",
    "        x = torch.cat((cls_tokens, x), dim=1) # bs x (seq_len + 1) x d_model\n",
    "        x += self.pos_embedding[:, :(n + 1)] # bs x (seq_len + 1) x d_model\n",
    "        x = self.dropout(x)\n",
    "\n",
    "        x = self.transformer(x, mask = mask, register_hook = register_hook) # bs x (seq_len + 1) x d_model\n",
    "\n",
    "        x = self.to_cls_token(x[:, 0]) # CLS output: bs x d_model\n",
    "        return self.mlp_head(x) # bs x num_classes\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)\n",
    "        # Lower the LR when the monitored train_loss stops improving for `patience` scheduler steps.\n",
    "        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=self.patience)\n",
    "        return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler, \"monitor\": \"train_loss\"}\n",
    "\n",
    "    def _shared_step(self, batch, stage):\n",
    "        \"\"\"Compute the weighted loss for `batch`, log `{stage}_loss` / `{stage}_auc`, return the loss.\n",
    "\n",
    "        NOTE: the `*_auc` metric is accuracy (FM.accuracy); the name is kept because the\n",
    "        Ray Tune callback reports `val_auc`.\n",
    "        \"\"\"\n",
    "        x, y, _, mask = batch\n",
    "        y_hat = self(x, mask)\n",
    "        y = y.long()\n",
    "        loss = F.cross_entropy(y_hat, y, weight = self.class_weight)\n",
    "        acc = FM.accuracy(F.softmax(y_hat, dim=1), y, num_classes = self.c_out)\n",
    "        # As before, only the val/test losses are synchronised across processes.\n",
    "        self.log(f'{stage}_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=(stage != 'train'))\n",
    "        self.log(f'{stage}_auc', acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)\n",
    "        return loss\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        return self._shared_step(batch, 'train')\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        return self._shared_step(batch, 'val')\n",
    "\n",
    "    def test_step(self, batch, batch_idx):\n",
    "        return self._shared_step(batch, 'test')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Routines for Single/Multi-GPU Training (DataParallel) and Hyperparameter Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def get_model(config):\n",
    "    \"\"\"Instantiate a fresh VitTimeSeriesTransformer from the hyperparameter dict `config`.\"\"\"\n",
    "    return VitTimeSeriesTransformer(config)\n",
    "\n",
    "def get_datamodule(config):\n",
    "    \"\"\"Instantiate the petastorm data module; it reads only config['batch_size'].\"\"\"\n",
    "    return VitMTSCPetastormDataModule(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The petastorm data module reads LOCAL_RANK to choose its shard; set it for single-process notebook runs.\n",
    "%env LOCAL_RANK=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import json\n",
    "import ray\n",
    "from ray.tune import ExperimentAnalysis\n",
    "from ray import tune\n",
    "from ray.tune import CLIReporter\n",
    "from ray.tune.schedulers import ASHAScheduler\n",
    "from ray.tune.integration.pytorch_lightning import TuneReportCallback\n",
    "\n",
    "def tune_training(config, num_epochs = TUNE_EPOCHS, num_gpus = NUM_GPUS):\n",
    "    \"\"\"Train one Ray Tune trial with the sampled hyperparameters `config`.\n",
    "\n",
    "    After every validation run the Lightning metrics val_loss / val_auc are reported\n",
    "    to Tune as 'loss' / 'auc'.\n",
    "\n",
    "    Args:\n",
    "        config: sampled hyperparameters (see tune_training_asha).\n",
    "        num_epochs: maximum epochs for this trial.\n",
    "        num_gpus: GPUs for the Trainer; fractional values are rounded up.\n",
    "    \"\"\"\n",
    "    pl.seed_everything(42, workers=True)\n",
    "    model = get_model(config)\n",
    "    dm = get_datamodule(config)\n",
    "    metrics = {\"loss\": \"val_loss\", \"auc\": \"val_auc\"}\n",
    "    callbacks = [TuneReportCallback(metrics, on=\"validation_end\")]\n",
    "\n",
    "    # The petastorm readers loop over their data for several epochs, so Lightning is told\n",
    "    # explicitly how many batches make one pass over each split. Each size is read once.\n",
    "    train_batches = math.ceil(get_train_dataset_size() / config['batch_size'])\n",
    "    val_batches = math.ceil(get_valid_dataset_size() / config['batch_size'])\n",
    "\n",
    "    trainer = pl.Trainer(\n",
    "        max_epochs=num_epochs,\n",
    "        accelerator='gpu', devices=math.ceil(num_gpus),\n",
    "        strategy=\"dp\",\n",
    "        callbacks=callbacks,\n",
    "        limit_train_batches=train_batches,\n",
    "        limit_val_batches=val_batches,\n",
    "        val_check_interval=train_batches,\n",
    "        num_sanity_val_steps=0,\n",
    "        reload_dataloaders_every_n_epochs=1,\n",
    "        deterministic=True\n",
    "    )\n",
    "\n",
    "    trainer.fit(model, dm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def tune_training_asha(num_samples=NUM_SAMPLES, num_epochs=TUNE_EPOCHS, num_gpus = NUM_GPUS, gpus_per_trial=0.2):\n",
    "    \"\"\"Search VitMTSC hyperparameters with Ray Tune and the ASHA early-stopping scheduler.\n",
    "\n",
    "    Args:\n",
    "        num_samples: number of trials sampled from the search space below.\n",
    "        num_epochs: maximum epochs per trial (ASHA max_t).\n",
    "        num_gpus: devices each trial's Lightning Trainer uses (see tune_training).\n",
    "        gpus_per_trial: GPU fraction Ray reserves per trial; values below 1 let\n",
    "            several trials share one GPU.\n",
    "\n",
    "    Returns:\n",
    "        The analysis object returned by tune.run (trials are minimised on 'loss').\n",
    "    \"\"\"\n",
    "    config = {\n",
    "        \"d_model\": tune.choice([32, 64, 128, 256, 512]),\n",
    "        \"depth\": tune.choice([4, 6, 8, 10]),\n",
    "        \"heads\": tune.choice([6, 8, 10, 12, 14, 16]),\n",
    "        \"mlp_dim\": tune.choice([8, 10, 12, 14, 16]),\n",
    "        \"dim_head\": tune.choice([8, 10, 12, 14, 16]),\n",
    "        \"dropout\": tune.loguniform(1e-6, 1e-3),\n",
    "        \"emb_dropout\": tune.loguniform(1e-6, 1e-3),\n",
    "        \"weight_decay\": tune.loguniform(1e-5, 1e-1),\n",
    "        \"lr\": tune.loguniform(1e-6, 1e-3),\n",
    "        \"patience\": tune.choice([1, 2]),\n",
    "        \"batch_size\": tune.choice([16, 32, 64, 128, 256])\n",
    "    }\n",
    "\n",
    "    scheduler = ASHAScheduler(\n",
    "        max_t=num_epochs,\n",
    "        grace_period=1,\n",
    "        reduction_factor=2)\n",
    "\n",
    "    # Show every searched hyperparameter, in the same order as `config`, so the\n",
    "    # progress table cannot drift out of sync with the search space.\n",
    "    reporter = CLIReporter(\n",
    "        parameter_columns=list(config),\n",
    "        metric_columns=[\"loss\", \"auc\", \"training_iteration\"])\n",
    "\n",
    "    trainable = tune.with_parameters(\n",
    "        tune_training,\n",
    "        num_epochs=num_epochs,\n",
    "        num_gpus=num_gpus)\n",
    "    analysis = tune.run(\n",
    "        trainable,\n",
    "        resources_per_trial={\n",
    "            \"cpu\": 1,\n",
    "            \"gpu\": gpus_per_trial\n",
    "        },\n",
    "        metric=\"loss\",\n",
    "        mode=\"min\",\n",
    "        config=config,\n",
    "        num_samples=num_samples,\n",
    "        scheduler=scheduler,\n",
    "        progress_reporter=reporter,\n",
    "        verbose = 1,\n",
    "        name=\"InsectWingbeat\",\n",
    "        # Don't abort the whole search when individual trials fail.\n",
    "        raise_on_failed_trial = False)\n",
    "\n",
    "    print(\"Best hyperparameters found were: \", analysis.best_config)\n",
    "    return analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The petastorm data module reads LOCAL_RANK to choose its shard; set it for single-process notebook runs.\n",
    "%env LOCAL_RANK=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from a clean slate: delete previous Ray Tune results and training outputs, then recreate the output directory.\n",
    "# NOTE(review): the ./output paths are relative to the working directory; they presumably match the absolute `product` paths when run from the project root - confirm.\n",
    "!rm -rf ~/ray_results/InsectWingbeat/ \n",
    "!rm -rf ./output/InsectWingbeat/ray_results/\n",
    "!rm -rf ./output/InsectWingbeat/experiments_result\n",
    "!mkdir -p  output/InsectWingbeat/experiments_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "tune"
    ]
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "# Start a local Ray instance sized for the machine (keep the ray.init line matching your instance type),\n",
    "# then run the ASHA search; num_gpus=1 is the device count each trial's Lightning Trainer uses.\n",
    "#ray.init(num_cpus=16, num_gpus=8) # p3dn.24xlarge\n",
    "ray.init(num_cpus=4, num_gpus=4) # g4dn.12xlarge\n",
    "#ray.init(num_cpus=1, num_gpus=1)\n",
    "tune_training_asha(num_gpus=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the Ray Tune results into the project output directory read below via product['InsectWingbeat_MODEL_TUNE_OUTPUT'].\n",
    "!cp -rf ~/ray_results/InsectWingbeat/ output/InsectWingbeat/ray_results/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show full column contents so trial tags are not truncated.\n",
    "pd.set_option('display.max_columns', None)\n",
    "pd.set_option('display.expand_frame_repr', False)\n",
    "pd.set_option('max_colwidth', None)\n",
    "\n",
    "# Load the tuning results and list the 5 trials with the lowest 'loss' (val_loss).\n",
    "# Note: the 'auc' column holds val_auc, which is computed with FM.accuracy.\n",
    "#analysis = ExperimentAnalysis(f'~/ray_results/{DATASET_NAME}')\n",
    "analysis = ExperimentAnalysis(product['InsectWingbeat_MODEL_TUNE_OUTPUT'])\n",
    "tune_result_df = analysis.results_df[['loss', 'auc', 'training_iteration', 'experiment_tag']]\n",
    "tune_result_df.nsmallest(5, 'loss')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters of the trial with the lowest reported 'loss' (val_loss).\n",
    "best_config = analysis.get_best_config('loss', 'min')\n",
    "print(best_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import json\n",
    "import os\n",
    "def write_best_model_config():\n",
    "    \"\"\"Write the best Ray Tune hyperparameters (lowest 'loss', i.e. val_loss) to JSON.\n",
    "\n",
    "    Reads trial results from product['InsectWingbeat_MODEL_TUNE_OUTPUT'] and writes the\n",
    "    best config, indented by 4 spaces, to product['InsectWingbeat_BEST_MODEL_CONFIG'].\n",
    "    The parent directory is created if missing, because the shell cell that creates it\n",
    "    is not part of the exported module.\n",
    "    \"\"\"\n",
    "    analysis = ExperimentAnalysis(product['InsectWingbeat_MODEL_TUNE_OUTPUT'])\n",
    "    best_config = analysis.get_best_config('loss', 'min')\n",
    "    config_path = product['InsectWingbeat_BEST_MODEL_CONFIG']\n",
    "    parent_dir = os.path.dirname(config_path)\n",
    "    if parent_dir:\n",
    "        os.makedirs(parent_dir, exist_ok=True)\n",
    "    with open(config_path, 'w') as outfile:\n",
    "        json.dump(best_config, outfile, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "# Persist the best tuned config so final training (and the exported module) can reload it.\n",
    "write_best_model_config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def get_best_model_config():\n",
    "    \"\"\"Load the tuned hyperparameters stored at product['InsectWingbeat_BEST_MODEL_CONFIG'].\"\"\"\n",
    "    config_path = product['InsectWingbeat_BEST_MODEL_CONFIG']\n",
    "    with open(config_path, 'r') as config_file:\n",
    "        loaded_config = json.load(config_file)\n",
    "    return loaded_config\n",
    "\n",
    "best_config = get_best_model_config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def training_loop(TB_LOG_DIR, max_epochs = MAX_EPOCHS, config = best_config):\n",
    "    \"\"\"Train the final model with a fixed hyperparameter `config`.\n",
    "\n",
    "    Args:\n",
    "        TB_LOG_DIR: TensorBoard log directory.\n",
    "        max_epochs: upper bound on epochs (early stopping on val_loss may end sooner).\n",
    "        config: hyperparameters; defaults to the best config loaded when this cell ran.\n",
    "\n",
    "    Writes TensorBoard logs to TB_LOG_DIR and the checkpoint with the lowest val_loss\n",
    "    to product['InsectWingbeat_MODEL_TRAINING_CHECKPOINT_OUTPUT'].\n",
    "    \"\"\"\n",
    "    pl.seed_everything(42, workers=True)\n",
    "    model = get_model(config)\n",
    "    dm = get_datamodule(config)\n",
    "\n",
    "    # Monitor val_loss so save_top_k=1 keeps the *best* checkpoint. Without a monitor,\n",
    "    # ModelCheckpoint keeps only the most recent one, which can be several epochs past\n",
    "    # the best model when early stopping ends the run - yet it is later copied to best_model.ckpt.\n",
    "    checkpoint_callback = ModelCheckpoint(dirpath=product['InsectWingbeat_MODEL_TRAINING_CHECKPOINT_OUTPUT'],\n",
    "                                          monitor=\"val_loss\", mode=\"min\",\n",
    "                                          save_top_k = 1,\n",
    "                                          filename=f\"{DATASET_NAME}\" + '-vittsc-mask-{epoch:02d}')\n",
    "    tb_logger = pl_loggers.TensorBoardLogger(TB_LOG_DIR)\n",
    "\n",
    "    lr_monitor = LearningRateMonitor(logging_interval='step')\n",
    "\n",
    "    early_stop_callback = EarlyStopping(monitor=\"val_loss\", min_delta=0.001, patience=3, verbose=False, mode=\"min\")\n",
    "\n",
    "    # The petastorm readers loop over their data for several epochs, so Lightning is told\n",
    "    # explicitly how many batches make one pass over each split. Each size is read once.\n",
    "    train_batches = math.ceil(get_train_dataset_size() / config['batch_size'])\n",
    "    val_batches = math.ceil(get_valid_dataset_size() / config['batch_size'])\n",
    "\n",
    "    trainer = pl.Trainer(\n",
    "        accelerator='gpu', devices=1,\n",
    "        max_epochs=max_epochs,\n",
    "        strategy='dp',\n",
    "        logger=tb_logger,\n",
    "        callbacks=[lr_monitor, checkpoint_callback, early_stop_callback],\n",
    "        limit_train_batches=train_batches,\n",
    "        limit_val_batches=val_batches,\n",
    "        val_check_interval=train_batches,\n",
    "        num_sanity_val_steps=0,\n",
    "        reload_dataloaders_every_n_epochs=1,\n",
    "        deterministic=True\n",
    "    )\n",
    "\n",
    "    trainer.fit(model, dm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The petastorm data module reads LOCAL_RANK to choose its shard; set it for single-process notebook runs.\n",
    "%env LOCAL_RANK=0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Model Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "# Final training with the tuned hyperparameters; runs in the notebook and when the exported module is executed as a script, not when it is imported.\n",
    "if __name__ == \"__main__\":\n",
    "    training_loop(TB_LOG_DIR = product['InsectWingbeat_MODEL_TRAINING_OUTPUT'],\n",
    "                  max_epochs = MAX_EPOCHS, \n",
    "                  config = get_best_model_config())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. Training Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import os\n",
    "import shutil\n",
    "\n",
    "checkpoint_dir = product['InsectWingbeat_MODEL_TRAINING_CHECKPOINT_OUTPUT']\n",
    "checkpoints = glob.glob(os.path.join(checkpoint_dir, '*.ckpt'))\n",
    "if not checkpoints:\n",
    "    raise FileNotFoundError(f'no .ckpt file found in {checkpoint_dir}')\n",
    "# glob order is arbitrary; if several checkpoints exist (e.g. left over from an earlier run), copy the newest.\n",
    "source_file = max(checkpoints, key=os.path.getmtime)\n",
    "print(source_file)\n",
    "shutil.copyfile(source_file, product['InsectWingbeat_BEST_MODEL'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: inspect training curves in TensorBoard (uncomment to use).\n",
    "# NOTE(review): training logs are written to product['InsectWingbeat_MODEL_TRAINING_OUTPUT']\n",
    "# (output/InsectWingbeat/experiments_result), not the --logdir below - confirm and update.\n",
    "#%load_ext tensorboard\n",
    "#%tensorboard --logdir experiments_result/InsectWingbeat/vittsc_mask --port 8199"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__Shut down the kernel at this point before continuing.__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the cells marked for export to the library module (nbdev).\n",
    "from nbdev import nbdev_export\n",
    "nbdev_export()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__Multi-GPU Training__"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rapids-22.08_ploomber",
   "language": "python",
   "name": "rapids-22.08_ploomber"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "papermill": {
   "environment_variables": {},
   "parameters": {
    "product": {
     "InsectWingbeat_BEST_MODEL": "/home/ubuntu/vitmtsc_nbdev/output/InsectWingbeat/experiments_result/best_model.ckpt",
     "InsectWingbeat_BEST_MODEL_CONFIG": "/home/ubuntu/vitmtsc_nbdev/output/InsectWingbeat/experiments_result/best_model_config.json",
     "InsectWingbeat_MODEL_TRAINING_CHECKPOINT_OUTPUT": "/home/ubuntu/vitmtsc_nbdev/output/InsectWingbeat/experiments_result/checkpoint",
     "InsectWingbeat_MODEL_TRAINING_OUTPUT": "/home/ubuntu/vitmtsc_nbdev/output/InsectWingbeat/experiments_result",
     "InsectWingbeat_MODEL_TUNE_OUTPUT": "/home/ubuntu/vitmtsc_nbdev/output/InsectWingbeat/ray_results",
     "nb": "/home/ubuntu/vitmtsc_nbdev/output/402_model.optimization.nn.tsc.vittsc.insect_wingbeat_training_mask_tune.html"
    },
    "upstream": {
     "tabular_to_timeseries_insect_wingbeat": {
      "InsectWingbeat_TEST_MODEL_INPUT": "/home/ubuntu/vitmtsc_nbdev/output/InsectWingbeat/target_encoding-nn/test",
      "InsectWingbeat_TRAIN_MODEL_INPUT": "/home/ubuntu/vitmtsc_nbdev/output/InsectWingbeat/target_encoding-nn/train",
      "InsectWingbeat_VALID_MODEL_INPUT": "/home/ubuntu/vitmtsc_nbdev/output/InsectWingbeat/target_encoding-nn/valid",
      "nb": "/home/ubuntu/vitmtsc_nbdev/output/302_feature_preprocessing.insect_wingbeat.tabular_to_timeseries.html"
     }
    }
   },
   "version": null
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
