{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "420ac8ca-113b-4626-ae9e-3f594bca7b22",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Use the %pip magic (not !pip) so packages are installed into the environment of the running kernel.\n",
    "%pip install dotmap\n",
    "%pip install vitaldb\n",
    "%pip install pyPPG==1.0.41\n",
    "%pip install openpyxl\n",
    "%pip install torch_ecg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ac0944b-46aa-471f-88fe-c4ba9f8b1eaf",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "import numpy as np\n",
    "import os \n",
    "import sys\n",
    "import joblib\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from tqdm import tqdm \n",
    "from linearprobing.utils import resample_batch_signal, load_model_without_module_prefix, get_data_for_ml\n",
    "from preprocessing.ppg import preprocess_one_ppg_signal\n",
    "from segmentations import waveform_to_segments, save_segments_to_directory\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch_ecg._preprocessors import Normalize\n",
    "from models.resnet import ResNet1D, ResNet1DMoE\n",
    "from linearprobing.feature_extraction_papagei import save_embeddings\n",
    "from linearprobing.extracted_feature_combine import segment_avg_to_dict\n",
    "from linearprobing.regression import regression_model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a41f867-8e5c-424a-b8dd-806043de84e8",
   "metadata": {},
   "source": [
    "## 1. Data\n",
    "\n",
    "- **(a)**: Download the PPG-BP data from [PPG-BP Database](https://figshare.com/articles/dataset/PPG-BP_Database_zip/5459299) to a directory.\n",
    "  - Add the download path to `download_dir`.\n",
    "- **(b)**: Use the download path for further pre-processing.\n",
    "- **(c)**: Use pre-defined user splits for later linear evaluation\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f97dd234-6ca9-4375-b6bf-f07ccdd84d48",
   "metadata": {},
   "source": [
    "#### 1. (a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97fd5528-dac5-4e9b-8fce-2b6ae9bcd03a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to the downloaded PPG-BP data; later cells read from and write to \"<download_dir>/Data File/\".\n",
    "download_dir = \"data/5459299\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a35efad9-f5b8-4720-a571-53d2063e0b5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the subject-level spreadsheet; header=1 takes the column names from the second row of the sheet.\n",
    "df = pd.read_excel(f\"{download_dir}/Data File/PPG-BP dataset.xlsx\", header=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01597863-d4f8-45a7-a952-78f374edcb5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "subjects = df.subject_ID.values\n",
    "main_dir = f\"{download_dir}/Data File/0_subject/\"  # raw recordings, read in step 1(b)\n",
    "ppg_dir = f\"{download_dir}/Data File/ppg/\"  # preprocessed segments are written here\n",
    "\n",
    "# Create the output directory if needed; exist_ok makes the cell safe to re-run.\n",
    "os.makedirs(ppg_dir, exist_ok=True)\n",
    "\n",
    "fs = 1000  # sampling rate (Hz) passed for the raw recordings\n",
    "fs_target = 125  # sampling rate (Hz) the recordings are resampled to"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10bf86a4-d680-431a-be40-d610bb11778c",
   "metadata": {},
   "source": [
    "#### 1. (b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32562452-deb4-4c2d-b645-4dc4b559f762",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recordings are stored as \"<subject>_<recording>.txt\", so listing the directory yields every\n",
    "# subject once per recording. Keep a single entry per subject (and skip non-.txt entries) so the\n",
    "# preprocessing loop below handles each subject exactly once.\n",
    "filenames = sorted({name.split(\"_\")[0] for name in os.listdir(main_dir) if name.endswith(\".txt\")})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "840fd007-dd42-4fff-a9b8-331da4ebaaac",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "norm = Normalize(method='z-score')\n",
    "segment_length = 1250  # samples per stored segment, i.e. 10 s at fs_target = 125 Hz\n",
    "\n",
    "progress = tqdm(filenames)\n",
    "for subject in progress:\n",
    "    # Show the current subject in the progress bar instead of printing one line per file.\n",
    "    progress.set_description(f\"Processing: {subject}\")\n",
    "    segments = []\n",
    "    # Each subject has three recordings, numbered 1 to 3.\n",
    "    for recording in range(1, 4):\n",
    "        signal = pd.read_csv(f\"{main_dir}{subject}_{recording}.txt\", sep='\\t', header=None)\n",
    "        # The last value is discarded (presumably an empty trailing field - TODO confirm).\n",
    "        signal = signal.values.squeeze()[:-1]\n",
    "        signal, _ = norm.apply(signal, fs=fs)\n",
    "        signal, _, _, _ = preprocess_one_ppg_signal(waveform=signal,\n",
    "                                                frequency=fs)\n",
    "        signal = resample_batch_signal(signal, fs_original=fs, fs_target=fs_target, axis=0)\n",
    "\n",
    "        # Zero-pad symmetrically up to the fixed segment length.\n",
    "        padding_needed = segment_length - len(signal)\n",
    "        if padding_needed < 0:\n",
    "            # np.pad would fail with an opaque negative-width error; report the offending file instead.\n",
    "            raise ValueError(f\"{subject}_{recording}: resampled signal has {len(signal)} samples, \"\n",
    "                             f\"more than segment_length={segment_length}\")\n",
    "        pad_left = padding_needed // 2\n",
    "        pad_right = padding_needed - pad_left\n",
    "\n",
    "        signal = np.pad(signal, pad_width=(pad_left, pad_right))\n",
    "        segments.append(signal)\n",
    "    segments = np.vstack(segments)\n",
    "    # Directory names are zero-padded to four characters, e.g. \"2\" -> \"0002\".\n",
    "    save_segments_to_directory(save_dir=ppg_dir,\n",
    "                              dir_name=subject.zfill(4),\n",
    "                              segments=segments)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20ee0eab-9e3f-4be3-8a23-fc5f4780d988",
   "metadata": {},
   "source": [
    "#### 1. (c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f316a66-01da-488c-ad84-9e114ef79c2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Short column names used in the rest of the notebook.\n",
    "column_names = {\n",
    "    \"Sex(M/F)\": \"sex\",\n",
    "    \"Age(year)\": \"age\",\n",
    "    \"Systolic Blood Pressure(mmHg)\": \"sysbp\",\n",
    "    \"Diastolic Blood Pressure(mmHg)\": \"diasbp\",\n",
    "    \"Heart Rate(b/m)\": \"hr\",\n",
    "    \"BMI(kg/m^2)\": \"bmi\",\n",
    "}\n",
    "\n",
    "# Rename the spreadsheet headers, then replace missing values with 0.\n",
    "df = df.rename(columns=column_names).fillna(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69802db7-99e6-43b0-bd02-d744884d013e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# These are the randomly selected subject splits used in our work.\n",
    "# They are hardcoded here because we cannot share them as a \"data source\".\n",
    "\n",
    "train_ids = [  2,   6,   8,  10,  12,  15,  16,  17,  18,  19,  22,  23,  26,\n",
    "        31,  32,  34,  35,  38,  40,  45,  48,  50,  53,  55,  56,  58,\n",
    "        60,  61,  63,  65,  66,  83,  85,  87,  89,  92,  93,  97,  98,\n",
    "        99, 100, 104, 105, 106, 107, 112, 113, 114, 116, 120, 122, 126,\n",
    "       128, 131, 134, 135, 137, 138, 139, 140, 141, 146, 148, 149, 152,\n",
    "       153, 154, 158, 160, 162, 164, 165, 167, 169, 170, 175, 176, 179,\n",
    "       183, 184, 186, 188, 189, 190, 191, 193, 196, 197, 199, 205, 206,\n",
    "       207, 209, 210, 212, 216, 217, 218, 223, 226, 227, 230, 231, 233,\n",
    "       234, 240, 242, 243, 244, 246, 247, 248, 256, 257, 404, 407, 409,\n",
    "       412, 414, 415, 416, 417, 419]\n",
    "\n",
    "test_ids = [14,  21,  25,  51,  52,  62,  67,  86,  90,  96, 103, 108, 110,\n",
    "       119, 123, 124, 130, 142, 144, 157, 172, 173, 174, 180, 182, 185,\n",
    "       192, 195, 200, 201, 211, 214, 219, 221, 228, 239, 250, 403, 405,\n",
    "       406, 410]\n",
    "\n",
    "val_ids = [3,  11,  24,  27,  29,  30,  41,  43,  47,  64,  88,  91,  95,\n",
    "       115, 125, 127, 136, 145, 155, 156, 161, 163, 166, 178, 198, 203,\n",
    "       208, 213, 215, 222, 229, 232, 235, 237, 241, 245, 252, 254, 259,\n",
    "       411, 418]\n",
    "\n",
    "# Select the metadata rows of each split by subject ID.\n",
    "df_train = df[df.subject_ID.isin(train_ids)]\n",
    "df_val = df[df.subject_ID.isin(val_ids)]\n",
    "df_test = df[df.subject_ID.isin(test_ids)]\n",
    "\n",
    "# Persist the splits; the feature-extraction and evaluation sections reload these CSV files.\n",
    "df_train.to_csv(f\"{download_dir}/Data File/train.csv\", index=False)\n",
    "df_val.to_csv(f\"{download_dir}/Data File/val.csv\", index=False)\n",
    "df_test.to_csv(f\"{download_dir}/Data File/test.csv\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46389343-6c33-4927-aae3-df505988aa39",
   "metadata": {},
   "source": [
    "## 2. Extracting Features\n",
    "\n",
    "In this section, we describe how to load the *PaPaGei* model and extract embeddings.\n",
    "- **(a)**: Loading a model and extracting embeddings from a single signal. This code is the most relevant to customize for your own problems.\n",
    "- **(b)**: Function to extract features from all PPG segments (batched).\n",
    "- **(c)**: Extracting *PaPaGei-S* features for PPG-BP.\n",
    "- **(d)**: Extracting *PaPaGei-S sVRI only* features for PPG-BP.\n",
    "- **(e)**: Extracting *PaPaGei-P* features for PPG-BP."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "003c8e82-ea67-41fa-958a-fd44d241b10a",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 256\n",
    "# Fall back to the CPU when no CUDA device is available so the notebook still runs.\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "case_name = \"subject_ID\"  # column holding the subject identifier\n",
    "ppg_dir = f\"{download_dir}/Data File/ppg/\"\n",
    "\n",
    "df_train = pd.read_csv(f\"{download_dir}/Data File/train.csv\")\n",
    "df_val = pd.read_csv(f\"{download_dir}/Data File/val.csv\")\n",
    "df_test = pd.read_csv(f\"{download_dir}/Data File/test.csv\")\n",
    "\n",
    "# Convert the subject IDs to zero-padded strings so they match the directory names written in\n",
    "# step 1(b), e.g. 2 -> \"0002\". The whole column is replaced rather than set through .loc, which\n",
    "# would write strings into the existing numeric column in place.\n",
    "for split_df in (df_train, df_val, df_test):\n",
    "    split_df[case_name] = split_df[case_name].astype(str).str.zfill(4)\n",
    "\n",
    "dict_df = {'train': df_train, 'val': df_val, 'test': df_test}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95e4c2de-0431-4808-93db-fdd58f054117",
   "metadata": {},
   "source": [
    "#### 2. (a) Code to load model and extract embeddings for one signal\n",
    "\n",
    "**Extend this code to extract features for your own datasets**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de10c957-fab1-406e-9207-8f8cbe7a69b0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Architecture settings; every key is passed to ResNet1DMoE as the keyword argument of the same name.\n",
    "model_config = {\n",
    "    'base_filters': 32,\n",
    "    'kernel_size': 3,\n",
    "    'stride': 2,\n",
    "    'groups': 1,\n",
    "    'n_block': 18,\n",
    "    'n_classes': 512,\n",
    "    'n_experts': 3,\n",
    "}\n",
    "\n",
    "model = ResNet1DMoE(in_channels=1, **model_config)\n",
    "\n",
    "# Load the PaPaGei-S weights into the freshly built model and move it to the chosen device.\n",
    "model_path = \"weights/papagei_s.pt\"\n",
    "model = load_model_without_module_prefix(model, model_path)\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "713bdbb3-d5ec-4286-9974-9b18592e1718",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load one stored segment; note that it was already resampled and normalized in step 1(b).\n",
    "signal = joblib.load(os.path.join(ppg_dir, '0002', '0.p'))\n",
    "# Prepend two singleton axes (presumably batch and channel) before feeding the model.\n",
    "signal = torch.Tensor(signal)[None, None, :]\n",
    "print(f\"PPG dimensions before inference : {signal.shape}\")\n",
    "\n",
    "model.eval()\n",
    "with torch.inference_mode():\n",
    "    signal = signal.to(device)\n",
    "    outputs = model(signal)\n",
    "    # The first element of the model output is used as the embedding.\n",
    "    embeddings = outputs[0].cpu().detach().numpy()\n",
    "print(f\"Embedding dimensions : {embeddings.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cffc787b-7c00-45bd-8540-00ce49a465f8",
   "metadata": {},
   "source": [
    "#### 2. (b) Extracting features for all ppg data and saving them for downstream use"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d59ffe52-064e-4da1-8ce7-5ed7ea0c097f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def extract_features_and_save(model, ppg_dir, batch_size, device, output_idx, resample, normalize, fs, fs_target, content):\n",
    "    \"\"\"\n",
    "    Extract embeddings for the train/val/test splits and save them to disk.\n",
    "\n",
    "    For every split, `save_embeddings` is run over that split's subject directories and the\n",
    "    result of `segment_avg_to_dict` is dumped with joblib to\n",
    "    \"<download_dir>/features/<model_name>/dict_<split>_<content>.p\".\n",
    "\n",
    "    Note: besides its arguments, this function reads the notebook-level variables `dict_df`,\n",
    "    `download_dir`, `model_path` and `case_name`, so the cells defining them must be run first.\n",
    "    `model_path` must belong to the `model` passed in, because the output directory name is\n",
    "    derived from it.\n",
    "\n",
    "    Args:\n",
    "        model: model forwarded to `save_embeddings`.\n",
    "        ppg_dir: directory containing one sub-directory of stored segments per subject.\n",
    "        batch_size, device, output_idx, resample, normalize, fs, fs_target:\n",
    "            forwarded unchanged to `save_embeddings`.\n",
    "        content: passed to `segment_avg_to_dict` (this notebook uses \"patient\"); also part of\n",
    "            the output file name.\n",
    "    \"\"\"\n",
    "    save_dir = f\"{download_dir}/features\"\n",
    "    # File name of the weights without directory and extension, e.g. \"weights/papagei_s.pt\" -> \"papagei_s\".\n",
    "    model_name = os.path.splitext(os.path.basename(model_path))[0]\n",
    "\n",
    "    # Create the required directory structure once; exist_ok makes re-runs safe.\n",
    "    os.makedirs(f\"{save_dir}/{model_name}\", exist_ok=True)\n",
    "\n",
    "    for split in ['train', 'val', 'test']:\n",
    "        # Process one split at a time\n",
    "        df = dict_df[split]\n",
    "        split_dir = f\"{save_dir}/{model_name}/{split}/\"\n",
    "        child_dirs = np.unique(df[case_name].values)\n",
    "\n",
    "        # Function that extracts and saves embeddings\n",
    "        save_embeddings(path=ppg_dir,\n",
    "                        child_dirs=child_dirs,\n",
    "                        save_dir=split_dir,\n",
    "                        model=model,\n",
    "                        batch_size=batch_size,\n",
    "                        device=device,\n",
    "                        output_idx=output_idx,\n",
    "                        resample=resample,\n",
    "                        normalize=normalize,\n",
    "                        fs=fs,\n",
    "                        fs_target=fs_target)\n",
    "\n",
    "        # Compile the extracted embeddings at the patient or segment level and save them\n",
    "        dict_feat = segment_avg_to_dict(split_dir, content)\n",
    "        joblib.dump(dict_feat, f\"{save_dir}/{model_name}/dict_{split}_{content}.p\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39ca9253-fa95-4eac-8812-973a2093b226",
   "metadata": {},
   "source": [
    "#### 2. (c) Extraction: PaPaGei-S"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d5b2ead-b54d-40dc-9413-7bff9d8848e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architecture settings; every key is passed to ResNet1DMoE as the keyword argument of the same name.\n",
    "model_config = {\n",
    "    'base_filters': 32,\n",
    "    'kernel_size': 3,\n",
    "    'stride': 2,\n",
    "    'groups': 1,\n",
    "    'n_block': 18,\n",
    "    'n_classes': 512,\n",
    "    'n_experts': 3,\n",
    "}\n",
    "\n",
    "model = ResNet1DMoE(in_channels=1, **model_config)\n",
    "\n",
    "# Load the PaPaGei-S weights; model_path is also read later by extract_features_and_save.\n",
    "model_path = \"weights/papagei_s.pt\"\n",
    "model = load_model_without_module_prefix(model, model_path)\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbf9325d-58bc-44b0-ba25-0b210f52f940",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# The stored segments were already normalized and resampled to 125 Hz in step 1(b),\n",
    "# so no further resampling or normalization is requested here.\n",
    "extract_features_and_save(model=model,\n",
    "                         ppg_dir=ppg_dir,\n",
    "                         batch_size=batch_size,\n",
    "                         device=device,\n",
    "                         output_idx=0,\n",
    "                         resample=False,\n",
    "                         normalize=False,\n",
    "                         fs=125,\n",
    "                         fs_target=125,\n",
    "                         content=\"patient\"\n",
    "                         )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17387f1c-1e62-4dc4-9cca-bfedef5b57e6",
   "metadata": {},
   "source": [
    "#### 2. (d) Extraction: PaPaGei-S sVRI only"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "156e724d-a645-4973-b821-21fb2d7df7df",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Architecture settings; every key is passed to ResNet1D as the keyword argument of the same name.\n",
    "model_config = {\n",
    "    'base_filters': 32,\n",
    "    'kernel_size': 3,\n",
    "    'stride': 2,\n",
    "    'groups': 1,\n",
    "    'n_block': 18,\n",
    "    'n_classes': 512,\n",
    "}\n",
    "\n",
    "# The sVRI-only variant is a ResNet1D with multi-task regression and projection disabled.\n",
    "model = ResNet1D(in_channels=1,\n",
    "                 use_mt_regression=False,\n",
    "                 use_projection=False,\n",
    "                 **model_config)\n",
    "\n",
    "# Load the weights; model_path is also read later by extract_features_and_save.\n",
    "model_path = \"weights/papagei_s_svri.pt\"\n",
    "model = load_model_without_module_prefix(model, model_path)\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4536363-acd8-4d1b-aa66-5bfa410b365a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# The stored segments were already normalized and resampled to 125 Hz in step 1(b),\n",
    "# so no further resampling or normalization is requested here.\n",
    "extract_features_and_save(model=model,\n",
    "                         ppg_dir=ppg_dir,\n",
    "                         batch_size=batch_size,\n",
    "                         device=device,\n",
    "                         output_idx=0,\n",
    "                         resample=False,\n",
    "                         normalize=False,\n",
    "                         fs=125,\n",
    "                         fs_target=125,\n",
    "                         content=\"patient\"\n",
    "                         )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0059424f-262a-4f6e-a14a-7a7294fbd9db",
   "metadata": {},
   "source": [
    "#### 2. (e) Extraction: PaPaGei-P"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e765facc-55ce-442a-b46c-05811e759e8b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Architecture settings; every key is passed to ResNet1D as the keyword argument of the same name.\n",
    "model_config = {\n",
    "    'base_filters': 32,\n",
    "    'kernel_size': 3,\n",
    "    'stride': 2,\n",
    "    'groups': 1,\n",
    "    'n_block': 18,\n",
    "    'n_classes': 512,\n",
    "}\n",
    "\n",
    "model = ResNet1D(in_channels=1, **model_config)\n",
    "\n",
    "# Load the PaPaGei-P weights; model_path is also read later by extract_features_and_save.\n",
    "model_path = \"weights/papagei_p.pt\"\n",
    "model = load_model_without_module_prefix(model, model_path)\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d599e08-f2f7-4387-a0ad-b42aa1095527",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# The stored segments were already normalized and resampled to 125 Hz in step 1(b),\n",
    "# so no further resampling or normalization is requested here.\n",
    "extract_features_and_save(model=model,\n",
    "                         ppg_dir=ppg_dir,\n",
    "                         batch_size=batch_size,\n",
    "                         device=device,\n",
    "                         output_idx=0,\n",
    "                         resample=False,\n",
    "                         normalize=False,\n",
    "                         fs=125,\n",
    "                         fs_target=125,\n",
    "                         content=\"patient\"\n",
    "                         )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "875d96a2-697a-4715-a903-6a75358cd10b",
   "metadata": {},
   "source": [
    "## 3. Linear Evaluation\n",
    "\n",
    "In this section, we use the extracted embeddings to predict diastolic BP, systolic BP, and Average HR."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0b24f3d-aea1-4b5a-80d8-9bb5a2fc79c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import Ridge\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from sklearn.metrics import mean_absolute_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d730665-3a68-4ba6-9609-b684f0c4cb75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define case_name before it is used so this cell does not rely on state from an earlier section.\n",
    "case_name = \"subject_ID\"\n",
    "# No trailing slash: regression() joins this with \"/<model_name>/...\".\n",
    "save_dir = f\"{download_dir}/features\"\n",
    "\n",
    "df_train = pd.read_csv(f\"{download_dir}/Data File/train.csv\")\n",
    "df_val = pd.read_csv(f\"{download_dir}/Data File/val.csv\")\n",
    "df_test = pd.read_csv(f\"{download_dir}/Data File/test.csv\")\n",
    "\n",
    "# Convert the subject IDs to zero-padded strings (e.g. 2 -> \"0002\"). The whole column is replaced\n",
    "# rather than set through .loc, which would write strings into the existing numeric column in place.\n",
    "for split_df in (df_train, df_val, df_test):\n",
    "    split_df[case_name] = split_df[case_name].astype(str).str.zfill(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df175df6-089a-4587-8540-8139f5b1386e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def regression(save_dir, model_name, content, df_train, df_val, df_test, case_name, label):\n",
    "    \"\"\"\n",
    "    Fit a Ridge regressor on saved embeddings and evaluate it.\n",
    "\n",
    "    The embedding dictionaries \"dict_<split>_<content>.p\" of `model_name` are loaded from\n",
    "    `save_dir`, paired with the `label` column of the matching data frame via `get_data_for_ml`,\n",
    "    and passed to `regression_model`. The validation samples are appended to the test samples,\n",
    "    so evaluation runs on test + val.\n",
    "\n",
    "    Returns whatever `regression_model` returns.\n",
    "    \"\"\"\n",
    "    frames = {'train': df_train, 'val': df_val, 'test': df_test}\n",
    "    features, targets = {}, {}\n",
    "    for split, frame in frames.items():\n",
    "        embeddings = joblib.load(f\"{save_dir}/{model_name}/dict_{split}_{content}.p\")\n",
    "        features[split], targets[split], _ = get_data_for_ml(df=frame,\n",
    "                                                             dict_embeddings=embeddings,\n",
    "                                                             case_name=case_name,\n",
    "                                                             label=label)\n",
    "\n",
    "    # Evaluate on the union of the test and validation subjects (test first).\n",
    "    X_eval = np.concatenate((features['test'], features['val']))\n",
    "    y_eval = np.concatenate((targets['test'], targets['val']))\n",
    "\n",
    "    search_space = {\n",
    "        'alpha': [0.1, 1.0, 10.0, 100.0, 1000.0],  # Regularization strength\n",
    "        'solver': ['auto', 'cholesky', 'sparse_cg']  # Solver to use in the computational routines\n",
    "    }\n",
    "\n",
    "    return regression_model(estimator=Ridge(),\n",
    "                            param_grid=search_space,\n",
    "                            X_train=features['train'],\n",
    "                            y_train=targets['train'],\n",
    "                            X_test=X_eval,\n",
    "                            y_test=y_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd11f74a-74cc-4f98-a365-0a79c73919d7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Arguments shared by all three evaluations below.\n",
    "shared_args = dict(save_dir=save_dir,\n",
    "                   content=\"patient\",\n",
    "                   df_train=df_train,\n",
    "                   df_val=df_val,\n",
    "                   df_test=df_test,\n",
    "                   case_name=case_name)\n",
    "\n",
    "# One model/target pair per evaluation.\n",
    "results_papagei_s = regression(model_name='papagei_s', label=\"diasbp\", **shared_args)\n",
    "\n",
    "results_papagei_svri = regression(model_name='papagei_s_svri', label=\"hr\", **shared_args)\n",
    "\n",
    "results_papagei_p = regression(model_name='papagei_p', label=\"sysbp\", **shared_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dedb1ab-fef8-42b2-ad07-ecf7f2b9997f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Report the value stored under the 'mae' key of each result (one model/target pair per line).\n",
    "print(f\"PaPaGei-S Diastolic BP MAE: {results_papagei_s['mae']}\")\n",
    "print(f\"PaPaGei-S sVRI HR MAE: {results_papagei_svri['mae']}\")\n",
    "print(f\"PaPaGei-P Systolic BP MAE: {results_papagei_p['mae']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6cc6b65-f356-4e01-a25d-f3d4d874e4d4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
