# %% Import and setup

import torch
from scipy import stats

import matplotlib
import os
import pandas as pd
import re
import numpy as np
import seaborn as sns

os.environ["TOKENIZERS_PARALLELISM"] = "false"

from _objects.plot_config import *
from _objects.configs import *

from _objects.sae_models import *
from _objects.model_configs import *

# # from _llm_based.utils.sampling_utils import phq9_qs_inv_map
# from utils.analysis_utils import spec_levels, spec_levels_crit, spec_level_dict, spec_names, qa_loc_key_names, \
#     get_responses, get_responses_avg_merged, get_responses_merged_diff, responses_totals_shuffled
# from utils.plot_utils import set_scatter_axes, set_a_hist
# from objects.configs import *
# from objects.plot_config import *
# from _llm_based.objects.model_configs import *
# from _llm_based.sae_models import *

matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt

plt.ion()

# Plot styling and run-time switches (PlotConfig / Bools / Paths presumably come from the _objects star imports).
pc = PlotConfig()
bools = Bools()

paths = Paths(files_dir='', plots_subdir='', plots_subsubdir='')

# Directory layout for this test case; all paths are relative to the working directory.
testCase_dir = 'test3_mood_induction'
paths.plots_path = f'{testCase_dir}/_plots/_sae/'
paths.data_proc_dir = f'{testCase_dir}/_data/processed/'
paths.data_proc_dir_sae = f'{testCase_dir}/_data/processed/sae/'
paths.data_path = f'{testCase_dir}/_data/'
paths.prev_path_cd = '../construct_detector/'
paths.prev_path_int = '../online_tasks/qs_intervention/'

paths.states_path = f'{paths.prev_path_int}_analysis/_llm_based/files/'  # old path

# Later cells save plots to plots_path and CSVs to data_proc_dir_sae, so make sure both exist.
Path(paths.plots_path).mkdir(parents=True, exist_ok=True)
Path(paths.data_proc_dir_sae).mkdir(parents=True, exist_ok=True)

mname = 'gemma2-9b-it'
sample_config = SampleConfig(paths, model_name=mname, qs_name='phq9', instr_name='instr2', temp='', top_p='')
# Analyse the upper half of the layers: indices L//2 .. L-1 (assumes sample_config.L is the layer count).
layer_idxs = np.arange(sample_config.L // 2, sample_config.L)

do_zscores = True
max_epochs = 1000
device_name = 'mps'

bools.loadMe = True
bools.saveMe = False
bools.savePlot = True
# True: recompute SAE scores from hidden states and write the CSV cache; False: re-read the cached CSVs.
bools.saveFiles = False

id_cols = ['sub', 'condition', 'group', 'autobio']
hue_order = ['MH', 'ML']
hue_cols = ['tab:blue', 'tab:orange']
hue_cols2 = ['tab:green', 'tab:purple']

fig_no = 'Fig6'
# %% Load data
# Participants excluded from every analysis below.
sub_outliers = ['sub95_v2b', 'sub96_v2b', 'sub150_v2b', 'sub133_v2b', 'sub180_v3', 'sub131_v3']


def _load_autobio(csv_name):
    """Read ``csv_name`` from the processed-data dir, keeping rows with a true ``autobio`` flag
    whose ``sub`` is not listed in ``sub_outliers``."""
    frame = pd.read_csv(f'{paths.data_proc_dir}{csv_name}')
    keep = frame['autobio'] & ~frame['sub'].isin(sub_outliers)
    return frame[keep]


phq9_diff = _load_autobio('phq9_diff_data_wide.csv')
mood_diff = _load_autobio('mood_data.csv')
recall_data = _load_autobio("recall_data.csv")
text_data = _load_autobio('text_data.csv')

# 'phq9_q1' .. 'phq9_q9'
phq9_q_names = [f'phq9_q{q_num}' for q_num in range(1, 10)]
# %% Set configs
exp_name = 'SAE_Mistral_gemma_llama_exp_v3_best'
model_dir = f'{paths.prev_path_cd}saved_models/{exp_name}/{mname}/'

# Checkpoints are named 'best_<config>.pt', where <config> is a '^^'-separated list of 'key-value' fields.
# Only .pt files are considered: os.listdir also returns stray entries (e.g. macOS '.DS_Store') that would
# otherwise be parsed as a config. Sorting makes the order independent of the filesystem.
# The regex strips only the leading 'best_' and the trailing '.pt', not occurrences elsewhere in the name.
best_configs = sorted(re.sub(r'^best_|\.pt$', '', fname)
                      for fname in os.listdir(model_dir) if fname.endswith('.pt'))
if not best_configs:
    raise FileNotFoundError(f'No .pt checkpoints found in {model_dir}')
# 'key-value' -> {key: value}; only the first '-' separates key from value.
best_config_dicts = [{key: val for key, _, val in (seg.partition('-') for seg in t.split('^^'))}
                     for t in best_configs]
# The third config field is the questionnaire case; it is taken from the first checkpoint.
# NOTE(review): assumes all checkpoints share the same q_case — confirm.
q_case = list(best_config_dicts[0].values())[2]

exp_config = ExpConfig(sample_config, q_case=q_case, do_zscores=do_zscores, max_epochs=max_epochs)
# Per-model hidden-state width (d_h); other model names leave exp_config.d_h as ExpConfig set it.
if mname == 'MistralOo':
    exp_config.d_h = 4096
elif mname == 'gemma2-9b-it':
    exp_config.d_h = 3584

# %% Calculate average-across-layer SAE scores for each text type for each participant
# For each text type, when bools.saveFiles is True:
#   - encode every participant's last-token hidden state with each layer's best SAE;
#   - average the per-item SAE outputs across layers;
#   - write the result to sae_preds_<model>_<text>.csv.
# Otherwise the CSV cache from a previous run is re-read.
text_types = ['act', 'mood', 'energy', 'pospert']
sae_score_cols = [f'sae_{q_name}' for q_name in phq9_q_names]
for text_type in text_types:
    sae_csv_fp = f'{paths.data_proc_dir_sae}sae_preds_{mname}_{text_type}.csv'
    if not bools.saveFiles:
        sae_preds_df = pd.read_csv(sae_csv_fp)
        continue

    hs_path = f'{paths.states_path}text_hs/{text_type}/'
    pt_files = [f for f in os.listdir(hs_path) if '.pt' in f and sample_config.model_name in f]
    # Participant metadata (first text_data row per sub), looked up once instead of re-filtering per file.
    sub_meta = text_data.drop_duplicates('sub').set_index('sub')[['condition', 'group', 'autobio']]

    # Read each participant file once and keep only its last-token slice (all stored layers),
    # instead of re-reading the whole file for every layer. clone() drops the reference to the full tensor.
    # NOTE(review): assumes files hold (layer position within layer_idxs, token, hidden) tensors — confirm.
    last_tok_states = []
    for pt_file in pt_files:
        sub = pt_file.split('^^')[0]  # participant id is the first '^^' field of the file name
        if sub in sub_meta.index:
            condition, group, is_autobio = sub_meta.loc[sub]
            sub_hs_ts = torch.load(f'{hs_path}{pt_file}', weights_only=True)
            last_tok_states.append((sub, condition, group, is_autobio, sub_hs_ts[:, -1, :].clone()))
            del sub_hs_ts
    if not last_tok_states:
        raise FileNotFoundError(
            f'No {sample_config.model_name} hidden-state files for participants in text_data under {hs_path}')

    pred_frames = []
    for layer_pos, layer_idx in enumerate(layer_idxs):
        layer_name = f'layer-{layer_idx + 1}'
        # Match 'layer-N' only when no digit follows, so that e.g. 'layer-4' cannot select 'layer-40'.
        layer_pattern = re.compile(re.escape(layer_name) + r'(?!\d)')
        layer_configs = [c for c in best_configs if layer_pattern.search(c)]
        if not layer_configs:
            raise FileNotFoundError(f'No SAE checkpoint for {layer_name} in {model_dir}')
        model_layer_config = layer_configs[0]
        model_layer_config_dict = {key: val for key, _, val in
                                   (seg.partition('-') for seg in model_layer_config.split('^^'))}
        # Positional layout used here:
        #   [5] sparsity coeff, [6] qs coeff, [7] sae_factor (passed as x_m), [8] tied weights.
        hparams = list(model_layer_config_dict.values())
        sparsity_coeff = float(hparams[5])
        qs_coeff = int(hparams[6])
        sae_factor = int(hparams[7])
        tied_weights = hparams[8] == 'True'
        sae_cfg = SAE_Config(exp_config, device=device_name, x_m=sae_factor, sparse_coeff=sparsity_coeff,
                             tied_weights=tied_weights, qs_coeff=qs_coeff)

        model = SAE4(sae_cfg).to(device_name)
        model_fp = f'{model_dir}best_{model_layer_config}.pt'
        # The checkpoint is a plain state_dict, so the safe tensor-only unpickler suffices.
        model.load_state_dict(torch.load(model_fp, map_location=torch.device(device_name), weights_only=True))
        model.eval()

        with torch.no_grad():  # inference only: no autograd graph needed
            for sub, condition, group, is_autobio, sub_last_tok in last_tok_states:
                layer_hs = sub_last_tok[layer_pos].type(torch.float).to(device_name).unsqueeze(0)
                # L2-normalise the hidden state before encoding.
                layer_hs = layer_hs / torch.linalg.norm(layer_hs, dim=1, keepdim=True)
                latent_qs, _, _ = model.infer(layer_hs)
                latent_qs = latent_qs.to('cpu').detach().numpy()

                row = {'sub': sub, 'condition': condition, 'group': group, 'autobio': is_autobio,
                       'text': text_type, 'layer': layer_name,
                       't': np.arange(1, latent_qs.shape[0] + 1)}
                for q, q_name in enumerate(phq9_q_names):
                    row[f'sae_{q_name}'] = latent_qs[:, q]
                pred_frames.append(pd.DataFrame(row))
                del layer_hs

        del model

    # Single concat instead of growing a DataFrame inside the loop.
    sae_preds_df = pd.concat(pred_frames, axis=0)
    # 'layer' is deliberately left out of the grouping keys: this averages across layers.
    sae_preds_df_avg = sae_preds_df.groupby(['sub', 'condition', 'group', 'autobio', 'text', 't'], as_index=False)[
        sae_score_cols].mean()
    Path(paths.data_proc_dir_sae).mkdir(parents=True, exist_ok=True)
    sae_preds_df_avg.to_csv(sae_csv_fp, index=False)

# %% Plot SAE for each text type
# One figure per text type: box + violin + strip plots of each PHQ-9 item's SAE score, split by condition.
sae_names = [f'sae_{q}' for q in phq9_q_names]
ssize = 3.75  # strip-plot marker size
slw = 0.2  # strip-plot edge width (box/violin use slw + 1)
alpha = 0.5  # box face opacity
for text_type in text_types:
    sae_preds_df = pd.read_csv(f'{paths.data_proc_dir_sae}sae_preds_{mname}_{text_type}.csv')
    plt.close('all')
    pc.r, pc.c, pc.mlt = 2, 5, 1
    pc.figsize = ((pc.c + 5.25) * pc.mlt, (pc.r + 3.55) * pc.mlt)
    fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, sharex=False, sharey=True)
    axes = axes.flatten()
    pc.onerow = True
    pc.i = 0
    pc.j = 0

    pc.axes = axes
    pc.ax_ts(16, 1.1)
    pc.l_fs(8, 0.85)
    ts = 14
    pc.xyt_ls(ts, 10)
    pc.ax_ls(16)
    pc.kde_lw = 3
    pc.p_lab_spec[2] = 12
    pc.p_lab_spec[0] = -0.1
    pc.p_lab_spec[1] = 1.1
    pc.dpi_val = 300

    # pc.ax presumably resolves to the panel selected by pc.j, so pc.j is the loop target.
    for pc.j, sae_name in enumerate(sae_names):
        bp = sns.boxplot(data=sae_preds_df, x='condition', y=sae_name, hue='condition', width=0.4,
                         linewidth=slw + 1, ax=pc.ax, legend=pc.j == 0, gap=0.2, palette=hue_cols, hue_order=hue_order)
        # Lower the box faces' opacity to `alpha`.
        for patch in bp.patches:
            face_color = patch.get_facecolor()
            patch.set_facecolor((*face_color[:3], alpha))
        sns.violinplot(data=sae_preds_df, x='condition', y=sae_name, hue='condition', width=0.4, linewidth=slw + 1,
                       ax=pc.ax,
                       legend=False, gap=0.2, palette=hue_cols, hue_order=hue_order)
        sns.stripplot(data=sae_preds_df, x='condition', y=sae_name, hue='condition', edgecolor='black', linewidth=slw,
                      dodge=True,
                      jitter=0.05, size=ssize, legend=False, ax=pc.ax, palette=hue_cols, hue_order=hue_order)

        pc.ax.set_title(sae_name)
        pc.ax.set_ylabel('SAE Score')

    # The grid has more cells than items (2x5 for 9 items): remove every axis that was not drawn on.
    for spare_ax in axes[len(sae_names):]:
        spare_ax.remove()
    plt.suptitle(f'SAE Scores on the last token of the {text_type} text')
    plt.tight_layout()
    if bools.savePlot:
        plt.savefig(f'{paths.plots_path}sae_scores_{mname}_{text_type}.pdf', dpi=300)

# %% Plot SAE scores against Q2 diff
# Left panel: PHQ-9 Q2 vs the 'act' text SAE Q2 score, one regression line per condition.
# Right panel: distribution of the SAE Q2 score per condition.
text_type = 'act'
sae_preds_df = pd.read_csv(f'{paths.data_proc_dir_sae}sae_preds_{mname}_{text_type}.csv')
sae_phq9 = pd.merge(sae_preds_df, phq9_diff, on=id_cols)
sae_phq9.to_csv(f"{paths.data_proc_dir_sae}sae_phq9_{mname}_{text_type}.csv", index=False)

sae_phq9_mh = sae_phq9[sae_phq9['condition'] == 'MH']
sae_phq9_ml = sae_phq9[sae_phq9['condition'] == 'ML']
plt.close('all')
pc.r, pc.c, pc.mlt = 1, 2, 1.75
pc.figsize = ((pc.c + 6.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)

fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, sharex=False, sharey=False)
axes = np.array([axes])
pc.onerow = True
pc.i = 0
pc.j = 0

pc.axes = axes
pc.ax_ts(16, 1.1)
pc.l_fs(12, 0.85)
ts = 14
pc.xyt_ls(ts, 10)
pc.ax_ls(16)
pc.kde_lw = 3
pc.p_lab_spec[2] = 12
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300

ssize = 5.75
slw = 0.2
alpha = 0.5
sns.regplot(sae_phq9_mh, x='sae_phq9_q2', y='phq9_q2', color='tab:blue', label='MH', ax=pc.ax)
sns.regplot(sae_phq9_ml, x='sae_phq9_q2', y='phq9_q2', color='tab:orange', label='ML', ax=pc.ax)
pc.ax.set_xlabel('SAE Q2 score')
pc.ax.set_ylabel('PHQ-9 Q2 change.')
pc.ax.legend()

# Colour this panel by condition with the same MH/ML palette as the regression panel. Without hue,
# seaborn draws both conditions in a single default colour. The x tick labels already name the
# conditions, so no legend is added.
pc.j = 1
bp = sns.boxplot(data=sae_phq9, x='condition', y='sae_phq9_q2', hue='condition', width=0.4,
                 linewidth=slw + 1, ax=pc.ax, legend=False, gap=0.2, hue_order=hue_order, palette=hue_cols)
# Lower the box faces' opacity to `alpha`.
for patch in bp.patches:
    face_color = patch.get_facecolor()
    patch.set_facecolor((*face_color[:3], alpha))
sns.violinplot(data=sae_phq9, x='condition', y='sae_phq9_q2', hue='condition', width=0.4, linewidth=slw + 1,
               ax=pc.ax, legend=False, gap=0.2, hue_order=hue_order, palette=hue_cols)
sns.stripplot(data=sae_phq9, x='condition', y='sae_phq9_q2', hue='condition', edgecolor='black', linewidth=slw,
              dodge=True, jitter=0.05, size=ssize, legend=False, ax=pc.ax, hue_order=hue_order, palette=hue_cols)
pc.ax.set_ylabel('SAE Q2 score')

plt.tight_layout()
if bools.savePlot:
    plt.savefig(f"{paths.plots_path}SAEscore_vs_q2Diff_{mname}.pdf", dpi=300)

# %% Plot SAE scores against mood and recall diff
# Three panels, each regressing one measure on the 'act' text SAE Q2 score, separately per condition:
# PHQ-9 Q2, momentary mood, and recall sentiment.
text_type = 'act'
sae_preds_df = pd.read_csv(f'{paths.data_proc_dir_sae}sae_preds_{mname}_{text_type}.csv')
# Rebuilt here so this cell no longer depends on the previous cell having been run.
sae_phq9 = pd.merge(sae_preds_df, phq9_diff, on=id_cols)
sae_mood = pd.merge(sae_preds_df, mood_diff, on=id_cols)
sae_recall = pd.merge(sae_preds_df, recall_data, on=id_cols)
sae_mood.to_csv(f"{paths.data_proc_dir_sae}sae_mood_{mname}_{text_type}.csv", index=False)
sae_recall.to_csv(f"{paths.data_proc_dir_sae}sae_recall_{mname}_{text_type}.csv", index=False)

plt.close('all')
pc.r, pc.c, pc.mlt = 1, 3, 1.5
pc.figsize = ((pc.c + 6.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)

fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, sharex=False, sharey=False)
axes = np.array([axes])
pc.onerow = True
pc.i = 0
pc.j = 0

pc.axes = axes
pc.ax_ts(16, 1.1)
pc.l_fs(12, 0.85)
ts = 14
pc.xyt_ls(ts, ts)
pc.ax_ls(18)
pc.kde_lw = 3
pc.p_lab_spec[2] = 12
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300

ssize = 5.75
slw = 0.2
alpha = 0.5

# (merged data, y column, y-axis label) per panel; x is always the SAE Q2 score.
measure_panels = [
    (sae_phq9, 'phq9_q2', 'PHQ-9 Q2'),
    (sae_mood, 'mood_diff', 'Momentary mood'),
    (sae_recall, 'recall_diffSent', 'Recall sentiment'),
]
for pc.j, (panel_df, y_col, y_label) in enumerate(measure_panels):
    # hue_order / hue_cols pair MH with tab:blue and ML with tab:orange.
    for cond, cond_color in zip(hue_order, hue_cols):
        sns.regplot(panel_df[panel_df['condition'] == cond], x='sae_phq9_q2', y=y_col, color=cond_color,
                    label=cond, ax=pc.ax)
    pc.ax.set_xlabel('SAE Q2 score')
    pc.ax.set_ylabel(y_label)
    if pc.j == 1:  # a single condition legend, on the middle panel
        pc.ax.legend(title='condition')
    # Bold panel letter just outside the top-left corner.
    pc.ax.text(pc.p_lab_spec[0] + 0.05, pc.p_lab_spec[1], pc.p_labs[pc.i, pc.j], transform=pc.ax.transAxes,
               fontweight='bold',
               va='top', ha='right',
               fontsize=pc.p_lab_spec[2])

plt.suptitle('Affective change measures (FU - Baseline) vs SAE Q2 score')
plt.tight_layout()
if bools.savePlot:
    plt.savefig(f"{paths.plots_path}{fig_no}_p3_{testCase_dir}_measures-vs_sae.pdf", dpi=300)

# %% Calc FU (pospert)-baseline - average across SAEs
def _avg_sae_score(text_kind):
    """Return one row per participant (``id_cols``) with the mean SAE score for ``text_kind``.

    The mean pools every non-id column of the cached per-text CSV (the SAE item scores) over all
    token positions ``t``. The result column is named ``sae_score_<text_kind>``.
    """
    per_text = pd.read_csv(f'{paths.data_proc_dir_sae}sae_preds_{mname}_{text_kind}.csv')
    long_scores = per_text.melt(id_vars=id_cols + ['t', 'text'], var_name='sae_var', value_name='sae_score')
    participant_means = long_scores.groupby(id_cols, as_index=False)['sae_score'].mean()
    return participant_means.rename(columns={'sae_score': f'sae_score_{text_kind}'})


sae_preds_df_all_mood, sae_preds_df_all_energy, sae_preds_df_all_pospert = (
    _avg_sae_score(kind) for kind in ('mood', 'energy', 'pospert'))

sae_pred_df_avg = (sae_preds_df_all_mood
                   .merge(sae_preds_df_all_energy, on=id_cols)
                   .merge(sae_preds_df_all_pospert, on=id_cols))

# pospert score minus a baseline: the mood score alone, or the mean of the mood and energy scores.
baseline_mood = sae_pred_df_avg['sae_score_mood']
baseline_avg = (sae_pred_df_avg['sae_score_mood'] + sae_pred_df_avg['sae_score_energy']) / 2
sae_pred_df_avg['sae_diff_fuB_mood'] = sae_pred_df_avg['sae_score_pospert'] - baseline_mood
sae_pred_df_avg['sae_diff_fuB_avg'] = sae_pred_df_avg['sae_score_pospert'] - baseline_avg
sae_pred_df_avg.to_csv(f'{paths.data_proc_dir_sae}sae_diff_fuB_avg.csv', index=False)



# %% Plot SAE measures between conditions
# Two box + violin + strip panels comparing MH vs ML.
# Requires two frames built by earlier cells to be in memory:
#   - sae_phq9: 'act' SAE scores merged with PHQ-9;
#   - sae_pred_df_avg: pospert-minus-baseline SAE differences.
plt.close('all')
pc.r, pc.c, pc.mlt = 1, 2, 1.8
pc.figsize = ((pc.c + 2.25) * pc.mlt, (pc.r + 1.) * pc.mlt)

fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, sharex=False, sharey=False)
axes = np.array([axes])
pc.onerow = False
pc.axes = axes
pc.i = 0
pc.j = 0
pc.ax_ts(13.5, 1.1)
pc.l_fs(12, 0.85)
ts = 13
pc.xyt_ls(ts, ts)
pc.ax_ls(13)
pc.kde_lw = 3
pc.p_lab_spec[2] = 14
pc.p_lab_spec[0] = -0.2
pc.p_lab_spec[1] = 1.05
pc.dpi_val = 300

ssize = 5.75  # strip-plot marker size
slw = 0.2  # strip-plot edge width (box/violin use slw + 1)
alpha = 0.3  # box face opacity
b_width = 0.2  # box width
v_width = 0.4  # violin width

# (data, y column, y-axis label, title) per panel.
condition_panels = [
    (sae_phq9, 'sae_phq9_q2', 'SAE Q2 Score', 'SAE Q2 Score'),
    (sae_pred_df_avg, 'sae_diff_fuB_avg', 'FU-Baseline',
     'Average SAE score difference\npositive reevaluation vs baseline'),
]
for pc.j, (panel_df, y_col, y_label, panel_title) in enumerate(condition_panels):
    bp = sns.boxplot(data=panel_df, x='condition', hue='condition', y=y_col, width=b_width,
                     linewidth=slw + 1, ax=pc.ax, legend=False, gap=0.2, hue_order=hue_order, palette=hue_cols)
    # Lower the box faces' opacity to `alpha`.
    for patch in bp.patches:
        face_color = patch.get_facecolor()
        patch.set_facecolor((*face_color[:3], alpha))
    sns.violinplot(data=panel_df, x='condition', y=y_col, hue='condition', width=v_width, linewidth=slw + 1,
                   ax=pc.ax,
                   legend=False, gap=0.2, hue_order=hue_order, palette=hue_cols)
    sns.stripplot(data=panel_df, x='condition', y=y_col, hue='condition', edgecolor='black',
                  linewidth=slw, dodge=True,
                  jitter=0.05, size=ssize, legend=False, ax=pc.ax, hue_order=hue_order, palette=hue_cols)
    # Bold panel letter just outside the top-left corner.
    pc.ax.text(pc.p_lab_spec[0] + 0.05, pc.p_lab_spec[1], pc.p_labs[pc.i, pc.j], transform=pc.ax.transAxes,
               fontweight='bold',
               va='top', ha='right',
               fontsize=pc.p_lab_spec[2])

    pc.ax.set_ylabel(y_label)
    pc.ax.set_xlabel('Condition')
    pc.ax.set_title(panel_title)

plt.tight_layout()
if bools.savePlot:
    plt.savefig(f"{paths.plots_path}{fig_no}_p2_{testCase_dir}_sae-measures.pdf", dpi=300)

