# %% Import libraries and load data
import os
from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib
import scipy.stats as stats

matplotlib.use('Qt5Agg')
# matplotlib.get_backend()
import matplotlib.pyplot as plt

plt.ion()
import seaborn as sns

task_versions = ['v2', 'v2b', 'v3']
# task_versions = ['v3']

from _objects.plot_config import *
from _objects.configs import *
from _utils.utils import q1, q3, percentile

pc = PlotConfig()

bools = Bools()
paths = Paths(files_dir='', plots_subdir='', plots_subsubdir='')
# Columns identifying one subject row; shared by every table built below.
id_cols = ['sub', 'condition', 'group', 'autobio']
# Condition order and the colours used for it in the plots.
hue_order = ['MH', 'ML']
hue_cols = ['tab:blue', 'tab:orange']
hue_cols2 = ['tab:green', 'tab:purple', 'tab:red']
# Labels of the 's_bin3' PHQ-9 bins: below the 33rd percentile, middle, above the 66th.
sbin3_order = ['q33', 'm', 'q66']
# id columns plus both PHQ-9 bin columns; used as merge keys for the derived tables.
join_cols = id_cols + ['s_bin', 's_bin3']

# Output toggles (set to False to disable).
# NOTE: the CSV exports in this script are written unconditionally, regardless of saveMe.
bools.saveMe = True
bools.savePlot = True

testCase_dir = 'test3_mood_induction'
paths.plots_path = f'{testCase_dir}/_plots/scores/'
Path(paths.plots_path).mkdir(parents=True, exist_ok=True)
paths.data_proc_dir = f'{testCase_dir}/_data/processed/'
# Every processed CSV below is written here, so make sure the directory exists.
Path(paths.data_proc_dir).mkdir(parents=True, exist_ok=True)
paths.data_path = f'{testCase_dir}/_data/'

# %% Load PHQ9 scores and totals
store_phq9_dfs = []
for task_version in task_versions:
    file_path = f"{paths.data_path}qs-intervention-{task_version}/processed/phq9_data.csv"
    store_phq9_dfs.append(pd.read_csv(file_path))
phq9_data = pd.concat(store_phq9_dfs).reset_index(drop=True)
# Keep group 'D' only. The .copy() makes the in-place insert/drop below operate on an
# independent frame instead of a slice (avoids SettingWithCopyWarning).
phq9_data = phq9_data[phq9_data['group'] == 'D'].copy()
# phq9_data = phq9_data[~phq9_data.isna().any(axis=1)]
# Rows coming from the qs-intervention-v3 files are flagged as the autobiographical variant.
phq9_data.insert(4, 'autobio', phq9_data['task_version'] == 'qs-intervention-v3')
phq9_data.drop(columns=['task_version'], inplace=True)

# s_total: per-subject sum of the baseline item scores (score_b).
phq9_totals = phq9_data.groupby(id_cols, as_index=False)['score_b'].sum().rename(
    columns={'score_b': 's_total'})
# Baseline / follow-up totals and their difference (follow-up minus baseline).
phq9_totals_bf = phq9_data.groupby(id_cols, as_index=False)[
    ['score_b', 'score_fu']].sum().rename(
    columns={'score_b': 'baseline', 'score_fu': 'fu'})
phq9_totals_bf['total_diff'] = phq9_totals_bf['fu'] - phq9_totals_bf['baseline']

# One row per subject, one column per PHQ-9 question holding score_diff, plus s_total.
phq9_diff_data_wide = phq9_data.pivot(index=id_cols, columns='question',
                                      values='score_diff').reset_index()
phq9_diff_data_wide = pd.merge(phq9_diff_data_wide, phq9_totals, on=id_cols)


def _bin_by_thresholds(values, lo, hi, lo_label, hi_label):
    """Label each value lo_label if < lo, hi_label if > hi, else 'm'."""
    return np.select([values < lo, values > hi], [lo_label, hi_label], default='m')


# Add PHQ-9 bins of s_total: 's_bin' from the q1/q3 helpers, 's_bin3' from the 33rd/66th
# percentiles. .agg (not .apply) is the reduction API: with a dict, recent pandas versions
# make .apply call each function element-wise on scalars.
s_total = phq9_diff_data_wide['s_total']
qtls = s_total.agg({'q1': q1, 'q3': q3})
qtls3 = s_total.agg({'q33': percentile(0.33), 'q66': percentile(0.66)})
phq9_diff_data_wide.insert(4, 's_bin', _bin_by_thresholds(s_total, qtls['q1'], qtls['q3'], 'q1', 'q3'))
phq9_diff_data_wide.insert(5, 's_bin3', _bin_by_thresholds(s_total, qtls3['q33'], qtls3['q66'], 'q33', 'q66'))

phq9_diff_data_wide.to_csv(f'{paths.data_proc_dir}phq9_diff_data_wide.csv', index=False)
phq9_data.to_csv(f'{paths.data_proc_dir}phq9_data.csv', index=False)

# Outliers for q2: within each (condition, autobio) cell, flag subjects whose score change
# lies more than n_sd standard deviations from the cell mean. (The previous test,
# |x| > mean + 3*SD, mixed an absolute value with a signed mean and misclassified
# subjects whenever the cell mean was not ~0.)
# pcols = [c for c in phq9_diff_data_wide.columns if 'phq9_q' in c]
pcols = ['phq9_q2']
n_sd = 3
out_subs = []
for _, g_df in phq9_diff_data_wide.groupby(['condition', 'autobio']):
    vals = g_df[pcols]
    # A cell with a single subject has SD NaN; the comparison is then False (no outlier).
    is_out = ((vals - vals.mean()).abs() > n_sd * vals.std()).any(axis=1)
    out_subs.extend(g_df.loc[is_out, 'sub'].to_list())
# Sorted so the reported outlier list is reproducible between runs.
out_subs = sorted(set(out_subs))

sub_outliers = out_subs

# phq9_data = phq9_data[~phq9_data['sub'].isin(sub_outliers)]
# phq9_diff_data_wide = phq9_diff_data_wide[~phq9_diff_data_wide['sub'].isin(sub_outliers)]
# phq9_diff_data_wide.to_csv(f'{paths.data_proc_dir}phq9_diff_data_wide_outRem.csv', index=False)
# phq9_data.to_csv(f'{paths.data_proc_dir}phq9_data_outRem.csv', index=False)

# %% Load Mood data
# Mood data is read only for the last entry of task_versions.
store_mood_dfs = []
for task_version in task_versions[-1:]:
    file_path = f"{paths.data_path}qs-intervention-{task_version}/processed/mood_data.csv"
    store_mood_dfs.append(pd.read_csv(file_path))
mood_data = pd.concat(store_mood_dfs).reset_index(drop=True)
# .copy() so the insert below does not mutate a filtered slice (SettingWithCopyWarning).
mood_data = mood_data[mood_data['group'] == 'D'].copy()
# mood_data = mood_data[~mood_data.isna().any(axis=1)]
mood_data.insert(4, 'autobio', mood_data['task_version'] == 'qs-intervention-v3')
mood_data = mood_data.drop(columns=['task_version', 'question'])
# Attach the PHQ-9 bins (s_bin, s_bin3); the default inner merge keeps only subjects
# present in both tables.
mood_data = pd.merge(phq9_diff_data_wide[join_cols], mood_data, on=id_cols)
mood_data.to_csv(f'{paths.data_proc_dir}mood_data.csv', index=False)

# %% Load Feedback data
store_feedback_dfs = []
for task_version in task_versions:
    file_path = f"{paths.data_path}qs-intervention-{task_version}/processed/feedback_data.csv"
    feedback_data_pd = pd.read_csv(file_path)
    feedback_data_pd.insert(4, 'autobio', feedback_data_pd['task_version'] == 'qs-intervention-v3')
    drop_cols = ['task_version']
    if task_version == 'v2':
        drop_cols.append('mood_change2')
    feedback_data_pd = feedback_data_pd.drop(columns=drop_cols)
    # Long format: one row per (subject, feedback question).
    feedback_data_pd_long = pd.melt(feedback_data_pd, id_vars=id_cols,
                                    value_name='score', var_name='feedback_q')
    store_feedback_dfs.append(feedback_data_pd_long)
feedback_data = pd.concat(store_feedback_dfs, axis=0).reset_index(drop=True)
# feedback_data = feedback_data[~feedback_data.isna().any(axis=1)]
feedback_data = pd.merge(phq9_diff_data_wide[join_cols], feedback_data, on=id_cols)
feedback_data.to_csv(f'{paths.data_proc_dir}feedback_data.csv', index=False)

feedback_data = feedback_data[~feedback_data['sub'].isin(sub_outliers)]

feedback_cols = ['shoes', 'sim_sit', 'mood_change_new', 'mood_change']
# Fold 'mood_change_new' into 'mood_change' so both versions share one column after pivoting.
# .assign builds a new frame instead of writing through .loc into a filtered slice, and
# Series.replace does an exact-value match (str.replace is substring/regex based).
feedback_data_subset = feedback_data[feedback_data['feedback_q'].isin(feedback_cols)].assign(
    feedback_q=lambda d: d['feedback_q'].replace({'mood_change_new': 'mood_change'}))
# NOTE(review): pivot raises if a subject has both 'mood_change' and 'mood_change_new'.
feedback_data_subset_wide = feedback_data_subset.pivot(index=join_cols,
                                                       columns='feedback_q', values='score').reset_index()
feedback_data_subset_wide.to_csv(f'{paths.data_proc_dir}feedback_data_subset_wide.csv', index=False)

# %% Merge phq9 score diffs with feedback / mood
# Both tables are inner joins on join_cols (id columns + PHQ-9 bins) and are exported to CSV.
phq9_diff_feedback_df = phq9_diff_data_wide.merge(feedback_data_subset_wide, how='inner', on=join_cols)
phq9_diff_feedback_df.to_csv(f'{paths.data_proc_dir}phq9-diff_feedback.csv', index=False)

phq9_diff_mood_df = phq9_diff_data_wide.merge(mood_data, how='inner', on=join_cols)
phq9_diff_mood_df.to_csv(f'{paths.data_proc_dir}phq9-diff_mood.csv', index=False)

#%% Remove outliers
print(f'Outliers: {sub_outliers}')


def _drop_outlier_subs(df):
    """Return *df* without the rows whose 'sub' appears in ``sub_outliers``."""
    keep_mask = ~df['sub'].isin(sub_outliers)
    return df.loc[keep_mask]


phq9_data = _drop_outlier_subs(phq9_data)
phq9_diff_data_wide = _drop_outlier_subs(phq9_diff_data_wide)
mood_data = _drop_outlier_subs(mood_data)
phq9_diff_mood_df = _drop_outlier_subs(phq9_diff_mood_df)
phq9_diff_feedback_df = _drop_outlier_subs(phq9_diff_feedback_df)
# %% Plot mood difference for autobiographical study
plt.close('all')
pc.r, pc.c, pc.mlt = 1, 1, 2
pc.figsize = ((pc.c + 4.55) * pc.mlt, (pc.r + 1.25) * pc.mlt)

fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
pc.i = 0
pc.onerow = False
pc.j = 0
# A 1x1 subplots() call returns a single Axes; wrap it in an array for pc.axes.
# axes = np.array([[axes]])
axes = np.array([axes])
pc.axes = axes
pc.ax_ts(16, 1.1)
pc.l_fs(12, 0.85)
ts = 14
pc.xyt_ls(ts, 10)
pc.ax_ls(16)
pc.kde_lw = 3
pc.p_lab_spec[2] = 12
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300

# VAS mood change per condition: violin (with inner box) plus the individual subjects.
ssize = 4.75
slw = 0.2
alpha = 0.5
sns.violinplot(data=mood_data, x='condition', y='mood_diff', inner='box', ax=pc.ax, width=0.3, legend=False,
               hue='condition', hue_order=hue_order, palette=hue_cols, order=hue_order)

# order=hue_order puts the points on the same x positions as the violins. hue duplicates x
# here, so dodge=False keeps each condition's points centred on its violin instead of
# shifting the two conditions sideways in opposite directions.
sns.stripplot(data=mood_data, x='condition', y='mood_diff', edgecolor='black',
              linewidth=slw, dodge=False, jitter=0.05, order=hue_order,
              size=ssize, legend=False, ax=pc.ax, hue='condition', hue_order=hue_order, palette=hue_cols)

# Alternative panel split by PHQ-9 tertile bin (s_bin3); kept for reference.
# pc.j = 1
# sns.violinplot(data=mood_data, hue='condition', y='mood_diff', x='s_bin3', inner='box', ax=pc.ax, width=0.3,
#                palette=hue_cols, legend=True, hue_order=hue_order, order=sbin3_order)
#
# sns.stripplot(data=mood_data, hue='condition', y='mood_diff', x='s_bin3', edgecolor='black',
#               linewidth=slw, dodge=True, jitter=0.05,
#               size=ssize, legend=False, ax=pc.ax, palette=hue_cols, hue_order=hue_order, order=sbin3_order)

# Subjects per condition for the title. reindex() reports 0 for an empty condition instead
# of raising KeyError. Double quotes inside the f-string keep it valid before Python 3.12.
N_mood = mood_data.groupby(['condition'])['mood_diff'].count().reindex(hue_order, fill_value=0)
t_extra = f" (N_mh: {N_mood['MH']}, N_ml: {N_mood['ML']})"
pc.ax.set_ylabel('VAS Mood Diff: FU - Baseline')
pc.ax.set_xlabel('Condition')
pc.ax.set_title(f'VAS Mood Diff: FU - Baseline{t_extra}')
plt.tight_layout()
if bools.savePlot:
    plt.savefig(f"{paths.plots_path}mood_diff.pdf", dpi=pc.dpi_val)


#%% Phq9 data scores
# Mean and SD of the per-subject s_total (outliers already removed) scaled by 3.
# NOTE(review): the reason for the factor 3 is not visible in this script — presumably it
# rescales the stored item scores to the 0-3 PHQ-9 item range; confirm upstream.
print(phq9_diff_data_wide['s_total'].mul(3).agg(['mean', 'std']))
