# %%
from matplotlib.lines import Line2D

import os, re, copy

from scipy import stats
import pandas as pd
import seaborn as sns
# download_blob('llm-bucket-res', 'data_sae.zip', 'data_sae.zip')
from _utils.plot_utils import set_scatter_axes

# from sae_new import train_config

import matplotlib
import matplotlib.pyplot as plt

import sys

# Choose the compute device name and the matplotlib backend for this host.
# sys.platform exists on every OS; os.uname() does not exist on Windows and
# would raise AttributeError there.
if sys.platform == 'darwin':  # macOS: 'mps' device and an interactive Qt window
    device_name = 'mps'
    matplotlib.use('Qt5Agg')
    plt.ion()
else:  # any other host: 'cuda' device and the non-interactive Agg backend
    matplotlib.use('Agg')
    device_name = 'cuda'

from _objects.model_configs import *
from _objects.configs import *
from _objects.plot_config import *

pc = PlotConfig()

# Test-case identifiers used to build input/output paths.
testCase_dir = 'test2_ssae'
subPath = 'analysis_sampling'
test_path = f"{testCase_dir}/{subPath}/"
fig_no = 'Fig4'  # prefix of the paper-figure copies written by the plot cells

# Set saveFig to True to write the figures produced below to disk.
bools = Bools()
bools.saveFig = False
# %%
# Input locations: the new in-repo layout next to the legacy paths still in use.
files_path_new = '/'.join((testCase_dir, 'llm_sampling', 'files_logits', ''))
files_path = '../online_tasks/qs_structure/_analysis/_llm_based/files_logits/'
prev_path = '../construct_detector/'
prev_path_new = ''

paths = Paths(files_dir=files_path, plots_subdir='', plots_subsubdir='')
# Redirect plot output so the training figures land under the test-case folder.
paths.plots_path = '/'.join((testCase_dir, '_plots', 'training', ''))

# Experiment whose trained sSAEs are analysed, the data experiment, and the LLM.
exp_name_org, exp_name_data = 'SAE_Mistral_gemma_llama_exp_v3_best', 'SAE_Mistral_gemma_llama_exp_v1'
model_name = 'gemma2-9b-it'

sample_config = SampleLogitsConfig(paths, model_name=model_name,
                                   qs_name='phq9', instr_name='instr3')

# %% Find the best hyperparameter setting for each model
loss_dir = f'{prev_path}loss/{exp_name_org}/{model_name}/'
metric_dir = f'{prev_path}metrics/{exp_name_org}/{model_name}/'

# Keep only real CSVs: os.listdir also returns OS files such as macOS
# `.DS_Store`, which pd.read_csv and the name parsing below cannot handle.
# Sorting makes the run order (and therefore train_configs[0]) reproducible,
# because os.listdir returns entries in arbitrary order.
loss_files = sorted(f for f in os.listdir(loss_dir) if f.endswith('.csv'))
metric_files = sorted(f for f in os.listdir(metric_dir) if 'metric' in f)
if not loss_files:
    raise FileNotFoundError(f'No loss CSV files found in {loss_dir}')
if not metric_files:
    raise FileNotFoundError(f'No metric files found in {metric_dir}')

# Recover one {hyper-parameter: value} dict per run from the loss file names.
# The parsing expects names like 'loss_model-<v>^^<key>-<value>^^...csv';
# only the first '-' of each segment separates key from value, so values may
# themselves contain '-'.
train_configs = ['model-' + f.replace('loss_model-', '').replace('.csv', '') for f in loss_files]
train_configs = [{key: value for key, _, value in (seg.partition('-') for seg in t.split('^^'))}
                 for t in train_configs]

all_model_loss = pd.concat([pd.read_csv(f'{loss_dir}{f}') for f in loss_files], axis=0)
all_metrics = pd.concat([pd.read_csv(f'{metric_dir}{f}') for f in metric_files], axis=0)

# Mean and std of r and p per hyper-parameter combination. The grouping keys
# come from the first run's name (assumed identical across runs).
id_vars = list(train_configs[0].keys())
metric_avg = all_metrics.groupby(id_vars, as_index=False)[['r', 'p']].aggregate(['mean', 'std'])
# Flatten the (column, statistic) MultiIndex into 'r_mean', 'r_std', ...;
# grouping columns carry an empty statistic level and keep their bare name.
metric_avg.columns = [f'{c1}_{c2}' if c2 != '' else c1 for c1, c2 in metric_avg.columns]


# %%
# Layer of the hyper-parameter combination with the highest mean r.
best_layer = int(metric_avg.sort_values(by='r_mean', ascending=False).head(1)['layer'].values[0])

# Match 'layer-<n>' only when it is not followed by another digit: a plain
# substring test for 'layer-1' would also select the files of layers 10-19,
# 100, ... and mix their predictions in.
layer_tag = re.compile(rf'layer-{best_layer}(?!\d)')
pred_files = sorted(f for f in os.listdir(metric_dir) if 'preds' in f and layer_tag.search(f))
if not pred_files:
    raise FileNotFoundError(f'No prediction files for layer {best_layer} in {metric_dir}')
# NOTE(review): only the layer is matched here; if other hyper-parameters vary
# between runs, files from several settings may be concatenated and the pivot
# below may then fail on duplicate (sub, variable, source) entries.
best_preds = pd.concat([pd.read_csv(f'{metric_dir}{f}') for f in pred_files], axis=0)
# One row per (sub, variable) with one column per value of 'source'.
best_preds = best_preds.pivot(index=['sub', 'variable'], columns='source', values='value').reset_index()

# %% Plot metrics layer wise
# Correlation r for every (question, layer) pair, drawn as an annotated heatmap.
all_metrics_wide = all_metrics.pivot(index='q_name', columns='layer', values='r')
plt.close('all')
# Single-panel figure; figsize scales with the grid size and the pc.mlt factor.
pc.r, pc.c, pc.mlt = 1, 1, 2
pc.figsize = ((pc.c + 3.5) * pc.mlt, (pc.r + 0.5) * pc.mlt)
fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
# Wrap the lone Axes in a 2-D array, presumably so pc.ax / pc.p_labs can be
# indexed with [pc.i, pc.j] as in multi-panel figures.
pc.onerow = False
pc.axes = np.array([[axes]])
pc.i, pc.j = 0, 0

pc.ax_ts(13.5, 1.1)
pc.l_fs(12, 0.85)
ts = 13
pc.xyt_ls(ts, ts - 1)
pc.ax_ls(13)
pc.kde_lw = 3
pc.p_lab_spec[2] = 14
pc.p_lab_spec[0] = -.05
pc.p_lab_spec[1] = 1.15
pc.dpi_val = 300
pc.ms = 100
ax_space = 5
hrat = 3.75
pc.annot_fs = 9
sc_lw = 3
# yticklabels=True forces one tick per row so the relabelling below always
# matches the number of rows.
sns.heatmap(all_metrics_wide, annot=True, cmap='Oranges', ax=pc.ax,
            annot_kws={"size": pc.annot_fs}, cbar=False, fmt='.2f', yticklabels=True)
pc.ax.set_xlabel('Layer')
pc.ax.set_ylabel('PHQ9 Latent Question')
# Rows follow the sorted q_name index produced by pivot; label them Q1..Qn.
pc.ax.set_yticklabels([f'Q{q + 1}' for q in range(len(all_metrics_wide.index))],
                      rotation='horizontal')
# Bold panel letter placed just outside the top-left corner of the axes.
pc.ax.text(pc.p_lab_spec[0] + 0.05, pc.p_lab_spec[1], pc.p_labs[pc.i, pc.j], transform=pc.ax.transAxes,
           fontweight='bold',
           va='top', ha='right',
           fontsize=pc.p_lab_spec[2])
pc.ax.set_title(f'{sample_config.model_name_plot} layerwise sSAE prediction performance')
plt.tight_layout()
if bools.saveFig:
    # savefig does not create missing directories.
    os.makedirs(paths.plots_path, exist_ok=True)
    plt.savefig(
        f"{paths.plots_path}ssae_training_performance_{sample_config.model_name}.pdf", dpi=300)
    plt.savefig(
        f"{paths.plots_path}{fig_no}_p1_ssae_training_performance_{sample_config.model_name}.pdf", dpi=300)

# %% Plot best layer predictions
# Per-question distributions of best-layer sSAE scores ('ae') grouped by the
# subject's answer ('ppt'): violin + box + strip overlay in a 2 x 5 grid.
plt.close('all')
pc.r, pc.c, pc.mlt = 2, 5, 1.8
pc.figsize = ((pc.c + 2.75) * pc.mlt, (pc.r + 1.25) * pc.mlt)
pc.ax_ts(13.5, 1.1)
pc.l_fs(12, 0.85)
pc.xyt_ls(14, 11)
pc.ax_ls(14)
pc.kde_lw = 3
pc.p_lab_spec[2] = 14
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300

# Styling of the violin / box / strip layers.
ms = 110
alpha = 0.7  # face opacity re-applied to the box patches below
s_ec = '#ffedcb'  # strip-point fill colour (edges are black)
b_c = '#e19f20'  # box colour
s_lw = 0.8
s_size = 1.5
s_out_size = 5
b_width = 0.35
v_width = 0.8
b_lw = 1.5
v_cols = {'lvl1': '#1f77b4', 'lvl2': '#009ac8', 'lvl3': '#00b9c2'}

fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
pc.axes = axes.flatten()
pc.onerow = True
pc.i = 0
q_names = best_preds['variable'].unique()

v_col = '#1f77b4'
# One panel per question; like the original zip(), questions beyond the number
# of panels are not drawn. pc.j selects the panel that pc.ax refers to.
for pc.j, q_name in enumerate(q_names[:len(pc.axes)]):
    q_name_short = f'PHQ9 Q{pc.j + 1}'
    best_preds_q = best_preds[best_preds['variable'] == q_name]

    sns.violinplot(best_preds_q, y='ae', x='ppt', ax=pc.ax, orient='v', inner=None,
                   width=v_width, color=v_col)
    bp = sns.boxplot(data=best_preds_q, y='ae', x='ppt', ax=pc.ax, width=b_width, color=b_c,
                     linewidth=b_lw, orient='v', fliersize=s_out_size)
    # Make every patch on the axes (the box bodies) semi-transparent.
    for patch in bp.patches:
        face_color = patch.get_facecolor()
        patch.set_facecolor((*face_color[:3], alpha))
    sns.stripplot(data=best_preds_q, y='ae', x='ppt', ax=pc.ax, edgecolor='black',
                  color=s_ec, linewidth=s_lw,
                  size=s_size, orient='v')
    # Spearman correlation of sSAE vs. subject score; p is Bonferroni-corrected
    # over all questions and capped at 1.
    r, p = stats.spearmanr(best_preds_q['ae'], best_preds_q['ppt'])
    p = min(p * len(q_names), 1)
    pc.t = f"{q_name_short}: r={r:.3f}\n p={p:.3e}"
    set_scatter_axes(pc)
    # Categorical x positions are 0..k-1; label them with the sorted answers.
    xv = [float(v) for v in sorted(best_preds_q['ppt'].unique())]
    pc.ax.set_xticks(range(len(xv)))
    pc.ax.set_xticklabels(np.round(xv, 1))
    pc.ax.set_ylabel('sSAE score')
    pc.ax.set_xlabel('Subject score')

# Remove every panel that received no question (e.g. the 10th panel for 9
# questions) before tight_layout, so the layout ignores the removed panels.
for unused_ax in pc.axes[len(q_names):]:
    unused_ax.remove()
plt.suptitle(f"{sample_config.model_name_plot} best layer sSAE score predictions")
plt.tight_layout()

if bools.saveFig:
    # savefig does not create missing directories.
    os.makedirs(paths.plots_path, exist_ok=True)
    plt.savefig(
        f"{paths.plots_path}ssae_bestlayer_performance_{sample_config.model_name}.pdf", dpi=300)
    plt.savefig(
        f"{paths.plots_path}{fig_no}_p2_ssae_bestlayer_performance_{sample_config.model_name}.pdf", dpi=300)

#%%
# Training and validation loss per epoch, one curve per layer.
plt.close('all')
# Single-panel figure; the lone Axes is wrapped in a 1x1 array for pc.ax.
pc.r, pc.c, pc.mlt = 1, 1, 2
pc.figsize = ((pc.c + 3.5) * pc.mlt, (pc.r + 0.5) * pc.mlt)
fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
pc.onerow = False
pc.axes = np.array([[axes]])
pc.i, pc.j = 0, 0

pc.ax_ts(13.5, 1.1)
pc.l_fs(12, 0.85)
ts = 13
pc.xyt_ls(ts, ts - 1)
pc.ax_ls(13)
pc.kde_lw = 3
pc.p_lab_spec[2] = 14
pc.p_lab_spec[0] = -.05
pc.p_lab_spec[1] = 1.15
pc.dpi_val = 300
pc.ms = 100
ax_space = 5
hrat = 3.75
pc.annot_fs = 9
sc_lw = 0.5  # thin lines because one curve is drawn per layer

# estimator=None with units='layer' draws each layer's raw curve instead of
# an aggregated mean.
# NOTE(review): if several runs share a layer they are joined into one line.
sns.lineplot(all_model_loss, x='epoch', y='train_loss', units='layer', ax=pc.ax,
             color='tab:blue', estimator=None, lw=sc_lw)
sns.lineplot(all_model_loss, x='epoch', y='val_loss', units='layer', ax=pc.ax,
             color='tab:orange', estimator=None, lw=sc_lw)
# The per-layer lines are thin and unlabeled, so the legend uses thicker proxy
# handles, one per curve type.
custom_lines = [Line2D([0], [0], color='tab:blue', lw=2),
                Line2D([0], [0], color='tab:orange', lw=2)]
pc.ax.legend(custom_lines, ['training loss', 'validation loss'])
pc.ax.set_ylim([0, 10])
pc.ax.set_xlabel('Epoch')
pc.ax.set_ylabel('Loss')
pc.ax.set_title(f'Loss curve for each layer for {sample_config.model_name_plot}')
plt.tight_layout()
if bools.saveFig:
    # savefig does not create missing directories.
    os.makedirs(paths.plots_path, exist_ok=True)
    plt.savefig(
        f"{paths.plots_path}ssae_loss_{sample_config.model_name}.pdf", dpi=300)
