{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp model.optimization.nn.tsc.vittsc.face_detection_evaluation_mask\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# declare a list of tasks whose products you want to use as inputs\n",
    "# (this cell is tagged \"parameters\"; the resolved paths appear in the `upstream` dict further below)\n",
    "upstream = ['tabular_to_timeseries_face_detection', 'model_training_face_detection']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import sys\n",
    "import pathlib as p\n",
    "\n",
    "def is_running_from_ipython():\n",
    "    \"\"\"Return True when executed inside an IPython/Jupyter kernel.\n",
    "\n",
    "    Returns False (instead of raising ImportError) when IPython is not\n",
    "    installed, e.g. when the exported module is run as a plain script.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        from IPython import get_ipython\n",
    "    except ImportError:\n",
    "        return False\n",
    "    return get_ipython() is not None\n",
    "\n",
    "# When the exported module is executed directly as a script (not imported as part\n",
    "# of a package), put its parent directory on sys.path and set __package__\n",
    "# (presumably so package-relative imports resolve).\n",
    "if not is_running_from_ipython() and __package__ is None:\n",
    "    DIR = p.Path(__file__).resolve().parent\n",
    "    sys.path.insert(0, str(DIR.parent))\n",
    "    __package__ = DIR.name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from vitmtsc.model.optimization.nn.tsc.vittsc.face_detection_training_mask_tune import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# |export\n",
    "# Resolved input (`upstream`) and output (`product`) paths for this run; they mirror\n",
    "# the papermill parameters stored in the notebook metadata.\n",
    "# NOTE(review): these absolute, machine-specific paths are baked into the exported module.\n",
    "upstream = {\n",
    "    \"tabular_to_timeseries_face_detection\": {\n",
    "        \"nb\": \"/home/ubuntu/vitmtsc_nbdev/output/301_feature_preprocessing.face_detection.tabular_to_timeseries.html\",\n",
    "        \"FaceDetection_TRAIN_MODEL_INPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/target_encoding-nn/train\",\n",
    "        \"FaceDetection_VALID_MODEL_INPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/target_encoding-nn/valid\",\n",
    "        \"FaceDetection_TEST_MODEL_INPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/target_encoding-nn/test\",\n",
    "    },\n",
    "    \"model_training_face_detection\": {\n",
    "        \"nb\": \"/home/ubuntu/vitmtsc_nbdev/output/401_model.optimization.nn.tsc.vittsc.face_detection_training_mask_tune.html\",\n",
    "        \"FaceDetection_MODEL_TUNE_OUTPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/ray_results\",\n",
    "        \"FaceDetection_MODEL_TRAINING_OUTPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result\",\n",
    "        \"FaceDetection_MODEL_TRAINING_CHECKPOINT_OUTPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result/checkpoint\",\n",
    "        \"FaceDetection_BEST_MODEL\": \"/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result/best_model.ckpt\",\n",
    "        \"FaceDetection_BEST_MODEL_CONFIG\": \"/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result/best_model_config.json\",\n",
    "    },\n",
    "}\n",
    "# Outputs of this notebook: per-split directories for the prediction CSVs.\n",
    "product = {\n",
    "    \"nb\": \"/home/ubuntu/vitmtsc_nbdev/output/501_model.optimization.nn.tsc.vittsc.face_detection_evaluation_mask.html\",\n",
    "    \"FaceDetection_MODEL_VALID_EVAL_OUTPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result/evaluation/valid\",\n",
    "    \"FaceDetection_MODEL_TEST_EVAL_OUTPUT\": \"/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result/evaluation/test\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# |export\n",
    "import json\n",
    "def get_best_model_config(path=None):\n",
    "    \"\"\"Load the best hyper-parameter config written by the training task.\n",
    "\n",
    "    Args:\n",
    "        path: JSON file to read. Defaults to the upstream\n",
    "            `FaceDetection_BEST_MODEL_CONFIG` product (looked up at call time).\n",
    "\n",
    "    Returns:\n",
    "        The decoded JSON content (presumably a dict of model hyper-parameters).\n",
    "    \"\"\"\n",
    "    if path is None:\n",
    "        path = upstream['model_training_face_detection']['FaceDetection_BEST_MODEL_CONFIG']\n",
    "    with open(path, 'r', encoding='utf-8') as json_file:\n",
    "        return json.load(json_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import pandas as pd\n",
    "import os\n",
    "import torch\n",
    "import math\n",
    "import glob\n",
    "\n",
    "import pytorch_lightning as pl\n",
    "from torch.nn import functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "import scikitplot as skplt\n",
    "from pytorch_lightning import LightningModule\n",
    "from pytorch_lightning import Trainer\n",
    "from petastorm import make_batch_reader\n",
    "from petastorm.pytorch import DataLoader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Vision Transformer for Multivariate Time-Series Classification (VitMTSC) Model with Masking - Evaluation\n",
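    "> Load the best model checkpoint and write its predictions for the validation and test datasets.\n",
    "\n",
    "> __Model Evaluation__: Evaluate the model on the validation and test datasets using PR-AUC, ROC-AUC and F1 scores."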
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "DATASET_NAME = 'FaceDetection'\n",
    "\n",
    "# Inputs: the best checkpoint and the prediction datasets (handed to make_batch_reader as file:// URLs).\n",
    "BEST_MODEL_CHECKPOINT = upstream['model_training_face_detection']['FaceDetection_BEST_MODEL']\n",
    "VALID_DATA_DIR = f\"file://{upstream['tabular_to_timeseries_face_detection']['FaceDetection_VALID_MODEL_INPUT']}\"\n",
    "TEST_DATA_DIR = f\"file://{upstream['tabular_to_timeseries_face_detection']['FaceDetection_TEST_MODEL_INPUT']}\"\n",
    "\n",
    "# Outputs: one directory per split for the prediction CSVs.\n",
    "VALID_EVAL_OUTPUT_DIR = product['FaceDetection_MODEL_VALID_EVAL_OUTPUT']\n",
    "TEST_EVAL_OUTPUT_DIR = product['FaceDetection_MODEL_TEST_EVAL_OUTPUT']\n",
    "\n",
    "# Reader / loader settings.\n",
    "NUM_WORKERS = 1\n",
    "SHARD_COUNT = 1\n",
    "BATCH_SIZE = 64\n",
    "\n",
    "# Batch budget per split: ceil(dataset size / batch size); used as Trainer(limit_test_batches=...).\n",
    "TOTAL_VALID_BATCHES = math.ceil(get_valid_dataset_size() / BATCH_SIZE)\n",
    "TOTAL_TEST_BATCHES = math.ceil(get_test_dataset_size() / BATCH_SIZE)\n",
    "\n",
    "(BEST_MODEL_CHECKPOINT,\n",
    " TOTAL_VALID_BATCHES, TOTAL_TEST_BATCHES,\n",
    " VALID_DATA_DIR, TEST_DATA_DIR,\n",
    " VALID_EVAL_OUTPUT_DIR, TEST_EVAL_OUTPUT_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the per-split output directories: pure-Python `mkdir -p` that, unlike\n",
    "# unquoted shell interpolation, is safe for paths containing spaces.\n",
    "for eval_output_dir in (VALID_EVAL_OUTPUT_DIR, TEST_EVAL_OUTPUT_DIR):\n",
    "    os.makedirs(eval_output_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. VitMTSC Classification Prediction Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class VitMTSCClassificationPredictionTask(LightningModule):\n",
    "    \"\"\"Run a trained model over a Petastorm dataset and write per-sample predictions to CSV.\n",
    "\n",
    "    Results are accumulated in Python lists by `test_step` and written once per\n",
    "    process by `test_epoch_end` to `{output_pred_dir}/{LOCAL_RANK}_predictions.csv`\n",
    "    with columns case_id, probability_0, probability_1, prediction and target.\n",
    "\n",
    "    Args:\n",
    "        model: trained network, called as `model(x, mask)`; its output is treated as\n",
    "            per-class scores along dim 1 (assumed (batch, 2) logits - TODO confirm).\n",
    "        output_pred_dir: directory for the prediction CSV (created if missing).\n",
    "        input_data_dir: dataset URL handed to `make_batch_reader`.\n",
    "        batch_size: rows per DataLoader batch.\n",
    "        num_workers: Petastorm reader worker count.\n",
    "        shard_count: number of reader shards; this process reads shard LOCAL_RANK.\n",
    "        num_epochs: passes the reader makes over the data. With the default of 2,\n",
    "            rows may be emitted more than once, so consumers of the CSV should\n",
    "            de-duplicate by case_id.\n",
    "    \"\"\"\n",
    "    def __init__(self,\n",
    "                 model,\n",
    "                 output_pred_dir,\n",
    "                 input_data_dir,\n",
    "                 batch_size=BATCH_SIZE,\n",
    "                 num_workers=NUM_WORKERS,\n",
    "                 shard_count = SHARD_COUNT,\n",
    "                 num_epochs=2):\n",
    "        super().__init__()\n",
    "        pl.seed_everything(42, workers=True)\n",
    "        self.model = model\n",
    "        # per-sample results, filled by test_step and flushed by test_epoch_end\n",
    "        self.case_id = []\n",
    "        self.probability_0 = []\n",
    "        self.probability_1 = []\n",
    "        self.prediction = []\n",
    "        self.target = []\n",
    "        self.output_pred_dir = output_pred_dir\n",
    "        self.input_data_dir = input_data_dir\n",
    "        self.prediction_files = input_data_dir\n",
    "        self.batch_size = batch_size\n",
    "        self.num_workers = num_workers\n",
    "        self.shard_count = shard_count\n",
    "        self.num_epochs = num_epochs\n",
    "        self.test_ds = None\n",
    "\n",
    "    @staticmethod\n",
    "    def _local_rank():\n",
    "        \"\"\"Return LOCAL_RANK from the environment as a string; '0' when it is unset.\"\"\"\n",
    "        return os.environ.get('LOCAL_RANK', '0')\n",
    "\n",
    "    def test_step(self, batch, batch_idx):\n",
    "        \"\"\"Predict one batch and append ids, class probabilities, predictions and targets.\"\"\"\n",
    "        x, y, case_id_1, mask = batch\n",
    "        y_hat = self.model(x, mask)\n",
    "        # softmax once and reuse it for both probability columns\n",
    "        probabilities = F.softmax(y_hat, dim=1).detach().cpu().numpy()\n",
    "        self.case_id.extend(case_id_1.cpu().numpy())\n",
    "        self.probability_0.extend(probabilities[:, 0])\n",
    "        self.probability_1.extend(probabilities[:, 1])\n",
    "        self.prediction.extend(torch.argmax(y_hat, dim=1).cpu().numpy())\n",
    "        self.target.extend(y.cpu().numpy())\n",
    "\n",
    "    def test_dataloader(self):\n",
    "        \"\"\"Build a Petastorm DataLoader over this process's shard of the input dataset.\"\"\"\n",
    "        local_rank = int(self._local_rank())\n",
    "        print('test_dataloader: local rank :', local_rank, 'shard count: ', self.shard_count)\n",
    "        self.test_ds = make_batch_reader(self.prediction_files, workers_count=self.num_workers,\n",
    "                                         cur_shard = local_rank,\n",
    "                                         shard_count = self.shard_count, num_epochs = self.num_epochs)\n",
    "        return DataLoader(self.test_ds, batch_size = self.batch_size, collate_fn= petastorm_collate_fn)\n",
    "\n",
    "    def test_epoch_end(self, outputs):\n",
    "        \"\"\"Write this process's accumulated predictions to one CSV, then release the reader.\"\"\"\n",
    "        local_rank = self._local_rank()\n",
    "        print('Consolidating predictions on GPU:', local_rank)\n",
    "        df_text_predictions = pd.DataFrame({'case_id': self.case_id,\n",
    "                                            'probability_0': self.probability_0,\n",
    "                                            'probability_1': self.probability_1,\n",
    "                                            'prediction': self.prediction,\n",
    "                                            'target': self.target\n",
    "                                            })\n",
    "        print('Writing predictions on GPU:', local_rank)\n",
    "        # the exported script has no `mkdir` step, so make sure the directory exists\n",
    "        os.makedirs(self.output_pred_dir, exist_ok=True)\n",
    "        df_text_predictions.to_csv(os.path.join(self.output_pred_dir, f'{local_rank}_predictions.csv'), index=False)\n",
    "        print('Finished Writing predictions on GPU:', local_rank)\n",
    "        # stop the reader's workers: limit_test_batches can end the loop while the\n",
    "        # reader still has unread rows\n",
    "        if self.test_ds is not None:\n",
    "            self.test_ds.stop()\n",
    "            self.test_ds.join()\n",
    "            self.test_ds = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def get_model_for_prediction(BEST_MODEL_CHECKPOINT, config, output_pred_dir, input_data_dir, shard_count = SHARD_COUNT):\n",
    "    \"\"\"Restore the best checkpoint and wrap it in a prediction task.\n",
    "\n",
    "    Args:\n",
    "        BEST_MODEL_CHECKPOINT: path of the checkpoint to restore.\n",
    "        config: hyper-parameter config forwarded to `load_from_checkpoint`.\n",
    "        output_pred_dir: where the prediction task writes its CSV.\n",
    "        input_data_dir: dataset URL the prediction task reads from.\n",
    "        shard_count: number of Petastorm reader shards.\n",
    "\n",
    "    Returns:\n",
    "        A `VitMTSCClassificationPredictionTask` wrapping the restored model in eval mode.\n",
    "    \"\"\"\n",
    "    pl.seed_everything(42, workers=True)\n",
    "    restored = VitTimeSeriesTransformer.load_from_checkpoint(BEST_MODEL_CHECKPOINT, config = config)\n",
    "    restored.eval()  # inference mode before predictions are taken\n",
    "    return VitMTSCClassificationPredictionTask(\n",
    "        model=restored,\n",
    "        output_pred_dir=output_pred_dir,\n",
    "        input_data_dir=input_data_dir,\n",
    "        shard_count=shard_count,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _write_predictions(BEST_MODEL_CHECKPOINT, config, shard_count, output_pred_dir, input_data_dir, limit_test_batches):\n",
    "    \"\"\"Restore the best model and run `trainer.test` on one dataset split.\n",
    "\n",
    "    Shared implementation behind the split-specific `write_prediction_for_*`\n",
    "    functions; the prediction task itself writes the CSV output.\n",
    "    \"\"\"\n",
    "    pl.seed_everything(42, workers=True)\n",
    "    model = get_model_for_prediction(BEST_MODEL_CHECKPOINT = BEST_MODEL_CHECKPOINT,\n",
    "                                     config = config,\n",
    "                                     shard_count = shard_count,\n",
    "                                     output_pred_dir = output_pred_dir,\n",
    "                                     input_data_dir = input_data_dir)\n",
    "    # the batch budget bounds the test loop (the reader defaults to num_epochs=2)\n",
    "    trainer = Trainer(gpus = [0],\n",
    "                      accelerator='dp',\n",
    "                      progress_bar_refresh_rate=1,\n",
    "                      limit_test_batches = limit_test_batches)\n",
    "    trainer.test(model)\n",
    "\n",
    "def write_prediction_for_valid_dataset(BEST_MODEL_CHECKPOINT,\n",
    "                                      config,\n",
    "                                      shard_count,\n",
    "                                      output_pred_dir = VALID_EVAL_OUTPUT_DIR,\n",
    "                                      input_data_dir=VALID_DATA_DIR,\n",
    "                                      limit_test_batches=None):\n",
    "    \"\"\"Write the best model's predictions for the validation split.\n",
    "\n",
    "    Args:\n",
    "        BEST_MODEL_CHECKPOINT: checkpoint path to restore.\n",
    "        config: hyper-parameter config for the model.\n",
    "        shard_count: number of Petastorm reader shards.\n",
    "        output_pred_dir: directory for the prediction CSV.\n",
    "        input_data_dir: dataset URL to predict on.\n",
    "        limit_test_batches: batch budget; `None` uses TOTAL_VALID_BATCHES.\n",
    "    \"\"\"\n",
    "    if limit_test_batches is None:\n",
    "        limit_test_batches = TOTAL_VALID_BATCHES\n",
    "    _write_predictions(BEST_MODEL_CHECKPOINT, config, shard_count,\n",
    "                       output_pred_dir, input_data_dir, limit_test_batches)\n",
    "\n",
    "def write_prediction_for_test_dataset(BEST_MODEL_CHECKPOINT,\n",
    "                                      config,\n",
    "                                      shard_count,\n",
    "                                      output_pred_dir = TEST_EVAL_OUTPUT_DIR,\n",
    "                                      input_data_dir=TEST_DATA_DIR,\n",
    "                                      limit_test_batches=None):\n",
    "    \"\"\"Write the best model's predictions for the test split.\n",
    "\n",
    "    Args:\n",
    "        BEST_MODEL_CHECKPOINT: checkpoint path to restore.\n",
    "        config: hyper-parameter config for the model.\n",
    "        shard_count: number of Petastorm reader shards.\n",
    "        output_pred_dir: directory for the prediction CSV.\n",
    "        input_data_dir: dataset URL to predict on.\n",
    "        limit_test_batches: batch budget; `None` uses TOTAL_TEST_BATCHES.\n",
    "    \"\"\"\n",
    "    if limit_test_batches is None:\n",
    "        limit_test_batches = TOTAL_TEST_BATCHES\n",
    "    _write_predictions(BEST_MODEL_CHECKPOINT, config, shard_count,\n",
    "                       output_pred_dir, input_data_dir, limit_test_batches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The prediction task reads LOCAL_RANK from the environment; pin it to 0 for this single-process run.\n",
    "%env LOCAL_RANK=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "if __name__ == \"__main__\":\n",
    "    # (start message, writer, end message) for each split, processed in order\n",
    "    _prediction_jobs = (\n",
    "        ('Processing valid dataset...\\n', write_prediction_for_valid_dataset, 'Finished Processing valid dataset!!!\\n'),\n",
    "        ('Processing test dataset...\\n', write_prediction_for_test_dataset, 'Finished Processing test dataset!!!\\n'),\n",
    "    )\n",
    "    for _start_msg, _write_split, _done_msg in _prediction_jobs:\n",
    "        print(_start_msg)\n",
    "        # the config is re-read for every split, so each run gets a fresh object\n",
    "        _write_split(BEST_MODEL_CHECKPOINT=BEST_MODEL_CHECKPOINT,\n",
    "                     config=get_best_model_config(),\n",
    "                     shard_count=SHARD_COUNT)\n",
    "        print(_done_msg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Valid Dataset Prediction Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# Collect the per-rank CSVs written by the prediction task for the validation split.\n",
    "valid_csv_files = sorted(glob.glob(f'{VALID_EVAL_OUTPUT_DIR}/*.csv'))\n",
    "assert valid_csv_files, f'no prediction CSVs found in {VALID_EVAL_OUTPUT_DIR}'\n",
    "valid_gdf = (\n",
    "    pd.concat(map(pd.read_csv, valid_csv_files), ignore_index=True)\n",
    "    .astype({'target': 'int64', 'case_id': 'int64'})\n",
    "    # the reader may emit a sample more than once (num_epochs=2) and a re-predicted\n",
    "    # row need not match bit-for-bit, so keep exactly one row per case_id\n",
    "    # (assumes case_id uniquely identifies a sample - confirm upstream)\n",
    "    .drop_duplicates(subset='case_id')\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "valid_gdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of correct predictions and the resulting accuracy on the validation split\n",
    "valid_correct = valid_gdf['prediction'] == valid_gdf['target']\n",
    "valid_correct.sum(), valid_correct.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the evaluation uses two probability columns, i.e. it assumes a binary 0/1 task\n",
    "assert valid_gdf['target'].isin([0, 1]).all(), 'unexpected target labels'\n",
    "assert valid_gdf['prediction'].isin([0, 1]).all(), 'unexpected predicted labels'\n",
    "valid_gdf[['target', 'prediction']].agg(['min', 'max'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the title names the split so the validation and test figures can be told apart\n",
    "skplt.metrics.plot_precision_recall(valid_gdf['target'].to_numpy(),\n",
    "                                    valid_gdf[['probability_0', 'probability_1']].to_numpy(),\n",
    "                                    title='Validation set: Precision-Recall Curve',\n",
    "                                    cmap='nipy_spectral')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the title names the split so the validation and test figures can be told apart\n",
    "skplt.metrics.plot_roc(valid_gdf['target'].to_numpy(),\n",
    "                       valid_gdf[['probability_0', 'probability_1']].to_numpy(),\n",
    "                       title='Validation set: ROC Curves',\n",
    "                       cmap='nipy_spectral')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# macro F1: unweighted mean of the per-class F1 scores\n",
    "f1_score(y_true=valid_gdf['target'], y_pred=valid_gdf['prediction'], average='macro')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# weighted F1: per-class F1 scores averaged by class support\n",
    "f1_score(y_true=valid_gdf['target'], y_pred=valid_gdf['prediction'], average='weighted')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Test Dataset Prediction Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the per-rank CSVs written by the prediction task for the test split.\n",
    "test_csv_files = sorted(glob.glob(f'{TEST_EVAL_OUTPUT_DIR}/*.csv'))\n",
    "assert test_csv_files, f'no prediction CSVs found in {TEST_EVAL_OUTPUT_DIR}'\n",
    "test_gdf = (\n",
    "    pd.concat(map(pd.read_csv, test_csv_files), ignore_index=True)\n",
    "    .astype({'target': 'int64', 'case_id': 'int64'})\n",
    "    # the reader may emit a sample more than once (num_epochs=2) and a re-predicted\n",
    "    # row need not match bit-for-bit, so keep exactly one row per case_id\n",
    "    # (assumes case_id uniquely identifies a sample - confirm upstream)\n",
    "    .drop_duplicates(subset='case_id')\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "test_gdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of correct predictions and the resulting accuracy on the test split\n",
    "test_correct = test_gdf['prediction'] == test_gdf['target']\n",
    "test_correct.sum(), test_correct.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the evaluation uses two probability columns, i.e. it assumes a binary 0/1 task\n",
    "assert test_gdf['target'].isin([0, 1]).all(), 'unexpected target labels'\n",
    "assert test_gdf['prediction'].isin([0, 1]).all(), 'unexpected predicted labels'\n",
    "test_gdf[['target', 'prediction']].agg(['min', 'max'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the title names the split so the validation and test figures can be told apart\n",
    "skplt.metrics.plot_precision_recall(test_gdf['target'].to_numpy(),\n",
    "                                    test_gdf[['probability_0', 'probability_1']].to_numpy(),\n",
    "                                    title='Test set: Precision-Recall Curve',\n",
    "                                    cmap='nipy_spectral')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the title names the split so the validation and test figures can be told apart\n",
    "skplt.metrics.plot_roc(test_gdf['target'].to_numpy(),\n",
    "                       test_gdf[['probability_0', 'probability_1']].to_numpy(),\n",
    "                       title='Test set: ROC Curves',\n",
    "                       cmap='nipy_spectral')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# macro F1: unweighted mean of the per-class F1 scores\n",
    "f1_score(y_true=test_gdf['target'], y_pred=test_gdf['prediction'], average='macro')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# weighted F1: per-class F1 scores averaged by class support\n",
    "f1_score(y_true=test_gdf['target'], y_pred=test_gdf['prediction'], average='weighted')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__We shut down the kernel!!!__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the `#| export` cells to the library package with nbdev.\n",
    "from nbdev import nbdev_export\n",
    "nbdev_export()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rapids-22.08_ploomber",
   "language": "python",
   "name": "rapids-22.08_ploomber"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "papermill": {
   "environment_variables": {},
   "parameters": {
    "product": {
     "FaceDetection_MODEL_TEST_EVAL_OUTPUT": "/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result/evaluation/test",
     "FaceDetection_MODEL_VALID_EVAL_OUTPUT": "/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result/evaluation/valid",
     "nb": "/home/ubuntu/vitmtsc_nbdev/output/501_model.optimization.nn.tsc.vittsc.face_detection_evaluation_mask.html"
    },
    "upstream": {
     "model_training_face_detection": {
      "FaceDetection_BEST_MODEL": "/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result/best_model.ckpt",
      "FaceDetection_BEST_MODEL_CONFIG": "/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result/best_model_config.json",
      "FaceDetection_MODEL_TRAINING_CHECKPOINT_OUTPUT": "/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result/checkpoint",
      "FaceDetection_MODEL_TRAINING_OUTPUT": "/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/experiments_result",
      "FaceDetection_MODEL_TUNE_OUTPUT": "/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/ray_results",
      "nb": "/home/ubuntu/vitmtsc_nbdev/output/401_model.optimization.nn.tsc.vittsc.face_detection_training_mask_tune.html"
     },
     "tabular_to_timeseries_face_detection": {
      "FaceDetection_TEST_MODEL_INPUT": "/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/target_encoding-nn/test",
      "FaceDetection_TRAIN_MODEL_INPUT": "/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/target_encoding-nn/train",
      "FaceDetection_VALID_MODEL_INPUT": "/home/ubuntu/vitmtsc_nbdev/output/FaceDetection/target_encoding-nn/valid",
      "nb": "/home/ubuntu/vitmtsc_nbdev/output/301_feature_preprocessing.face_detection.tabular_to_timeseries.html"
     }
    }
   },
   "version": null
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
