# %% Import libraries and load data
import os

import pandas as pd
import numpy as np
import matplotlib
import scipy.stats as stats
import re
from textblob import TextBlob
from transformers import pipeline

# Device string passed as `device=` to the transformers sentiment pipeline below.
# NOTE(review): 'mps' is presumably the Apple-silicon torch backend — confirm/adjust on other machines.
device_name = 'mps'

# Select the GUI backend before pyplot is imported.
matplotlib.use('Qt5Agg')
# matplotlib.use('TkAgg')
# matplotlib.get_backend()
import matplotlib.pyplot as plt

# Interactive mode on, so figures can be shown while cells are run one by one.
plt.ion()
import seaborn as sns


# Task versions whose recall files are pooled below; each maps to a
# "qs-intervention-<version>" data folder. Alternative subset: ['v2', 'v2b'].
task_versions = [
    "v2",
    "v2b",
    "v3",
]

# # saveMe = False
# saveMe = True

from _objects.plot_config import *
from _objects.configs import *
# from _utils.utils import q1, q3, percentile

# Project-level helper objects (classes come from the star imports above).
pc = PlotConfig()
bools = Bools()
paths = Paths(files_dir='', plots_subdir='', plots_subsubdir='')

# Run switches — each flag is assigned exactly once.
bools.saveMe = True
bools.savePlot = True
# False: the sentiment model is not loaded/run and cached scores are read from disk instead.
# Set to True to recompute the scores and overwrite the cached CSV.
bools.saveFiles = False

# Input/output locations for this test case.
testCase_dir = 'test3_mood_induction'
paths.plots_path = f'{testCase_dir}/_plots/recall/'
Path(paths.plots_path).mkdir(parents=True, exist_ok=True)
paths.data_proc_dir = f'{testCase_dir}/_data/processed/'
paths.data_path = f'{testCase_dir}/_data/'

# Columns identifying one participant/condition row; used as melt/merge/pivot keys below.
id_cols = ['sub', 'condition', 'group', 'autobio']
# Order and colours of the 'condition' hue levels in the plots.
hue_order = ['MH', 'ML']
hue_cols = ['tab:blue', 'tab:orange']
hue_cols2 = ['tab:green', 'tab:purple']
# Response-time columns that are dropped from the recall data before reshaping.
cols_to_drop = [f'rts_recall{r + 1}' for r in range(3)]
sbin3_order = ['q33','m','q66']
# Columns pulled from the PHQ-9 table when it is merged onto the recall data.
join_cols = id_cols + ['s_bin', 's_bin3']
phq9_diff = pd.read_csv(f'{paths.data_proc_dir}phq9_diff_data_wide.csv')

# %% Load sentiment analyser
# The model is only instantiated when scores are going to be recomputed (bools.saveFiles).
if bools.saveFiles:
    sentiment_pipeline = pipeline(
        "sentiment-analysis",
        model="siebert/sentiment-roberta-large-english",
        device=device_name,
    )
# %% Recall analysis
# Pool the processed recall responses of every task version into one frame.
store_recall_dfs = []
for task_version in task_versions:
    # paths.data_path already ends with '/', so join the parts instead of
    # concatenating a second separator.
    file_path = os.path.join(paths.data_path, f"qs-intervention-{task_version}", "processed", "recall_data.csv")
    # recall_data_pd = pd.read_csv(file_path).drop(columns=['task_version'])  # [recall_cols]
    recall_data_pd = pd.read_csv(file_path)
    store_recall_dfs.append(recall_data_pd)
recall_data = pd.concat(store_recall_dfs).reset_index(drop=True)

# 'autobio' is True for rows whose task_version is 'qs-intervention-v3'; it replaces
# the task_version column, which is dropped together with the response-time columns.
recall_data.insert(4, 'autobio', recall_data['task_version'] == 'qs-intervention-v3')
recall_data.drop(columns=['task_version'] + cols_to_drop, inplace=True)

# Long format: one row per id combination and recall column, raw response text in 'words'.
recall_data_long = recall_data.melt(id_vars=id_cols, var_name='recall_time', value_name='words')

# Rows containing any missing value get an empty response so the string handling below works.
recall_data_long.loc[recall_data_long.isna().any(axis=1),'words'] = ''
# recall_data_long = recall_data_long[~recall_data_long.isna().any(axis=1)]

# Turn each response string into a list of letters-only tokens.
# Splitting on a bare ',' (not ', ') keeps "happy,sad" as two words instead of fusing
# them into "happysad"; the blanks around commas are removed by the regex anyway.
# A response without any comma yields a single-element list.
regex = re.compile('[^a-zA-Z]')
recall_data_long['words'] = recall_data_long['words'].apply(
    lambda x: [regex.sub('', w) for w in x.split(',')])

# Attach the s_bin / s_bin3 columns from the PHQ-9 table (inner join on the id columns).
recall_data_long = pd.merge(phq9_diff[join_cols],recall_data_long, on=id_cols)
# %% Clean recalled
# NaN (rather than a numeric sentinel) marks "no score", so an unscored row can never be
# mistaken for a real sentiment value in later averages.
recall_data_long['avgSentiment'] = np.nan
# recall_scores = {f'recall{r + 1}_avgSent': [] for r in range(3)}
recall_scores = []
if bools.saveFiles:
    # Sign applied to the classifier confidence: positive labels count up, negative down.
    label_sign = {'POSITIVE': 1, 'NEGATIVE': -1}
    # for i, row in recall_data_long.iloc[:10,:].iterrows():
    for i, row in recall_data_long.iterrows():
        # Spell-correct and lower-case every token, then keep tokens longer than 2 characters.
        recall_words = [str(TextBlob(w).correct()).lower() for w in row['words']]
        recall_words = [w for w in recall_words if len(w) > 2]
        if not recall_words:
            # Nothing usable was recalled: leave the score as NaN instead of calling the
            # model on an empty batch and averaging an empty list.
            continue
        outputs = sentiment_pipeline(recall_words)
        # outputs = [{'score':np.random.rand(),'label':['POSITIVE','NEGATIVE'][np.random.randint(0,2)]} for r in range(len(recall_words))]
        scores = [out['score'] * label_sign[out['label']] for out in outputs]
        recall_data_long.loc[i, 'avgSentiment'] = np.nanmean(scores)

    #% Calculate baseline and difference
    # Wide format: one row per id combination, one avgSentiment<k> column per recall column.
    recall_data_wide = recall_data_long.pivot(index=id_cols,columns='recall_time', values='avgSentiment').reset_index()
    recall_data_wide.columns = recall_data_wide.columns.str.replace('responses_recall','avgSentiment')
    # Baseline = mean of recalls 1 and 2; diff = recall 3 minus that baseline.
    # NaN in either baseline recall propagates into both derived columns.
    recall_data_wide['recall_baselineSent'] = (recall_data_wide['avgSentiment1'] + recall_data_wide['avgSentiment2']) / 2
    recall_data_wide['recall_diffSent'] = recall_data_wide['avgSentiment3'] - recall_data_wide['recall_baselineSent']

# %% Save/Load sentiment scores
# The wide sentiment table is cached as CSV: read it back unless it was just recomputed,
# in which case the fresh table is written out.
file_save_path = f"{paths.data_proc_dir}recall_data.csv"
if not bools.saveFiles:
    recall_data_wide = pd.read_csv(file_save_path)
else:
    recall_data_wide.to_csv(file_save_path, index=False)
# %% Plot recall sentiment
plt.close('all')

# One-panel figure; its size is derived from the grid shape and the scale factor pc.mlt.
pc.r, pc.c, pc.mlt = 1, 1, 2.75
pc.figsize = ((pc.c + 2.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)
fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)

# Hand the single Axes to the plot config wrapped in a 1x1 array.
# pc.i / pc.j are presumably the row/column of the active panel — confirm in PlotConfig.
pc.onerow = True
pc.i = pc.j = 0
axes = np.array([[axes]])
# axes = np.array([axes])
pc.axes = axes

# Font/label styling through the PlotConfig helpers (semantics defined in PlotConfig).
pc.ax_ts(16, 1.1)
pc.l_fs(12, 0.85)
ts = 14
pc.xyt_ls(ts, 10)
pc.ax_ls(16)
pc.kde_lw = 3
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.p_lab_spec[2] = 12
pc.dpi_val = 300

# Marker size, thin line width and box face transparency.
ssize, slw, alpha = 5.75, 0.2, 0.5

# Arguments common to all three layers: diff score by autobio flag, split by condition.
shared_kws = dict(data=recall_data_wide, x='autobio', y='recall_diffSent', hue='condition',
                  hue_order=hue_order, palette=hue_cols, ax=pc.ax)
# Geometry shared by the box and violin layers.
outline_kws = dict(width=0.4, linewidth=slw + 1, gap=0.2)

# Layer 1: boxes (the only layer that contributes a legend), made semi-transparent.
bp = sns.boxplot(legend=True, **shared_kws, **outline_kws)
for patch in bp.patches:
    red, green, blue = patch.get_facecolor()[:3]
    patch.set_facecolor((red, green, blue, alpha))
# Layer 2: violins; layer 3: the individual data points.
sns.violinplot(legend=False, **shared_kws, **outline_kws)
sns.stripplot(edgecolor='black', linewidth=slw, dodge=True, jitter=0.05, size=ssize,
              legend=False, **shared_kws)

pc.ax.set(
    ylabel='Sentiment Diff',
    xlabel='Autobiographical?',
    title='Recall Sentiment Diff: FU - Average Baseline',
)


# bp = sns.boxplot(data=recall_data_wide[recall_data_wide['autobio']==False], x='condition', y='recall_diffSent', width=0.4,
#                  linewidth=slw + 1, ax=pc.ax, legend=True, gap=0.2)
# for patch in bp.patches:
#     face_color = patch.get_facecolor()
#     patch.set_facecolor((*face_color[:3], alpha))
# sns.violinplot(data=recall_data_wide[recall_data_wide['autobio']==False], x='condition', y='recall_diffSent', width=0.4, linewidth=slw + 1, ax=pc.ax,
#                legend=True, gap=0.2, color=hue_cols[0])
# sns.stripplot(data=recall_data_wide[recall_data_wide['autobio']==False], x='condition', y='recall_diffSent', edgecolor='black', linewidth=slw, dodge=True,
#               jitter=0.05, size=ssize, legend=False, ax=pc.ax, color=hue_cols[1])
#
# pc.ax.set_ylabel('Sentiment Diff')
# pc.ax.set_title('Non-Autobiographical\nRecall Sentiment Diff: FU - Average Baseline')
#
# pc.j=1
# bp = sns.boxplot(data=recall_data_wide[recall_data_wide['autobio']==True], x='condition', y='recall_diffSent', width=0.4,
#                  linewidth=slw + 1, ax=pc.ax, legend=True, gap=0.2)
# for patch in bp.patches:
#     face_color = patch.get_facecolor()
#     patch.set_facecolor((*face_color[:3], alpha))
# sns.violinplot(data=recall_data_wide[recall_data_wide['autobio']==True], x='condition', y='recall_diffSent', width=0.4, linewidth=slw + 1, ax=pc.ax,
#                legend=True, gap=0.2, color=hue_cols[0])
# sns.stripplot(data=recall_data_wide[recall_data_wide['autobio']==True], x='condition', y='recall_diffSent', edgecolor='black', linewidth=slw, dodge=True,
#               jitter=0.05, size=ssize, legend=False, ax=pc.ax, color=hue_cols[1])
#
# pc.ax.set_ylabel('Sentiment Diff')
# pc.ax.set_title('Autobiographical\nRecall Sentiment Diff: FU - Average Baseline')



# Adjust the layout so labels and title fit inside the figure.
plt.tight_layout()
# NOTE(review): saving is commented out, so bools.savePlot currently has no effect on this figure.
# if bools.savePlot:
#     plt.savefig(f"{paths.plots_path}recall_sent_diff.pdf", dpi=300)

# %%
#
# # recall_data = recall_data[recall_data['group'] == 'D']
# recall_data = recall_data[~recall_data.isna().any(axis=1)]
# # firstN = -1
# firstN = len(recall_data)
# recall_data = recall_data[:firstN].reset_index(drop=True)
#
# # %% Calculate sentiment scores
# regex = re.compile('[^a-zA-Z]')
# recall_scores = {f'recall{r + 1}_avgSent': [] for r in range(3)}
# if saveFiles:
#     for i, row in recall_data.iterrows():
#         for r in range(3):
#
#             if ',' in row[f'responses_recall{r + 1}']:
#                 recall_words = row[f'responses_recall{r + 1}'].split(', ')
#                 recall_words = [regex.sub('', w) for w in recall_words]
#             else:
#                 recall_words = [row[f'responses_recall{r + 1}']]
#                 recall_words = [regex.sub('', w) for w in recall_words]
#             recall_words = [str(TextBlob(w).correct()).lower() for w in recall_words]
#             # print(recall_words)
#
#             # recall_words_bool = [(len(w)>2 and w in engwords) for w in recall_words]
#             recall_words_bool = [(len(w) > 2) for w in recall_words]
#             recall_words = [w.lower() for b, w in zip(recall_words_bool, recall_words) if b]
#
#             #
#             # print(recall_words_bool)
#             # print(recall_words)
#             outputs = sentiment_pipeline(recall_words)
#             scores = [out['score'] * {'POSITIVE': 1, 'NEGATIVE': -1}[out['label']] for out in outputs]
#             scores_avg = np.mean(scores)
#             recall_scores[f'recall{r + 1}_avgSent'].append(scores_avg)
#             # print(row['condition'],scores_avg)
#
#     for k, v in recall_scores.items():
#         recall_data[k] = v
#     recall_data = pd.merge(recall_data, phq9_totals, on=['sub', 'condition', 'group'])
#
# # %% Calculate baseline and diff sentiment
# recall_data['recall_baselineSent'] = (recall_data['recall1_avgSent'] + recall_data['recall2_avgSent']) / 2
# recall_data['recall_diffSent'] = recall_data['recall3_avgSent'] - recall_data['recall_baselineSent']
# recall_data_long = recall_data.melt(id_vars=['sub', 'condition', 'group'],
#                                     value_vars=[f'recall{r + 1}_avgSent' for r in range(3)], var_name='timepoint',
#                                     value_name='sentiment')
# # %% Save/Load sentiment scores
# file_save_path = f"_data/recall_data_proc.csv"
# if saveFiles:
#     recall_data.to_csv(file_save_path, index=False)
# else:
#     recall_data = pd.read_csv(file_save_path)
#
# # %% Plot recall sentiment
# plt.close('all')
# pc.r, pc.c, pc.mlt = 1, 1, 2.75
# pc.figsize = ((pc.c + 4.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)
#
# fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
# pc.onerow = True
# pc.i = 0
# pc.j = 0
# axes = np.array([[axes]])
# # axes = np.array([axes]).T
# pc.axes = axes
# pc.ax_ts(16, 1.1)
# pc.l_fs(12, 0.85)
# ts = 14
# pc.xyt_ls(ts, 10)
# pc.ax_ls(16)
# pc.kde_lw = 3
# pc.p_lab_spec[2] = 12
# pc.p_lab_spec[0] = -0.1
# pc.p_lab_spec[1] = 1.1
# pc.dpi_val = 300
#
# ssize = 5.75
# slw = 0.2
# alpha = 0.5
# hue_cols = ['tab:blue', 'tab:orange']
# bp = sns.boxplot(data=recall_data, x='condition', y='recall_diffSent', width=0.4,
#                  linewidth=slw + 1, ax=pc.ax, legend=True, gap=0.2)
# for patch in bp.patches:
#     face_color = patch.get_facecolor()
#     patch.set_facecolor((*face_color[:3], alpha))
# sns.violinplot(data=recall_data, x='condition', y='recall_diffSent', width=0.4, linewidth=slw + 1, ax=pc.ax,
#                legend=True, gap=0.2, color=hue_cols[0])
# sns.stripplot(data=recall_data, x='condition', y='recall_diffSent', edgecolor='black', linewidth=slw, dodge=True,
#               jitter=0.05, size=ssize, legend=False, ax=pc.ax, color=hue_cols[1])
#
# plt.tight_layout()
# if savePlot:
#     plt.savefig(f"{plots_path}recall_sent.pdf", dpi=300)
#
# md = smf.mixedlm("recall_diffSent ~ condition", recall_data, groups=recall_data['sub'])
# mdf = md.fit()
# print(mdf.summary())

# # %% Ratio recall analysis
# word_list = ["vigorous", "energetic", "lively", "exhausted", "tired", "drained", "joyful", "delighted", "happy",
#              "unhappy", "hopeless", "miserable", "genuine", "wholesome", "ethical", "corrupt", "sloppy", "unsafe"]
# wld = {'mh': ["joyful", "delighted", "happy"], 'ml': ["unhappy", "hopeless", "miserable"],
#        'el': ["exhausted", "tired", "drained"], 'eh': ["vigorous", "energetic", "lively"],
#        'p': ["genuine", "wholesome", "ethical"], 'n': ["corrupt", "sloppy", "unsafe"]}
# word_list_broad = {'h': wld['mh'] + wld['eh'], 'l': wld['ml'] + wld['el'], 'ho': wld['p'], 'lo': wld['n']}
#
# # %% Ratio recall analysis p2
# recall_ratio = {f'recall{r + 1}_congrRatio': [] for r in range(3)}
# recall_ratio_broad = {f'recall{r + 1}_congrRatioBroad': [] for r in range(3)}
# firstN = len(recall_data)
# # firstN = 10
# for i, row in recall_data[:firstN].iterrows():
#     for r in range(3):
#         if ',' in row[f'responses_recall{r + 1}']:
#             recall_words = row[f'responses_recall{r + 1}'].split(', ')
#             recall_words = [regex.sub('', w) for w in recall_words]
#         else:
#             recall_words = [row[f'responses_recall{r + 1}']]
#             recall_words = [regex.sub('', w) for w in recall_words]
#         # recall_words = [w.lower() for w in recall_words]
#         recall_words = [str(TextBlob(w).correct()).lower() for w in recall_words]
#         # print(recall_words)
#         matched_words = []
#         for w in recall_words:
#             match_w = get_close_matches(w, word_list)
#             if len(match_w) > 0:
#                 matched_words.append(match_w[0])
#             # else:
#             # print(w)
#         match_word_cats = []
#         match_word_cats_broad = []
#         count_cond_congr = 0
#         count_cond_congr_broad = 0
#         for m in matched_words:
#             for k, v in wld.items():
#                 if m in v:
#                     match_word_cats.append(k)
#                     if k == row['condition'].lower():
#                         count_cond_congr += 1
#             for k, v in word_list_broad.items():
#                 if m in v:
#                     match_word_cats_broad.append(k)
#                     if k == row['condition'].lower()[1]:
#                         count_cond_congr_broad += 1
#         # print(count_cond_congr)
#         if len(matched_words) > 0:
#             cond_congr_ratio = count_cond_congr / len(matched_words)
#             cond_congr_ratio_broad = count_cond_congr_broad / len(matched_words)
#         else:
#             cond_congr_ratio = 0
#             cond_congr_ratio_broad = 0
#
#         recall_ratio[f'recall{r + 1}_congrRatio'].append(cond_congr_ratio)
#         recall_ratio_broad[f'recall{r + 1}_congrRatioBroad'].append(cond_congr_ratio_broad)
#
# for k, v in recall_ratio.items():
#     recall_data[k] = v
# for k, v in recall_ratio_broad.items():
#     recall_data[k] = v
#
#
# # %% Calculate ratio recall stuff
# recall_data['recall_baselineRatio'] = (recall_data['recall1_congrRatio'] + recall_data['recall2_congrRatio']) / 2
# recall_data['recall_diffRatio'] = recall_data['recall3_congrRatio'] - recall_data['recall_baselineRatio']
# recall_data['recall_baselineRatioBroad'] = (recall_data['recall1_congrRatioBroad'] + recall_data[
#     'recall2_congrRatioBroad']) / 2
# recall_data['recall_diffRatioBroad'] = recall_data['recall3_congrRatioBroad'] - recall_data['recall_baselineRatioBroad']
#
# recall_data['recall_diffRatioBroad'].mean()
# recall_data['recall_diffRatioBroad'].std()
# recall_data['recall_diffRatio'].mean()
# recall_data['recall_diffRatio'].std()
#
# # %% Load Ratio recall
# file_save_path_wratio = f"_data/recall_data_proc_wratio.csv"
# if saveFiles:
#     recall_data.to_csv(file_save_path_wratio, index=False)
# else:
#     recall_data = pd.read_csv(file_save_path_wratio)
#
# # %% Plot congruent ratios
# plt.close('all')
# pc.r, pc.c, pc.mlt = 1, 1, 2.75
# pc.figsize = ((pc.c + 4.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)
#
# fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
# axes = np.array([[axes]])
# # axes = np.array([axes])
# pc.axes = axes
# pc.ax_ts(16, 1.1)
# pc.l_fs(12, 0.85)
# ts = 14
# pc.xyt_ls(ts, 10)
# pc.ax_ls(16)
# pc.kde_lw = 3
# pc.p_lab_spec[2] = 12
# pc.p_lab_spec[0] = -0.1
# pc.p_lab_spec[1] = 1.1
# pc.dpi_val = 300
#
# pc.onerow = True
# pc.i = 0
# pc.j = 0
# ssize = 5.75
# slw = 0.2
# alpha = 0.5
# hue_cols = ['tab:blue', 'tab:orange']
#
# bp = sns.boxplot(data=recall_data, x='condition', y='recall_diffRatio', width=0.4,
#                  linewidth=slw + 1, ax=pc.ax, legend=True, gap=0.2)
# for patch in bp.patches:
#     face_color = patch.get_facecolor()
#     patch.set_facecolor((*face_color[:3], alpha))
# sns.violinplot(data=recall_data, x='condition', y='recall_diffRatio', width=0.4, linewidth=slw + 1, ax=pc.ax,
#                legend=True, gap=0.2, color=hue_cols[0])
# sns.stripplot(data=recall_data, x='condition', y='recall_diffRatio', edgecolor='black', linewidth=slw, dodge=True,
#               jitter=0.05, size=ssize, legend=False, ax=pc.ax, color=hue_cols[1])
#
# # pc.j=1
# # bp = sns.boxplot(data=recall_data, x='condition',y='recall_diffRatioBroad', width=0.4,
# #                  linewidth=slw + 1, ax=pc.ax, legend=True, gap=0.2)
# # for patch in bp.patches:
# #     face_color = patch.get_facecolor()
# #     patch.set_facecolor((*face_color[:3], alpha))
# # sns.violinplot(data=recall_data,x='condition', y='recall_diffRatioBroad', width=0.4, linewidth=slw + 1, ax=pc.ax,
# #                legend=True, gap=0.2, color=hue_cols[0])
# # sns.stripplot(data=recall_data, x='condition',y='recall_diffRatioBroad', edgecolor='black', linewidth=slw, dodge=True,
# #               jitter=0.05, size=ssize, legend=False, ax=pc.ax, color=hue_cols[1])
#
#
# plt.tight_layout()
#
# if savePlot:
#     plt.savefig(f"{plots_path}recall_ratioCongr.pdf", dpi=300)
#
# md = smf.mixedlm("recall_diffRatio ~ condition", recall_data, groups=recall_data['sub'])
# mdf = md.fit()
# print(mdf.summary())
#
# # md = smf.mixedlm("recall_diffRatioBroad ~ condition", recall_data, groups=recall_data['sub'])
# # mdf = md.fit()
# # print(mdf.summary())
#
# recall_summary = recall_data.groupby(['condition', 'group'], as_index=False)[
#     ['recall_diffSent', 'recall_diffRatio', 'recall_diffRatioBroad']].agg(['mean', 'std'])
