import re
from scipy import stats

import matplotlib

matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt

plt.ion()
import os
import pandas as pd
import seaborn as sns

os.environ["TOKENIZERS_PARALLELISM"] = "false"

from _objects.model_configs import *
from _objects.configs import *
from _objects.data_configs import QsConfig, phq9_qs_inv_map

from _objects.plot_config import *
from _utils.plot_utils import set_scatter_axes

# Repo-relative location of this test case and of its analysis sub-folder
# (the same package path the analysis_utils import below is taken from).
# NOTE(review): test_path is not referenced in the visible part of this file.
testCase_dir, subPath = 'test1_qs_structure', 'analysis_sampling'
test_path = f"{testCase_dir}/{subPath}/"
# os.chdir(f"{testCase_dir}/{subPath}/")

from test1_qs_structure.analysis_sampling.analysis_utils import process_logits_itemLevel

# Shared plotting state and run flags used by every cell below.
pc = PlotConfig()
bools = Bools()
# %% Set paths and sets
files_path_new = f'{testCase_dir}/llm_sampling/files_logits/'  # new repo path
files_path = '../online_tasks/qs_structure/_analysis/_llm_based/files_logits/'  # old path

paths = Paths(files_dir=files_path, plots_subdir='', plots_subsubdir='')
# Redirect figure output to a script-specific folder, replacing whatever
# plots location the Paths constructor stored.
paths.plots_path = f'{testCase_dir}/_plots/logits_itemLevel/'

# Responses are read from the old location and written to the new repo layout.
# NOTE(review): responses_path ends with "/" while save_responses_path does
# not -- confirm this matches what process_logits_itemLevel expects.
label_permutation = 'ABCD'
paths.responses_path = f"{files_path}paired/{label_permutation}/subjects/"
paths.save_responses_path = f"{files_path_new}paired/{label_permutation}"
paths.data_path = f'{testCase_dir}/_data/'

# Create both output folders up front. plt.savefig does not create missing
# directories, so without this the first save into the redirected plots
# folder fails with FileNotFoundError on a fresh checkout.
Path(paths.save_responses_path).mkdir(parents=True, exist_ok=True)
Path(paths.plots_path).mkdir(parents=True, exist_ok=True)

# Run flags; `bools` is handed to process_logits_itemLevel below.
# presumably loadMe/saveMe switch between loading stored results and
# recomputing + saving them -- verify against analysis_utils.
bools.loadMe = True
bools.saveMe = False
# bools.loadMe = False
# bools.saveMe = True
bools.saveFig = True  # write the figures produced below to paths.plots_path
# bools.saveFig = False
fig_no = 'Fig2'  # file-name prefix for the cross-model heatmap panel

instr_name_str = 'instr3'
task_v = ['v4', 'v4_d', 'v4_dd', 'v4_ddd']
model_names = ['MistralOo', 'gemma2-2b-it', 'llama32-3b-it', 'gemma2-9b-it', 'llama31-8b-it']
# model_names = ['gemma2-9b-it']
# %% Plot item level for each model
# For every model: draw one panel per 'lvl3' question showing the LLM score
# distribution at each subject score (violin + box + strip), compute the
# Spearman correlation between the two scores, and collect the per-question
# results in `model_corrs_df` for the cross-model summary further down.

# Plain styling constants are identical for every model, so they are defined
# once here instead of being re-assigned on each loop iteration.
ms = 110  # only referenced by the commented-out scatterplot variant below
alpha = 0.7  # opacity applied to the box faces
s_ec = '#ffedcb'  # strip-plot marker fill colour
b_c = '#e19f20'  # box colour
s_lw = 0.8  # strip-plot marker edge width
s_size = 2  # strip-plot marker size
s_out_size = 8  # box-plot outlier marker size
b_width = 0.5  # box width
v_width = 1.2  # violin width
b_lw = 1.5  # box line width
v_cols = {'lvl1': '#1f77b4', 'lvl2': '#009ac8', 'lvl3': '#00b9c2'}  # violin colour per level prefix

model_corrs_df = []  # one dict per (model, question); converted to a DataFrame in the next cell
for model_name in model_names:
    print(model_name)

    sample_config = SampleLogitsConfig(paths, model_name=model_name, qs_name='phq9', instr_name=instr_name_str)
    responses_avg_merged, q_names = process_logits_itemLevel(sample_config, paths, bools, task_v, phq9_qs_inv_map)
    # Only questions whose name contains 'lvl3' are plotted.
    q_names = [qn for qn in q_names if 'lvl3' in qn]

    # The PlotConfig settings stay inside the loop: `pc` is shared, mutable
    # state that is also handed to set_scatter_axes on every panel.
    pc.r, pc.c, pc.mlt = 2, 4, 1.75
    pc.figsize = ((pc.c + 2) * pc.mlt, (pc.r + 0.75) * pc.mlt)
    pc.ax_ts(15, 1.1)
    pc.l_fs(12, 0.85)
    pc.xyt_ls(16, 16)
    pc.ax_ls(16)
    pc.kde_lw = 3
    pc.p_lab_spec[2] = 14
    pc.p_lab_spec[0] = -0.05
    pc.p_lab_spec[1] = 1.05
    pc.dpi_val = 300

    # Drop the previous model's figure before opening a new one.
    plt.close('all')
    fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
    pc.axes = axes.flatten()
    pc.onerow = True
    pc.i = 0
    # zip() stops at the shorter sequence, so at most pc.r * pc.c questions
    # are drawn. The loop index is written straight into pc.j; all drawing
    # below goes through pc.ax rather than the unused `ax` from zip.
    # NOTE(review): presumably pc.ax resolves to pc.axes[pc.j] while
    # pc.onerow is True -- confirm in PlotConfig.
    for pc.j, (q_name, ax) in enumerate(zip(q_names, pc.axes)):
        # Remove everything from 'lvl' through the last underscore to get the
        # short question label used in titles and in the summary table.
        q_name_short = re.sub('lvl.*_', '', q_name).upper()
        # Level prefix (text before '_q...'), used to pick the violin colour.
        lvl_name = re.sub('_q.*', '', q_name)
        # Rows for this question, discarding any row with a missing value.
        responses_avg_merged_q = responses_avg_merged[(responses_avg_merged['q_name'] == q_name)].dropna(how='any')

        N = len(responses_avg_merged_q['sub'].unique())  # distinct values in 'sub'; not displayed at the moment
        # `data=` is passed by keyword, matching the boxplot/stripplot calls
        # and working on seaborn releases whose plot arguments are keyword-only.
        sns.violinplot(data=responses_avg_merged_q, y='score_llm', x='score_sub', ax=pc.ax, orient='v', inner=None,
                       width=v_width, color=v_cols[lvl_name])
        bp = sns.boxplot(data=responses_avg_merged_q, y='score_llm', x='score_sub', ax=pc.ax, width=b_width, color=b_c,
                         linewidth=b_lw, orient='v', fliersize=s_out_size)
        # Make the box faces semi-transparent so the violins stay visible.
        for patch in bp.patches:
            face_color = patch.get_facecolor()
            patch.set_facecolor((*face_color[:3], alpha))
        sns.stripplot(data=responses_avg_merged_q, y='score_llm', x='score_sub', ax=pc.ax, edgecolor='black',
                      color=s_ec, linewidth=s_lw,
                      size=s_size, orient='v')
        # sns.scatterplot(responses_avg_merged_q, x='score_llm', y='score_sub', ax=pc.ax, s=ms)
        # r, p = stats.pearsonr(responses_avg_merged_q['score_llm'], responses_avg_merged_q['score_sub'])
        r, p = stats.spearmanr(responses_avg_merged_q['score_llm'], responses_avg_merged_q['score_sub'])
        # Bonferroni-style correction: scale by the number of questions tested
        # for this model and cap the result at 1.
        p = min(p * len(q_names), 1)
        pc.t = f"{q_name_short}: r={r:.3f}\np-val: {p:.2e}"
        # pc.t = f"{q_name_short}: r={r:.3f}"
        pc.ax.set_xlim([-0.75, 3.75])
        pc.ax.set_ylim([-0.75, 3.75])
        set_scatter_axes(pc)
        # Both axes show the ticks 0..3.
        pc.ax.set_xticks(np.arange(4))
        pc.ax.set_xticklabels(np.arange(4))
        pc.ax.set_yticks(np.arange(4))
        pc.ax.set_yticklabels(np.arange(4))
        pc.ax.set_ylabel('LLM score')
        pc.ax.set_xlabel('Subject score')

        tmp_dict = {'model': sample_config.model_name_plot, 'q_name': q_name_short, 'corr': r, 'pval': p}
        model_corrs_df.append(tmp_dict)

    # Address the figure explicitly instead of relying on pyplot's notion of
    # the "current" figure.
    fig.suptitle(f'{sample_config.model_name_plot}')

    # fig.suptitle('Question level score correlations - LLM vs subjects')
    fig.tight_layout()
    if bools.saveFig:
        fig.savefig(
            f"{paths.plots_path}scores_correlations_{sample_config.model_name}.pdf", dpi=pc.dpi_val)
        # if model_name == 'gemma2-9b-it':
        #     fig.savefig(f"{paths.plots_path}{fig_no}_p1_test1_item_level.pdf", dpi=pc.dpi_val)

# %% Get correlations across models
# Turn the collected per-(model, question) records into a table, then widen
# it: one row per model, one column per question, separately for the
# correlation ('corr') and the corrected p-value ('pval').
model_corrs_df = pd.DataFrame(model_corrs_df)
wide_layout = {'index': 'model', 'columns': 'q_name', 'values': ['corr', 'pval']}
model_corrs_df_wide = model_corrs_df.pivot(**wide_layout)

# %% Plot performance across models
# Heatmap of the item-level correlations: one row per model (ordered by mean
# correlation, best first) and one column per question.
plt.close('all')

# Single-panel figure.
pc.r, pc.c, pc.mlt = 1, 1, 2
fig_width = (pc.c + 1.5) * pc.mlt
fig_height = (pc.r + 0.75) * pc.mlt
pc.figsize = (fig_width, fig_height)
fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)

# With a 1x1 grid plt.subplots returns a single Axes, so it is wrapped into a
# 2-D array here. NOTE(review): presumably so that pc.ax can index it with
# (pc.i, pc.j) when pc.onerow is False -- confirm in PlotConfig.
pc.onerow = False
pc.axes = np.array([[axes]])
pc.i, pc.j = 0, 0

# Font sizes and panel-label placement for this figure.
ts = 13
pc.ax_ts(13.5, 1.1)
pc.l_fs(12, 0.85)
pc.xyt_ls(ts, ts)
pc.ax_ls(13)
pc.kde_lw = 3
pc.p_lab_spec[0], pc.p_lab_spec[1], pc.p_lab_spec[2] = -0.4, 1.1, 14
pc.dpi_val = 300
pc.ms = 100
pc.annot_fs = 10
ax_space, hrat, sc_lw = 5, 3.75, 3  # not used in the visible part of this file

# Order the models by their mean correlation across questions.
corr_by_model = model_corrs_df_wide['corr']
corrs_sorted = corr_by_model.mean(axis=1).sort_values(ascending=False)
model_list = corrs_sorted.index.tolist()

sns.heatmap(
    corr_by_model.loc[model_list, :],
    annot=True,
    cmap='Oranges',
    ax=pc.ax,
    annot_kws={"size": pc.annot_fs},
    cbar=False,
)
pc.ax.set_xlabel('Question')
pc.ax.set_ylabel('Model')

# Bold panel label, positioned in axes coordinates via pc.p_lab_spec.
panel_label_style = dict(
    transform=pc.ax.transAxes,
    fontweight='bold',
    va='top',
    ha='right',
    fontsize=pc.p_lab_spec[2],
)
pc.ax.text(pc.p_lab_spec[0] + 0.05, pc.p_lab_spec[1], pc.p_labs[pc.i, pc.j], **panel_label_style)
pc.ax.set_title('Item-level Correlations')
plt.tight_layout()

# The same figure is written twice: under a descriptive name and under the
# figure-number name.
if bools.saveFig:
    for pdf_path in (
        f"{paths.plots_path}all_model_item_correlations.pdf",
        f"{paths.plots_path}{fig_no}_p2_test1_item_level.pdf",
    ):
        plt.savefig(pdf_path, dpi=300)

# %% Totals
def _print_total_stats(csv_name):
    """Load one questionnaire CSV and print mean and std of its ``total`` column.

    Args:
        csv_name: File name of the CSV inside ``paths.data_path``.

    Returns:
        The loaded ``pandas.DataFrame``.
    """
    # os.path.join adds no extra separator when data_path already ends with
    # "/", so the path no longer contains a doubled slash.
    data = pd.read_csv(os.path.join(paths.data_path, csv_name))
    print(data['total'].aggregate(['mean', 'std']))
    return data


# Each file is read and summarised in turn; the frames stay available under
# their usual names.
phq9_data = _print_total_stats('phq9_data.csv')
sds_data = _print_total_stats('sds_data.csv')
gad7_data = _print_total_stats('gad7_data.csv')
