# %%
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.markers as markers

# Colours for the two compared runs; used by every figure below.
GREEN = '#00946c'  # plotted for the df2 / file2 run
BLUE = '#568ce9'  # plotted for the df1 / file1 run

# %%
# Per-sample error tables written by the two training runs being compared.
file1 = "/data1/ippg/pttraining/username_results/trainSynFull_testAFRL_ppg_8epoch_modelv1_baseline/error_df.csv"
file2 = "/data1/ippg/pttraining/username_results/trainSynFull_testAFRL_ppgSD_8epoch_modelv1_SDframes/error_df.csv"

# Load both tables (first row is the header), file1 first, then file2.
df1, df2 = [pd.read_csv(csv_path, header=0) for csv_path in (file1, file2)]


# %%
def plot_bland_altman_sys_dicr_time(true, pred, plot_save_dir,
                                    file_name_prefix="sys_dicr_time_bland_altman_plot",
                                    normalize=False, font_size=24, ci=1.96,
                                    xlim=None, ylim=None,
                                    label=None, xlabel=None, ylabel=None,
                                    marker_type='s', marker_color='blue', s=100,
                                    add_text=True,
                                    ax=None, text_offset=10, ddof=0, **kwargs):
    """Scatter y-values against true values and overlay mean / limit lines.

    Draws hollow markers at ``(true, pred)`` on ``ax``. It then adds three
    horizontal lines spanning the current x-limits: the mean of the plotted
    y-values (solid) and ``mean +/- ci * std`` (dotted). Optionally it writes
    each line's value to the right of the x-range.

    Args:
        true: x-values (array-like).
        pred: y-values to plot (array-like). Callers in this script pass an
            error column here, not a raw prediction.
        plot_save_dir: Unused; kept for interface compatibility (this function
            does not save anything).
        file_name_prefix: Unused; kept for interface compatibility.
        normalize: If True, plot ``pred / true * 100`` instead of ``pred``.
        font_size: Font size for axis labels, tick labels and annotations.
        ci: Multiplier on the standard deviation for the dotted limit lines.
        xlim, ylim: Optional ``(min, max)`` axis limits, applied when not None.
        label: Legend label for the scatter series.
        xlabel, ylabel: Axis labels.
        marker_type: Matplotlib marker code for the scatter points.
        marker_color: Marker edge colour and colour of the horizontal lines.
        s: Marker size.
        add_text: If True, annotate the mean and limit lines with their values.
        ax: Axes to draw on. A new 8x6 figure is created when None.
        text_offset: Horizontal gap, in x data units, between the right x-limit
            and the annotation text. Defaults to 10 (the previous fixed value).
        ddof: Delta degrees of freedom for the standard deviation. 0 gives the
            population SD (previous behaviour); 1 gives the sample SD.
        **kwargs: Accepted and ignored.

    Returns:
        The Axes that was drawn on.
    """
    marker = markers.MarkerStyle(marker=marker_type)
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(8, 6))

    xvals = true
    yvals = pred
    if normalize:
        # Express each y-value as a percentage of its true value.
        yvals = (yvals / true) * 100.
    # color= sets the edge colour; facecolors='none' keeps the markers hollow.
    ax.scatter(xvals, yvals,
               color=marker_color, marker=marker, facecolors='none', s=s, linewidths=2,
               label=label)

    # Annotate the plot with the mean and the mean +/- ci*SD limits.
    mean_val = np.mean(yvals)
    std_val = np.std(yvals, ddof=ddof)
    upper_val = mean_val + ci * std_val
    lower_val = mean_val - ci * std_val

    # Compare against None rather than relying on truthiness, so array-valued
    # limits (e.g. np.array([lo, hi])) work instead of raising ValueError.
    if xlim is not None:
        ax.set_xlim(xlim)
    # Read the limits back so the lines span the effective x-range,
    # including an autoscaled one.
    min_x, max_x = ax.get_xlim()
    ax.set_xticks(np.linspace(min_x, max_x, num=5, dtype=int))
    ax.hlines(mean_val, min_x, max_x, colors=marker_color, linewidth=3.0)
    ax.hlines(upper_val, min_x, max_x, colors=marker_color, linestyles=':', linewidth=2.5)
    ax.hlines(lower_val, min_x, max_x, colors=marker_color, linestyles=':', linewidth=2.5)
    if add_text:
        text_x = max_x + text_offset
        ax.text(x=text_x, y=mean_val, s="{:.1f}".format(mean_val), fontsize=font_size,)
        ax.text(x=text_x, y=upper_val, s="{:.0f} (+{:.2f}SD)".format(upper_val, ci), fontsize=font_size,)
        ax.text(x=text_x, y=lower_val, s="{:.0f} (-{:.2f}SD)".format(lower_val, ci), fontsize=font_size,)
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.set_xlabel(xlabel, fontsize=font_size)
    ax.set_ylabel(ylabel, fontsize=font_size)
    ax.tick_params(labelsize=font_size)
    return ax

# %%
# Output directory shared by every figure in this script.
plot_dir = "sys_dicr_plots_baseline_vs_SD_SDframes"
os.makedirs(plot_dir, exist_ok=True)

# LVET: per-sample error vs. true LVET for both runs, overlaid on one axes.
xlabel = "True LVET (ms)"
ylabel = "Absolute LVET Error (ms)"
xlim = (125, 325)
ylim = (-20, 300)

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
# NOTE(review): df1 is loaded from the "..._baseline" run but is labelled
# "FD-optimized Predictions" -- confirm the legend text is intended.
ax = plot_bland_altman_sys_dicr_time(
    true=df1["smoothed_systolic_dicrotic_time"],
    pred=df1["smoothed_systolic_dicrotic_time_mae"],
    plot_save_dir=plot_dir,
    file_name_prefix="LVET_bland_altman",
    xlim=xlim, ylim=ylim,
    label="FD-optimized Predictions",
    marker_type='s', marker_color=BLUE,
    xlabel=xlabel, ylabel=ylabel,
    add_text=False,
    ax=ax)

ax = plot_bland_altman_sys_dicr_time(
    true=df2["smoothed_systolic_dicrotic_time_SD"],
    pred=df2["smoothed_systolic_dicrotic_time_SD_mae"],
    plot_save_dir=plot_dir,
    file_name_prefix="LVET_bland_altman",
    xlim=xlim, ylim=ylim,
    label="SD-optimized Predictions",
    marker_type='o', marker_color=GREEN,
    xlabel=xlabel, ylabel=ylabel,
    add_text=False,
    ax=ax)

# Act on this figure/axes explicitly rather than pyplot's "current" state.
ax.legend(loc="upper right", fontsize=24, frameon=False, markerfirst=False)
fig.tight_layout()
fig.savefig(os.path.join(plot_dir, "LVET_bland_altman.png"))
fig.savefig(os.path.join(plot_dir, "LVET_bland_altman.svg"))
plt.show()
plt.close(fig)

# %%
# HR: per-sample error vs. true heart rate for both runs, overlaid on one axes.
xlabel = "True HR (BPM)"
ylabel = "Absolute HR Error (BPM)"
xlim = (40, 100)
ylim = None  # let matplotlib autoscale the error axis

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
ax = plot_bland_altman_sys_dicr_time(
    true=df1["ppg_hr"],
    pred=df1["error"],
    plot_save_dir=plot_dir,
    file_name_prefix="HR_bland_altman",
    xlim=xlim, ylim=ylim,
    label="FD-optimized Predictions",
    marker_type='s', marker_color=BLUE,
    xlabel=xlabel, ylabel=ylabel,
    add_text=False,
    ax=ax)

ax = plot_bland_altman_sys_dicr_time(
    true=df2["ppg_hr"],
    pred=df2["error"],
    plot_save_dir=plot_dir,
    file_name_prefix="HR_bland_altman",
    xlim=xlim, ylim=ylim,
    label="SD-optimized Predictions",
    marker_type='o', marker_color=GREEN,
    xlabel=xlabel, ylabel=ylabel,
    add_text=False,
    ax=ax)

# Act on this figure/axes explicitly rather than pyplot's "current" state.
ax.legend(loc="upper right", fontsize=24, frameon=False, markerfirst=False)
fig.tight_layout()
fig.savefig(os.path.join(plot_dir, "HR_bland_altman.png"))
fig.savefig(os.path.join(plot_dir, "HR_bland_altman.svg"))
plt.show()
plt.close(fig)
