import matplotlib

matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt

plt.ion()
import os
import pandas as pd
import seaborn as sns
from natsort import natsorted
import joblib

os.environ["TOKENIZERS_PARALLELISM"] = "false"

from _objects.model_configs import *
from _objects.configs import *
from _objects.data_configs import QsConfig

from _objects.plot_config import *

# Root folder of this test case and the analysis sub-folder this script belongs to.
testCase_dir = 'test1_qs_structure'
subPath = 'analysis_sampling'
test_path = '/'.join((testCase_dir, subPath)) + '/'

from test1_qs_structure.analysis_sampling.analysis_utils import retrieve_logits_genLevel, process_logits_genLevel, \
    get_corrs_gen_wphq9, get_corrs_pvals_genq_logits, responses_totals_gen_logits

# Shared plotting, flag and questionnaire configuration objects.
pc = PlotConfig()
bools = Bools()
qs_config = QsConfig()
# %% Set paths and sets
# Logit files: the new in-repo location (written to when regenerating) and the
# old online_tasks location (read from).
files_path_new = testCase_dir + '/llm_sampling/files_logits/'  # new repo path
files_path = '../online_tasks/qs_structure/_analysis/_llm_based/files_logits/'  # old path

paths = Paths(files_dir=files_path, plots_subdir='', plots_subsubdir='')
paths.plots_path = testCase_dir + '/_plots/logits_genLevel/'

label_permutation = 'ABCD'
paths.responses_path = files_path + 'paired/' + label_permutation + '/subjects/'
paths.data_path = testCase_dir + '/_data/'

# Read cached logits rather than regenerating them.
# Set to (False, True) to regenerate and save instead.
bools.loadMe, bools.saveMe = True, False

bools.saveFig = True
fig_no = 'Fig3'

# Prompt instruction variant, task versions and the models / generalisation
# questionnaires that the cells below iterate over.
instr_name_str = 'instr3'
task_v = ['v4', 'v4_d', 'v4_dd', 'v4_ddd']
model_names = ['MistralOo', 'gemma2-2b-it', 'llama32-3b-it', 'gemma2-9b-it', 'llama31-8b-it']
gen_qs_list = ['sds', 'gad7', 'phq9']

p_thr = 0.05  # p-value threshold separating significant from non-significant correlations


# %% Get data
def get_logits_genLevel(model_names, gen_qs_list, paths, bools, n_jobs=5):
    """Extract generalisation-level logits for every model / questionnaire pair.

    One joblib task is launched per model. Each task builds a ``SampleLogitsConfig``
    for the PHQ-9 prompt and calls ``retrieve_logits_genLevel`` once per entry of
    ``gen_qs_list``. Nothing is run when ``bools.loadMe`` is truthy.

    Args:
        model_names: model identifiers to process (one parallel task each).
        gen_qs_list: generalisation questionnaire names, e.g. 'sds', 'gad7', 'phq9'.
        paths: Paths object. Its ``save_responses_path`` and ``gen_responses_path``
            attributes are overwritten for each questionnaire.
        bools: flag container. ``loadMe`` gates the whole extraction, and the object
            is forwarded unchanged to ``retrieve_logits_genLevel``.
        n_jobs: number of joblib workers (default 5, the previously hard-coded value).

    Returns:
        None. The return values of ``retrieve_logits_genLevel`` are discarded here.
        Presumably it writes its output under ``save_responses_path`` (which is
        created below); confirm this.
    """
    from pathlib import Path

    def run_model(model_name):
        # Process every generalisation questionnaire for a single model.
        print(model_name)
        sample_config = SampleLogitsConfig(paths, model_name=model_name, qs_name='phq9', instr_name=instr_name_str)
        for gen_qs in gen_qs_list:
            print('\t', gen_qs)
            # NOTE(review): `paths` is mutated per questionnaire. This relies on each
            # joblib task working on its own copy (process-based backend); revisit
            # before switching to a threading backend.
            paths.save_responses_path = f"{files_path_new}cross/{gen_qs}/{label_permutation}/"
            Path(paths.save_responses_path).mkdir(parents=True, exist_ok=True)
            paths.gen_responses_path = f"{files_path}cross/{gen_qs}/{label_permutation}/subjects/"
            retrieve_logits_genLevel(gen_qs, bools, sample_config, paths)

    if not bools.loadMe:
        joblib.Parallel(n_jobs=n_jobs)(joblib.delayed(run_model)(name) for name in model_names)


# Regenerate the logits only when cached files are not being loaded. Afterwards,
# flip the flags so later calls in this script run with loadMe=True / saveMe=False.
if not bools.loadMe:
    get_logits_genLevel(model_names, gen_qs_list, paths, bools)
    bools.loadMe, bools.saveMe = True, False

# %% Plot all questionnaires, all models
# One figure per model with one row per generalisation questionnaire:
#   col 0 - LLM vs subject total scores (scatter with identity line),
#   col 1 - subject PHQ-8 x questionnaire item correlations,
#   col 2 - LLM vs subject item correlations.
# Per-questionnaire summaries are accumulated in corr_diff_df / total_corr_df.
corr_diff_df = []
total_corr_df = []
for model_name in model_names:
    print(model_name)
    # Prepare plot
    pc.r, pc.c, pc.mlt = len(gen_qs_list), 3, 2.75
    pc.figsize = ((pc.c + 2.75) * pc.mlt, (pc.r + 0.75) * pc.mlt)
    pc.ax_ts(15.5, 1.1)
    pc.l_fs(12, 0.85)
    ts = 14
    pc.xyt_ls(ts, ts)
    pc.ax_ls(16)
    pc.kde_lw = 3
    pc.p_lab_spec[2] = 14
    pc.p_lab_spec[0] = -0.075
    pc.p_lab_spec[1] = 1.1
    pc.dpi_val = 300
    pc.ms = 100
    ax_space = 5  # padding (in score units) around the total-score axis limits
    hrat = 3.75  # relative width of each heatmap column vs the scatter column
    plt.close('all')
    # squeeze=False keeps pc.axes 2-D (rows x cols) even for a single questionnaire;
    # by default plt.subplots squeezes a 1-row grid to a 1-D array.
    fig, pc.axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, squeeze=False,
                                gridspec_kw={'width_ratios': [1, hrat, hrat]})
    annots = [7.75, 14, 12]  # heatmap annotation font size, indexed by questionnaire row
    notsig_colors = 'Greys'  # colormap for non-significant correlations
    sc_lw = 3

    for pc.i, gen_qs in enumerate(gen_qs_list):
        print(gen_qs)
        sample_config = SampleLogitsConfig(paths, model_name=model_name, qs_name='phq9', instr_name=instr_name_str)
        paths.save_responses_path = f"{files_path_new}cross/{gen_qs}/{label_permutation}/"
        paths.gen_responses_path = f"{files_path}cross/{gen_qs}/{label_permutation}/subjects/"
        responses_mean = retrieve_logits_genLevel(gen_qs, bools, sample_config, paths)
        responses_avg_merged, closed_data_long, q_names = process_logits_genLevel(gen_qs, bools,
                                                                                  sample_config, paths,
                                                                                  task_v)
        # Keep only the 'lvl3' contexts (natural sort order preserved).
        context_names = [c for c in natsorted(responses_avg_merged['q_name_context'].unique()) if 'lvl3' in c]
        responses_avg_merged = responses_avg_merged[responses_avg_merged['q_name_context'].isin(context_names)]
        q_names = natsorted(responses_avg_merged['q_name'].unique())

        # Get correlations with phq9
        cross_corr, cross_corr_subset, cross_pvals, cross_pvals_subset = get_corrs_gen_wphq9(
            paths, gen_qs, closed_data_long, q_names, context_names, qs_config, task_v, p_thr=p_thr)

        # llm vs ppt gen correlations
        df_corr, df_corr_wide, N_range = get_corrs_pvals_genq_logits(
            responses_avg_merged, context_names, q_names, p_thr=p_thr)

        # PHQ-9 rows without 'phq9_q9' (plotted as PHQ-8 below). Rows are selected by
        # name rather than by position so every use below refers to the same rows.
        no_sui_idx = [c for c in cross_corr_subset.index if c != 'phq9_q9']
        n_rows = len(no_sui_idx)
        n_cols = len(cross_corr_subset.columns)

        # Subject minus LLM item correlations over the lower triangle. A boolean mask
        # replaces the earlier np.tril(...) + "drop zeros" approach, which also discarded
        # genuine zero differences.
        # NOTE(review): these are signed differences, so positive and negative gaps
        # cancel when averaged, yet the cross-model plot ranks models by ascending mean.
        # Confirm whether absolute differences were intended.
        subj_corr = cross_corr_subset.loc[no_sui_idx].to_numpy(dtype=float)
        llm_corr = df_corr_wide['r'].to_numpy(dtype=float)
        lower_tri = np.tril(np.ones(subj_corr.shape, dtype=bool))
        corr_diff = [float(d) for d in (subj_corr - llm_corr)[lower_tri]]
        tmp_dict_corr = {'model': sample_config.model_name_plot, 'gen_qs': gen_qs, 'corr_diff': corr_diff}
        corr_diff_df.append(pd.DataFrame(tmp_dict_corr))

        # calculate total scores for the llm samples and compare with subject data
        totals_sub_llm, r_totals, p_totals = responses_totals_gen_logits(responses_avg_merged,
                                                                         closed_data_long)

        tmp_dict_total = {'model': sample_config.model_name_plot, 'gen_qs': gen_qs, 'total_corr': [float(r_totals)]}
        total_corr_df.append(pd.DataFrame(tmp_dict_total))

        # Plotting correlation maps
        pc.annot_fs = annots[pc.i]
        pc.onerow = False

        # Column 0: total scores, LLM (x) vs subject (y), with identity line.
        pc.j = 0
        lims = [qs_config.qs_min_total[gen_qs] - ax_space, qs_config.qs_max_total[gen_qs] + ax_space]
        pc.ax.plot(lims, lims, color='gray', linewidth=sc_lw)
        sns.scatterplot(totals_sub_llm, x='total_llm', y='total_sub', ax=pc.ax, s=pc.ms)
        pc.ax.set_title(f"{gen_qs.upper()}, r: {r_totals:.3f}")
        pc.ax.set_aspect('equal', 'box')
        pc.ax.set_xlim(lims)
        pc.ax.set_ylim(lims)
        pc.ax.set_xlabel('LLM total score')
        pc.ax.set_ylabel('Subject total score')
        pc.ax.text(pc.p_lab_spec[0], pc.p_lab_spec[1], pc.p_labs[pc.i, pc.j], transform=pc.ax.transAxes,
                   fontweight='bold',
                   va='top', ha='right', fontsize=pc.p_lab_spec[2])

        # Column 1: subject item-level correlations; significant cells in colour,
        # the rest overlaid in grey.
        pc.j = 1
        sns.heatmap(cross_corr_subset[cross_pvals < p_thr].loc[no_sui_idx], cmap='coolwarm', vmin=-1, vmax=1,
                    annot=True,
                    annot_kws={"size": pc.annot_fs},
                    ax=pc.ax, cbar=False, fmt='.2f')

        sns.heatmap(cross_corr_subset[cross_pvals >= p_thr].loc[no_sui_idx], cmap=notsig_colors, vmin=-1, vmax=1,
                    annot=True, annot_kws={"size": pc.annot_fs},
                    ax=pc.ax, cbar=False, fmt='.2f')
        pc.ax.set_title("Subject item score pairwise correlations")
        pc.ax.set_xlabel(f"{gen_qs.upper()} question")
        pc.ax.set_ylabel('PHQ-8 question')

        pc.ax.set_xticks(np.arange(n_cols) + 0.5)
        pc.ax.set_xticklabels([f'{q + 1}' for q in range(n_cols)], rotation='horizontal')

        pc.ax.set_yticks(np.arange(n_rows) + 0.5)
        pc.ax.set_yticklabels([f'{q + 1}' for q in range(n_rows)])
        pc.ax.text(pc.p_lab_spec[0] + 0.05, pc.p_lab_spec[1], pc.p_labs[pc.i, pc.j], transform=pc.ax.transAxes,
                   fontweight='bold',
                   va='top', ha='right',
                   fontsize=pc.p_lab_spec[2])

        # Column 2: item-level correlations between participant score on the generalised
        # questionnaire and score on the given level-3 question.
        pc.j = 2
        sns.heatmap(df_corr_wide['r'][df_corr_wide['p-val'] < p_thr], cmap='coolwarm', vmin=-1, vmax=1, annot=True,
                    annot_kws={"size": pc.annot_fs},
                    ax=pc.ax, cbar=False, fmt='.2f')

        sns.heatmap(df_corr_wide['r'][df_corr_wide['p-val'] >= p_thr], cmap=notsig_colors, vmin=-1, vmax=1,
                    annot=True,
                    annot_kws={"size": pc.annot_fs},
                    ax=pc.ax, cbar=False, fmt='.2f')

        pc.ax.set_title("LLM vs subject item score correlations")
        pc.ax.set_xlabel(gen_qs.upper() + ' question')
        pc.ax.set_ylabel('Open PHQ-8 question')
        pc.ax.set_xticks(np.arange(n_cols) + 0.5)
        pc.ax.set_xticklabels([f'{q + 1}' for q in range(n_cols)], rotation='horizontal')

        n_llm_rows = len(df_corr_wide.index)
        pc.ax.set_yticks(np.arange(n_llm_rows) + 0.5)
        pc.ax.set_yticklabels([f'{q + 1}' for q in range(n_llm_rows)])
        pc.ax.text(pc.p_lab_spec[0] + 0.05, pc.p_lab_spec[1], pc.p_labs[pc.i, pc.j], transform=pc.ax.transAxes,
                   fontweight='bold',
                   va='top', ha='right', fontsize=pc.p_lab_spec[2])

    plt.suptitle(f'{sample_config.model_name_plot}')
    plt.tight_layout()

    if bools.saveFig:
        # savefig does not create missing directories.
        os.makedirs(paths.plots_path, exist_ok=True)
        plt.savefig(
            f"{paths.plots_path}gen_qs_total_item_corrs_{sample_config.model_name}.pdf", dpi=300)

corr_diff_df = pd.concat(corr_diff_df)
total_corr_df = pd.concat(total_corr_df)
# %%
# Per (model, questionnaire) means, reshaped to model x questionnaire tables.
corr_diff_df_avg = corr_diff_df.groupby(['model', 'gen_qs'], as_index=False).mean()
total_corr_df_avg = total_corr_df.groupby(['model', 'gen_qs'], as_index=False).mean()
corr_diff_df_avg_wide = corr_diff_df_avg.pivot(index='model', columns='gen_qs', values='corr_diff')
total_corr_df_avg_wide = total_corr_df_avg.pivot(index='model', columns='gen_qs', values='total_corr')

# %% Plot performance across models
# Two heatmaps (models x questionnaires): mean item-correlation difference (left)
# and total-score correlation (right). Each panel sorts its model rows
# independently, so the y axes are deliberately NOT shared. Shared y axes share
# their tick labels, so the left panel would show the right panel's model order.
plt.close('all')
pc.r, pc.c, pc.mlt = 1, 2, 2.25
pc.figsize = ((pc.c + 2.25) * pc.mlt, (pc.r + 0.75) * pc.mlt)
fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
pc.onerow = False
pc.axes = np.array([axes])  # wrap the 1-D row of axes so it is indexed as [row, col]
pc.i, pc.j = 0, 0

pc.ax_ts(18, 1.1)
pc.l_fs(12, 0.85)
ts = 15
pc.xyt_ls(ts, ts)
pc.ax_ls(20)
pc.kde_lw = 3
pc.p_lab_spec[2] = 14
pc.p_lab_spec[0] = -0.4
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300
pc.ms = 100
ax_space = 5
hrat = 3.75
pc.annot_fs = 16
sc_lw = 3

# Left panel: rows ordered by mean difference across questionnaires (ascending).
diff_corrs_sorted = corr_diff_df_avg_wide.mean(axis=1).sort_values(ascending=True)
model_list = diff_corrs_sorted.index.tolist()
q_list = corr_diff_df_avg_wide.columns.sort_values(ascending=False).tolist()

sns.heatmap(corr_diff_df_avg_wide.loc[model_list, q_list], annot=True, cmap='Oranges', ax=pc.ax,
            annot_kws={"size": pc.annot_fs}, cbar=False, fmt='.2f')
pc.ax.set_xticklabels([f'{q.upper()}' for q in q_list], rotation='horizontal')
pc.ax.set_xlabel('Questionnaire')
pc.ax.set_ylabel('Model')
pc.ax.text(pc.p_lab_spec[0] + 0.05, pc.p_lab_spec[1], pc.p_labs[pc.i, pc.j], transform=pc.ax.transAxes,
           fontweight='bold',
           va='top', ha='right',
           fontsize=pc.p_lab_spec[2])
pc.ax.set_title('Average cov item difference')

# Right panel: rows ordered by mean total-score correlation (descending).
pc.j = 1
pc.p_lab_spec[0] = -0.05
totals_corrs_sorted = total_corr_df_avg_wide.mean(axis=1).sort_values(ascending=False)
model_list = totals_corrs_sorted.index.tolist()

sns.heatmap(total_corr_df_avg_wide.loc[model_list, q_list], annot=True, cmap='Oranges', ax=pc.ax,
            annot_kws={"size": pc.annot_fs}, cbar=False, fmt='.2f')
pc.ax.set_xticklabels([f'{q.upper()}' for q in q_list], rotation='horizontal')
pc.ax.set_xlabel('Questionnaire')
pc.ax.set_ylabel('')
pc.ax.text(pc.p_lab_spec[0] + 0.05, pc.p_lab_spec[1], pc.p_labs[pc.i, pc.j], transform=pc.ax.transAxes,
           fontweight='bold',
           va='top', ha='right',
           fontsize=pc.p_lab_spec[2])
pc.ax.set_title('Totals correlations')

plt.tight_layout()

if bools.saveFig:
    # savefig does not create missing directories.
    os.makedirs(paths.plots_path, exist_ok=True)
    plt.savefig(
        f"{paths.plots_path}all_model_gen_corrs_diff.pdf", dpi=300)
    plt.savefig(
        f"{paths.plots_path}{fig_no}_p2_test1_gen_level.pdf", dpi=300)

# %% Plot best model and sds only
# Same three-panel layout as the all-models figure, restricted to one model and
# the SDS questionnaire. Note that this cell overwrites the module-level
# model_names / gen_qs_list.
model_names = ['gemma2-9b-it']
gen_qs_list = ['sds']
for model_name in model_names:
    print(model_name)
    # Prepare plot
    pc.r, pc.c, pc.mlt = len(gen_qs_list), 3, 2.75
    pc.figsize = ((pc.c + 2.75) * pc.mlt, (pc.r + 0.75) * pc.mlt)
    pc.ax_ts(15.5, 1.1)
    pc.l_fs(12, 0.85)
    ts = 14
    pc.xyt_ls(ts, ts)
    pc.ax_ls(16)
    pc.kde_lw = 3
    pc.p_lab_spec[2] = 14
    pc.p_lab_spec[0] = -0.075
    pc.p_lab_spec[1] = 1.1
    pc.dpi_val = 300
    pc.ms = 100
    ax_space = 5  # padding (in score units) around the total-score axis limits
    hrat = 3.75  # relative width of each heatmap column vs the scatter column
    plt.close('all')
    # With a single questionnaire this is a 1-row grid. plt.subplots would squeeze
    # it to a 1-D array, while the code below uses pc.onerow = False and
    # [row, col] panel indexing (pc.p_labs[pc.i, pc.j]). squeeze=False keeps it 2-D.
    fig, pc.axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, squeeze=False,
                                gridspec_kw={'width_ratios': [1, hrat, hrat]})
    annots = [9, 14, 12]  # heatmap annotation font size, indexed by questionnaire row
    notsig_colors = 'Greys'  # colormap for non-significant correlations
    sc_lw = 3

    for pc.i, gen_qs in enumerate(gen_qs_list):
        print(gen_qs)
        sample_config = SampleLogitsConfig(paths, model_name=model_name, qs_name='phq9', instr_name=instr_name_str)
        paths.save_responses_path = f"{files_path_new}cross/{gen_qs}/{label_permutation}/"
        paths.gen_responses_path = f"{files_path}cross/{gen_qs}/{label_permutation}/subjects/"
        responses_mean = retrieve_logits_genLevel(gen_qs, bools, sample_config, paths)
        responses_avg_merged, closed_data_long, q_names = process_logits_genLevel(gen_qs, bools,
                                                                                  sample_config, paths,
                                                                                  task_v)
        # Keep only the 'lvl3' contexts (natural sort order preserved).
        context_names = [c for c in natsorted(responses_avg_merged['q_name_context'].unique()) if 'lvl3' in c]
        responses_avg_merged = responses_avg_merged[responses_avg_merged['q_name_context'].isin(context_names)]
        q_names = natsorted(responses_avg_merged['q_name'].unique())

        # Get correlations with phq9
        cross_corr, cross_corr_subset, cross_pvals, cross_pvals_subset = get_corrs_gen_wphq9(
            paths, gen_qs, closed_data_long, q_names, context_names, qs_config, task_v, p_thr=p_thr)

        # llm vs ppt gen correlations
        df_corr, df_corr_wide, N_range = get_corrs_pvals_genq_logits(
            responses_avg_merged, context_names, q_names, p_thr=p_thr)

        # calculate total scores for the llm samples and compare with subject data
        totals_sub_llm, r_totals, p_totals = responses_totals_gen_logits(responses_avg_merged,
                                                                         closed_data_long)

        # Plotting correlation maps
        pc.annot_fs = annots[pc.i]
        pc.onerow = False

        # Column 0: total scores, LLM (x) vs subject (y), with identity line.
        pc.j = 0
        lims = [qs_config.qs_min_total[gen_qs] - ax_space, qs_config.qs_max_total[gen_qs] + ax_space]
        pc.ax.plot(lims, lims, color='gray', linewidth=sc_lw)
        sns.scatterplot(totals_sub_llm, x='total_llm', y='total_sub', ax=pc.ax, s=pc.ms)
        pc.ax.set_title(f"{gen_qs.upper()}, r: {r_totals:.3f}")
        pc.ax.set_aspect('equal', 'box')
        pc.ax.set_xlim(lims)
        pc.ax.set_ylim(lims)
        pc.ax.set_xlabel('LLM total score')
        pc.ax.set_ylabel('Subject total score')
        pc.ax.text(pc.p_lab_spec[0], pc.p_lab_spec[1], pc.p_labs[pc.i, pc.j], transform=pc.ax.transAxes,
                   fontweight='bold',
                   va='top', ha='right', fontsize=pc.p_lab_spec[2])

        # PHQ-9 rows without 'phq9_q9' (plotted as PHQ-8).
        no_sui_idx = [c for c in cross_corr_subset.index if c != 'phq9_q9']
        n_rows = len(no_sui_idx)
        n_cols = len(cross_corr_subset.columns)

        # Column 1: subject item-level correlations; significant cells in colour,
        # the rest overlaid in grey.
        pc.j = 1
        sns.heatmap(cross_corr_subset[cross_pvals < p_thr].loc[no_sui_idx], cmap='coolwarm', vmin=-1, vmax=1,
                    annot=True,
                    annot_kws={"size": pc.annot_fs},
                    ax=pc.ax, cbar=False, fmt='.2f')

        sns.heatmap(cross_corr_subset[cross_pvals >= p_thr].loc[no_sui_idx], cmap=notsig_colors, vmin=-1, vmax=1,
                    annot=True, annot_kws={"size": pc.annot_fs},
                    ax=pc.ax, cbar=False, fmt='.2f')
        pc.ax.set_title("Subject item score pairwise correlations")
        pc.ax.set_xlabel(f"{gen_qs.upper()} question")
        pc.ax.set_ylabel('PHQ-8 question')

        pc.ax.set_xticks(np.arange(n_cols) + 0.5)
        pc.ax.set_xticklabels([f'{q + 1}' for q in range(n_cols)], rotation='horizontal')

        pc.ax.set_yticks(np.arange(n_rows) + 0.5)
        pc.ax.set_yticklabels([f'{q + 1}' for q in range(n_rows)])
        pc.ax.text(pc.p_lab_spec[0] + 0.05, pc.p_lab_spec[1], pc.p_labs[pc.i, pc.j], transform=pc.ax.transAxes,
                   fontweight='bold',
                   va='top', ha='right',
                   fontsize=pc.p_lab_spec[2])

        # Column 2: item-level correlations between participant score on the generalised
        # questionnaire and score on the given level-3 question.
        pc.j = 2
        sns.heatmap(df_corr_wide['r'][df_corr_wide['p-val'] < p_thr], cmap='coolwarm', vmin=-1, vmax=1, annot=True,
                    annot_kws={"size": pc.annot_fs},
                    ax=pc.ax, cbar=False, fmt='.2f')

        sns.heatmap(df_corr_wide['r'][df_corr_wide['p-val'] >= p_thr], cmap=notsig_colors, vmin=-1, vmax=1,
                    annot=True,
                    annot_kws={"size": pc.annot_fs},
                    ax=pc.ax, cbar=False, fmt='.2f')

        pc.ax.set_title("LLM vs subject item score correlations")
        pc.ax.set_xlabel(gen_qs.upper() + ' question')
        pc.ax.set_ylabel('Open PHQ-8 question')
        pc.ax.set_xticks(np.arange(n_cols) + 0.5)
        pc.ax.set_xticklabels([f'{q + 1}' for q in range(n_cols)], rotation='horizontal')

        n_llm_rows = len(df_corr_wide.index)
        pc.ax.set_yticks(np.arange(n_llm_rows) + 0.5)
        pc.ax.set_yticklabels([f'{q + 1}' for q in range(n_llm_rows)])
        pc.ax.text(pc.p_lab_spec[0] + 0.05, pc.p_lab_spec[1], pc.p_labs[pc.i, pc.j], transform=pc.ax.transAxes,
                   fontweight='bold',
                   va='top', ha='right', fontsize=pc.p_lab_spec[2])

    plt.suptitle(f'{sample_config.model_name_plot}')
    plt.tight_layout()

    if bools.saveFig and model_name == 'gemma2-9b-it':
        # savefig does not create missing directories.
        os.makedirs(paths.plots_path, exist_ok=True)
        plt.savefig(f"{paths.plots_path}{fig_no}_p1_test1_gen_level.pdf", dpi=300)
