# %% Import libraries and load data
import csv
import pickle
import os
import pandas as pd
import numpy as np
import matplotlib
import scipy.stats as stats
from transformers import pipeline
import statsmodels.api as sm
import statsmodels.formula.api as smf
from textblob import TextBlob
import re
from difflib import get_close_matches
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity as cossim

# Device passed to both the SentenceTransformer and the HF sentiment pipeline below.
# NOTE(review): 'mps' targets Apple-silicon GPUs; presumably needs changing ('cpu'/'cuda') on other machines.
device_name = 'mps'

# Select the interactive Qt backend before pyplot is imported.
matplotlib.use('Qt5Agg')
# matplotlib.use('TkAgg')
# matplotlib.get_backend()
import matplotlib.pyplot as plt

# Interactive mode: figures are drawn as cells run instead of blocking on plt.show().
plt.ion()
import seaborn as sns

# establish working directory: move to <...>/online_tasks/qs_intervention/_analysis
# so every relative data/plot path used below resolves, no matter where inside the
# `online_tasks` tree the script or console was started.
cwd = os.getcwd()
cwd_split = cwd.split('/')
base_task = 'qs_intervention'
path_to_go = '_analysis'
if 'online_tasks' not in cwd_split:
    raise RuntimeError(
        f"Cannot locate an 'online_tasks' directory in the current path {cwd!r}; "
        "start the script from inside the online_tasks tree.")
# list.index returns the first occurrence, i.e. the outermost 'online_tasks'.
cwd_base = '/'.join(cwd_split[:cwd_split.index('online_tasks') + 1])
os.chdir(f"{cwd_base}/{base_task}/{path_to_go}")

task_versions = ['v2', 'v2b', 'v3']
#
# saveMe = False
# saveMe = True
from objects.plot_config import *
from utils.utils import percentile

from objects.plot_config import *
from utils.utils import percentile
from utils.plot_utils import set_a_hist

# Input/output locations and the shared plotting configuration object.
data_proc_dir = '_data/processed/'
plots_path = '_plots/int/'
pc = PlotConfig()
savePlot = True
# savePlot = False

# Cache switches: reuse pickled embeddings / cached sentiment instead of recomputing.
loadTextEmb, saveTextEmb = True, False
# loadTextEmb = False
# saveTextEmb = True
runSent, saveSent = False, False
# runSent = True
# saveSent = True

# Text models: sentence embeddings and a binary (POSITIVE/NEGATIVE) sentiment classifier.
model = SentenceTransformer("all-MiniLM-L6-v2", device=device_name)
sentiment_pipeline = pipeline("sentiment-analysis", model="siebert/sentiment-roberta-large-english", device=device_name)

# Subject identifier columns, condition ordering/colours and PHQ-9 bin ordering.
id_cols = ['sub', 'condition', 'group', 'autobio']
hue_order = ['MH', 'ML']
hue_cols = ['tab:blue', 'tab:orange']
hue_cols2 = ['tab:green', 'tab:purple']
sbin3_order = ['q33', 'm', 'q66']
join_cols = [*id_cols, 's_bin', 's_bin3']

# Pre-processed outcome tables (PHQ-9 change, wide format; mood change).
phq9_diff = pd.read_csv(f'{data_proc_dir}phq9_diff_data_wide.csv')
mood_diff = pd.read_csv(f'{data_proc_dir}mood_data.csv')
# %% Load int data
# Minimum word counts for the recreated texts and the act text; a subject who
# falls below the threshold on any response is collected in bad_sub_int.
recreate_thr, act_thr = 5, 50
cols_to_drop = [f'rts_recreate_{r}' for r in range(4)] + ['rts_act_0']
cols_text = [f'responses_recreate_{r}' for r in range(4)] + ['responses_act_0']
store_int_dfs = [pd.read_csv(f"_data/qs-intervention-{task_version}/processed/int_data.csv")
                 for task_version in task_versions]
int_data = pd.concat(store_int_dfs).reset_index(drop=True)
# Rows coming from the v3 task are flagged as autobiographical.
int_data.insert(4, 'autobio', int_data['task_version'] == 'qs-intervention-v3')
int_data.drop(columns=['task_version'] + cols_to_drop, inplace=True)

# Normalise whitespace: no space before '.', one space after ',', and every
# remaining whitespace run (including newlines) collapsed to a single space.
int_data[cols_text] = int_data[cols_text].replace(r'\s+\.', '.', regex=True)
int_data[cols_text] = int_data[cols_text].replace(r'\s+,', ', ', regex=True)
int_data[cols_text] = int_data[cols_text].replace(r'\s+', ' ', regex=True)

bad_sub_int = set()
for col_text in cols_text:
    # split() without a separator ignores repeated/leading/trailing whitespace, so
    # empty strings are not counted as words; missing (non-string) answers count 0.
    word_count = int_data[col_text].str.split().str.len().fillna(0).astype(int)
    int_data[f'{col_text}_wc'] = word_count
    min_words = act_thr if col_text.endswith('_act_0') else recreate_thr
    bad_sub_int.update(int_data.loc[word_count < min_words, 'sub'])
bad_sub_int = list(bad_sub_int)
int_data.columns = int_data.columns.str.replace('responses_', '')

# %% Load open q data
# Minimum word count for every open-question answer; a subject below it on any
# answer is collected in bad_sub_oq.
oq_thr = 15
cols_oq_to_drop = ['oq_mood_rt', 'oq_energy_rt', 'oq_pospert_rt']
cols_oq_text = ['oq_mood_text', 'oq_energy_text', 'oq_pospert_text']
store_openq_dfs = [pd.read_csv(f"_data/qs-intervention-{task_version}/processed/openq_data.csv")
                   for task_version in task_versions]
openq_data = pd.concat(store_openq_dfs).reset_index(drop=True)
# Rows coming from the v3 task are flagged as autobiographical.
openq_data.insert(4, 'autobio', openq_data['task_version'] == 'qs-intervention-v3')
openq_data.drop(columns=['task_version'] + cols_oq_to_drop, inplace=True)

# Normalise whitespace: no space before '.', one space after ',', and every
# remaining whitespace run (including newlines) collapsed to a single space.
openq_data[cols_oq_text] = openq_data[cols_oq_text].replace(r'\s+\.', '.', regex=True)
openq_data[cols_oq_text] = openq_data[cols_oq_text].replace(r'\s+,', ', ', regex=True)
openq_data[cols_oq_text] = openq_data[cols_oq_text].replace(r'\s+', ' ', regex=True)

bad_sub_oq = set()
for col_oq_text in cols_oq_text:
    # split() without a separator ignores repeated/leading/trailing whitespace, so
    # empty strings are not counted as words; missing (non-string) answers count 0.
    word_count = openq_data[col_oq_text].str.split().str.len().fillna(0).astype(int)
    openq_data[f'{col_oq_text}_wc'] = word_count
    bad_sub_oq.update(openq_data.loc[word_count < oq_thr, 'sub'])
bad_sub_oq = list(bad_sub_oq)
openq_data.columns = openq_data.columns.str.replace('_text', '')

# %% Combine
# One row per subject: outcome-bin columns, then intervention texts, then open answers.
text_data = phq9_diff[join_cols].merge(int_data.merge(openq_data, on=id_cols), on=id_cols)
text_data.to_csv(f"{data_proc_dir}text_data.csv", index=False)
# %% Transcripts and PHQ9 statements
# Read each condition's transcript, treating every non-empty line as one segment,
# and embed all segments with the sentence model.
transcripts = {'MH': [], 'ML': []}
transcripts_embd = {'MH': [], 'ML': []}
for k, v in transcripts.items():
    with open(f"_data/transcripts/{k}.txt", encoding='utf-8') as f:
        # Strip the trailing newline and skip blank lines, which would otherwise be
        # embedded as spurious near-empty segments (extra heatmap columns later).
        v.extend(line.strip() for line in f if line.strip())
    transcripts_embd[k] = model.encode(v)

# PHQ-9 item statements (item text is embedded verbatim, so keep it unchanged).
phq9_q1 = "Little interest or pleasure in doing things."
phq9_q2 = "Feeling down, depressed, or hopeless."
phq9_q3 = "Trouble falling or staying asleep, or sleeping too much."
phq9_q4 = "Feeling tired or having little energy."
phq9_q5 = "Poor appetite or overeating."
phq9_q6 = "Feeling bad about yourself - or that you are a failure or have let yourself or your family down."
phq9_q7 = "Trouble concentrating on things, such as reading the newspaper or watching television."
phq9_q8 = "Moving or speaking so slowly that other people could have noticed? Or the opposite - being so fidgety or restless that you have been moving around a lot more than usual."
phq9_q9 = "Thoughts that you would be better off dead or of hurting yourself in some way."

phq9_qs = [phq9_q1, phq9_q2, phq9_q3, phq9_q4, phq9_q5, phq9_q6, phq9_q7, phq9_q8, phq9_q9]
phq9_qs_lab = [f'phq9_q{q + 1}' for q in range(9)]

# Row i is the embedding of item i + 1 (so phq9_embds[1] is Q2). Encoded once;
# the label->embedding dict that was previously built here was immediately discarded.
phq9_embds = model.encode(phq9_qs)
# %% Calculate/Load Text embeddings
# Per-subject text field names after prefix/suffix removal, e.g. 'recreate_0', 'act_0', 'oq_mood'.
text_keys = [re.sub(r'responses_|_text', '', c) for c in cols_text + cols_oq_text]


def _embed_text_row(subject_row):
    """Embed every text field of one subject row; fields that are not strings map to None."""
    return {key: (model.encode(subject_row[key]) if type(subject_row[key]) == str else None)
            for key in text_keys}


if not loadTextEmb:
    # Nested as {condition: {sub: {text_key: embedding or None}}}, split by autobio flag.
    embd_autobio = {'MH': {}, 'ML': {}}
    embd_nonautobio = {'MH': {}, 'ML': {}}
    for _, subject_row in text_data.iterrows():
        store = embd_autobio if subject_row['autobio'] else embd_nonautobio
        store[subject_row['condition']][subject_row['sub']] = _embed_text_row(subject_row)
    if saveTextEmb:
        with open(f'{data_proc_dir}text_embd_autobio.pkl', 'wb') as f:
            pickle.dump(embd_autobio, f)
        with open(f'{data_proc_dir}text_embd_nonautobio.pkl', 'wb') as f:
            pickle.dump(embd_nonautobio, f)
else:
    # Locally produced cache files (see the save branch); never point these at untrusted pickles.
    with open(f'{data_proc_dir}text_embd_autobio.pkl', 'rb') as f:
        embd_autobio = pickle.load(f)
    with open(f'{data_proc_dir}text_embd_nonautobio.pkl', 'rb') as f:
        embd_nonautobio = pickle.load(f)

# %% Calculate text features
bad_sub = list(set(bad_sub_int + bad_sub_oq))
text_data = text_data[~text_data['sub'].isin(bad_sub)]


def _as_row(vec):
    """Turn a 1-D embedding into a (1, dim) row matrix, as cosine_similarity expects 2-D input."""
    return vec[:, None].T


def _pair_sim(first, second):
    """Scalar cosine similarity between two 1-D embeddings."""
    return cossim(_as_row(first), _as_row(second))[0, 0]


rec_transcript_autobio_dict = {'MH': [], 'ML': []}
rec_transcript_nonautobio_dict = {'MH': [], 'ML': []}

text_measures = []
for _, subject_row in text_data.iterrows():
    sub, condition, is_autobio = subject_row['sub'], subject_row['condition'], subject_row['autobio']
    text_embd = (embd_autobio if is_autobio else embd_nonautobio)[condition][sub]

    # Stacked recreated-text embeddings and their mean; mean of the baseline open
    # answers (every 'oq_' field except pospert).
    rec_embs = np.array([emb for key, emb in text_embd.items() if 'recreate' in key])
    rec_embs_avg = rec_embs.mean(axis=0, keepdims=True)
    oq_rembs = np.array([emb for key, emb in text_embd.items() if 'oq_' in key and 'pospert' not in key])
    oq_rembs_avg = oq_rembs.mean(axis=0, keepdims=True)

    # Every recreated text vs every transcript line; pooled per autobio flag and condition.
    rec_transcript_sim = cossim(rec_embs, transcripts_embd[condition])
    pooled = rec_transcript_autobio_dict if is_autobio else rec_transcript_nonautobio_dict
    pooled[condition].append(rec_transcript_sim)

    act_emb, mood_emb, pospert_emb = text_embd['act_0'], text_embd['oq_mood'], text_embd['oq_pospert']
    q2_emb = phq9_embds[1]  # PHQ-9 item 2 statement
    text_measures.append({
        'sub': sub, 'condition': condition, 'group': subject_row['group'], 'autobio': is_autobio,
        's_bin': subject_row['s_bin'], 's_bin3': subject_row['s_bin3'],
        'avgRecAct_sim': cossim(_as_row(act_emb), rec_embs_avg)[0, 0],
        'avgBaselinePospert_sim': cossim(_as_row(pospert_emb), oq_rembs_avg)[0, 0],
        'moodBaselineAct_sim': _pair_sim(act_emb, mood_emb),
        'actPospert_sim': _pair_sim(act_emb, pospert_emb),
        'q2Mood_sim': _pair_sim(q2_emb, mood_emb),
        'q2Act_sim': _pair_sim(q2_emb, act_emb),
        'q2Pospert_sim': _pair_sim(q2_emb, pospert_emb),
    })

# Average the per-subject (recreated x transcript-line) similarity matrices.
for pooled in (rec_transcript_nonautobio_dict, rec_transcript_autobio_dict):
    for cond_key in ('MH', 'ML'):
        pooled[cond_key] = np.array(pooled[cond_key]).mean(axis=0)

text_measures = pd.DataFrame(text_measures)
sent_cols = ['act_sentiment', 'pospert_sentiment']


def _signed_sentiment(texts):
    """Run the sentiment model; return each score signed by its label (+POSITIVE, -NEGATIVE)."""
    # truncation=True: the RoBERTa model accepts at most 512 tokens, and longer
    # answers would otherwise make the pipeline raise.
    outputs = sentiment_pipeline(list(texts), truncation=True)
    return [out['score'] * {'POSITIVE': 1, 'NEGATIVE': -1}[out['label']] for out in outputs]


if runSent:
    # text_measures rows were appended while iterating text_data, so positions align.
    text_measures['act_sentiment'] = _signed_sentiment(text_data['act_0'])
    text_measures['pospert_sentiment'] = _signed_sentiment(text_data['oq_pospert'])
    if saveSent:
        text_measures.to_csv(f'{data_proc_dir}text_measures.csv', index=False)
else:
    # Reuse only the cached (expensive) sentiment scores. The similarity measures
    # computed above are kept rather than replaced wholesale by a possibly stale file;
    # subjects missing from the cache get NaN sentiment.
    cached_sent = pd.read_csv(f'{data_proc_dir}text_measures.csv')
    sent_keys = ['sub', 'condition', 'autobio']
    text_measures = text_measures.merge(cached_sent[sent_keys + sent_cols], on=sent_keys,
                                        how='left', validate='many_to_one')

# %% Merge text measures
# Attach the per-subject text measures to each outcome table and persist the result.
phq9_diff_text = text_measures.merge(phq9_diff, on=join_cols)
phq9_diff_text.to_csv(f'{data_proc_dir}phq9_diff_text.csv', index=False)

mood_diff_text = text_measures.merge(mood_diff, on=join_cols)
mood_diff_text.to_csv(f'{data_proc_dir}mood_diff_text.csv', index=False)

# %% Plot transcript and recreated similarity
plt.close('all')
pc.r, pc.c, pc.mlt = 2, 2, 2.75
pc.figsize = ((pc.c + 4.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)

fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
pc.axes = axes

# Font sizes, label placement and output resolution for this figure.
pc.ax_ts(16, 1.1)
pc.l_fs(12, 0.85)
ts = 14
pc.xyt_ls(ts, 10)
pc.ax_ls(16)
pc.kde_lw = 3
pc.p_lab_spec[0], pc.p_lab_spec[1], pc.p_lab_spec[2] = -0.1, 1.1, 12
pc.dpi_val = 300
pc.onerow = False

# (row, col, averaged similarity dict, condition, colormap, title): rows = autobio flag.
heatmap_panels = [
    (0, 0, rec_transcript_nonautobio_dict, 'MH', 'Oranges', 'MH - Non-autobiographical'),
    (0, 1, rec_transcript_nonautobio_dict, 'ML', 'Blues', 'ML - Non-autobiographical'),
    (1, 0, rec_transcript_autobio_dict, 'MH', 'Oranges', 'MH - Autobiographical'),
    (1, 1, rec_transcript_autobio_dict, 'ML', 'Blues', 'ML - Autobiographical'),
]
for pc.i, pc.j, sim_dict, cond_key, cmap_name, panel_title in heatmap_panels:
    sns.heatmap(sim_dict[cond_key], cmap=cmap_name, ax=pc.ax, annot=True)
    pc.ax.set_title(panel_title)

plt.suptitle('Cosine similarity - transcripts vs recreated')
plt.tight_layout()
if savePlot:
    plt.savefig(f"{plots_path}transcript-vs-recreated_sims.pdf", dpi=300)

# %% Plot each text measure (similarities and sentiments) by autobio flag and condition
to_plot = ['avgRecAct_sim', 'avgBaselinePospert_sim', 'moodBaselineAct_sim', 'actPospert_sim', 'q2Mood_sim',
           'q2Act_sim',
           'q2Pospert_sim', 'act_sentiment', 'pospert_sentiment']
to_plot_titles = ['Avg Recreated vs Act', 'Avg Baseline vs Pospert', 'Mood Baseline vs Act', 'Act vs Pospert',
                  'Q2 statement vs Baseline Mood', 'Q2 Statement vs Act', 'Q2 Statement vs Pospert', 'Act Sentiment',
                  'Pospert Sentiment']
for var_to_plot, plot_title in zip(to_plot, to_plot_titles):
    plt.close('all')
    pc.r, pc.c, pc.mlt = 1, 1, 2.75
    pc.figsize = ((pc.c + 4.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)

    fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
    # Wrap the single Axes in a 1x1 array (presumably PlotConfig indexes axes[i, j]).
    axes = np.array([[axes]])
    pc.axes = axes
    pc.i, pc.j = 0, 0
    pc.ax_ts(16, 1.1)
    pc.l_fs(12, 0.85)
    ts = 14
    pc.xyt_ls(ts, 10)
    pc.ax_ls(16)
    pc.kde_lw = 3
    pc.p_lab_spec[2] = 12
    pc.p_lab_spec[0] = -0.1
    pc.p_lab_spec[1] = 1.1
    pc.dpi_val = 300

    ssize = 5.75
    slw = 0.2
    alpha = 0.5
    # hue_order + palette pin condition -> colour (MH blue, ML orange) in all three
    # layers and across figures; otherwise the mapping follows row order in the data.
    bp = sns.boxplot(data=phq9_diff_text, x='autobio', y=var_to_plot, hue='condition', hue_order=hue_order,
                     width=0.4, linewidth=slw + 1, ax=pc.ax, legend=True, gap=0.2, palette=hue_cols)

    # Make the box faces semi-transparent so the other layers remain visible.
    for patch in bp.patches:
        face_color = patch.get_facecolor()
        patch.set_facecolor((*face_color[:3], alpha))
    sns.violinplot(data=phq9_diff_text, x='autobio', y=var_to_plot, hue='condition', hue_order=hue_order,
                   palette=hue_cols, width=0.4, linewidth=slw + 1, ax=pc.ax, legend=False, gap=0.2)
    sns.stripplot(data=phq9_diff_text, x='autobio', y=var_to_plot, hue='condition', hue_order=hue_order,
                  palette=hue_cols, edgecolor='black', linewidth=slw, dodge=True, jitter=0.05, size=ssize,
                  legend=False, ax=pc.ax)
    if 'sentiment' in var_to_plot:
        pc.ax.set_ylabel('Sentiment')
        plt.suptitle(f'Sentiment - {plot_title}')
    else:
        pc.ax.set_ylabel('Cosine similarity')
        plt.suptitle(f'Similarity - {plot_title}')

    pc.ax.set_xlabel('Autobiographical')
    plt.tight_layout()
    if savePlot:
        plt.savefig(f"{plots_path}{var_to_plot}.pdf", dpi=300)

# %% Plot each text measure by autobio flag and condition, one row per PHQ-9 bin (s_bin3)
to_plot = ['avgRecAct_sim', 'avgBaselinePospert_sim', 'moodBaselineAct_sim', 'actPospert_sim', 'q2Mood_sim',
           'q2Act_sim',
           'q2Pospert_sim', 'act_sentiment', 'pospert_sentiment']
to_plot_titles = ['Avg Recreated vs Act', 'Avg Baseline vs Pospert', 'Mood Baseline vs Act', 'Act vs Pospert',
                  'Q2 statement vs Baseline Mood', 'Q2 Statement vs Act', 'Q2 Statement vs Pospert', 'Act Sentiment',
                  'Pospert Sentiment']
for var_to_plot, plot_title in zip(to_plot, to_plot_titles):
    plt.close('all')
    pc.r, pc.c, pc.mlt = 3, 1, 2.75
    pc.figsize = ((pc.c + 4.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)

    pc.i, pc.j = 0, 0
    fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
    # Reshape the 3 stacked Axes to a (3, 1) array so rows are selected via pc.i
    # (presumably PlotConfig indexes axes[i, j]).
    axes = np.array([axes]).T
    pc.axes = axes
    pc.ax_ts(16, 1.1)
    pc.l_fs(12, 0.85)
    ts = 14
    pc.xyt_ls(ts, 10)
    pc.ax_ls(16)
    pc.kde_lw = 3
    pc.p_lab_spec[2] = 12
    pc.p_lab_spec[0] = -0.1
    pc.p_lab_spec[1] = 1.1
    pc.dpi_val = 300

    ssize = 5.75
    slw = 0.2
    alpha = 0.5
    is_sentiment = 'sentiment' in var_to_plot
    for pc.i, s_bin3 in enumerate(sbin3_order):
        phq9_diff_text_bin = phq9_diff_text[phq9_diff_text['s_bin3'] == s_bin3]
        # hue_order + palette pin condition -> colour (MH blue, ML orange) in every
        # layer and bin; otherwise the mapping follows row order within each bin.
        bp = sns.boxplot(data=phq9_diff_text_bin, x='autobio', y=var_to_plot, hue='condition',
                         hue_order=hue_order, width=0.4, linewidth=slw + 1, ax=pc.ax, legend=True, gap=0.2,
                         palette=hue_cols)

        # Make the box faces semi-transparent so the other layers remain visible.
        for patch in bp.patches:
            face_color = patch.get_facecolor()
            patch.set_facecolor((*face_color[:3], alpha))
        sns.violinplot(data=phq9_diff_text_bin, x='autobio', y=var_to_plot, hue='condition', hue_order=hue_order,
                       palette=hue_cols, width=0.4, linewidth=slw + 1, ax=pc.ax, legend=False, gap=0.2)
        sns.stripplot(data=phq9_diff_text_bin, x='autobio', y=var_to_plot, hue='condition', hue_order=hue_order,
                      palette=hue_cols, edgecolor='black', linewidth=slw, dodge=True, jitter=0.05, size=ssize,
                      legend=False, ax=pc.ax)
        pc.ax.set_title(f'\nPHQ9 Bin: {s_bin3}')
        pc.ax.set_ylabel('Sentiment' if is_sentiment else 'Cosine similarity')
        pc.ax.set_xlabel('Autobiographical')

    # Figure-level title and layout only need to be applied once, after all bins are drawn.
    if is_sentiment:
        plt.suptitle(f'Sentiment - {plot_title}')
    else:
        plt.suptitle(f'Similarity - {plot_title}')
    plt.tight_layout()
    if savePlot:
        plt.savefig(f"{plots_path}{var_to_plot}_binned.pdf", dpi=300)

# %% Plot Q2 Diff against stuff
# PHQ-9 item-2 change (FU - baseline) vs each text measure; row 0 = non-autobiographical,
# row 1 = autobiographical subjects, one regression line per condition.
to_plot = ['avgRecAct_sim', 'avgBaselinePospert_sim', 'moodBaselineAct_sim', 'actPospert_sim', 'q2Mood_sim',
           'q2Act_sim', 'q2Pospert_sim']
to_plot_titles = ['Avg Recreated vs Act', 'Avg Baseline vs Pospert', 'Mood Baseline vs Act', 'Act vs Pospert',
                  'Q2 statement vs Baseline Mood', 'Q2 Statement vs Act', 'Q2 Statement vs Pospert']
plt.close('all')
pc.ax_ts(10, 1.1)
pc.l_fs(12, 0.85)
ts = 14
pc.xyt_ls(ts, 10)
pc.ax_ls(12)
pc.kde_lw = 3
pc.p_lab_spec[2] = 12
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300
pc.r, pc.c, pc.mlt = 2, len(to_plot), 2
pc.figsize = ((pc.c + 5.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)

fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, sharey=True)
pc.axes = axes
pc.i, pc.j = 0, 0

# Listwise deletion only over the columns analysed here, so that NaNs in unrelated
# columns (e.g. sentiment or other PHQ-9 items) do not silently drop subjects.
# The same subset is used for every panel, so all panels share one sample.
analysed_cols = ['phq9_q2'] + to_plot
text_diff_autobio = phq9_diff_text[phq9_diff_text['autobio'] == True].dropna(subset=analysed_cols)
text_diff_autobio_MH = text_diff_autobio[text_diff_autobio['condition'] == 'MH']
text_diff_autobio_ML = text_diff_autobio[text_diff_autobio['condition'] == 'ML']
text_diff_nonautobio = phq9_diff_text[phq9_diff_text['autobio'] == False].dropna(subset=analysed_cols)
text_diff_nonautobio_MH = text_diff_nonautobio[text_diff_nonautobio['condition'] == 'MH']
text_diff_nonautobio_ML = text_diff_nonautobio[text_diff_nonautobio['condition'] == 'ML']
# Multiplier applied to each Pearson p-value (capped at 1); presumably a Bonferroni
# correction over the two conditions. Alternative: len(to_plot) * 4.
p_corr = 2

panel_rows = [
    ('Non-Autobiographical', text_diff_nonautobio_MH, text_diff_nonautobio_ML),
    ('Autobiographical', text_diff_autobio_MH, text_diff_autobio_ML),
]
for pc.j, var_to_plot in enumerate(to_plot):
    for pc.i, (row_label, df_mh, df_ml) in enumerate(panel_rows):
        sns.regplot(df_mh, x=var_to_plot, y='phq9_q2', color='tab:blue', label='MH', ax=pc.ax)
        sns.regplot(df_ml, x=var_to_plot, y='phq9_q2', color='tab:orange', label='ML', ax=pc.ax)
        if pc.j == 0:
            pc.ax.set_ylabel('Q2 diff: FU-Baseline')
            pc.ax.legend()
        else:
            pc.ax.set_ylabel(None)

        # Title: row label, then "r, corrected p" per condition (computed once per pair).
        title_lines = [row_label]
        for cond_key, df_cond in (('ML', df_ml), ('MH', df_mh)):
            r_val, p_val = stats.pearsonr(df_cond['phq9_q2'], df_cond[var_to_plot])
            title_lines.append(f'{cond_key}: r:{r_val:.2e}, p:{min(p_val * p_corr, 1):.2e}')
        pc.ax.set_title('\n'.join(title_lines))

plt.tight_layout()
if savePlot:
    plt.savefig(f"{plots_path}q2Diff-vs-TextFeatures.pdf", dpi=300)

# %% Plot Q2 Diff against stuff binned
# Same layout as the unbinned Q2 figure, one figure per PHQ-9 bin (s_bin3).
to_plot = ['avgRecAct_sim', 'avgBaselinePospert_sim', 'moodBaselineAct_sim', 'actPospert_sim', 'q2Mood_sim',
           'q2Act_sim', 'q2Pospert_sim']
to_plot_titles = ['Avg Recreated vs Act', 'Avg Baseline vs Pospert', 'Mood Baseline vs Act', 'Act vs Pospert',
                  'Q2 statement vs Baseline Mood', 'Q2 Statement vs Act', 'Q2 Statement vs Pospert']
# Listwise deletion only over the analysed columns (see the unbinned cell).
analysed_cols = ['phq9_q2'] + to_plot
# Multiplier applied to each Pearson p-value (capped at 1); presumably a Bonferroni
# correction over the two conditions. Alternative: len(to_plot) * 4.
p_corr = 2

for s_bin3 in sbin3_order:
    plt.close('all')
    pc.ax_ts(10, 1.1)
    pc.l_fs(12, 0.85)
    ts = 14
    pc.xyt_ls(ts, 10)
    pc.ax_ls(12)
    pc.kde_lw = 3
    pc.p_lab_spec[2] = 12
    pc.p_lab_spec[0] = -0.1
    pc.p_lab_spec[1] = 1.1
    pc.dpi_val = 300
    pc.r, pc.c, pc.mlt = 2, len(to_plot), 2
    pc.figsize = ((pc.c + 5.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)

    fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, sharey=True)
    pc.axes = axes
    pc.i, pc.j = 0, 0

    in_bin = phq9_diff_text['s_bin3'] == s_bin3
    text_diff_autobio = phq9_diff_text[(phq9_diff_text['autobio'] == True) & in_bin].dropna(subset=analysed_cols)
    text_diff_autobio_MH = text_diff_autobio[text_diff_autobio['condition'] == 'MH']
    text_diff_autobio_ML = text_diff_autobio[text_diff_autobio['condition'] == 'ML']
    text_diff_nonautobio = phq9_diff_text[(phq9_diff_text['autobio'] == False) & in_bin].dropna(
        subset=analysed_cols)
    text_diff_nonautobio_MH = text_diff_nonautobio[text_diff_nonautobio['condition'] == 'MH']
    text_diff_nonautobio_ML = text_diff_nonautobio[text_diff_nonautobio['condition'] == 'ML']

    panel_rows = [
        ('Non-Autobiographical', text_diff_nonautobio_MH, text_diff_nonautobio_ML),
        ('Autobiographical', text_diff_autobio_MH, text_diff_autobio_ML),
    ]
    for pc.j, var_to_plot in enumerate(to_plot):
        for pc.i, (row_label, df_mh, df_ml) in enumerate(panel_rows):
            sns.regplot(df_mh, x=var_to_plot, y='phq9_q2', color='tab:blue', label='MH', ax=pc.ax)
            sns.regplot(df_ml, x=var_to_plot, y='phq9_q2', color='tab:orange', label='ML', ax=pc.ax)
            if pc.j == 0:
                pc.ax.set_ylabel('Q2 diff: FU-Baseline')
                pc.ax.legend()
            else:
                pc.ax.set_ylabel(None)

            # Title: row label, then "r, corrected p" per condition (computed once per pair).
            title_lines = [row_label]
            for cond_key, df_cond in (('ML', df_ml), ('MH', df_mh)):
                r_val, p_val = stats.pearsonr(df_cond['phq9_q2'], df_cond[var_to_plot])
                title_lines.append(f'{cond_key}: r:{r_val:.2e}, p:{min(p_val * p_corr, 1):.2e}')
            pc.ax.set_title('\n'.join(title_lines))

    plt.suptitle(f'{s_bin3}')
    plt.tight_layout()
    if savePlot:
        plt.savefig(f"{plots_path}q2Diff-vs-TextFeatures_{s_bin3}.pdf", dpi=300)

# %% Plot Mood Diff against stuff
# Mood change (FU - baseline) vs each text measure for autobiographical subjects only,
# one regression line per condition.
to_plot = ['avgRecAct_sim', 'avgBaselinePospert_sim', 'moodBaselineAct_sim', 'actPospert_sim', 'q2Mood_sim',
           'q2Act_sim', 'q2Pospert_sim']
to_plot_titles = ['Avg Recreated vs Act', 'Avg Baseline vs Pospert', 'Mood Baseline vs Act', 'Act vs Pospert',
                  'Q2 statement vs Baseline Mood', 'Q2 Statement vs Act', 'Q2 Statement vs Pospert']
plt.close('all')
pc.ax_ts(10, 1.1)
pc.l_fs(12, 0.85)
ts = 14
pc.xyt_ls(ts, 10)
pc.ax_ls(12)
pc.kde_lw = 3
pc.p_lab_spec[2] = 12
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300
pc.r, pc.c, pc.mlt = 1, len(to_plot), 2
pc.figsize = ((pc.c + 5.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)

fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, sharey=True)
pc.axes = axes
pc.i, pc.j = 0, 0  # single row: pc.i stays 0

# Listwise deletion only over the columns analysed here, so that NaNs in unrelated
# columns do not silently drop subjects; all panels share one sample.
analysed_cols = ['mood_diff'] + to_plot
text_diff_autobio = mood_diff_text[mood_diff_text['autobio'] == True].dropna(subset=analysed_cols)
text_diff_autobio_MH = text_diff_autobio[text_diff_autobio['condition'] == 'MH']
text_diff_autobio_ML = text_diff_autobio[text_diff_autobio['condition'] == 'ML']
# Multiplier applied to each Pearson p-value (capped at 1); presumably a Bonferroni
# correction over the two conditions. Alternative: len(to_plot) * 4.
p_corr = 2

for pc.j, var_to_plot in enumerate(to_plot):
    sns.regplot(text_diff_autobio_MH, x=var_to_plot, y='mood_diff', color='tab:blue', label='MH', ax=pc.ax)
    sns.regplot(text_diff_autobio_ML, x=var_to_plot, y='mood_diff', color='tab:orange', label='ML', ax=pc.ax)
    if pc.j == 0:
        pc.ax.set_ylabel('Mood diff: FU-Baseline')
        pc.ax.legend()
    else:
        pc.ax.set_ylabel(None)

    # Title: group label, then "r, corrected p" per condition (computed once per pair).
    title_lines = ['Autobiographical']
    for cond_key, df_cond in (('ML', text_diff_autobio_ML), ('MH', text_diff_autobio_MH)):
        r_val, p_val = stats.pearsonr(df_cond['mood_diff'], df_cond[var_to_plot])
        title_lines.append(f'{cond_key}: r:{r_val:.2e}, p:{min(p_val * p_corr, 1):.2e}')
    pc.ax.set_title('\n'.join(title_lines))

plt.tight_layout()
if savePlot:
    plt.savefig(f"{plots_path}moodDiff-vs-TextFeatures.pdf", dpi=300)

# %% Plot Mood Diff against stuff binned
# For each s_bin3 bin, draw one row of panels (one per text-similarity feature in
# to_plot). Each panel regresses mood_diff on that feature for autobiographical
# entries, separately for the MH and ML conditions, and is titled with each
# condition's Pearson r and Bonferroni-corrected p-value.
# (The former commented-out non-autobiographical variant has been removed; recreate
# it by filtering on autobio == False instead of True.)
to_plot = ['avgRecAct_sim', 'avgBaselinePospert_sim', 'moodBaselineAct_sim', 'actPospert_sim', 'q2Mood_sim',
           'q2Act_sim', 'q2Pospert_sim']
to_plot_titles = ['Avg Recreated vs Act', 'Avg Baseline vs Pospert', 'Mood Baseline vs Act', 'Act vs Pospert',
                  'Q2 statement vs Baseline Mood', 'Q2 Statement vs Act', 'Q2 Statement vs Pospert']
if len(to_plot) != len(to_plot_titles):
    # zip() below would otherwise silently drop the unmatched tail.
    raise ValueError('to_plot and to_plot_titles must have the same length')


def _bonferroni_pearson(df, x_col, y_col, n_tests):
    """Return (r, p) for the Pearson correlation between df[y_col] and df[x_col].

    p is multiplied by ``n_tests`` (Bonferroni correction) and capped at 1.
    Returns (nan, nan) when ``df`` has fewer than 2 rows, where
    ``scipy.stats.pearsonr`` would raise and abort the script.
    """
    if len(df) < 2:
        return np.nan, np.nan
    r, p = stats.pearsonr(df[y_col], df[x_col])
    return r, min(p * n_tests, 1)


for s_bin3 in sbin3_order:
    plt.close('all')
    pc.ax_ts(10, 1.1)
    pc.l_fs(12, 0.85)
    ts = 14
    pc.xyt_ls(ts, 10)
    pc.ax_ls(12)
    pc.kde_lw = 3
    pc.p_lab_spec[2] = 12
    pc.p_lab_spec[0] = -0.1
    pc.p_lab_spec[1] = 1.1
    pc.dpi_val = 300
    # One row, one column per feature.
    pc.r, pc.c, pc.mlt = 1, len(to_plot), 2
    pc.figsize = ((pc.c + 5.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)

    fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, sharey=True)
    pc.axes = axes
    pc.i, pc.j = 0, 0

    # Scatter styling globals; not used in this cell (regplot defaults apply) but
    # kept in case later cells rely on them.
    ssize = 5.75
    slw = 0.2
    alpha = 0.5

    # Autobiographical entries in the current bin.
    # NOTE(review): dropna() discards rows with a NaN in *any* column, not only in
    # mood_diff / to_plot. Confirm this list-wise deletion is intended.
    text_diff_autobio = mood_diff_text[
        mood_diff_text['autobio'].eq(True) & mood_diff_text['s_bin3'].eq(s_bin3)
    ].dropna()
    text_diff_autobio_MH = text_diff_autobio[text_diff_autobio['condition'] == 'MH']
    text_diff_autobio_ML = text_diff_autobio[text_diff_autobio['condition'] == 'ML']

    # Bonferroni factor: two tests (MH and ML) per panel.
    p_corr = 2
    for pc.j, (var_to_plot, plot_title) in enumerate(zip(to_plot, to_plot_titles)):
        sns.regplot(text_diff_autobio_MH, x=var_to_plot, y='mood_diff', color='tab:blue', label='MH', ax=pc.ax)
        sns.regplot(text_diff_autobio_ML, x=var_to_plot, y='mood_diff', color='tab:orange', label='ML', ax=pc.ax)
        # Replace the raw column name regplot puts on the x-axis with the readable title.
        pc.ax.set_xlabel(plot_title)
        if pc.j == 0:
            pc.ax.set_ylabel('Mood diff: FU-Baseline')
            pc.ax.legend()
        else:
            pc.ax.set_ylabel(None)

        r_corr_abio_ML, p_corr_abio_ML = _bonferroni_pearson(text_diff_autobio_ML, var_to_plot, 'mood_diff', p_corr)
        r_corr_abio_MH, p_corr_abio_MH = _bonferroni_pearson(text_diff_autobio_MH, var_to_plot, 'mood_diff', p_corr)

        tmp_t = f'Autobiographical\nML: r:{r_corr_abio_ML:.2e}, p:{p_corr_abio_ML:.2e}\nMH: r:{r_corr_abio_MH:.2e}, p:{p_corr_abio_MH:.2e}'
        pc.ax.set_title(tmp_t)

    plt.suptitle(f'{s_bin3}')
    plt.tight_layout()
    if savePlot:
        plt.savefig(f"{plots_path}moodDiff-vs-TextFeatures_{s_bin3}.pdf", dpi=pc.dpi_val)