# %% Import libraries and load data
import csv
import os
import pandas as pd
import numpy as np
import matplotlib
import scipy.stats as stats
from transformers import pipeline
import statsmodels.api as sm
import statsmodels.formula.api as smf
import statsmodels.api as sm
from textblob import TextBlob
import re
from difflib import get_close_matches
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity as cossim

# Device string handed to SentenceTransformer and the transformers pipeline below.
device_name = 'mps'

# The backend must be selected before matplotlib.pyplot is imported.
matplotlib.use('Qt5Agg')
# matplotlib.use('TkAgg')
# matplotlib.get_backend()
import matplotlib.pyplot as plt

plt.ion()
import seaborn as sns

# Establish working directory: climb to the "online_tasks" ancestor of the
# current directory, then descend into this task's analysis folder so that the
# relative "_data/..." and "_plots/..." paths used below resolve.
cwd = os.getcwd()
cwd_split = cwd.split('/')
base_task = 'qs_intervention'
path_to_go = '_analysis'
if "online_tasks" not in cwd_split:
    # Fail with an actionable message instead of an opaque IndexError.
    raise RuntimeError(f'Expected to be launched from inside an "online_tasks" directory tree, got cwd={cwd!r}')
cwd_base = '/'.join(cwd_split[:cwd_split.index("online_tasks") + 1])
os.chdir(f"{cwd_base}/{base_task}/{path_to_go}")

task_versions = ['v2', 'v2b']
# task_versions = ['v4']
#
# saveMe = False
# saveMe = True
# Project-local helpers; imported after the chdir above, as in the original ordering.
from objects.plot_config import *
from utils.utils import percentile
from utils.plot_utils import set_a_hist

pc = PlotConfig()
plots_path = '_plots/prelim/'
savePlot = True
# savePlot = False

# Sentence-embedding model used for every cosine-similarity computation below.
model = SentenceTransformer("all-MiniLM-L6-v2", device=device_name)

# %% Load data and transcripts
task_versions = ['v2', 'v2b']
# One processed intervention table per task version; the per-file
# 'task_version' column is discarded on load.
store_int_dfs = [
    pd.read_csv(f"_data/qs-intervention-{task_version}/processed/int_data.csv").drop(columns=['task_version'])
    for task_version in task_versions
]
int_data = pd.concat(store_int_dfs, ignore_index=True)
# Keep group 'D' only, then keep rows in which every field is present.
int_data = int_data[int_data['group'] == 'D']
int_data = int_data[int_data.notna().all(axis=1)]
# Free-text columns are picked by position: the five columns starting at index 3.
text_cols = int_data.columns[3:8]

# %% PHQ9 questions
# Wording of the nine PHQ-9 items, in questionnaire order.
phq9_qs = [
    "Little interest or pleasure in doing things.",
    "Feeling down, depressed, or hopeless.",
    "Trouble falling or staying asleep, or sleeping too much.",
    "Feeling tired or having little energy.",
    "Poor appetite or overeating.",
    "Feeling bad about yourself - or that you are a failure or have let yourself or your family down.",
    "Trouble concentrating on things, such as reading the newspaper or watching television.",
    "Moving or speaking so slowly that other people could have noticed? Or the opposite - being so fidgety or restless that you have been moving around a lot more than usual.",
    "Thoughts that you would be better off dead or of hurting yourself in some way.",
]
# Per-item names are kept because individual items (e.g. phq9_q2) are referenced by name.
(phq9_q1, phq9_q2, phq9_q3, phq9_q4, phq9_q5,
 phq9_q6, phq9_q7, phq9_q8, phq9_q9) = phq9_qs

# Column labels 'phq9_q1' ... 'phq9_q9', aligned with phq9_qs.
phq9_qs_lab = [f'phq9_q{n}' for n in range(1, 10)]

# One embedding per PHQ-9 item, in the same order as phq9_qs.
phq9_embds = model.encode(phq9_qs)

# %% PHQ9
# One processed PHQ-9 table per task version, stacked into a single frame.
store_phq9_dfs = [
    pd.read_csv(f"_data/qs-intervention-{task_version}/processed/phq9_data.csv").drop(columns=['task_version'])
    for task_version in task_versions
]
phq9_data = pd.concat(store_phq9_dfs, ignore_index=True)
# Keep group 'D' only, then keep rows in which every field is present.
phq9_data = phq9_data[phq9_data['group'] == 'D']
phq9_data = phq9_data[phq9_data.notna().all(axis=1)]
# Sum the remaining columns over questions for each (sub, condition, group),
# then discard the last resulting column by position.
phq9_totals = (
    phq9_data.drop(columns='question')
    .groupby(['sub', 'condition', 'group'], as_index=False)
    .sum()
    .iloc[:, :-1]
)
# Display labels for the five free-text fields: four 'recreate' texts plus the 'act' text.
int_labs = [*(f'recreate_{r}' for r in range(4)), 'act']
text_cols_ls = text_cols.tolist()
# Split the text columns by substring match on their names.
rec_cols = [col for col in text_cols_ls if 'recreate' in col]
act_cols = [col for col in text_cols_ls if 'act' in col]
# %% Cosine similarity between each participant's texts and the PHQ-9 items
# Per-condition accumulators; keys are the lower-cased values of the 'condition' column.
store_phq9_text_simmats = {'mh': [], 'ml': []}
store_subs = {'mh': [], 'ml': []}
store_act_emb = {'mh': [], 'ml': []}
# Long-format frames are collected here and concatenated once after the loop;
# growing a DataFrame with pd.concat on every iteration re-copies all earlier rows.
sim_frames = []
for _, row in int_data.iterrows():
    condition = row['condition'].lower()
    sub = row['sub']
    store_subs[condition].append(sub)
    sub_rec_texts_embd = model.encode(row[rec_cols].values.tolist())
    sub_act_texts_embd = model.encode(row[act_cols].values.tolist())
    store_act_emb[condition].append(sub_act_texts_embd)
    # np.concatenate instead of np.concat: the latter only exists in NumPy >= 2.0.
    sub_texts_embd = np.concatenate([sub_rec_texts_embd, sub_act_texts_embd])
    # Rows follow rec_cols + act_cols, columns follow phq9_qs_lab.
    phq9_texts_sim = cossim(sub_texts_embd, phq9_embds)
    store_phq9_text_simmats[condition].append(phq9_texts_sim)

    # Reshape the similarity matrix to long format: one row per (text column, PHQ-9 item).
    tmp_pd = pd.DataFrame(phq9_texts_sim, index=rec_cols + act_cols, columns=phq9_qs_lab).reset_index()
    tmp_pd = tmp_pd.melt(id_vars=['index'], var_name='question', value_name='sim')
    tmp_pd.insert(0, 'sub', sub)
    tmp_pd.insert(1, 'condition', condition.upper())
    sim_frames.append(tmp_pd)

# pd.concat raises on an empty list, so fall back to an empty frame when there are no rows.
store_sims_pd = pd.concat(sim_frames) if sim_frames else pd.DataFrame()
store_sims_pd.rename(columns={'index': 'int_text', 'question': 'phq9_q'}, inplace=True)

# %% Plot average similarity
# Element-wise mean of the per-participant similarity matrices, per condition.
store_phq9_text_simmats_avg = {
    cond: np.mean(mats, axis=0) for cond, mats in store_phq9_text_simmats.items()
}

plt.close('all')
# Figure geometry: 1 row x 2 columns, scaled by pc.mlt.
pc.r, pc.c, pc.mlt = 1, 2, 2.75
pc.figsize = ((pc.c + 4.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)

fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
pc.i = 0
pc.j = 0
pc.onerow = True
# Wrap the 1-D axes array so that it becomes 2-D with shape (1, 2).
axes = np.array([axes])
pc.axes = axes
# PlotConfig styling calls; their exact meaning is defined in objects.plot_config.
pc.ax_ts(16, 1.1)
pc.l_fs(12, 0.85)
ts = 14
pc.xyt_ls(ts, 10)
pc.ax_ls(16)
pc.kde_lw = 3
pc.p_lab_spec[2] = 12
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300

# One heatmap per condition; pc.j advances to the next panel
# (pc.ax presumably resolves to the axes selected by pc.i / pc.j - confirm in PlotConfig).
for pc.j, (cond_key, panel_title) in enumerate((('mh', 'MH'), ('ml', 'ML'))):
    sns.heatmap(store_phq9_text_simmats_avg[cond_key], cmap='Blues', ax=pc.ax, annot=True, vmin=0, vmax=0.5)
    pc.ax.set_xticklabels(phq9_qs_lab, rotation=45)
    pc.ax.set_yticklabels(int_labs, rotation=45)
    pc.ax.set_title(panel_title)

plt.tight_layout()
if savePlot:
    plt.savefig(f"{plots_path}avg_similiarity_phq9-vs-ppt-entries.pdf", dpi=300)

# %% Plot sim between PHQ-9 statements and the recreate / act texts
# One panel per free-text column (5 rows x 1 column); each panel shows the
# distribution of text-to-PHQ-9 cosine similarities per PHQ-9 item, split by condition.
plt.close('all')
pc.r, pc.c, pc.mlt = 5, 1, 1.75
pc.figsize = ((pc.c + 8.25) * pc.mlt, (pc.r + 1.75) * pc.mlt)

fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
# Wrap and transpose the 1-D axes array into shape (5, 1).
axes = np.array([axes]).T

pc.axes = axes
pc.i, pc.j = 0, 0
pc.onerow = False

# PlotConfig styling calls; their exact meaning is defined in objects.plot_config.
pc.ax_ts(16, 1.1)
pc.l_fs(8, 0)
ts = 14  # assigned but not used in this cell (xyt_ls below is given literals)
pc.xyt_ls(5, 5)
pc.ax_ls(10)
pc.kde_lw = 3
pc.p_lab_spec[2] = 12
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300

ssize = 5.75  # strip-plot marker size
slw = 0.2  # base line width (box/violin use slw + 1)
alpha = 0.5  # opacity applied to the box faces

# pc.i is the loop target, so each iteration moves to the next panel
# (pc.ax presumably resolves to the axes selected by pc.i / pc.j - confirm in PlotConfig).
for pc.i, int_text in enumerate(text_cols_ls):
    pd_to_plot = store_sims_pd[store_sims_pd['int_text'] == int_text]
    bp = sns.boxplot(data=pd_to_plot, x='phq9_q', hue='condition', y='sim', width=0.4,
                     linewidth=slw + 1, ax=pc.ax, legend=True, gap=0.2)

    # Recolour the patches present on the axes right after the boxplot call,
    # i.e. before the violin and strip layers are added.
    for patch in bp.patches:
        face_color = patch.get_facecolor()
        patch.set_facecolor((*face_color[:3], alpha))
    sns.violinplot(data=pd_to_plot, x='phq9_q', hue='condition', y='sim', width=0.4, linewidth=slw + 1, ax=pc.ax,
                   legend=False, gap=0.2)
    sns.stripplot(data=pd_to_plot, x='phq9_q', hue='condition', y='sim', edgecolor='black', linewidth=slw, dodge=True,
                  jitter=0.05, size=ssize, legend=False, ax=pc.ax)

    pc.ax.set_title(int_labs[pc.i])
    pc.ax.set_ylim(-0.2, 0.75)

plt.tight_layout()
if savePlot:
    plt.savefig(f"{plots_path}ind_similiarity_phq9-vs-ppt-entries.pdf", dpi=300)

# %% Build the wide analysis table: similarities + PHQ-9 totals + per-item PHQ-9 values
# Wide similarity table: one row per (sub, condition), one column per (int_text, phq9_q) pair.
phq9_diff_text_sim_pd = store_sims_pd.pivot(index=['sub', 'condition'], columns=['int_text', 'phq9_q'], values='sim')
# Flatten the two-level column index into '<int_text>-<phq9_q>' names and strip the 'responses_' prefix.
phq9_diff_text_sim_pd.columns = phq9_diff_text_sim_pd.columns.get_level_values(
    0) + '-' + phq9_diff_text_sim_pd.columns.get_level_values(1)
phq9_diff_text_sim_pd.columns = phq9_diff_text_sim_pd.columns.str.replace('responses_', '')
phq9_diff_text_sim_pd.reset_index(inplace=True)
# phq9_diff_text_sim_pd.rename(columns={'sub-':'sub','condition-':'condition'},inplace=True)


# Attach the per-participant PHQ-9 totals (inner join on sub and condition).
phq9_diff_text_sim_pd = pd.merge(phq9_diff_text_sim_pd, phq9_totals, on=['sub', 'condition'])
# NOTE(review): `[:-1]` slices ROWS, so this drops the final row of phq9_data.
# If the intent was to drop the last column (cf. `phq9_totals.iloc[:, :-1]`),
# this needs `.iloc[:, :-1]` instead - confirm before relying on the last participant's data.
phq9_data_long = phq9_data[:-1]
# No `values` given, so every remaining column is pivoted; the resulting
# columns are (value column, question) pairs. Despite the name, the result is wide.
phq9_data_long = pd.pivot(phq9_data_long, index=['sub', 'condition', 'group'], columns=['question'])
# Flatten to '<value column>-<question>' names, same convention as above.
phq9_data_long.columns = phq9_data_long.columns.get_level_values(0) + '-' + phq9_data_long.columns.get_level_values(1)
phq9_data_long.columns = phq9_data_long.columns.str.replace('responses_', '')
phq9_data_long.reset_index(inplace=True)
# phq9_diff_text_sim_pd.rename(columns={'sub-':'sub','condition-':'condition'},inplace=True)

phq9_diff_text_sim_pd = pd.merge(phq9_diff_text_sim_pd, phq9_data_long, on=['sub', 'condition', 'group'])

# Short names for the model variables. The dict-union operator `|` needs Python >= 3.9.
yvar = 'score_diff-phq9_q2'
xvars = ['score_b', 'act_0-phq9_q2']
col_ren = {yvar: 'q2Diff', xvars[0]: 'baselineTotal', xvars[1]: 'actQ2sim'} | {f'recreate_{r}-phq9_q2': f'rec{r}-Q2sim'
                                                                               for r in range(4)}
phq9_diff_text_sim_pd.rename(columns=col_ren, inplace=True)
# Relies on dict insertion order: the last four values of col_ren are the rec{r}-Q2sim columns.
phq9_diff_text_sim_pd['recAvgQ2Sim'] = phq9_diff_text_sim_pd[list(col_ren.values())[-4:]].mean(axis=1)

# %% FEEDBACK
# One processed feedback table per task version, stacked into a single frame.
store_feedback_dfs = [
    pd.read_csv(f"_data/qs-intervention-{task_version}/processed/feedback_data.csv").drop(columns=['task_version'])
    for task_version in task_versions
]
feedback_data = pd.concat(store_feedback_dfs, ignore_index=True)
feedback_data = feedback_data.loc[feedback_data['group'] == 'D']

# Inner join: participants without a feedback row are dropped from the analysis table.
phq9_diff_text_sim_pd = phq9_diff_text_sim_pd.merge(feedback_data, on=['sub', 'condition', 'group'])

# Feedback columns carried into the exported regression table.
feedback_cols = ['shoes', 'sim_sit', 'mood_change']

# %% Get open q data
task_versions = ['v2', 'v2b']
store_openq_dfs = []
for task_version in task_versions:
    file_path = f"_data/qs-intervention-{task_version}/processed/openq_data.csv"
    openq_data_pd = pd.read_csv(file_path).drop(columns=['task_version'])  # [openq_cols]
    store_openq_dfs.append(openq_data_pd)

openq_data = pd.concat(store_openq_dfs).reset_index(drop=True)
openq_data = openq_data[openq_data['group'] == 'D']
# Explicit copy: columns are assigned below, and writing into a filtered
# selection can trigger pandas' SettingWithCopyWarning.
openq_data = openq_data[~openq_data.isna().any(axis=1)].copy()

# Light text clean-up of the three open-question answers.
openq_cols = ['oq_energy_text', 'oq_mood_text', 'oq_pospert_text']
openq_data[openq_cols] = openq_data[openq_cols].replace(r'\s+\.', '.', regex=True)
openq_data[openq_cols] = openq_data[openq_cols].replace(r'\s+,', ', ', regex=True)
# Collapse remaining newline runs into a single space. The previous pattern
# r'\n+,' could never match here: the r'\s+,' substitution above has already
# consumed every newline that precedes a comma (\s includes \n).
openq_data[openq_cols] = openq_data[openq_cols].replace(r'\n+', ' ', regex=True)

# Word count per answer = number of space-separated tokens.
for openq_col in openq_cols:
    openq_data[f'{openq_col}_wc'] = openq_data[openq_col].str.split(' ').str.len()

# store_openq_embs = {'mh': {'e':[],'m':[],'p':[]}, 'ml':{'e':[],'m':[],'p':[]}}

# First act text per participant, looked up by 'sub' alone (condition is not part of the key).
# NOTE(review): if a participant can appear under more than one condition, confirm that
# taking their first act text is the intended behaviour.
act_text_by_sub = int_data.drop_duplicates(subset='sub').set_index('sub')['responses_act_0']

# Per-participant long-format frames, concatenated once after the loop.
oq_act_frames = []
pospert_frames = []
for _, row in openq_data.iterrows():
    condition = row['condition'].lower()
    sub = row['sub']

    # Similarity of the mood and energy answers to the participant's act text;
    # NaN when the participant has no row in int_data.
    baseline_embd = model.encode(list(row[['oq_mood_text', 'oq_energy_text']].values))
    if sub in act_text_by_sub.index:
        sub_act_text = [act_text_by_sub.loc[sub]]
        act_embd = model.encode(sub_act_text)
        baseline_act_sim = cossim(baseline_embd, act_embd).T
    else:
        baseline_act_sim = np.nan

    # Similarity of the 'pospert' answer to each of the nine PHQ-9 items.
    pospert_embd = model.encode(list(row[['oq_pospert_text']].values))
    pospert_phq9_sim = cossim(pospert_embd, phq9_embds)

    # Column order matches the encode order above: mood first, then energy.
    tmp_pd = pd.DataFrame(baseline_act_sim, index=['act'], columns=['mood', 'energy']).reset_index()
    tmp_pd = tmp_pd.melt(id_vars=['index'], var_name='question', value_name='sim')
    tmp_pd.insert(0, 'sub', sub)
    tmp_pd.insert(1, 'condition', condition.upper())
    oq_act_frames.append(tmp_pd)

    tmp_pd = pd.DataFrame(pospert_phq9_sim, index=['pospert'],
                          columns=[f'q_{q + 1}' for q in range(9)]).reset_index()
    tmp_pd = tmp_pd.melt(id_vars=['index'], var_name='question', value_name='sim')
    tmp_pd.insert(0, 'sub', sub)
    tmp_pd.insert(1, 'condition', condition.upper())
    pospert_frames.append(tmp_pd)

# pd.concat raises on an empty list, so fall back to empty frames when there are no rows.
store_sims_oq_act = pd.concat(oq_act_frames) if oq_act_frames else pd.DataFrame()
store_sims_pospert_phq9 = pd.concat(pospert_frames) if pospert_frames else pd.DataFrame()

store_sims_oq_act.rename(columns={'index': 'int_text', 'question': 'oq'}, inplace=True)
store_sims_pospert_phq9.rename(columns={'index': 'oq_text', 'question': 'phq9_q'}, inplace=True)
# %%
# Wide form: one row per (sub, condition) with one similarity column per open question.
store_sims_oq_act_wide = pd.pivot(store_sims_oq_act, index=['sub', 'condition'], columns='oq',
                                  values='sim').reset_index()
store_sims_oq_act_wide.rename(columns={'energy': 'energyActSim', 'mood': 'moodActSim'}, inplace=True)

# %% Assemble and export the regression tables
# Exclude the participant flagged as an outlier.
phq9_diff_text_sim_pd = phq9_diff_text_sim_pd.loc[phq9_diff_text_sim_pd['sub'].ne('sub95_v2b')]  # outlier

# Identifier columns, the renamed model variables, the averaged recreate
# similarity and the feedback ratings, joined with the open-question similarities.
lmer_cols = ['sub', 'condition', *col_ren.values(), 'recAvgQ2Sim', *feedback_cols]
lmer_pd = phq9_diff_text_sim_pd[lmer_cols].merge(store_sims_oq_act_wide, on=['sub', 'condition'])
lmer_pd.to_csv('_data/sim_diff_data.csv', index=False)

# Same table with the raw act text attached.
lmer_wtext_pd = int_data[['sub', 'condition', 'responses_act_0']].merge(lmer_pd, on=['sub', 'condition'])
lmer_wtext_pd.to_csv('_data/sim_diff_data_wtext.csv', index=False)

# Disabled exploratory model, kept for reference:
# md = smf.mixedlm("q2Diff ~ (actQ2sim+baselineTotal)*condition", data=lmer_pd, groups=lmer_pd['sub'])
# mdf = md.fit()
# print(mdf.summary())
# %% Per-condition Pearson correlation: act-text/Q2 similarity vs. Q2 change
plt.close('all')
lmer_pd_mh, lmer_pd_ml = (lmer_pd[lmer_pd['condition'] == cond] for cond in ('MH', 'ML'))

# The reported p-value is the pearsonr p-value doubled and capped at 1
# (presumably a correction for running one test per condition - confirm).
corr_mh = stats.pearsonr(lmer_pd_mh['actQ2sim'], lmer_pd_mh['q2Diff'])
r_corr_mh, p_corr_mh = corr_mh[0], min(corr_mh[1] * 2, 1)

corr_ml = stats.pearsonr(lmer_pd_ml['actQ2sim'], lmer_pd_ml['q2Diff'])
r_corr_ml, p_corr_ml = corr_ml[0], min(corr_ml[1] * 2, 1)
# %% Regression plots: similarity measures vs. PHQ-9 Q2 change
plt.close('all')
# Figure geometry: 1 row x 2 columns, scaled by pc.mlt.
pc.r, pc.c, pc.mlt = 1, 2, 1.75
pc.figsize = ((pc.c + 4.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)

fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, sharex=True, sharey=True)
# Wrap the 1-D axes array so that it becomes 2-D with shape (1, 2).
axes = np.array([axes])
pc.onerow = True
pc.i = 0
pc.j = 0

pc.axes = axes
# PlotConfig styling calls; their exact meaning is defined in objects.plot_config.
pc.ax_ts(16, 1.1)
pc.l_fs(12, 0.85)
ts = 14
pc.xyt_ls(ts, 10)
pc.ax_ls(16)
pc.kde_lw = 3
pc.p_lab_spec[2] = 12
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300

ssize = 5.75
slw = 0.2
alpha = 0.5

# (x column, x-axis label) for each panel; both panels share the same y variable.
panels = (
    ('actQ2sim', 'Act Text and PHQ-9 Q2 cos sim.'),
    ('moodActSim', 'Mood Open Q to Act Text cos sim.'),
)
for pc.j, (x_col, x_label) in enumerate(panels):
    sns.regplot(lmer_pd_mh, x=x_col, y='q2Diff', color='tab:blue', label='MH', ax=pc.ax)
    sns.regplot(lmer_pd_ml, x=x_col, y='q2Diff', color='tab:orange', label='ML', ax=pc.ax)
    pc.ax.set_xlabel(x_label)
    pc.ax.set_ylabel('PHQ-9 Q2 change.')
    pc.ax.legend()

plt.tight_layout()
if savePlot:
    plt.savefig(f"{plots_path}actText-q2Sim_vs_q2Diff.pdf", dpi=300)

# %% Pospert and PHQ-9 item similarity
# Single panel: distribution of the cosine similarity between the 'pospert'
# open-question answer and each PHQ-9 item, split by condition.
plt.close('all')
pc.r, pc.c, pc.mlt = 1, 1, 2
pc.figsize = ((pc.c + 5.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)

fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
pc.i = 0
pc.onerow = True
pc.j = 0

# Wrap the single Axes object into a 2-D array of shape (1, 1).
axes = np.array([[axes]])
# axes = np.array([axes]).T
pc.axes = axes
# PlotConfig styling calls; their exact meaning is defined in objects.plot_config.
pc.ax_ts(16, 1.1)
pc.l_fs(12, 0.85)
ts = 14
pc.xyt_ls(ts, 10)
pc.ax_ls(16)
pc.kde_lw = 3
pc.p_lab_spec[2] = 12
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300

# Marker / line styling shared by the layers below.
ssize = 5.75
slw = 0.2
alpha = 0.5
hue_cols = ['tab:blue', 'tab:orange']
# Layers are drawn in this order: violins, then boxes, then individual points.
sns.violinplot(data=store_sims_pospert_phq9, x='phq9_q', y='sim', hue='condition', inner=None, ax=pc.ax,
               palette=hue_cols, legend=False)
bp = sns.boxplot(data=store_sims_pospert_phq9, x='phq9_q', y='sim', hue='condition', width=0.4, palette=hue_cols,
                 linewidth=slw + 1, ax=pc.ax, legend=True, gap=0.2)
# Apply the shared alpha to every patch present on the axes at this point.
for patch in bp.patches:
    face_color = patch.get_facecolor()
    patch.set_facecolor((*face_color[:3], alpha))
sns.stripplot(data=store_sims_pospert_phq9, x='phq9_q', y='sim', hue='condition', edgecolor='black', palette=hue_cols,
              linewidth=slw, dodge=True, jitter=0.05,
              size=ssize, legend=False, ax=pc.ax)
pc.ax.set_xlabel('PHQ-9 question')
pc.ax.set_ylabel('Cos similarity')

pc.ax.set_title(f'Cos sim between +ve perturbation and PHQ-9 questions')
plt.tight_layout()
if savePlot:
    plt.savefig(f"{plots_path}pospert-phq9qs_sim.pdf", dpi=300)

# %% Load sentiment analyser
# Alternative checkpoints tried previously:
#   distilbert/distilbert-base-uncased-finetuned-sst-2-english, distilbert/distilbert-base-uncased
sentiment_pipeline = pipeline(
    "sentiment-analysis",
    model="siebert/sentiment-roberta-large-english",
    device=device_name,
)

act_texts = int_data['responses_act_0'].tolist()
outputs = sentiment_pipeline(act_texts)

# Signed sentiment: the classifier score, negated when the predicted label is NEGATIVE.
label_sign = {'POSITIVE': 1, 'NEGATIVE': -1}
scores = [out['score'] * label_sign[out['label']] for out in outputs]
int_data['act_sentiment'] = scores

# Attach the intervention data (now including sentiment) to the analysis table and export it.
phq9_diff_int_sent = phq9_diff_text_sim_pd.merge(int_data, on=['sub', 'condition'])
phq9_diff_int_sent.to_csv('_data/phq9_diff_int_sent.csv', index=False)
# %%
# Per-condition views used by the plot below.
phq9_diff_int_sent_mh, phq9_diff_int_sent_ml = (
    phq9_diff_int_sent[phq9_diff_int_sent['condition'] == cond] for cond in ('MH', 'ML')
)

# %% Regression plot: act-text sentiment vs. PHQ-9 Q2 change
plt.close('all')
# Figure geometry: a single panel, scaled by pc.mlt.
pc.r, pc.c, pc.mlt = 1, 1, 1.75
pc.figsize = ((pc.c + 4.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)

fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, sharex=True, sharey=True)
# Wrap the single Axes object into a 1-element array.
axes = np.array([axes])
pc.onerow = True
pc.i = 0
pc.j = 0

pc.axes = axes
# PlotConfig styling calls; their exact meaning is defined in objects.plot_config.
pc.ax_ts(16, 1.1)
pc.l_fs(12, 0.85)
ts = 14
pc.xyt_ls(ts, 10)
pc.ax_ls(16)
pc.kde_lw = 3
pc.p_lab_spec[2] = 12
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300

ssize = 5.75
slw = 0.2
alpha = 0.5

# One scatter + regression line per condition, drawn on the same axes.
for cond_df, cond_color, cond_label in (
        (phq9_diff_int_sent_mh, 'tab:blue', 'MH'),
        (phq9_diff_int_sent_ml, 'tab:orange', 'ML'),
):
    sns.regplot(cond_df, x='act_sentiment', y='q2Diff', color=cond_color, label=cond_label, ax=pc.ax)
pc.ax.set_xlabel('Act Text Sentiment')
pc.ax.set_ylabel('PHQ-9 Q2 change.')
pc.ax.legend()

pc.ax.set_ylim([-0.5, 0.5])

plt.tight_layout()
if savePlot:
    plt.savefig(f"{plots_path}actText-sent_vs_q2Diff.pdf", dpi=300)
# %% Severity-graded restatements of PHQ-9 Q2
# Four sentences of increasing frequency built from the Q2 wording. Their list
# position (0-3) is what the severity-score cell uses as the weight.
# An earlier phrasing ("For several days, I have been bothered by ...") was
# assigned here and immediately overwritten; that dead assignment is removed.
q2_text = phq9_q2.lower()
q2_stem = q2_text[:-1]  # without the trailing full stop, so each sentence adds its own ending
phq9_q2_sev = [f'I have never been bothered by {q2_text}',
               f'I have been bothered by {q2_stem} for several days.',
               f'I have been bothered by {q2_stem} for more than half the days.',
               f'I have been bothered by {q2_stem} nearly every day.']

# One embedding per severity sentence, in list order.
phq9_q2_sev_embds = model.encode(phq9_q2_sev)
# act_embds= model.encode(act_texts)#
# %% Similarity-weighted severity score for each act text
# For each condition: stack the stored act-text embeddings, take their cosine
# similarity to the four severity sentences, and collapse the four similarities
# into one score as sum_k k * sim_k with k = 0..3.
# np.concatenate is used instead of np.concat, which only exists in NumPy >= 2.0.
act_sev_embd_sim_mh = cossim(np.concatenate(store_act_emb['mh']), phq9_q2_sev_embds)
act_sev_scores_mh = np.arange(0, 4) @ act_sev_embd_sim_mh.T

act_sev_embd_sim_ml = cossim(np.concatenate(store_act_emb['ml']), phq9_q2_sev_embds)
act_sev_scores_ml = np.arange(0, 4) @ act_sev_embd_sim_ml.T

# store_subs and store_act_emb were appended to in the same loop, so the
# participant ids line up with the score order.
act_sev_mh = pd.DataFrame({'sub': store_subs['mh'], 'condition': 'MH', 'sim_sev_scores': act_sev_scores_mh})
act_sev_ml = pd.DataFrame({'sub': store_subs['ml'], 'condition': 'ML', 'sim_sev_scores': act_sev_scores_ml})

act_sev_pd = pd.concat([act_sev_mh, act_sev_ml])
# Inner join: only participants present in the PHQ-9 analysis table are kept.
act_sev_pd = pd.merge(act_sev_pd, phq9_diff_int_sent[['sub', 'condition', 'q2Diff']], on=['sub', 'condition'])
# Z-score across both conditions together, after the merge.
act_sev_pd['sim_sev_scores'] = (act_sev_pd['sim_sev_scores'] - (act_sev_pd['sim_sev_scores'].mean()))/act_sev_pd['sim_sev_scores'].std()

act_sev_pd_mh = act_sev_pd[act_sev_pd['condition'] == 'MH']
act_sev_pd_ml = act_sev_pd[act_sev_pd['condition'] == 'ML']

# %% Regression plot: similarity-based severity score vs. PHQ-9 Q2 change
plt.close('all')
# Figure geometry: a single panel, scaled by pc.mlt.
pc.r, pc.c, pc.mlt = 1, 1, 1.75
pc.figsize = ((pc.c + 4.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)

fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, sharex=True, sharey=True)
# Wrap the single Axes object into a 1-element array.
axes = np.array([axes])
pc.onerow = True
pc.i = 0
pc.j = 0

pc.axes = axes
# PlotConfig styling calls; their exact meaning is defined in objects.plot_config.
pc.ax_ts(16, 1.1)
pc.l_fs(12, 0.85)
ts = 14
pc.xyt_ls(ts, 10)
pc.ax_ls(16)
pc.kde_lw = 3
pc.p_lab_spec[2] = 12
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300

ssize = 5.75
slw = 0.2
alpha = 0.5

# One scatter + regression line per condition, drawn on the same axes.
for cond_df, cond_color, cond_label in (
        (act_sev_pd_mh, 'tab:blue', 'MH'),
        (act_sev_pd_ml, 'tab:orange', 'ML'),
):
    sns.regplot(cond_df, x='sim_sev_scores', y='q2Diff', color=cond_color, label=cond_label, ax=pc.ax)
pc.ax.set_xlabel('PHQ9 Q2 sim based sev score')
pc.ax.set_ylabel('PHQ-9 Q2 change.')
pc.ax.legend()
plt.tight_layout()
# Export the plotted table before the figure is written.
act_sev_pd.to_csv('_data/act_sev_pd.csv', index=False)
if savePlot:
    plt.savefig(f"{plots_path}actText-sim-sev-score_vs_q2Diff.pdf", dpi=300)
