# %% imports
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import h5py
import glob
import scipy
from scipy.interpolate import CubicSpline
from scipy.io import loadmat
from scipy.signal import periodogram, detrend
import re
import argparse
import neurokit2 as nk

from calculate_metrics import calc_time_sys_to_dicrotic, calc_PPG_peaks

# Plot styling shared by every figure in this script.
GREEN = '#00946c'
BLUE = '#568ce9'
RED = '#D55E00'
DARK = 'black'
SCATTER_SIZE = 50
LABEL_FONT_SIZE = 24
LINE_WIDTH = 2.0
LINE_STYLE = "-"

# Result directories of the two runs being compared. Each directory holds one
# waveform CSV per recording; its basename ("ppg" or "dysub_SD_pred") decides
# which prediction column is read (see get_col_key).
file1 = "/data1/ippg/pttraining/username_results/trainSynFull_testAFRL_ppg_8epoch_modelv1_baseline/plots/waveforms/ppg"
file2 = "/data1/ippg/pttraining/username_results/trainSynFull_testAFRL_ppgSD_8epoch_modelv1_SDframes/plots/waveforms/dysub_SD_pred"
# Other run directories that can be swapped in for file1 / file2:
# file1 = "/data1/ippg/pttraining/username_results/trainSynFull_testAFRL_ppgSD_8epoch_modelv1_SDpeakloss/plots/waveforms/dysub_SD_pred"
# file1 = "/data1/ippg/pttraining/username_results/trainSynFull_testAFRL_ppgSD_8epoch_modelv1/plots/waveforms/dysub_SD_pred"
# file2 = "/data1/ippg/pttraining/username_results/trainSynFull_testAFRL_ppgSD_8epoch_modelv1/plots/waveforms/dysub_SD_pred"
# file2 = "/data1/ippg/pttraining/username_results/trainSynFull_testAFRL_ppg_8epoch_modelv1_FDpeakloss/plots/waveforms/ppg"
# file2 = "/data1/ippg/pttraining/username_results/trainSynFull_testAFRL_ppg_ppgSD_8epoch_modelv1/plots/waveforms/dysub_SD_pred"
# file2 = "/data1/ippg/pttraining/username_results/trainSynFull_testAFRL_ppgSD_8epoch_modelv1_SDpeakloss/plots/waveforms/dysub_SD_pred"

# Legend names for the two runs.
label1 = "FD-optimized Prediction"
label2 = "SD-optimized Prediction"

# Figures are written here; the directory is created if it does not exist.
plot_dir = "sys_dicr_plots_baseline_vs_SD_SDframes"
os.makedirs(plot_dir, exist_ok=True)
# Recording (CSV file name) loaded from both run directories.
file_to_compare = "P21 - T1.csv"

# Source sampling rate (Hz), display/peak-finding rate after spline upsampling
# (Hz), and how many seconds from the start of the recording are kept.
fs = 30
upsample_freq = 300
num_seconds_plot = 5
num_samples = num_seconds_plot * fs

# %%
# First run: the chosen recording, truncated to the first `num_samples` rows.
df1 = pd.read_csv(os.path.join(file1, file_to_compare)).iloc[:num_samples]
df1.head()

# %%
# Second run: the same recording, truncated identically.
df2 = pd.read_csv(os.path.join(file2, file_to_compare)).iloc[:num_samples]
df2.head()

# %%
def upsample_signal(signal, upsample_freq=256, sample_rate=None):
    """Resample a 1-D signal to ``upsample_freq`` Hz with a cubic spline.

    Args:
        signal: 1-D array of samples taken at ``sample_rate`` Hz.
        upsample_freq: target sampling rate in Hz.
        sample_rate: sampling rate of ``signal`` in Hz. When omitted, the
            module-level ``fs`` is used.

    Returns:
        ``int(len(signal) * upsample_freq / sample_rate)`` samples spaced
        ``len(signal) / n_out`` input samples apart, starting at the first
        input sample. When that count is exact, output index ``k``
        corresponds to ``k / upsample_freq`` seconds.
    """
    if sample_rate is None:
        sample_rate = fs
    n_in = signal.shape[0]
    cs = CubicSpline(x=np.arange(n_in), y=signal)
    # Multiply before dividing to limit float error before truncation.
    n_out = int(n_in * upsample_freq / sample_rate)
    # endpoint=False keeps the step at n_in / n_out (the intended
    # sample_rate / upsample_freq). Including the endpoint would stretch the
    # step to n_in / (n_out - 1) and evaluate the spline at x = n_in, a full
    # sample past the data. Outputs after x = n_in - 1 are still
    # spline-extrapolated.
    xs = np.linspace(0, n_in, num=n_out, endpoint=False)
    return cs(xs)
# %%
def get_col_key(file):
    """Return the prediction column to read from a results directory's CSVs.

    The column is chosen from the last path component of ``file``:
    ``"ppg"`` -> ``"ppg_pred_SD"`` and ``"dysub_SD_pred"`` ->
    ``"dysub_SD_pred_pred_FD"``. A trailing path separator is ignored.

    Raises:
        ValueError: if the directory name is not one of the known ones.
    """
    col_keys = {
        "ppg": "ppg_pred_SD",
        "dysub_SD_pred": "dysub_SD_pred_pred_FD",
    }
    # normpath strips a trailing separator, which would otherwise make the
    # basename an empty string.
    dir_name = os.path.basename(os.path.normpath(file))
    try:
        return col_keys[dir_name]
    except KeyError:
        raise ValueError(
            f"Unrecognised results directory {file!r}: expected its basename "
            f"to be one of {sorted(col_keys)}"
        ) from None

if "dysub_SD_pred_true_FD" in df2.columns.values:
    # rename column 
    df2["ppg_pred_SD"] = df2["dysub_SD_pred_pred_FD"]
    # calculate predicted first derivative 
    df2["ppg_pred_FD"] = np.cumsum(df2["ppg_pred_SD"])
    # calculate predicted raw PPG signal 
    df2["ppg_pred"] = np.cumsum(df2["ppg_pred_FD"])

try:
    true_SD = df2["dysub_SD_pred_true_FD"].values
except KeyError:
    true_SD = df1["ppg_true_SD"].values

df1_col_key = get_col_key(file1)
pred1_SD = df1[df1_col_key].values
df2_col_key = get_col_key(file2)
pred2_SD = df2[df2_col_key].values

def reshape_norm_detrend(sig, win_size=30):
    """Integrate, detrend and z-normalise a signal one fixed-length window at a time.

    ``sig`` is split into consecutive windows of ``win_size`` samples. Each
    window is processed in three steps:

    1. It is cumulatively summed (a discrete integral, e.g. 2nd derivative ->
       1st derivative).
    2. It is linearly detrended over all but its last sample.
    3. It is z-scored with NaN-aware mean and standard deviation.

    The windows are then flattened back to 1-D.

    Args:
        sig: array-like whose total size is a multiple of ``win_size``. It is
            converted to float, so integer input is not truncated when the
            detrended values are written back.
        win_size: samples per window. The default of 30 matches the file's
            ``fs``, i.e. presumably one window per second (confirm).

    Returns:
        1-D float ndarray with the same number of elements as ``sig``.

    Raises:
        ValueError: if the size of ``sig`` is not a multiple of ``win_size``.
    """
    arr = np.asarray(sig, dtype=float)
    if arr.size == 0:
        return arr.reshape(-1)
    if arr.size % win_size != 0:
        raise ValueError(
            f"signal length {arr.size} is not a multiple of win_size={win_size}"
        )
    # One row per window. The shape is passed positionally because NumPy
    # deprecated the `newshape=` keyword in 2.1.
    windows = np.reshape(arr, (-1, win_size))
    # Discrete integral within each window (returns a new array, so the
    # caller's data is never modified).
    integrated = np.cumsum(windows, axis=1)
    # Remove the linear trend from all but the last sample of each window.
    # NOTE(review): presumably the last sample of each window is a NaN gap
    # marker in these CSVs (NaNs are stripped later), which detrend cannot
    # handle. Confirm; otherwise that sample is left un-detrended.
    integrated[:, :-1] = detrend(integrated[:, :-1], axis=1)
    # z-score each window; the nan* reductions ignore NaN samples.
    mean = np.nanmean(integrated, axis=1, keepdims=True)
    std = np.nanstd(integrated, axis=1, keepdims=True)
    return ((integrated - mean) / std).flatten()

# Integrate every second-derivative series window-by-window into a first
# derivative, then integrate that again into a raw PPG estimate.
true_FD, pred1_FD, pred2_FD = [
    reshape_norm_detrend(sd) for sd in (true_SD, pred1_SD, pred2_SD)
]
true_raw, pred1_raw, pred2_raw = [
    reshape_norm_detrend(fd) for fd in (true_FD, pred1_FD, pred2_FD)
]

# %% remove NaN values 
true_SD, true_FD, true_raw = [s[~np.isnan(s)] for s in (true_SD, true_FD, true_raw)]
pred1_SD, pred1_FD, pred1_raw = [s[~np.isnan(s)] for s in (pred1_SD, pred1_FD, pred1_raw)]
pred2_SD, pred2_FD, pred2_raw = [s[~np.isnan(s)] for s in (pred2_SD, pred2_FD, pred2_raw)]

# %% upsample signal to smooth
# Spline-upsample every series to `upsample_freq` Hz for plotting and peak finding.
true_raw_upsample, pred1_raw_upsample, pred2_raw_upsample = [
    upsample_signal(s, upsample_freq=upsample_freq) for s in (true_raw, pred1_raw, pred2_raw)
]
true_FD_upsample, pred1_FD_upsample, pred2_FD_upsample = [
    upsample_signal(s, upsample_freq=upsample_freq) for s in (true_FD, pred1_FD, pred2_FD)
]
true_SD_upsample, pred1_SD_upsample, pred2_SD_upsample = [
    upsample_signal(s, upsample_freq=upsample_freq) for s in (true_SD, pred1_SD, pred2_SD)
]

# %% reshape, replace the gaps in the signal after upsampling, and flatten again
# win_reshape_size = int(upsample_freq-(1/30)*upsample_freq)
# print(win_reshape_size)
# tt = np.reshape(true_raw_upsample, newshape=(-1, win_reshape_size))
# tt = np.pad(tt, ((0, 0), (0, 10)), mode='constant', constant_values=np.nan)
# plt.plot(tt.flatten())
# plt.show()

# # %% for generating plots for architecture diagram
# fig, ax = plt.subplots(3, 1, figsize=(4, 18))
# ax[1].plot(true_FD_upsample, label="True FD", c=DARK)
# ax[1].plot(pred2_FD_upsample, label=f"{label2} FD", linestyle=LINE_STYLE, linewidth=LINE_WIDTH, c=GREEN)
# ax[2].plot(true_SD_upsample, label="True SD", c=DARK)
# ax[2].plot(pred2_SD_upsample, label=f"{label2} SD", linestyle=LINE_STYLE, linewidth=LINE_WIDTH, c=GREEN)
# plt.show()

# %% 
# Figure 1 has three panels for the ground truth and both predictions:
# raw PPG (top), first derivative (middle) and second derivative (bottom).
# Systolic peaks and the matched dicrotic notches are marked on the
# second-derivative panel.
fig, ax = plt.subplots(3, 1, figsize=(16, 18))

# Zero-phase (filtfilt) first-order Butterworth band-pass of the upsampled raw
# signals, 0.75-4 Hz (45-240 beats per minute). The cutoffs are divided by the
# Nyquist frequency of the upsampled signal (upsample_freq / 2).
min_freq = 0.75
max_freq = 4.0
[b, a] = scipy.signal.butter(1, [min_freq / upsample_freq * 2, max_freq / upsample_freq * 2], btype='bandpass')
true_raw_upsample = scipy.signal.filtfilt(b, a, np.double(true_raw_upsample))
pred1_raw_upsample = scipy.signal.filtfilt(b, a, np.double(pred1_raw_upsample))
pred2_raw_upsample = scipy.signal.filtfilt(b, a, np.double(pred2_raw_upsample))

ax[0].plot(true_raw_upsample, label="True PPG", c=DARK)
ax[0].plot(pred1_raw_upsample, label=f"{label1}", linestyle=LINE_STYLE, linewidth=LINE_WIDTH, c=BLUE)
ax[0].plot(pred2_raw_upsample, label=f"{label2}", linestyle=LINE_STYLE, linewidth=LINE_WIDTH, c=GREEN)

ax[1].plot(true_FD_upsample, label="True FD", c=DARK)
ax[1].plot(pred1_FD_upsample, label=f"{label1} FD", linestyle=LINE_STYLE, linewidth=LINE_WIDTH, c=BLUE)
ax[1].plot(pred2_FD_upsample, label=f"{label2} FD", linestyle=LINE_STYLE, linewidth=LINE_WIDTH, c=GREEN)

ax[2].plot(true_SD_upsample, label="True SD", c=DARK)
ax[2].plot(pred1_SD_upsample, label=f"{label1} SD", linestyle=LINE_STYLE, linewidth=LINE_WIDTH, c=BLUE)
ax[2].plot(pred2_SD_upsample, label=f"{label2} SD", linestyle=LINE_STYLE, linewidth=LINE_WIDTH, c=GREEN)

# Systolic peaks are found on the TRUE second derivative only. They serve as
# the reference beats for all three signals.
# NOTE(review): assumes calc_PPG_peaks forwards `distance` (in samples) and
# `height` to scipy.signal.find_peaks, i.e. peaks >= 0.1 at least 0.5 s
# apart. Confirm in calculate_metrics.
true_sys_upstroke = calc_PPG_peaks(true_SD_upsample,
    distance=int(0.5*upsample_freq), height=(0.1,))
ax[2].scatter(true_sys_upstroke, true_SD_upsample[true_sys_upstroke], c="red", label="Systolic Peak", s=SCATTER_SIZE)

# Ignore notch candidates closer than 0.05 s after a systolic peak, so the
# peak's own neighbourhood is not mistaken for the dicrotic notch.
buffer_window = 0.05 * upsample_freq
first_iter = True
for upsampled_sig in [true_SD_upsample, pred1_SD_upsample, pred2_SD_upsample]:
    # Dicrotic-notch candidates: lower-threshold peaks at least 0.1 s apart.
    dicrotic_peaks = calc_PPG_peaks(upsampled_sig,
        distance=int(0.1*upsample_freq), height=(0.,))

    # Per-sample systolic-to-notch interval (in upsampled samples, NaN outside
    # a matched interval), plus the list of intervals. Both are rebuilt for
    # every signal, so only the last signal's values remain after the loop.
    sys_dicr_time_per_frame = np.full(upsampled_sig.shape[0], np.nan)
    times = []

    # A beat runs from one true systolic peak to the next; the last beat runs
    # to the end of the signal. The notch must lie inside its own beat. A beat
    # with no candidate is skipped rather than paired with a peak from the
    # following cardiac cycle.
    beat_ends = list(true_sys_upstroke[1:]) + [upsampled_sig.shape[0]]
    for m, beat_end in zip(true_sys_upstroke, beat_ends):
        in_beat = (dicrotic_peaks > (m + buffer_window)) & (dicrotic_peaks < beat_end)
        dicrotic_peak_candidates = dicrotic_peaks[in_beat]
        if dicrotic_peak_candidates.shape[0] == 0:
            continue
        closest_dicrotic_peak = dicrotic_peak_candidates[np.argmin(dicrotic_peak_candidates - m)]

        # Only the very first notch marker gets a legend label.
        if first_iter:
            ax[2].scatter(closest_dicrotic_peak, upsampled_sig[closest_dicrotic_peak], c="blue", label="Dicrotic Notch", s=SCATTER_SIZE)
            first_iter = False
        else:
            ax[2].scatter(closest_dicrotic_peak, upsampled_sig[closest_dicrotic_peak], c="blue", s=SCATTER_SIZE)

        # Systolic-to-notch interval in upsampled samples.
        time_diff = closest_dicrotic_peak - m
        times.append(time_diff)
        sys_dicr_time_per_frame[m:closest_dicrotic_peak] = time_diff
        # Dashed horizontal bar from the systolic peak to the notch.
        ax[2].plot([m, closest_dicrotic_peak], 
            [upsampled_sig[m], upsampled_sig[m]], 
            linestyle="--", c=RED)

ax[0].set_ylabel("Normalized Raw\nPPG Amplitude", fontsize=LABEL_FONT_SIZE)
ax[1].set_ylabel("Normalized 1st Derivative\nPPG Amplitude", fontsize=LABEL_FONT_SIZE)
ax[2].set_ylabel("Normalized 2nd Derivative\nPPG Amplitude", fontsize=LABEL_FONT_SIZE)
ax[2].set_xlabel("Time (seconds)", fontsize=LABEL_FONT_SIZE)

# One tick per second, spanning the longest plotted series rather than
# whichever signal the loop above visited last.
n_plot_samples = max(s.shape[0] for s in (
    true_raw_upsample, pred1_raw_upsample, pred2_raw_upsample,
    true_FD_upsample, pred1_FD_upsample, pred2_FD_upsample,
    true_SD_upsample, pred1_SD_upsample, pred2_SD_upsample,
))
for a in ax:
    a.set_xticks(np.arange(0, n_plot_samples+1, upsample_freq))
    a.set_xticklabels([int(x/upsample_freq) for x in a.get_xticks()], fontsize=LABEL_FONT_SIZE)
    a.tick_params(labelsize=LABEL_FONT_SIZE)
ax[0].legend(loc="upper center", fontsize=LABEL_FONT_SIZE, 
    frameon=False, markerfirst=False, bbox_to_anchor=(0.5, 1.15), ncol=3)
plt.tight_layout()
plt.savefig(os.path.join(plot_dir, "{}_waveform_compare.png".format(os.path.splitext(file_to_compare)[0])))
plt.savefig(os.path.join(plot_dir, "{}_waveform_compare.svg".format(os.path.splitext(file_to_compare)[0])))
plt.show()

# %%
# Figure 2 plots the systolic-to-dicrotic time (labelled "LVET" in the legend)
# for the ground truth and both predictions on one set of axes. All three
# series are computed by calc_time_sys_to_dicrotic from the NaN-stripped,
# non-upsampled SD series at `fs`.
fig, ax = plt.subplots(1, 1, figsize=(16, 6))
# NOTE(review): calc_time_sys_to_dicrotic returns three values. The first is
# presumably the computed times. The second is an Axes that is chained into
# the next call so all series share it. The third is discarded. Confirm
# against calculate_metrics.
true_sys_dicrotic_time, true_sys_dicrotic_time_ax, _ = calc_time_sys_to_dicrotic(true_SD, diff=False, fs=fs, ax=ax, c=DARK, linewidth=LINE_WIDTH,)
pred_sys_dicrotic_time1, pred1_sys_dicrotic_time_ax, _ = calc_time_sys_to_dicrotic(pred1_SD, reference_signal=true_SD, diff=False, fs=fs, ax=true_sys_dicrotic_time_ax, linestyle="-", c=BLUE, linewidth=LINE_WIDTH,)
pred_sys_dicrotic_time2, pred2_sys_dicrotic_time_ax, _ = calc_time_sys_to_dicrotic(pred2_SD, reference_signal=true_SD, diff=False, fs=fs, ax=pred1_sys_dicrotic_time_ax, linestyle="-", c=GREEN, linewidth=LINE_WIDTH,)
# plt.legend / plt.grid act on the current Axes (plt.gca()), while the
# tick/label sizing below acts on `ax`. If calc_time_sys_to_dicrotic returns
# a different Axes (e.g. a twin), these may target different axes.
# The legend labels are assigned positionally to the plotted lines.
# NOTE(review): the commented-out six-label list suggests the helper once drew
# two lines per signal (raw and 10 s window). If it still does, these three
# labels will not line up with the curves. Confirm.
plt.legend([
    # "True time", "True time (10s window)", 
    # f"{label1} time", f"{label1} time (10s window)",
    # f"{label2} time", f"{label2} time (10s window)",
    # ], fontsize=LABEL_FONT_SIZE, loc="upper center", 
    "True LVET",
    f"{label1}",
    f"{label2}",
    ], fontsize=LABEL_FONT_SIZE, loc="upper center", 
    frameon=False, markerfirst=False, bbox_to_anchor=(0.5, 1.15), ncol=3)
plt.grid()
ax.tick_params(labelsize=LABEL_FONT_SIZE)
ax.xaxis.label.set_size(LABEL_FONT_SIZE)
ax.yaxis.label.set_size(LABEL_FONT_SIZE)
plt.tight_layout()
plt.savefig(os.path.join(plot_dir, "{}_sys_dicr_time.png".format(os.path.splitext(file_to_compare)[0])))
plt.savefig(os.path.join(plot_dir, "{}_sys_dicr_time.svg".format(os.path.splitext(file_to_compare)[0])))
plt.show()
# %%
# Histograms of the times returned above. The ground truth is drawn opaque
# in black first, and both predictions are overlaid semi-transparently.
plt.hist(true_sys_dicrotic_time, label="True", color="black")
plt.hist(pred_sys_dicrotic_time1, label="pred1", alpha=0.5)
plt.hist(pred_sys_dicrotic_time2, label="pred2", alpha=0.5)
plt.legend(fontsize=14)
plt.show()
# %%
