{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to the preprocessed dataset folder; replace the placeholder before running.\n",
    "# The next cell lists this folder for '...input<N>.npy' / '...label<N>.npy' chunk files\n",
    "# and searches this path string for SizeW<N>, SizeH<N> and ClipLength<N>.\n",
    "preprocessed_dataset_path = '/your/preprocessed/dataset/path/here'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validate the preprocessed dataset folder and, if it is valid, gather metadata.\n",
    "# Input chunks are the files named '<prefix>input<N>.npy'; each of them needs a\n",
    "# '<prefix>label<N>.npy' counterpart in the same folder.\n",
    "\n",
    "label_files = []\n",
    "isValid = True\n",
    "\n",
    "# List the folder once. A missing folder is reported as an invalid dataset instead of\n",
    "# raising, and the label lookups below use a set rather than a new os.listdir() call\n",
    "# for every input file.\n",
    "if os.path.isdir(preprocessed_dataset_path):\n",
    "    dataset_filenames = os.listdir(preprocessed_dataset_path)\n",
    "else:\n",
    "    dataset_filenames = []\n",
    "    isValid = False\n",
    "    print(f\"Preprocessed dataset folder {preprocessed_dataset_path} does not exist!\")\n",
    "existing_filenames = set(dataset_filenames)\n",
    "\n",
    "# Get a list of all input files in the preprocessed dataset folder\n",
    "input_files = [filename for filename in dataset_filenames if re.fullmatch(r\".*input\\d+\\.npy\", filename)]\n",
    "if isValid and not input_files:\n",
    "    # An empty folder would otherwise be reported as valid with 0 chunks.\n",
    "    isValid = False\n",
    "    print(\"No input chunks were found in the preprocessed dataset folder!\")\n",
    "\n",
    "# Check if each input file has a corresponding label file\n",
    "for input_file in input_files:\n",
    "    # Rename only the trailing 'input<N>.npy' part, so an 'input' that appears earlier\n",
    "    # in the filename (for example in a subject name) is left untouched.\n",
    "    label_filename = re.sub(r\"input(\\d+\\.npy)$\", r\"label\\1\", input_file)\n",
    "    if label_filename in existing_filenames:\n",
    "        label_files.append(label_filename)\n",
    "    else:\n",
    "        isValid = False\n",
    "        print(f\"Label file {label_filename} is missing for input file {input_file}!\")\n",
    "\n",
    "# The metadata values stay None when the dataset is invalid or a key is absent from the path.\n",
    "size_w = size_h = clip_length = None\n",
    "if isValid:\n",
    "    print(f\"{len(input_files)} input chunks and {len(label_files)} label chunks detected.\")\n",
    "    # Extract SizeW value\n",
    "    match_size_w = re.search(r\"SizeW(\\d+)\", preprocessed_dataset_path)\n",
    "    size_w = match_size_w.group(1) if match_size_w else None\n",
    "\n",
    "    # Extract SizeH value\n",
    "    match_size_h = re.search(r\"SizeH(\\d+)\", preprocessed_dataset_path)\n",
    "    size_h = match_size_h.group(1) if match_size_h else None\n",
    "\n",
    "    # Extract ClipLength value\n",
    "    match_clip_length = re.search(r\"ClipLength(\\d+)\", preprocessed_dataset_path)\n",
    "    clip_length = match_clip_length.group(1) if match_clip_length else None\n",
    "\n",
    "    print(f\"Preprocessed data has a width of {size_w}, height of {size_h}, and a clip length of {clip_length}.\")\n",
    "else:\n",
    "    print(\"Preprocessed dataset is invalid! Please delete the preprocessed dataset folder and try to preprocess it again.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell to visualize RGB frames, diff normalized frames, and ground truth (label) after preprocessing\n",
    "# Each visualization is done one chunk at a time\n",
    "!pip install natsort\n",
    "!pip install scipy\n",
    "!pip install ipywidgets\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import ipywidgets as widgets\n",
    "from IPython.display import display, clear_output\n",
    "from natsort import natsorted\n",
    "import os\n",
    "import scipy\n",
    "from scipy.sparse import spdiags\n",
    "from scipy.signal import butter\n",
    "import math\n",
    "from scipy import linalg\n",
    "from scipy import signal\n",
    "from scipy import sparse\n",
    "\n",
    "class Visualizer:\n",
    "    \"\"\"Interactive viewer for preprocessed input chunks and their label signals.\n",
    "\n",
    "    A dropdown selects the chunk and a slider selects the frame inside it. Each\n",
    "    change redraws the selected frame together with the label of that chunk,\n",
    "    plotted over time and as a normalized periodogram.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, preprocessed_dataset_path, input_files, label_files, fs=30):\n",
    "        \"\"\"Sort the chunk lists, then build the widgets and the frame figure.\n",
    "\n",
    "        Args:\n",
    "            preprocessed_dataset_path: Folder that contains the chunk files.\n",
    "            input_files: Filenames of the input chunks inside that folder.\n",
    "            label_files: Filenames of the label chunks inside that folder.\n",
    "            fs: Sampling rate in Hz used by the band-pass filter and the\n",
    "                periodogram. Defaults to 30, the previously hard-coded value.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If input_files is empty.\n",
    "        \"\"\"\n",
    "        if len(input_files) == 0:\n",
    "            raise ValueError('No input chunks were given, so there is nothing to visualize.')\n",
    "\n",
    "        # Inputs and labels are paired by their position in the two sorted lists.\n",
    "        self.input_files = natsorted(input_files)\n",
    "        self.label_files = natsorted(label_files)\n",
    "        self.dataset_path = preprocessed_dataset_path\n",
    "        self.fs = fs\n",
    "        self.chunk_dropdown = None\n",
    "        self.fig = None\n",
    "        self.ax = None\n",
    "        self.frame_index_slider = None\n",
    "        # Built by create_label_figure(), or on the first redraw if that was never called.\n",
    "        self.label_fig = None\n",
    "        self.label_ax_time = None\n",
    "        self.label_ax_fft = None\n",
    "        self.current_chunk_index = 0\n",
    "        self.current_frame_index = 0\n",
    "        # Last input chunk read from disk, reused while only the frame index changes.\n",
    "        self._cached_chunk_index = None\n",
    "        self._cached_chunk = None\n",
    "\n",
    "        self.create_visualization()\n",
    "\n",
    "    def create_visualization(self):\n",
    "        \"\"\"Build the chunk dropdown, the frame figure and the frame slider.\"\"\"\n",
    "        self.create_dropdown()\n",
    "        self.create_figure()\n",
    "        self.create_slider()\n",
    "\n",
    "    def create_dropdown(self):\n",
    "        \"\"\"Create and display the chunk dropdown, initially without a selection.\"\"\"\n",
    "        self.chunk_dropdown = widgets.Dropdown(\n",
    "            options=self.input_files,\n",
    "            description='Chunk:',\n",
    "            layout=widgets.Layout(width='800px'),\n",
    "            value=None  # Nothing is selected until the user picks a chunk\n",
    "        )\n",
    "        self.chunk_dropdown.observe(self.chunk_selection_changed, names='value')\n",
    "\n",
    "        # Display dropdown\n",
    "        display(self.chunk_dropdown)\n",
    "\n",
    "    def create_figure(self):\n",
    "        \"\"\"Create the figure and axes that show the selected frame.\"\"\"\n",
    "        self.fig, self.ax = plt.subplots(figsize=(8, 4))\n",
    "        self.fig.tight_layout()\n",
    "        self.ax.set_xticks([])\n",
    "        self.ax.set_yticks([])\n",
    "\n",
    "    def create_slider(self):\n",
    "        \"\"\"Create and display the frame slider, sized to the current chunk.\"\"\"\n",
    "        num_frames = self.get_num_frames()\n",
    "        self.frame_index_slider = widgets.IntSlider(\n",
    "            min=0,\n",
    "            max=max(num_frames - 1, 0),\n",
    "            step=1,\n",
    "            value=0,\n",
    "            description='Frame:',\n",
    "            layout=widgets.Layout(width='800px')\n",
    "        )\n",
    "        self.frame_index_slider.observe(self.frame_index_changed, names='value')\n",
    "\n",
    "        # Display slider\n",
    "        display(self.frame_index_slider)\n",
    "\n",
    "    def chunk_selection_changed(self, change):\n",
    "        \"\"\"Dropdown callback: switch to the selected chunk and show its first frame.\"\"\"\n",
    "        selected_chunk = change.new\n",
    "        if selected_chunk is None:\n",
    "            # The selection was cleared; there is no chunk to switch to.\n",
    "            return\n",
    "        self.current_chunk_index = self.input_files.index(selected_chunk)\n",
    "        self.current_frame_index = 0\n",
    "\n",
    "        # Resize and reset the slider with its observer detached, so the reset does not\n",
    "        # trigger a redraw of its own right before the one below.\n",
    "        slider = self.frame_index_slider\n",
    "        slider.unobserve(self.frame_index_changed, names='value')\n",
    "        try:\n",
    "            slider.max = max(self.get_num_frames() - 1, 0)\n",
    "            slider.value = 0\n",
    "        finally:\n",
    "            slider.observe(self.frame_index_changed, names='value')\n",
    "        self.update_frame()\n",
    "\n",
    "    def frame_index_changed(self, change):\n",
    "        \"\"\"Slider callback: show the newly selected frame of the current chunk.\"\"\"\n",
    "        self.current_frame_index = change.new\n",
    "        self.update_frame()\n",
    "\n",
    "    def _load_input_chunk(self):\n",
    "        \"\"\"Return the current input chunk, reading the file only when the chunk changes.\"\"\"\n",
    "        if getattr(self, '_cached_chunk_index', None) != self.current_chunk_index:\n",
    "            input_file = self.input_files[self.current_chunk_index]\n",
    "            self._cached_chunk = np.load(os.path.join(self.dataset_path, input_file))\n",
    "            self._cached_chunk_index = self.current_chunk_index\n",
    "        return self._cached_chunk\n",
    "\n",
    "    def get_num_frames(self):\n",
    "        \"\"\"Return the size of the first axis of the current input chunk.\"\"\"\n",
    "        return self._load_input_chunk().shape[0]\n",
    "\n",
    "    def load_frame(self):\n",
    "        \"\"\"Return the current frame of the current chunk as an image array.\n",
    "\n",
    "        With 6 values on the last axis, values 3..5 and values 0..2 are placed side by\n",
    "        side; with 3 values the frame is returned unchanged.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If the last axis of the chunk has neither 3 nor 6 entries.\n",
    "        \"\"\"\n",
    "        input_data = self._load_input_chunk()\n",
    "        print(f'The shape of the loaded chunk is {np.shape(input_data)}.')\n",
    "\n",
    "        num_channels = input_data.shape[-1]\n",
    "        if num_channels not in (3, 6):\n",
    "            raise ValueError(\"Invalid input_data shape. Expected shape (..., 3) or (..., 6).\")\n",
    "\n",
    "        # Slice the frame out first so that only one frame is copied, not the whole chunk.\n",
    "        frame = input_data[self.current_frame_index]\n",
    "        if num_channels == 6:\n",
    "            rgb_frame = frame[..., 3:]\n",
    "            diff_normalized_frame = frame[..., :3]\n",
    "\n",
    "            # Create a side-by-side visualization (axis 1 of a frame is axis 2 of the chunk)\n",
    "            frame = np.concatenate((rgb_frame, diff_normalized_frame), axis=1)\n",
    "        return frame\n",
    "\n",
    "    def create_label_figure(self):\n",
    "        \"\"\"Create the label figure with a time-domain and a frequency-domain axes.\"\"\"\n",
    "        self.label_fig, (self.label_ax_time, self.label_ax_fft) = plt.subplots(2, 1, figsize=(8, 4))\n",
    "        self.label_ax_time.set_xlabel('Frame #')\n",
    "        self.label_ax_time.set_ylabel('Magnitude')\n",
    "        self.label_ax_fft.set_xlabel('Frequency')\n",
    "        self.label_ax_fft.set_ylabel('Magnitude')\n",
    "\n",
    "    def update_label_frame(self, clear=True):\n",
    "        \"\"\"Redraw the label figure for the current chunk.\n",
    "\n",
    "        Args:\n",
    "            clear: Clear the cell output first. update_frame() passes False because it\n",
    "                has already cleared the output and printed the chunk shape.\n",
    "        \"\"\"\n",
    "        if clear:\n",
    "            clear_output(wait=True)\n",
    "        if getattr(self, 'label_fig', None) is None:\n",
    "            # A redraw was requested before create_label_figure() was called.\n",
    "            self.create_label_figure()\n",
    "        label_frame = self.load_label_frame()\n",
    "\n",
    "        fs = self.fs\n",
    "        # FFT length of 30 seconds worth of samples.\n",
    "        nfft = int(30 * fs)\n",
    "        ppg_label_f, ppg_label_pxx = scipy.signal.periodogram(label_frame, fs=fs, nfft=nfft, detrend=False)\n",
    "        # Normalize to a peak of 1; an all-zero spectrum is plotted as it is.\n",
    "        peak_power = ppg_label_pxx.max()\n",
    "        if peak_power > 0:\n",
    "            ppg_label_pxx = ppg_label_pxx / peak_power\n",
    "\n",
    "        self.label_ax_time.clear()\n",
    "        self.label_ax_fft.clear()\n",
    "\n",
    "        self.label_ax_time.set_xlabel('Frame #')\n",
    "        self.label_ax_time.set_ylabel('Magnitude')\n",
    "        self.label_ax_time.plot(label_frame)\n",
    "\n",
    "        self.label_ax_fft.set_xlabel('Frequency')\n",
    "        self.label_ax_fft.set_ylabel('Magnitude')\n",
    "        self.label_ax_fft.plot(ppg_label_f, ppg_label_pxx)\n",
    "\n",
    "        self.label_ax_fft.set_xlim([0, 5])      # Set x-axis limits from 0 to 5 Hz\n",
    "\n",
    "        self.label_fig.suptitle(self.label_files[self.current_chunk_index], y=1.02)\n",
    "        self.label_fig.tight_layout()\n",
    "\n",
    "    @staticmethod\n",
    "    def _detrend(input_signal, lambda_value):\n",
    "        \"\"\"Detrend a signal with a second-difference smoothness penalty.\n",
    "\n",
    "        Computes (I - inv(I + lambda_value**2 * D.T @ D)) @ input_signal, where D is\n",
    "        the second-order difference matrix.\n",
    "        \"\"\"\n",
    "        signal_length = input_signal.shape[0]\n",
    "        # observation matrix\n",
    "        H = np.identity(signal_length)\n",
    "        ones = np.ones(signal_length)\n",
    "        minus_twos = -2 * np.ones(signal_length)\n",
    "        diags_data = np.array([ones, minus_twos, ones])\n",
    "        diags_index = np.array([0, 1, 2])\n",
    "        D = spdiags(diags_data, diags_index,\n",
    "                    (signal_length - 2), signal_length).toarray()\n",
    "        return np.dot(\n",
    "            (H - np.linalg.inv(H + (lambda_value ** 2) * np.dot(D.T, D))), input_signal)\n",
    "\n",
    "    def _process_signal(self, ppg_signal, diff_flag=False, use_bandpass=True):\n",
    "        \"\"\"Detrend the label signal and, optionally, band-pass filter the result.\n",
    "\n",
    "        Args:\n",
    "            ppg_signal: Label signal of one chunk.\n",
    "            diff_flag: Set when the labels are the 1st derivative of the PPG signal;\n",
    "                the signal is then integrated with a cumulative sum before detrending.\n",
    "            use_bandpass: Apply the [0.75, 2.5] Hz band-pass filter after detrending.\n",
    "        \"\"\"\n",
    "        fs = self.fs\n",
    "        if diff_flag:\n",
    "            gt_bvp = self._detrend(np.cumsum(ppg_signal), 100)\n",
    "        else:\n",
    "            gt_bvp = self._detrend(ppg_signal, 100)\n",
    "        if use_bandpass:\n",
    "            # bandpass filter between [0.75, 2.5] Hz\n",
    "            # equals [45, 150] beats per min\n",
    "            [b, a] = butter(1, [0.75 / fs * 2, 2.5 / fs * 2], btype='bandpass')\n",
    "            # Filter the detrended signal; filtering the raw input here would throw\n",
    "            # the detrending result away.\n",
    "            gt_bvp = scipy.signal.filtfilt(b, a, np.double(gt_bvp))\n",
    "        return gt_bvp\n",
    "\n",
    "    def load_label_frame(self):\n",
    "        \"\"\"Load the label of the current chunk and return it detrended and filtered.\"\"\"\n",
    "        label_file = self.label_files[self.current_chunk_index]\n",
    "        label_data = np.load(os.path.join(self.dataset_path, label_file))\n",
    "        print(f'The shape of the loaded label file is {np.shape(label_data)}.')\n",
    "\n",
    "        return self._process_signal(label_data)\n",
    "\n",
    "    def update_frame(self):\n",
    "        \"\"\"Redraw the frame figure and the label figure, then display widgets and figures.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If the frame dtype is neither floating point nor integer.\n",
    "        \"\"\"\n",
    "        clear_output(wait=True)\n",
    "        frame = self.load_frame()\n",
    "\n",
    "        # Clip frame to a valid range\n",
    "        if np.issubdtype(frame.dtype, np.floating):\n",
    "            frame = np.clip(frame, 0, 1)  # Clip float pixel values to [0, 1]\n",
    "        elif np.issubdtype(frame.dtype, np.integer):\n",
    "            frame = np.clip(frame, 0, 255)  # Clip integer pixel values to [0, 255]\n",
    "        else:\n",
    "            raise ValueError(\"Unsupported pixel value data type.\")\n",
    "\n",
    "        # Remove the previously drawn image so redraws do not pile up images on the axes;\n",
    "        # clearing also resets the ticks, so they are hidden again.\n",
    "        self.ax.clear()\n",
    "        self.ax.set_xticks([])\n",
    "        self.ax.set_yticks([])\n",
    "        self.ax.imshow(frame)\n",
    "        self.ax.set_title(f\"Frame {self.current_frame_index}\")\n",
    "        self.fig.suptitle(self.input_files[self.current_chunk_index], y=1.02)\n",
    "        self.fig.tight_layout()\n",
    "        # The output was cleared above; clearing again would erase the chunk shape message.\n",
    "        self.update_label_frame(clear=False)\n",
    "        display(self.chunk_dropdown, self.frame_index_slider, self.fig, self.label_fig)\n",
    "\n",
    "\n",
    "# Create the visualizer instance; the constructor builds and displays the chunk\n",
    "# dropdown and the frame slider and creates the frame figure.\n",
    "visualizer = Visualizer(preprocessed_dataset_path, input_files, label_files)\n",
    "\n",
    "# Create the label figure (one time-domain and one frequency-domain axes). The\n",
    "# figures are only redrawn once a chunk is picked or the frame slider is moved.\n",
    "visualizer.create_label_figure()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ma-rppg-video-toolbox",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
