# %% imports
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import h5py
import glob
import scipy
from scipy.interpolate import CubicSpline
from scipy.io import loadmat
from scipy.signal import periodogram
import re
import argparse
import neurokit2 as nk

# Regex patterns for the supported video file-name conventions. In every
# pattern, group 1 is the participant ID; the remaining groups carry the task
# and/or chunk IDs extracted by the file-name parsing helpers.
real_file_string = r"([PFM]\d+)(T\d+)\S*(C\d+)"  # participant, task, chunk (e.g. "P101T3VideoB2C43")
synthetic_file_string = r"(P\d+)(T\d+).+"       # participant, task (no chunk group)
ubfc_file_string = r"(P\d+)(C\d+)"              # participant, chunk
v4v_file_string = r"(\d+)(C\d+)"                # numeric subject, chunk

# Make modules under ./tensorflow importable. Note the path is relative to the
# current working directory, not to this file.
sys.path.append("./tensorflow")
# %% functions for parsing video file names
def get_participant_id(file: str):
    """Use regex to extract participant ID from file name string.

    The base name of ``file`` is matched against ``real_file_string``,
    ``synthetic_file_string``, ``ubfc_file_string`` and ``v4v_file_string``
    in that order; the first pattern that matches supplies the ID (its first
    capture group).

    For example:
        get_participant_id("P101T3VideoB2C43") -> "P101"

    Args:
        file (str): path to file

    Returns:
        str: participant ID

    Raises:
        ValueError: if the file name matches none of the patterns.
    """
    base_name = os.path.basename(file)
    for pattern in (real_file_string, synthetic_file_string,
                    ubfc_file_string, v4v_file_string):
        result = re.search(pattern, base_name)
        if result is not None:
            return result.group(1)
    raise ValueError(f"Could not parse participant ID from file name: {file!r}")

def get_task_id(file: str):
    """Use regex to extract task ID from file name string.

    Only ``real_file_string`` and ``synthetic_file_string`` contain a task
    group; they are tried in that order against the base name of ``file``.

    For example:
        get_task_id("P101T3VideoB2C43") -> "T3"

    Args:
        file (str): path to file

    Returns:
        str: task ID

    Raises:
        ValueError: if neither pattern matches the file name.
    """
    base_name = os.path.basename(file)
    for pattern in (real_file_string, synthetic_file_string):
        result = re.search(pattern, base_name)
        if result is not None:
            return result.group(2)
    raise ValueError(f"Could not parse task ID from file name: {file!r}")

def get_chunk_id(file: str):
    """Use regex to extract chunk ID from file name string.

    The base name of ``file`` is matched against ``real_file_string``
    (chunk = group 3) and then ``v4v_file_string`` (chunk = group 2).
    ``synthetic_file_string`` is deliberately not consulted: it has only two
    groups and no chunk group, so using it would raise IndexError.

    For example:
        get_chunk_id("P101T3VideoB2C43") -> "C43"

    Args:
        file (str): path to file

    Returns:
        str: chunk ID

    Raises:
        ValueError: if no chunk ID can be found in the file name.
    """
    base_name = os.path.basename(file)
    result = re.search(real_file_string, base_name)
    if result is not None:
        return result.group(3)
    result = re.search(v4v_file_string, base_name)
    if result is not None:
        return result.group(2)
    raise ValueError(f"Could not parse chunk ID from file name: {file!r}")

def get_window_number(file: str):
    """Return the integer that follows the last underscore in ``file``.

    If ``file`` contains no underscore, the whole string is converted.

    For example:
        get_window_number("P101T3VideoB2C43_30") -> 30

    Args:
        file (str): path to file

    Returns:
        int: window number

    Raises:
        ValueError: if the text after the last "_" is not an integer literal.
    """
    _, _, suffix = file.rpartition("_")
    return int(suffix)

def calc_HR_freq(signal: np.ndarray, fs=30, detrend=False, 
    min_freq=0.75, max_freq=2.5):
    """Estimate heart rate (HR) from the strongest frequency in a band.

    The power spectrum is computed with scipy.signal.periodogram (zero-padded
    to ``4*30*fs`` FFT points, i.e. a bin spacing of 1/120 Hz = 0.5 BPM). Only
    frequencies within [min_freq, max_freq] are considered, and the one with
    the highest power is converted to beats per minute.

    Args:
        signal (np.ndarray): waveform signal
        fs (int, optional): sample frequency. Defaults to 30.
        detrend (bool | str | callable, optional): detrending applied by the
            periodogram. True selects linear detrending, False disables it,
            and a string ('linear'/'constant') or callable is passed through
            unchanged. Defaults to False.
        min_freq (float, optional): minimum frequency considered (Hz).
            Defaults to 0.75.
        max_freq (float, optional): maximum frequency considered (Hz).
            Defaults to 2.5.

    Returns:
        float: estimated heart rate (beats per minute)

    Raises:
        ValueError: if no spectrum bin falls inside [min_freq, max_freq].
    """
    # periodogram has no boolean "on" value (its detrend step rejects True),
    # so map True to scipy.signal.detrend's default type
    if detrend is True:
        detrend = 'linear'
    freqs, power = periodogram(signal, fs=fs, nfft=4 * 30 * fs, detrend=detrend)
    # restrict to plausible heart-rate frequencies
    in_band = (freqs >= min_freq) & (freqs <= max_freq)
    peak_freq = freqs[in_band][np.argmax(power[in_band])]
    return peak_freq * 60

def calc_HR_from_beat_intervals(peaks: np.ndarray, fs=30):
    """Convert the mean inter-beat interval of a peak sequence to BPM.

    Args:
        peaks (np.ndarray): sample indices of detected beats
        fs (int, optional): sample frequency. Defaults to 30.

    Returns:
        float: fs * 60 / mean(diff(peaks)), i.e. beats per minute. Fewer
        than two peaks gives NaN (mean of an empty array).
    """
    mean_interval = np.mean(np.diff(peaks))
    # samples/beat divided by samples/minute -> minutes/beat; invert for BPM
    return 1. / (mean_interval / (fs * 60))

def calc_HR_ECG(signal: np.ndarray, fs=30):
    """Estimate the mean heart rate (BPM) of an ECG window.

    R-peak sample indices come from neurokit2's ``ecg_peaks`` (its second
    return value, key ``'ECG_R_Peaks'``). The mean R-R interval is then
    converted to beats per minute by calc_HR_from_beat_intervals.

    Args:
        signal (np.ndarray): ECG waveform signal
        fs (int, optional): sample frequency. Defaults to 30.

    Returns:
        float: estimated heart rate (beats per minute)
    """
    peak_info = nk.ecg_peaks(signal, sampling_rate=fs)[1]
    r_peak_indices = peak_info['ECG_R_Peaks']
    return calc_HR_from_beat_intervals(r_peak_indices, fs=fs)

def calc_instant_HR_from_beat_intervals(peaks: np.ndarray, fs=30):
    """Convert each interval between consecutive peaks to an instantaneous HR.

    Args:
        peaks (np.ndarray): indices of PPG peaks
        fs (int, optional): sample frequency. Defaults to 30.

    Returns:
        np.ndarray: vector of length peaks.shape[0]-1 of instantaneous HR
        values (beats per minute), one per consecutive peak pair
    """
    samples_per_minute = fs * 60
    # samples/beat / samples/minute = minutes/beat; invert for BPM
    return 1. / (np.diff(peaks) / samples_per_minute)

def calc_PPG_peaks(signal: np.ndarray, distance=12, height=(0.1,)):
    """Locate peaks in a PPG signal with scipy.signal.find_peaks.

    Args:
        signal (np.ndarray): window of PPG signal
        distance (int, optional): minimum number of samples between
            consecutive peaks. Defaults to 12.
        height (optional): ``height`` argument forwarded to find_peaks; the
            default ``(0.1,)`` keeps only peaks of at least 0.1.

    Returns:
        np.ndarray: indices of PPG signal peaks
    """
    peak_indices = scipy.signal.find_peaks(signal, distance=distance, height=height)[0]
    return peak_indices

def generate_per_frame_instant_HR(signal: np.ndarray, fs=30, filter=True, 
    min_freq=0.75, max_freq=4.0, smooth=False, peak_distance=None):
    """Calculate the instantaneous HR for each frame (sample) of a PPG window.

    The signal is optionally band-pass filtered (1st-order Butterworth,
    applied zero-phase with filtfilt). Peaks are detected with calc_PPG_peaks,
    and every frame from one peak up to (not including) the next is labelled
    with the HR implied by that inter-peak interval. Frames before the first
    peak and from the last peak onward stay NaN.

    Args:
        signal (np.ndarray): window of PPG signal
        fs (int, optional): sample frequency. Defaults to 30.
        filter (bool, optional): band-pass filter before peak detection.
            Defaults to True.
        min_freq (float, optional): lower band edge (Hz). Defaults to 0.75.
        max_freq (float, optional): upper band edge (Hz). Defaults to 4.0.
        smooth (bool, optional): if True, replace the values by their
            NaN-median over consecutive non-overlapping 10 s blocks.
            Defaults to False.
        peak_distance (int, optional): minimum number of samples between
            peaks. Defaults to None, meaning 12 samples at 30 Hz (0.4 s),
            scaled proportionally to ``fs``.

    Returns:
        np.ndarray: vector containing the calculated instantaneous HR (BPM)
        for each frame (sample) in the input waveform
    """
    signal = np.asarray(signal)
    if filter:
        [b, a] = scipy.signal.butter(1, [min_freq / fs * 2, max_freq / fs * 2], btype='bandpass')
        signal = scipy.signal.filtfilt(b, a, np.double(signal))

    if peak_distance is None:
        # 12 samples is 0.4 s at 30 Hz; keep that minimum spacing in seconds
        # at other sample rates (identical to before when fs == 30)
        peak_distance = max(1, int(round(12 * fs / 30)))
    peaks = calc_PPG_peaks(signal, distance=peak_distance)
    instant_HR = calc_instant_HR_from_beat_intervals(peaks, fs=fs)

    # TODO: how do we handle the frames at the beginning/end of the video
    # before/after the first/last peaks? For now they are left as NaN.
    instant_HR_per_frame = np.full(signal.shape[0], np.nan)
    # frames between each pair of consecutive peaks get that pair's HR
    for t in range(peaks.shape[0] - 1):
        instant_HR_per_frame[peaks[t]:peaks[t + 1]] = instant_HR[t]

    if smooth:
        # NaN-median over non-overlapping 10 s blocks; every frame is covered
        smoothed_window_size = fs * 10
        smoothed_instant_HR_per_frame = np.full(instant_HR_per_frame.shape[0], np.nan)
        for i in range(0, smoothed_instant_HR_per_frame.shape[0], smoothed_window_size):
            smoothed_instant_HR_per_frame[i:i + smoothed_window_size] = np.nanmedian(
                instant_HR_per_frame[i:i + smoothed_window_size])
        return smoothed_instant_HR_per_frame

    return instant_HR_per_frame

def plot_per_frame_instant_HR(signal: np.ndarray, signal2: np.ndarray = None, save_dir=None, file_prefix=None,
    fs=30, min_freq=0.75, max_freq=4.0, font_size=14):
    """Plot PPG signal, annotated with peaks detected with peak detection
    method, alongside estimated instantaneous heart rate estimated using 
    the interbeat intervals calculated from the PPG peaks. 

    Both waveforms are band-pass filtered (1st-order Butterworth, zero-phase)
    before peak detection. The HR curves are the smoothed output of
    generate_per_frame_instant_HR.

    Args:
        signal (np.ndarray): PPG waveform
        signal2 (np.ndarray, optional): Comparison PPG waveform. Defaults to None.
        save_dir (str, optional): Directory to save plot (PNG and SVG).
            Defaults to None (not saved).
        file_prefix (str, optional): Prefix of plot file name. Defaults to None.
        fs (int, optional): sample frequency. Defaults to 30.
        min_freq (float, optional): min frequency for bandpass filter. Defaults to 0.75 Hz.
        max_freq (float, optional): max frequency for bandpass filter. Defaults to 4.0 Hz. 
        font_size (int, optional): font size for plot. Defaults to 14.
    """
    if file_prefix:
        assert save_dir is not None
    [b, a] = scipy.signal.butter(1, [min_freq / fs * 2, max_freq / fs * 2], btype='bandpass')
    signal = scipy.signal.filtfilt(b, a, np.double(signal))
    peaks = calc_PPG_peaks(signal)
    # already filtered above, so skip filtering inside the helper
    instant_HR_per_frame = generate_per_frame_instant_HR(signal, fs=fs, filter=False, smooth=True)

    # top: annotated waveform(s); bottom: instantaneous HR over time
    _, ax = plt.subplots(2, 1, figsize=(12, 12),)
    ax[0].plot(signal, label="PPG")
    ax[0].scatter(peaks, signal[peaks], label="PPG Peaks")
    ax[0].tick_params(labelsize=font_size)
    ax[1].plot(instant_HR_per_frame, linestyle="--", label="Instantaneous HR")
    if signal2 is not None:
        signal2 = scipy.signal.filtfilt(b, a, np.double(signal2))
        peaks2 = calc_PPG_peaks(signal2)
        instant_HR_per_frame2 = generate_per_frame_instant_HR(signal2, fs=fs, filter=False, smooth=True)
        ax[0].plot(signal2, label="Predicted PPG")
        ax[0].scatter(peaks2, signal2[peaks2], label="Predicted PPG Peaks")
        ax[1].plot(instant_HR_per_frame2, linestyle="--", label="Predicted Instantaneous HR")
    # label with the actual sample rate instead of a hard-coded 30 Hz
    ax[1].set_xlabel(f"Samples ({fs}Hz)", fontsize=font_size)
    ax[1].set_ylabel("Instantaneous HR", fontsize=font_size)
    ax[1].tick_params(labelsize=font_size)
    ax[0].legend(loc="upper left", fontsize=font_size)
    ax[1].legend(loc="upper right", fontsize=font_size)
    plt.tight_layout()
    if save_dir:
        plt.savefig(os.path.join(save_dir, f"{file_prefix}_instant_HR_v_time.png"))
        plt.savefig(os.path.join(save_dir, f"{file_prefix}_instant_HR_v_time.svg"))
    plt.show()
    plt.close()

def recover_signal_from_diff(new_signal: np.ndarray, 
    original_signal: np.ndarray):
    """Integrate a differenced signal back into a detrended waveform.

    Each row of the squeezed ``new_signal`` is prefixed with
    ``original_signal[:, 0, :1]`` (first step along axis 1, first entry of
    axis 2) to restore the offset lost by differencing. The result is
    cumulatively summed along axis 1 and linearly detrended along axis 1.
    Assumes original_signal is (windows, time, channels) and new_signal
    squeezes to (windows, time) — TODO confirm against callers.

    Args:
        new_signal (np.ndarray): differenced signal
        original_signal (np.ndarray): signal providing the starting values

    Returns:
        np.ndarray: recovered, detrended signal (2-D)
    """
    differences = np.squeeze(new_signal)
    seeds = original_signal[:, 0, :1]
    integrated = np.cumsum(np.concatenate((seeds, differences), axis=1), axis=1)
    # integration accumulates drift; remove the linear trend per row
    return scipy.signal.detrend(integrated, axis=1)

def upsample_signal(signal, fs=30, upsample_freq=256):
    """Resample a waveform to a higher rate with cubic-spline interpolation.

    The input is flattened to 1-D. A CubicSpline is fitted on knots
    0..N-1 (N = number of input samples) and evaluated at
    int(N / fs * upsample_freq) evenly spaced points over [0, N].

    NOTE(review): the query grid ends at N, one sample past the last knot, so
    the final stretch is extrapolated (CubicSpline extrapolates by default).
    The grid is kept unchanged so it stays aligned with other code in this
    module that upsamples the same way.

    Args:
        signal (np.ndarray): waveform (any shape; flattened)
        fs (int, optional): input sample frequency. Defaults to 30.
        upsample_freq (int, optional): target sample frequency. Defaults to 256.

    Returns:
        np.ndarray: 1-D upsampled waveform
    """
    samples = signal.flatten()
    n_samples = samples.shape[0]
    n_out = int((n_samples / fs) * upsample_freq)
    spline = CubicSpline(x=np.arange(n_samples), y=samples)
    return spline(np.linspace(0, n_samples, num=n_out))

def calc_time_sys_to_dicrotic(signal: np.ndarray, reference_signal=None,
    diff=True, fs=30, upsample_freq=256, ax=None, **kwargs):
    """Estimate systolic-peak-to-dicrotic-notch times and plot them over time.

    The signal (optionally differenced and z-scored per window) is flattened
    and upsampled to ``upsample_freq`` with upsample_signal. Systolic peaks
    are found on the upsampled ``reference_signal`` when one is given
    (otherwise on the upsampled signal). They must be at least 0.5 s apart
    with height >= 0.1. Dicrotic-notch candidates are peaks of the upsampled
    signal at least 0.1 s apart with height >= 0. Each systolic peak is
    paired with the first candidate more than 0.05 s after it.

    Args:
        signal (np.ndarray): windowed waveform. Differencing and z-scoring act
            along axis 1. If the array is 3-D with 2 entries on the last axis,
            only entry 0 is used.
        reference_signal (np.ndarray, optional): waveform used to locate the
            systolic peaks (e.g. ground truth). It receives the same channel
            selection and differencing as ``signal``. Defaults to None.
        diff (bool, optional): if True, difference along axis 1 and z-score
            each window before upsampling. Defaults to True.
        fs (int, optional): input sample frequency. Defaults to 30.
        upsample_freq (int, optional): upsampling rate (Hz). Defaults to 256.
        ax (matplotlib.axes.Axes, optional): axes to draw on; a new 16x6
            figure is created when None.
        **kwargs: forwarded to both ``ax.plot`` calls.

    Returns:
        tuple: (per-beat times in ms as np.ndarray, the axes used, per-sample
        times in ms averaged with a NaN-mean over non-overlapping 10 s windows)
    """
    # keep only the first of two channels
    if len(signal.shape) > 2 and signal.shape[2] == 2:
        signal = signal[:, :, 0]
    # the reference needs the same selection; flattening a 2-channel array
    # would interleave channels and misalign it with `signal`
    if (reference_signal is not None and len(reference_signal.shape) > 2
            and reference_signal.shape[2] == 2):
        reference_signal = reference_signal[:, :, 0]
    if diff:
        # difference along axis 1, then z-score each window
        signal = np.diff(signal, axis=1)
        signal = (signal - np.mean(signal, axis=1, keepdims=True)) / np.std(signal, axis=1, keepdims=True)
        if reference_signal is not None:
            reference_signal = np.diff(reference_signal, axis=1)
            reference_signal = (reference_signal - np.mean(reference_signal, axis=1, keepdims=True)) / np.std(reference_signal, axis=1, keepdims=True)

    # flatten and upsample with a cubic spline
    upsampled_signal = upsample_signal(signal, fs=fs, upsample_freq=upsample_freq)

    # systolic peaks: at least 0.5 s apart, height >= 0.1
    sys_distance = int(0.5 * upsample_freq)
    if reference_signal is not None:
        upsampled_reference_signal = upsample_signal(reference_signal,
            fs=fs, upsample_freq=upsample_freq)
        sys_upstroke = calc_PPG_peaks(upsampled_reference_signal,
            distance=sys_distance, height=(0.1,))
    else:
        sys_upstroke = calc_PPG_peaks(upsampled_signal,
            distance=sys_distance, height=(0.1,))
    # dicrotic notch candidates: at least 0.1 s apart, height >= 0
    dicrotic_peaks = calc_PPG_peaks(upsampled_signal,
        distance=int(0.1 * upsample_freq), height=(0.,))

    sys_dicr_time_per_frame = np.full(upsampled_signal.shape[0], np.nan)
    # ignore candidates within 0.05 s after the systolic peak, as a buffer
    # against grabbing the wrong dicrotic notch candidate
    buffer_window = 0.05 * upsample_freq
    times = []
    for m in sys_upstroke:
        candidates = dicrotic_peaks[dicrotic_peaks > (m + buffer_window)]
        if candidates.shape[0] == 0:
            continue
        # find_peaks returns indices in increasing order, so the first
        # remaining candidate is the closest one after the buffer
        closest_dicrotic_peak = candidates[0]
        time_diff = closest_dicrotic_peak - m
        times.append(time_diff)
        # label every upsampled sample between the two peaks with the gap
        sys_dicr_time_per_frame[m:closest_dicrotic_peak] = time_diff

    # upsampled samples -> milliseconds
    sys_dicr_time_per_frame = (sys_dicr_time_per_frame / upsample_freq) * 1000.

    # NaN-mean over non-overlapping 10 s windows
    smoothed_window_size = upsample_freq * 10
    smoothed_sys_dicr_time_per_frame = np.full(upsampled_signal.shape[0], np.nan)
    for i in range(0, smoothed_sys_dicr_time_per_frame.shape[0], smoothed_window_size):
        smoothed_sys_dicr_time_per_frame[i:i + smoothed_window_size] = np.nanmean(
            sys_dicr_time_per_frame[i:i + smoothed_window_size])

    if not ax:
        _, ax = plt.subplots(1, 1, figsize=(16, 6))
    # set the last sample of each constant run to NaN so the line plot shows
    # gaps instead of vertical jumps between consecutive beats
    bad_locations = np.nonzero(np.diff(sys_dicr_time_per_frame))
    sys_dicr_time_per_frame[bad_locations] = np.nan
    ax.plot(sys_dicr_time_per_frame, label='_nolegend_', **kwargs)
    ax.plot(smoothed_sys_dicr_time_per_frame, alpha=0.8, **kwargs)
    # one x tick per minute, labelled in seconds
    ax.set_xticks(np.arange(0, sys_dicr_time_per_frame.shape[0] + 1, upsample_freq * 60))
    ax.set_xticklabels([int(x / upsample_freq) for x in ax.get_xticks()], fontsize=14)
    ax.set_xlabel("Time (seconds)", fontsize=16)
    ax.set_ylabel("LVET (ms)", fontsize=16)
    ax.tick_params(labelsize=14)

    return (np.array(times) / upsample_freq) * 1000., ax, smoothed_sys_dicr_time_per_frame

def get_ECG_HR_values(labels, fs=30, num_chunks_in_window=5, chunk_len=6, plot=False, plot_save_dir=".", title_string=""):
    """Estimate ECG heart rate over sliding windows of ``labels``.

    Each window spans ``num_chunks_in_window * chunk_len`` rows of ``labels``.
    Window starts advance by ``chunk_len`` rows and stop before
    ``labels.shape[0] - window_rows``, so input no longer than one window
    yields an empty result. Each window is flattened and passed to
    calc_HR_ECG.

    Args:
        labels (np.ndarray): ECG signal arranged in rows
        fs (int, optional): sample frequency. Defaults to 30.
        num_chunks_in_window (int, optional): chunks per window. Defaults to 5.
        chunk_len (int, optional): rows per chunk / window step. Defaults to 6.
        plot (bool, optional): if True, plot and save an example window with
            probability 0.001 per window. Defaults to False.
        plot_save_dir (str, optional): directory for the example plot.
        title_string (str, optional): plot title and file-name suffix.

    Returns:
        np.ndarray: heart rate (BPM) per window
    """
    window_rows = num_chunks_in_window * chunk_len
    stop = labels.shape[0] - window_rows
    rates = []
    for start in range(0, stop, chunk_len):
        ecg_window = labels[start:start + window_rows].flatten()
        rate = calc_HR_ECG(ecg_window, fs=fs)
        rates.append(rate)
        if not (plot and np.random.uniform() > 0.999):
            continue
        # occasionally save an example window with its detected R peaks
        plt.subplots(1, 1, figsize=(16, 6))
        plt.plot(ecg_window, label="True - {}".format(rate))
        r_peaks = nk.ecg_peaks(ecg_window, sampling_rate=fs)[1]['ECG_R_Peaks']
        plt.scatter(r_peaks, ecg_window[r_peaks], c="red")
        plt.legend()
        plt.title(title_string)
        plt.tight_layout()
        plt.savefig(os.path.join(plot_save_dir, "example_window_ECG_{}.png".format(title_string)))
        plt.show()
        plt.close()

    return np.array(rates)

def get_rate_values(pred, labels, cumsum=True, filter_signal=True, fs=30, min_freq=0.75, max_freq=2.5,
    num_chunks_in_window=5, chunk_len=6, plot=False, plot_save_dir=".", title_string=""):
    """Estimate true and predicted heart rates over sliding windows.

    Windows span ``num_chunks_in_window * chunk_len`` rows and start every
    ``chunk_len`` rows. At least the window at row 0 is attempted even when
    the input is shorter than one window. Each window is flattened,
    optionally integrated (cumsum), optionally band-pass filtered
    (1st-order Butterworth, zero-phase), and converted to BPM with
    calc_HR_freq. Windows whose label slice is empty are skipped.

    Args:
        pred (np.ndarray): predicted signal, row-aligned with ``labels``
        labels (np.ndarray): ground-truth signal
        cumsum (bool, optional): integrate each window first. Defaults to True.
        filter_signal (bool, optional): band-pass filter each window.
            Defaults to True.
        fs (int, optional): sample frequency. Defaults to 30.
        min_freq (float, optional): lower band / HR search limit (Hz).
        max_freq (float, optional): upper band / HR search limit (Hz).
        num_chunks_in_window (int, optional): chunks per window. Defaults to 5.
        chunk_len (int, optional): rows per chunk / window step. Defaults to 6.
        plot (bool, optional): if True, plot and save an example window with
            probability 0.001 per window. Defaults to False.
        plot_save_dir (str, optional): directory for the example plot.
        title_string (str, optional): plot title and file-name suffix.

    Returns:
        tuple(np.ndarray, np.ndarray): true rates and predicted rates (BPM)
    """
    window_rows = num_chunks_in_window * chunk_len
    [b, a] = scipy.signal.butter(1, [min_freq / fs * 2, max_freq / fs * 2], btype='bandpass')

    def _prepare(window):
        # optional integration followed by optional zero-phase band-pass
        if cumsum:
            window = np.cumsum(window)
        if filter_signal:
            window = scipy.signal.filtfilt(b, a, np.double(window))
        return window

    true_rates = []
    pred_rates = []
    stop = max(1, labels.shape[0] - window_rows)
    for start in range(0, stop, chunk_len):
        label_window = labels[start:start + window_rows].flatten()
        if label_window.shape[0] == 0:
            continue
        label_window = _prepare(label_window)
        true_rate = calc_HR_freq(label_window, fs=fs, min_freq=min_freq, max_freq=max_freq)
        true_rates.append(true_rate)

        pred_window = _prepare(pred[start:start + window_rows].flatten())
        pred_rate = calc_HR_freq(pred_window, fs=fs, min_freq=min_freq, max_freq=max_freq)
        pred_rates.append(pred_rate)

        if not (plot and np.random.uniform() > 0.999):
            continue
        # occasionally save an example of the processed windows
        plt.subplots(1, 1, figsize=(16, 6))
        plt.plot(label_window, label="True - {}".format(true_rate))
        plt.plot(pred_window, label="Pred - {}".format(pred_rate))
        plt.legend()
        plt.title(title_string)
        plt.tight_layout()
        plt.savefig(os.path.join(plot_save_dir, "example_window_{}.png".format(title_string)))
        plt.show()
        plt.close()

    return np.array(true_rates), np.array(pred_rates)

def calculate_metrics(true, predicted):
    """Evaluate error and correlation metrics between true and predicted values.

    Each metric is also printed as "name: value" with 3 decimals.

    Args:
        true (array-like): ground-truth values (lists are accepted)
        predicted (array-like): predicted values that correspond to the
            true values

    Returns:
        dict: keys "MAE", "STD AE" (std of absolute errors), "RMSE" and "R"
        (Pearson correlation); values are numeric
    """
    # accept plain sequences as well as arrays; list - list would raise
    true = np.asarray(true)
    predicted = np.asarray(predicted)
    error = predicted - true
    abs_error = np.abs(error)

    metrics = {}
    metrics["MAE"] = np.mean(abs_error)
    metrics["STD AE"] = np.std(abs_error)
    metrics["RMSE"] = np.sqrt(np.mean(np.square(error)))
    metrics["R"] = np.corrcoef(predicted, true)[0, 1]
    for name, val in metrics.items():
        print("{}: {:.3f}".format(name, val))
    return metrics

def save_waveforms_to_file(true, pred, signal_name, filename, save_dir, fs=30,
    filter=True, min_freq=0.75, max_freq=4.0, pad=False, standardize=True):
    """Save true/predicted waveforms and their derivatives to a CSV file.

    ``true`` and ``pred`` are treated as the first-derivative signal (written
    as the ``*_FD`` columns). Assumes a (windows, samples) or
    (windows, samples, channels) layout — confirm against callers.

    Columns written, in order:
      - ``{signal_name}_true`` / ``_pred``: cumulative sum of the flattened
        input (NaN treated as 0), band-pass filtered when ``filter`` is True
      - ``{signal_name}_true_FD`` / ``_pred_FD``: the processed input, flattened
      - ``{signal_name}_true_SD`` / ``_pred_SD``: per-window np.diff along
        axis 1, NaN-padded by one sample per window

    Args:
        true (np.ndarray): ground-truth windows
        pred (np.ndarray): predicted windows, same shape as ``true``
        signal_name (str): column-name prefix
        filename (str): CSV file name inside ``save_dir``
        save_dir (str): output directory (created if missing)
        fs (int, optional): sample frequency. Defaults to 30.
        filter (bool, optional): band-pass filter the integrated signals
            (1st-order Butterworth, zero-phase). Defaults to True.
        min_freq (float, optional): lower band edge (Hz). Defaults to 0.75.
        max_freq (float, optional): upper band edge (Hz). Defaults to 4.0.
        pad (bool, optional): append one NaN sample to every window before
            saving. Defaults to False.
        standardize (bool, optional): z-score each window (along axis 1).
            Defaults to True.
    """
    os.makedirs(save_dir, exist_ok=True)
    save_df = pd.DataFrame()
    if standardize:
        true_mean_vals = np.mean(true, axis=1, keepdims=True)
        true_std_vals = np.std(true, axis=1, keepdims=True)
        true = (true - true_mean_vals) / true_std_vals
        pred_mean_vals = np.mean(pred, axis=1, keepdims=True)
        pred_std_vals = np.std(pred, axis=1, keepdims=True)
        pred = (pred - pred_mean_vals) / pred_std_vals

    if pad:
        # append a NaN sample per window so the SD columns line up; compute
        # both before assigning so a failure cannot leave `true` half-padded
        try:
            padded_true = np.pad(np.squeeze(true), ((0, 0), (0, 1)), mode='constant', constant_values=np.nan)
            padded_pred = np.pad(np.squeeze(pred), ((0, 0), (0, 1)), mode='constant', constant_values=np.nan)
        except ValueError:
            # not squeezable to 2-D: keep the first entry of the last axis
            padded_true = np.pad(np.squeeze(true[:, :, 0]), ((0, 0), (0, 1)), mode='constant', constant_values=np.nan)
            padded_pred = np.pad(np.squeeze(pred[:, :, 0]), ((0, 0), (0, 1)), mode='constant', constant_values=np.nan)
        true, pred = padded_true, padded_pred

    # integrate the derivative; nancumsum treats the NaN padding as 0, whereas
    # cumsum would make everything after the first pad value NaN
    true_integrated = np.nancumsum(true.flatten())
    pred_integrated = np.nancumsum(pred.flatten())
    if filter:
        [b, a] = scipy.signal.butter(1, [min_freq / fs * 2, max_freq / fs * 2], btype='bandpass')
        true_integrated = scipy.signal.filtfilt(b, a, true_integrated)
        pred_integrated = scipy.signal.filtfilt(b, a, pred_integrated)
    save_df[f"{signal_name}_true"] = true_integrated
    save_df[f"{signal_name}_pred"] = pred_integrated

    save_df[f"{signal_name}_true_FD"] = true.flatten()
    save_df[f"{signal_name}_pred_FD"] = pred.flatten()

    # per-window difference, NaN-padded at the end to keep the row count
    true = np.squeeze(true)
    pred = np.squeeze(pred)
    save_df[f"{signal_name}_true_SD"] = np.pad(np.diff(true, axis=1), 
        ((0, 0), (0, 1)), 
        mode='constant', constant_values=np.nan).flatten()
    save_df[f"{signal_name}_pred_SD"] = np.pad(np.diff(pred, axis=1), 
        ((0, 0), (0, 1)), 
        mode='constant', constant_values=np.nan).flatten()
    save_df.to_csv(os.path.join(save_dir, filename), header=True, index=False)

def calculate_waveform_mae(true, pred):
    """Return the mean absolute difference between two waveforms.

    Both inputs are squeezed (singleton axes removed) before subtracting.

    Args:
        true (np.ndarray): reference waveform(s)
        pred (np.ndarray): predicted waveform(s)

    Returns:
        float: mean absolute error over all elements
    """
    abs_err = np.abs(np.squeeze(true) - np.squeeze(pred))
    return np.mean(abs_err)

def plot_error_by_HR(error_df, plot_save_dir, col="ppg_hr", font_size=14):
    """Scatter each row's ``error`` against its true heart rate and save it.

    The figure is saved as ``error_by_HR.png`` in ``plot_save_dir``, shown,
    then closed.

    Args:
        error_df (pd.DataFrame): must contain ``col`` and an ``error`` column
        plot_save_dir (str): output directory
        col (str, optional): column holding the true HR. Defaults to "ppg_hr".
        font_size (int, optional): font size. Defaults to 14.
    """
    _, ax = plt.subplots(1, 1, figsize=(10, 6))
    ax.scatter(error_df[col], error_df["error"], edgecolor="black")
    ax.set_xlabel("True HR", fontsize=font_size)
    ax.set_ylabel("MAE", fontsize=font_size)
    ax.set_title("Error vs HR", fontsize=font_size)
    ax.tick_params(labelsize=font_size)
    plt.tight_layout()
    # explicit extension like the sibling plots; without one the format and
    # extension would depend on rcParams["savefig.format"]
    plt.savefig(os.path.join(plot_save_dir, "error_by_HR.png"))
    plt.show()
    plt.close()

def _plot_error_boxplot(error_df, plot_save_dir, by, xlabel, title, file_name,
    font_size=14):
    """Box-plot ``error_df["error"]`` grouped by column ``by``; save, show, close."""
    _, ax = plt.subplots(1, 1, figsize=(12, 8))
    error_df.boxplot(ax=ax, by=by, column="error", 
        rot=45, fontsize=12)
    ax.set_xlabel(xlabel, fontsize=font_size)
    ax.set_ylabel("MAE", fontsize=font_size)
    ax.set_title(title, fontsize=font_size)
    ax.tick_params(labelsize=font_size)
    plt.tight_layout()
    plt.savefig(os.path.join(plot_save_dir, file_name))
    plt.show()
    plt.close()

def plot_error_by_participant(error_df, plot_save_dir, font_size=14):
    """Box plot of per-row error grouped by the ``participant`` column.

    Saved as ``error_by_participant.png`` in ``plot_save_dir``.
    """
    _plot_error_boxplot(error_df, plot_save_dir, by="participant",
        xlabel="Participant", title="HR MAE by participant",
        file_name="error_by_participant.png", font_size=font_size)

def plot_error_by_task(error_df, plot_save_dir, font_size=14):
    """Box plot of per-row error grouped by the ``task`` column.

    Saved as ``error_by_task.png`` in ``plot_save_dir``.
    """
    _plot_error_boxplot(error_df, plot_save_dir, by="task",
        xlabel="Task", title="HR MAE by task",
        file_name="error_by_task.png", font_size=font_size)

def plot_ECG_vs_PPG_HR(error_df, plot_save_dir, font_size=14):
    """Scatter ECG-derived against PPG-derived heart rate with a y = x line.

    Both axes share the range [min - 5, max + 5] taken over the ``ecg_hr``
    and ``ppg_hr`` columns. Saved as ``ECG_vs_PPG_HR.png`` in
    ``plot_save_dir``, shown, then closed.

    Args:
        error_df (pd.DataFrame): must contain ``ecg_hr`` and ``ppg_hr``
        plot_save_dir (str): output directory
        font_size (int, optional): font size. Defaults to 14.
    """
    hr_values = error_df[["ecg_hr", "ppg_hr"]]
    axis_lims = (hr_values.min().min() - 5, hr_values.max().max() + 5)

    _, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.scatter(error_df["ecg_hr"], error_df["ppg_hr"])
    ax.set_xlim(axis_lims)
    ax.set_ylim(axis_lims)
    # y = x reference spanning both axes
    lims = [
        np.min([ax.get_xlim(), ax.get_ylim()]),
        np.max([ax.get_xlim(), ax.get_ylim()]),
    ]
    # colour and style as keywords: combining a 'k-' fmt string with
    # linestyle= makes matplotlib warn about a redundant definition
    ax.plot(lims, lims, color="k", linestyle="--", alpha=0.75, zorder=0)
    ax.set_title("ECG HR vs PPG HR", fontsize=font_size)
    ax.set_xlabel("ECG HR (BPM)", fontsize=font_size)
    ax.set_ylabel("PPG HR (BPM)", fontsize=font_size)
    ax.tick_params(labelsize=font_size)
    plt.tight_layout()
    plt.savefig(os.path.join(plot_save_dir, "ECG_vs_PPG_HR.png"))
    plt.show()
    plt.close()

def plot_scatter(true, pred, plot_save_dir, font_size=14):
    """Scatter predictions against gold-standard values with fit annotations.

    Both axes share the range [min - 5, max + 5] over ``true`` and ``pred``.
    Drawn on top are a dashed y = x line, the least-squares line from
    scipy.stats.linregress, and two parallel lines offset by
    +/- 1.96 * intercept_stderr. The r^2 value and the fitted equation are
    added as text. Saved as ``scatter_plot.png`` in ``plot_save_dir``, shown,
    then closed.

    Args:
        true (np.ndarray): gold-standard values (BPM)
        pred (np.ndarray): predictions aligned with ``true`` (BPM)
        plot_save_dir (str): output directory
        font_size (int, optional): font size. Defaults to 14.
    """
    import matplotlib.markers as markers
    from scipy.stats import linregress

    marker = markers.MarkerStyle(marker='s')
    _, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.scatter(true, pred, color='blue', marker=marker, facecolors='none')

    axis_lims = (np.min([np.min(true), np.min(pred)]) - 5,
                 np.max([np.max(true), np.max(pred)]) + 5)
    ax.set_xlim(axis_lims)
    ax.set_ylim(axis_lims)
    ax.set_xlabel("Gold-Standard (Beats/Min)", fontsize=font_size)
    ax.set_ylabel("Predictions (Beats/Min)", fontsize=font_size)
    # y = x reference; colour/style as keywords to avoid the fmt-vs-linestyle
    # redundancy warning
    lims = [
        np.min([ax.get_xlim(), ax.get_ylim()]),
        np.max([ax.get_xlim(), ax.get_ylim()]),
    ]
    ax.plot(lims, lims, color='k', linestyle="--", alpha=0.25, zorder=0)

    # least-squares fit with lines offset by 1.96 * intercept standard error
    regression = linregress(true, pred)
    line_x_vals = np.arange(*ax.get_xlim())
    fit_y = regression.intercept + regression.slope * line_x_vals
    band = 1.96 * regression.intercept_stderr
    fit_style = dict(c='black', alpha=1, zorder=0, linewidth=0.7)
    ax.plot(line_x_vals, fit_y, **fit_style)
    ax.plot(line_x_vals, fit_y + band, **fit_style)
    ax.plot(line_x_vals, fit_y - band, **fit_style)
    ax.text(x=ax.get_xlim()[0], y=ax.get_ylim()[1]-10, s="r$^2$={:.2f}".format(regression.rvalue**2), fontsize=font_size,)
    ax.text(x=ax.get_xlim()[0], y=ax.get_ylim()[1]-15, s="y={:.2f}x+{:.2f}".format(regression.slope, regression.intercept), fontsize=font_size,)
    ax.tick_params(labelsize=font_size)
    plt.tight_layout()
    plt.savefig(os.path.join(plot_save_dir, "scatter_plot.png"))
    plt.show()
    plt.close()

def plot_bland_altman(true, pred, plot_save_dir, normalize=True, font_size=14, ci=1.96):
    """Bland-Altman style plot of prediction differences against mean values.

    x = (true + pred) / 2 and y = pred - true, expressed as a percentage of
    ``true`` when ``normalize`` is set. Horizontal lines mark mean(y) and
    mean(y) +/- ci * std(y) (np.std, ddof=0), labelled to the right of the
    axes. The y-axis is fixed to [-105, 105]. Saved as
    ``bland-altman_plot.png`` in ``plot_save_dir``, shown, then closed.

    Args:
        true (np.ndarray): gold-standard values
        pred (np.ndarray): predictions aligned with ``true``
        plot_save_dir (str): output directory
        normalize (bool, optional): express differences in percent of
            ``true``. Defaults to True.
        font_size (int, optional): font size. Defaults to 14.
        ci (float, optional): std multiplier for the limit lines.
            Defaults to 1.96.
    """
    import matplotlib.markers as markers

    square_marker = markers.MarkerStyle(marker='s')
    _, ax = plt.subplots(1, 1, figsize=(8, 6))
    averages = (true + pred) / 2.
    differences = pred - true
    if normalize:
        differences = (differences / true) * 100.
    ax.scatter(averages, differences, color='blue', marker=square_marker, facecolors='none')

    mean_diff = np.mean(differences)
    sd_diff = np.std(differences)
    upper = mean_diff + ci * sd_diff
    lower = mean_diff - ci * sd_diff
    left, right = ax.get_xlim()
    ax.hlines(mean_diff, left, right, colors='black', linewidth=0.7)
    ax.hlines(upper, left, right, colors='black', linestyles=':', linewidth=0.9)
    ax.hlines(lower, left, right, colors='black', linestyles=':', linewidth=0.9)
    label_x = right + 5
    ax.text(x=label_x, y=mean_diff, s="{:.1f}".format(mean_diff), fontsize=font_size,)
    ax.text(x=label_x, y=upper, s="{:.0f} (+{:.2f}SD)".format(upper, ci), fontsize=font_size,)
    ax.text(x=label_x, y=lower, s="{:.0f} (-{:.2f}SD)".format(lower, ci), fontsize=font_size,)
    ax.set_ylim((-105, 105))
    ax.set_xlabel("Mean Gold-Standard & Predictions (Beats/Min)", fontsize=font_size)
    ax.set_ylabel("Predictions - Gold-Standard {}".format("(%)" if normalize else ""), fontsize=font_size)
    ax.tick_params(labelsize=font_size)
    plt.tight_layout()
    plt.savefig(os.path.join(plot_save_dir, "bland-altman_plot.png"))
    plt.show()
    plt.close()

def plot_bland_altman_sys_dicr_time(true, pred, plot_save_dir,
    file_name_prefix="sys_dicr_time_bland_altman_plot",
    normalize=True, font_size=14, ci=1.96):
    """Plot ``pred`` (optionally as % of ``true``) against true
    systolic-to-dicrotic-notch times.

    ``pred`` is plotted as given, with no difference taken here. The y-label
    reads "|True - Predicted|", so callers presumably pass absolute errors —
    TODO confirm. Horizontal lines mark mean(y) and mean(y) +/- ci * std(y)
    (np.std, ddof=0). Axes are fixed to x in [50, 260] and y in [-5, 305].
    Saved as ``{file_name_prefix}.png`` and ``.svg`` in ``plot_save_dir``,
    shown, then closed.

    Args:
        true (np.ndarray): true times (ms)
        pred (np.ndarray): values plotted on the y-axis
        plot_save_dir (str): output directory
        file_name_prefix (str, optional): output file name without extension
        normalize (bool, optional): express y as percent of ``true``.
            Defaults to True.
        font_size (int, optional): font size. Defaults to 14.
        ci (float, optional): std multiplier for the limit lines.
            Defaults to 1.96.
    """
    import matplotlib.markers as markers

    square_marker = markers.MarkerStyle(marker='s')
    _, ax = plt.subplots(1, 1, figsize=(8, 6))
    y_values = pred
    if normalize:
        y_values = (y_values / true) * 100.
    ax.scatter(true, y_values, color='blue', marker=square_marker, facecolors='none')

    mean_y = np.mean(y_values)
    sd_y = np.std(y_values)
    upper = mean_y + ci * sd_y
    lower = mean_y - ci * sd_y
    ax.set_xlim((50, 260))
    left, right = ax.get_xlim()
    ax.hlines(mean_y, left, right, colors='black', linewidth=0.7)
    ax.hlines(upper, left, right, colors='black', linestyles=':', linewidth=0.9)
    ax.hlines(lower, left, right, colors='black', linestyles=':', linewidth=0.9)
    label_x = right + 10
    ax.text(x=label_x, y=mean_y, s="{:.1f}".format(mean_y), fontsize=font_size,)
    ax.text(x=label_x, y=upper, s="{:.0f} (+{:.2f}SD)".format(upper, ci), fontsize=font_size,)
    ax.text(x=label_x, y=lower, s="{:.0f} (-{:.2f}SD)".format(lower, ci), fontsize=font_size,)
    ax.set_ylim((-5, 305))
    ax.set_xlabel("True Systolic to Dicrotic Notch Time (ms)", fontsize=font_size)
    ax.set_ylabel("|True - Predicted| Time (ms){}".format("(%)" if normalize else ""), fontsize=font_size)
    ax.tick_params(labelsize=font_size)
    plt.tight_layout()
    for extension in ("png", "svg"):
        plt.savefig(os.path.join(plot_save_dir, f"{file_name_prefix}.{extension}"))
    plt.show()
    plt.close()

def combine_sliding_window_preds(signal: np.ndarray, step_size: int):
    """Merge overlapping sliding-window predictions into one estimate per step.

    ``signal`` holds one row per window. Consecutive windows are assumed to
    start ``step_size`` samples apart (TODO confirm against the windowing code).
    Let ``K = signal.shape[1] // step_size`` be the number of steps per window.
    Then the step-sized segment ``t`` appears in windows
    ``max(0, t - K + 1) .. t``, at offset ``(t - r) * step_size`` inside
    window ``r``. Those copies lie on a flipped block diagonal of the window
    matrix. Their element-wise median is returned for segments
    ``t = 0 .. N - 1`` (``N = signal.shape[0]``). The trailing segments that
    are covered only by the last window are not included.

    Args:
        signal (np.ndarray): window matrix of shape (N, window_len, ...).
        step_size (int): number of samples between the starts of consecutive windows.

    Returns:
        np.ndarray: array of shape (N, step_size, ...) holding the combined segments.

    Raises:
        ValueError: if ``step_size`` is not between 1 and the window length.
    """
    window_len = signal.shape[1]
    if step_size <= 0 or step_size > window_len:
        raise ValueError(
            "step_size must be between 1 and the window length ({}), got {}".format(window_len, step_size))
    # number of step-sized segments contained in one window
    steps_per_window = window_len // step_size

    combined_windows = []
    for seg_idx in range(signal.shape[0]):
        # earliest window that still contains this segment (fewer windows
        # overlap the segments at the very start of the recording)
        first_row = max(0, seg_idx - steps_per_window + 1)
        copies = [
            signal[row, (seg_idx - row) * step_size:(seg_idx - row + 1) * step_size]
            for row in range(first_row, seg_idx + 1)
        ]
        combined_windows.append(np.median(np.array(copies), axis=0))

    return np.array(combined_windows)

def main():
    """Command-line entry point: evaluate true vs. predicted waveforms from a metric_HR.mat file.

    Loads the predictions/labels stored in the ``.mat`` file given by
    ``--metric_file``. It optionally skips the (participant, task) pairs listed
    in ``--exclude_file``. For every participant and task ("T1", "T2") it then
    computes heart-rate, respiration-rate, instantaneous-HR and
    systolic-peak-to-dicrotic-notch metrics.

    Outputs:
        - Per-recording waveforms, diagnostic plots and metric tables go under
          the save directory (default: a ``plots`` directory next to the
          ``.mat`` file).
        - The per-recording error table ``error_df.csv`` is written next to
          the ``.mat`` file.

    Raises:
        FileNotFoundError: if ``--metric_file`` does not exist.
        ValueError: if ``--exclude_file_dataset`` is not one of the datasets
            listed in the ``Dataset`` column of ``--exclude_file``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-file', '--metric_file', type=str, default='C:\\Users\\username\\Downloads\\trainSyn2345678900_testAFRLall_ppg_resp_8epoch_modelv3_msemaskloss\\metric_HR.mat',
                    help='path to metric_HR.mat file containing true/predicted waveform values')
    parser.add_argument('-save', '--save_dir', type=str, default=None,
                    help='Path to directory to save plots/metrics (default: "plots" directory in same directory as metric_HR.mat file)')
    parser.add_argument('-fs', '--sample_freq', type=int, default=30,
                    help='Sample frequency (Hz)')
    parser.add_argument('--exclude_file', type=str, default=None,
                    help='File containing files to exclude from analysis')
    parser.add_argument('--exclude_file_dataset', type=str, default=None,
                    help='Dataset in exclude file to exclude from analysis')
    args = parser.parse_args()

    # %% load metric file
    metric_file_HR = args.metric_file
    if not os.path.exists(metric_file_HR):
        raise FileNotFoundError("Metric file not found: {}".format(metric_file_HR))

    # (participant_id, task_id) pairs that are skipped in the analysis
    excluded_pairs = set()
    if args.exclude_file:
        # exclude file should have at least two columns:
        # "Dataset" (eg. AFRL) and "File name" (the file to exclude)
        print("Excluding bad files from {}".format(args.exclude_file_dataset))
        exclude_df = pd.read_csv(args.exclude_file, header=0)

        available_datasets = exclude_df["Dataset"].unique()
        if args.exclude_file_dataset not in available_datasets:
            raise ValueError("Dataset must be one of: {}".format(available_datasets))
        # only take rows for the dataset of interest; copy so the columns
        # added below are not written into a view of the original frame
        exclude_df = exclude_df[exclude_df["Dataset"] == args.exclude_file_dataset].copy()
        # get the list of files to ignore in the analysis
        files_to_exclude = exclude_df["File name"]
        print("Excluding {} files from the analysis".format(files_to_exclude.shape[0]))
        # get participant and task ID from the file name
        exclude_df["participant_id"] = exclude_df["File name"].apply(get_participant_id)
        try:
            exclude_df["task_id"] = exclude_df["File name"].apply(get_task_id)
        # unless there is no task in the file name
        except AttributeError:
            exclude_df["task_id"] = 0
        print(exclude_df.head())
        excluded_pairs = set(zip(exclude_df["participant_id"], exclude_df["task_id"]))

    # if save directory not supplied as command line argument, use default directory
    if args.save_dir is not None:
        plot_save_dir = args.save_dir
    else:
        plot_save_dir = os.path.join(os.path.dirname(metric_file_HR), "plots")
    # create directory to save plots/metrics if it does not already exist
    os.makedirs(plot_save_dir, exist_ok=True)
    # create directory for saving true/predicted waveforms
    waveform_save_dir = os.path.join(plot_save_dir, "waveforms")
    os.makedirs(waveform_save_dir, exist_ok=True)
    # create directory to save sys->dicrotic notch time plots
    systolic_dicrotic_time_plot_dir = os.path.join(plot_save_dir, "sys_dicrotic_time")
    os.makedirs(systolic_dicrotic_time_plot_dir, exist_ok=True)
    # create directory to save instantaneous heart rate plots
    instant_HR_plot_dir = os.path.join(plot_save_dir, "instant_HR")
    os.makedirs(instant_HR_plot_dir, exist_ok=True)

    # set sample frequency
    fs = args.sample_freq

    pred_file_HR = loadmat(metric_file_HR,)
    print(pred_file_HR.keys())

    # %% get predictions and labels from metric_HR file
    # if dysub_pred is not in keys but we have the 2nd derivative (SD),
    # then calculate dysub_pred from SD
    if "dysub_pred" not in pred_file_HR.keys() and "dysub_SD_pred" in pred_file_HR.keys():
        print("Missing dysub_pred but we have dysub_SD_pred.")
        print("Calculating dysub_pred using dysub_SD_pred...")
        print(pred_file_HR["dysub_SD_pred"].shape)
        pred_file_HR["dysub_pred"] = recover_signal_from_diff(
            new_signal=pred_file_HR["dysub_SD_pred"],
            original_signal=pred_file_HR["dysub_label"]
            )
    if "dysub_pred" in pred_file_HR.keys():
        pulse_pred = pred_file_HR["dysub_pred"]
        pulse_labels = pred_file_HR["dysub_label"]
        # labels may carry an extra mask channel along axis 2; keep only the signal
        if len(pulse_labels.shape) > 2 and pulse_labels.shape[2] > 1:
            print("Removing mask...")
            pulse_labels = pulse_labels[:, :, :1]
        print(pulse_pred.shape)
        print(pulse_labels.shape)
    # if ground truth ECG data is also included, load data
    if "ecg30" in pred_file_HR.keys():
        ecg_labels = pred_file_HR["ecg30"]

    pulse_files = pred_file_HR["file_label"]
    print(pulse_files.shape)

    if "drsub_pred" in pred_file_HR.keys():
        resp_pred = pred_file_HR["drsub_pred"]
        resp_labels = pred_file_HR["drsub_label"]
        print(resp_pred.shape)
        print(resp_labels.shape)

    # for each window, get participant and task IDs
    participant_ids = np.array([get_participant_id(i) for i in pulse_files])
    try:
        task_ids = np.array([get_task_id(i) for i in pulse_files])
    except AttributeError as e:
        print(e)
        print("Creating dummy task")
        task_ids = np.zeros(len(pulse_files))
    # get chunk IDs as numeric values
    chunk_ids = np.array([get_chunk_id(i) for i in pulse_files])
    numeric_chunk_ids = np.array([int(i.split("C")[1]) for i in chunk_ids])
    # get window numbers as numeric values to further sort within chunks
    numeric_window_numbers = np.array([int(get_window_number(i)) for i in pulse_files])

    unique_participant_ids = np.unique(participant_ids)

    HR, HR0 = [], []
    RR, RR0 = [], []
    if "ecg30" in pred_file_HR.keys():
        ECGHR = []

    # per (participant, task) columns of the error dataframe
    participant_col, task_col, pred_HR_col, error_col, ppg_hr_col, ecg_hr_col = [], [], [], [], [], []
    waveform_mae_col, instant_HR_mae, systolic_dicrotic_time_mae, systolic_dicrotic_time = [], [], [], []
    smoothed_systolic_dicrotic_time_mae, smoothed_systolic_dicrotic_time = [], []
    systolic_dicrotic_time_SD_mae = []
    smoothed_systolic_dicrotic_time_SD_mae, smoothed_systolic_dicrotic_time_SD = [], []
    SD_waveform_mae_col = []
    # iterate through each participant and task
    for participant_id in unique_participant_ids:
        print(participant_id)
        # get indices of windows where we have data for this participant
        participant_files = np.where(participant_ids == participant_id)
        # only tasks T1 and T2 are analysed; other task IDs are ignored
        for task_id in ["T1", "T2"]:
            print(task_id)
            # skip this participant/task pair if it is listed in the exclude
            # file. A task ID of 0 there means the excluded file names carry no
            # task, so the participant is excluded for every task.
            if (participant_id, task_id) in excluded_pairs or (participant_id, 0) in excluded_pairs:
                print("{} - {} found in the file for videos to exclude. Skipping...".format(participant_id, task_id))
                continue
            # get indices of windows where we have data for this task
            task_files = np.where(task_ids == task_id)
            # get the overlap between the two sets of files
            overlap_files = np.intersect1d(participant_files, task_files)
            if overlap_files.shape[0] == 0:
                continue
            # get the sorted ordering of the chunks based on numeric value,
            # and then within each chunk sort by window number
            sorted_overlap_file_idx = np.lexsort((numeric_window_numbers[overlap_files], numeric_chunk_ids[overlap_files]))
            # reorder the windows in ascending (chunk, window) order
            overlap_files = overlap_files[sorted_overlap_file_idx]

            title_string = "{} - {}".format(participant_id, task_id)

            # calculate heart rate values
            if "dysub_pred" in pred_file_HR.keys():
                HR0_temp, HR_temp = get_rate_values(pulse_pred[overlap_files], pulse_labels[overlap_files],
                    cumsum=True, filter_signal=True, min_freq=0.75, max_freq=4.0, fs=fs,
                    plot=True, plot_save_dir=plot_save_dir, title_string=title_string)
                HR0.extend(HR0_temp)
                HR.extend(HR_temp)

                pred_HR_col.append(np.median(np.array(HR_temp)))
                error_col.append(np.median(np.abs(np.array(HR0_temp) - np.array(HR_temp))))
                ppg_hr_col.append(np.median(np.array(HR0_temp)))

                # instantaneous HR is computed on the cumulatively summed
                # windows concatenated into one signal for the recording
                plot_per_frame_instant_HR(np.cumsum(pulse_labels[overlap_files], axis=1).flatten(),
                    np.cumsum(pulse_pred[overlap_files], axis=1).flatten(),
                    save_dir=instant_HR_plot_dir, file_prefix=title_string)
                true_instant_HR = generate_per_frame_instant_HR(np.cumsum(pulse_labels[overlap_files], axis=1).flatten(), fs=fs, smooth=True)
                pred_instant_HR = generate_per_frame_instant_HR(np.cumsum(pulse_pred[overlap_files], axis=1).flatten(), fs=fs, smooth=True)
                instant_HR_mae.append(np.nanmedian(np.abs(true_instant_HR - pred_instant_HR)))

                # get difference between median times from systolic peak
                # to dicrotic notch (the predicted curve is drawn on the true axis)
                true_sys_dicrotic_time, true_sys_dicrotic_time_ax, sst = calc_time_sys_to_dicrotic(pulse_labels[overlap_files], fs=fs)
                pred_sys_dicrotic_time, _, ssp = calc_time_sys_to_dicrotic(pulse_pred[overlap_files], reference_signal=pulse_labels[overlap_files], fs=fs, ax=true_sys_dicrotic_time_ax)
                plt.legend(["True time", "True time (10s window)", "Pred time", "Pred time (10s window)"])
                plt.tight_layout()
                plt.savefig(os.path.join(systolic_dicrotic_time_plot_dir, "{}.png".format(title_string)))
                plt.savefig(os.path.join(systolic_dicrotic_time_plot_dir, "{}.svg".format(title_string)))
                plt.show()
                plt.close()

                systolic_dicrotic_time_mae.append(np.abs(np.median(true_sys_dicrotic_time) - np.median(pred_sys_dicrotic_time)))
                systolic_dicrotic_time.append(np.median(true_sys_dicrotic_time))
                smoothed_systolic_dicrotic_time_mae.append(np.mean(np.abs(sst - ssp)))
                smoothed_systolic_dicrotic_time.append(np.mean(sst))

            # get difference between median times from systolic peak
            # to dicrotic notch for predicted SD
            if "dysub_SD_pred" in pred_file_HR.keys():
                print("True")
                true_sys_dicrotic_time_SD, true_sys_dicrotic_time_SD_ax, sst = calc_time_sys_to_dicrotic(pred_file_HR["dysub_SD_label"][overlap_files], diff=False, fs=fs)
                print("Pred")
                pred_sys_dicrotic_time_SD, _, ssp = calc_time_sys_to_dicrotic(pred_file_HR["dysub_SD_pred"][overlap_files], reference_signal=pred_file_HR["dysub_SD_label"][overlap_files], diff=False, fs=fs, ax=true_sys_dicrotic_time_SD_ax)
                plt.legend(["True time", "True time (10s window)", "Pred time", "Pred time (10s window)"])
                plt.tight_layout()
                plt.savefig(os.path.join(systolic_dicrotic_time_plot_dir, "{} - dysub_SD_pred.png".format(title_string)))
                plt.savefig(os.path.join(systolic_dicrotic_time_plot_dir, "{} - dysub_SD_pred.svg".format(title_string)))
                plt.show()
                plt.close()

                systolic_dicrotic_time_SD_mae.append(np.abs(np.median(true_sys_dicrotic_time_SD) - np.median(pred_sys_dicrotic_time_SD)))
                smoothed_systolic_dicrotic_time_SD_mae.append(np.mean(np.abs(sst - ssp)))
                smoothed_systolic_dicrotic_time_SD.append(np.mean(sst))

            # calculate resp rate values
            if "drsub_pred" in pred_file_HR.keys():
                RR0_temp, RR_temp = get_rate_values(resp_pred[overlap_files], resp_labels[overlap_files],
                    cumsum=False, filter_signal=False, min_freq=0.08, max_freq=0.5, fs=fs,
                    plot=False, plot_save_dir=plot_save_dir, title_string=title_string)
                RR0.extend(RR0_temp)
                RR.extend(RR_temp)

            # if we also have ECG ground truth data, extract HR
            if "ecg30" in pred_file_HR.keys():
                ECGHR0_temp = get_ECG_HR_values(ecg_labels[overlap_files], fs=fs,
                    plot=True, plot_save_dir=plot_save_dir, title_string=title_string)
                ECGHR.extend(ECGHR0_temp)

            waveform_mae = np.nan
            SD_waveform_mae = np.nan
            # save true/predicted waveforms for each participant/task
            if "dysub_pred" in pred_file_HR.keys():
                save_waveforms_to_file(true=pulse_labels[overlap_files],
                    pred=pulse_pred[overlap_files],
                    signal_name="ppg",
                    filename="{}.csv".format(title_string),
                    save_dir=os.path.join(waveform_save_dir, "ppg"))
                waveform_mae = calculate_waveform_mae(pulse_labels[overlap_files],
                    pulse_pred[overlap_files])
            if "drsub_pred" in pred_file_HR.keys():
                save_waveforms_to_file(true=resp_labels[overlap_files],
                    pred=resp_pred[overlap_files],
                    signal_name="resp",
                    filename="{}.csv".format(title_string),
                    save_dir=os.path.join(waveform_save_dir, "resp"))
            # TODO: check if labels exist, and use if they do
            # (for now the prediction is saved in both the true and pred slots)
            # if ECG signal is available, save true/preds
            if "ecg_pred" in pred_file_HR.keys():
                save_waveforms_to_file(true=pred_file_HR["ecg_pred"][overlap_files],
                    pred=pred_file_HR["ecg_pred"][overlap_files],
                    signal_name="ecg",
                    filename="{}.csv".format(title_string),
                    save_dir=os.path.join(waveform_save_dir, "ecg"))
            # if ABP signal is available, save true/preds
            # make sure we don't standardize ABP waveform magnitude
            if "abp_pred" in pred_file_HR.keys():
                save_waveforms_to_file(true=pred_file_HR["abp_pred"][overlap_files],
                    pred=pred_file_HR["abp_pred"][overlap_files],
                    signal_name="abp",
                    filename="{}.csv".format(title_string),
                    save_dir=os.path.join(waveform_save_dir, "abp"),
                    standardize=False)
            # if dysub_raw_pred signal is available, save true/preds
            if "dysub_raw_pred" in pred_file_HR.keys():
                save_waveforms_to_file(true=pred_file_HR["dysub_raw_label"][overlap_files, :, 0],
                    pred=pred_file_HR["dysub_raw_pred"][overlap_files],
                    signal_name="dysub_raw_pred",
                    filename="{}.csv".format(title_string),
                    save_dir=os.path.join(waveform_save_dir, "dysub_raw_pred"))
            # if dysub_SD_pred signal is available, save true/preds
            if "dysub_SD_pred" in pred_file_HR.keys():
                save_waveforms_to_file(true=pred_file_HR["dysub_SD_label"][overlap_files],
                    pred=pred_file_HR["dysub_SD_pred"][overlap_files],
                    signal_name="dysub_SD_pred",
                    filename="{}.csv".format(title_string),
                    save_dir=os.path.join(waveform_save_dir, "dysub_SD_pred"),
                    pad=True)
                SD_waveform_mae = calculate_waveform_mae(pred_file_HR["dysub_SD_label"][overlap_files, :, 0],
                    pred_file_HR["dysub_SD_pred"][overlap_files])
            # if dysub_raw_joint_pred signal is available, save true/preds
            if "dysub_raw_joint_pred" in pred_file_HR.keys():
                save_waveforms_to_file(true=pred_file_HR["dysub_raw_label"][overlap_files],
                    pred=pred_file_HR["dysub_raw_joint_pred"][overlap_files],
                    signal_name="dysub_raw_joint_pred",
                    filename="{}.csv".format(title_string),
                    save_dir=os.path.join(waveform_save_dir, "dysub_raw_joint_pred"))

            # save for breaking down errors by participant and task
            participant_col.append(participant_id)
            task_col.append(task_id)
            waveform_mae_col.append(waveform_mae)
            SD_waveform_mae_col.append(SD_waveform_mae)

            if "ecg30" in pred_file_HR.keys():
                ecg_hr_col.append(np.median(np.array(ECGHR0_temp)))

    HR, HR0 = np.array(HR), np.array(HR0)
    RR, RR0 = np.array(RR), np.array(RR0)
    if "ecg30" in pred_file_HR.keys():
        ECGHR = np.array(ECGHR)

    # histogram of per-window HR error (predicted - true)
    plt.hist(HR - HR0, bins=20)
    plt.show()
    plt.close()

    # %% create error dataframe
    error_df = pd.DataFrame.from_dict({"participant": participant_col,
    "task": task_col,
    "pred_hr": pred_HR_col,
    "error": error_col,
    "ppg_hr": ppg_hr_col,
    "waveform_mae": waveform_mae_col,
    "instant_HR_mae": instant_HR_mae,
    "systolic_dicrotic_time_mae": systolic_dicrotic_time_mae,
    "systolic_dicrotic_time": systolic_dicrotic_time,
    "smoothed_systolic_dicrotic_time_mae": smoothed_systolic_dicrotic_time_mae,
    "smoothed_systolic_dicrotic_time": smoothed_systolic_dicrotic_time,
    }, orient="columns")
    # calculate percent error
    error_df["smoothed_systolic_dicrotic_time_mae_percent"] = (error_df["smoothed_systolic_dicrotic_time_mae"] / error_df["smoothed_systolic_dicrotic_time"])*100.

    # optionally add additional columns if data is available
    if "ecg30" in pred_file_HR.keys():
        error_df["ecg_hr"] = ecg_hr_col
    if "dysub_SD_pred" in pred_file_HR.keys():
        error_df["systolic_dicrotic_time_SD_mae"] = systolic_dicrotic_time_SD_mae
        error_df["smoothed_systolic_dicrotic_time_SD_mae"] = smoothed_systolic_dicrotic_time_SD_mae
        error_df["smoothed_systolic_dicrotic_time_SD"] = smoothed_systolic_dicrotic_time_SD
        error_df["SD_waveform_mae"] = SD_waveform_mae_col
        # calculate percent error
        error_df["smoothed_systolic_dicrotic_time_SD_mae_percent"] = (error_df["smoothed_systolic_dicrotic_time_SD_mae"] / error_df["smoothed_systolic_dicrotic_time_SD"])*100.

    # remove rows with missing data
    try:
        error_df.dropna(axis=0, subset=["ppg_hr"], inplace=True)
    except KeyError:
        pass
    error_df.to_csv(os.path.join(os.path.dirname(metric_file_HR), "error_df.csv"), header=True, index=True)
    print(error_df)

    print(error_df["task"].unique())
    waveform_metric_tasks = ["T1", "T2"]
    waveform_metric_cols = [
        "waveform_mae",
        "instant_HR_mae",
        "systolic_dicrotic_time",
        "systolic_dicrotic_time_mae",
        "systolic_dicrotic_time_SD_mae",
        "smoothed_systolic_dicrotic_time",
        "smoothed_systolic_dicrotic_time_mae",
        "smoothed_systolic_dicrotic_time_mae_percent",
        "smoothed_systolic_dicrotic_time_SD_mae",
        "smoothed_systolic_dicrotic_time_SD_mae_percent",
        "SD_waveform_mae"]

    # format a column as "mean $\pm$ std" (LaTeX-ready)
    def calc_mean_std(x):
        return r'{:.2f} $\pm$ {:.2f}'.format(x.mean(), x.std())

    # rows belonging to the tasks summarised in waveform_metrics.csv
    waveform_task_df = error_df[error_df["task"].isin(waveform_metric_tasks)]
    with open(os.path.join(plot_save_dir, "waveform_metrics.csv"), "w") as out_f:
        for col in waveform_metric_cols:
            # some columns only exist when the matching data was in the .mat file
            try:
                col_values = waveform_task_df[col]
            except KeyError as e:
                print(e)
                print("Cannot find {} in error_df. Skipping...".format(col))
                continue
            print("{},{:.3f}".format(col, col_values.median()))
            print("{},{:.3f}".format(col, col_values.mean()))
            print(r"{},{}".format(col, calc_mean_std(col_values)))
            out_f.write("{},{:.3f}\n".format(col, col_values.median()))
            out_f.write("{},{}\n".format(col, calc_mean_std(col_values)))
    # %% Plot error as function of HR
    plot_error_by_HR(error_df, plot_save_dir=plot_save_dir, col="ppg_hr")

    # %% Plot error by participant
    plot_error_by_participant(error_df, plot_save_dir=plot_save_dir)

    # %% Plot error by task
    plot_error_by_task(error_df, plot_save_dir=plot_save_dir)

    # %% Scatter plot of true vs predicted HR
    plot_scatter(HR0, HR, plot_save_dir=plot_save_dir)

    # %% Bland-Altman plot comparing true and predicted HR
    plot_bland_altman(HR0, HR, plot_save_dir=plot_save_dir, normalize=True)

    # plot bland-altman plot for systolic to dicrotic notch time error
    plot_bland_altman_sys_dicr_time(
        true=error_df["smoothed_systolic_dicrotic_time"],
        pred=error_df["smoothed_systolic_dicrotic_time_mae"],
        plot_save_dir=plot_save_dir,
        file_name_prefix="smoothed_systolic_dicrotic_time",
        normalize=False)

    if "dysub_SD_pred" in pred_file_HR.keys():
        # plot bland-altman plot for systolic to dicrotic notch time error for SD
        plot_bland_altman_sys_dicr_time(
            true=error_df["smoothed_systolic_dicrotic_time_SD"],
            pred=error_df["smoothed_systolic_dicrotic_time_SD_mae"],
            plot_save_dir=plot_save_dir,
            file_name_prefix="smoothed_systolic_dicrotic_time_SD",
            normalize=False)

    # %% Plot ECG vs PPG HR
    if "ecg30" in pred_file_HR.keys():
        plot_ECG_vs_PPG_HR(error_df, plot_save_dir=plot_save_dir)

    # %% calculate metrics
    metric_dict = {}
    print("PPG-based Heart Rate:")
    HR_metrics = calculate_metrics(true=HR0, predicted=HR)
    metric_dict["PPG HR"] = HR_metrics
    if "ecg30" in pred_file_HR.keys():
        print("="*40)
        print("ECG-based Heart Rate:")
        ECG_metrics = calculate_metrics(true=ECGHR, predicted=HR)
        metric_dict["ECG HR"] = ECG_metrics
    print("="*40)
    print("Resp Rate:")
    RR_metrics = calculate_metrics(true=RR0, predicted=RR)
    metric_dict["RR"] = RR_metrics
    # save metrics to file
    metrics_df = pd.DataFrame(metric_dict)
    metrics_df.to_csv(os.path.join(plot_save_dir, "metrics.csv"), index=True, header=True)

    # %% save example prediction plots to file
    if "dysub_pred" in pred_file_HR.keys():
        min_freq = 0.75
        max_freq = 4.0
        [b, a] = scipy.signal.butter(1, [min_freq / fs * 2, max_freq / fs * 2], btype='bandpass')
        for i in np.random.randint(low=0, high=pulse_labels.shape[0], size=10):
            print(i)
            # use the mask-stripped labels so a mask channel is not folded
            # into the flattened cumulative sum
            label_window = scipy.signal.filtfilt(b, a, np.double(np.cumsum(pulse_labels[i])))
            plt.plot(label_window, label="True")
            try:
                pred_window = scipy.signal.filtfilt(b, a, np.double(np.cumsum(pulse_pred[i][:, 0])))
            except IndexError:
                pred_window = scipy.signal.filtfilt(b, a, np.double(np.cumsum(pulse_pred[i])))

            plt.plot(pred_window, label="Pred")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(plot_save_dir, f"example_pred_{i}.png"))
            plt.show()
            plt.close()

    # %% plot ECG windows
    if "ecg_pred" in pred_file_HR.keys():
        for i in np.random.randint(low=0, high=pred_file_HR["ecg_pred"].shape[0], size=10):
            print(i)
            try:
                pred_window = np.cumsum(pred_file_HR["ecg_pred"][i][:, 0])
            except IndexError:
                pred_window = np.cumsum(pred_file_HR["ecg_pred"][i])
            plt.plot(pred_window, label="Pred")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(plot_save_dir, f"example_ecg_pred_{i}.png"))
            plt.show()
            # close so the next example is not drawn on top of this one
            plt.close()

    # %% plot ABP windows
    if "abp_pred" in pred_file_HR.keys():
        for i in np.random.randint(low=0, high=pred_file_HR["abp_pred"].shape[0], size=10):
            print(i)
            try:
                pred_window = np.cumsum(pred_file_HR["abp_pred"][i][:, 0])
            except IndexError:
                pred_window = np.cumsum(pred_file_HR["abp_pred"][i])

            plt.plot(pred_window, label="Pred")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(plot_save_dir, f"example_abp_pred_{i}.png"))
            plt.show()
            # close so the next example is not drawn on top of this one
            plt.close()

if __name__ == "__main__":
    # run the analysis only when executed as a script, not when imported
    main()
