# %% Import libraries and load data
import os

import pandas as pd
import numpy as np
import matplotlib
import scipy.stats as stats

matplotlib.use('Qt5Agg')
# matplotlib.get_backend()
import matplotlib.pyplot as plt

plt.ion()
import seaborn as sns


from _objects.plot_config import *
from _objects.configs import *

# Project plotting/config helpers (provided by the `_objects` star imports).
pc = PlotConfig()

bools = Bools()
paths = Paths(files_dir='', plots_subdir='', plots_subsubdir='')

# Identifier columns shared by the processed tables.
id_cols = ['sub', 'condition', 'group', 'autobio']
# Condition draw order and the matching colour per level
# (an earlier variant drew MH first, in blue).
hue_order = ['ML', 'MH']
hue_cols = ['tab:orange', 'tab:blue']
hue_cols2 = ['tab:green', 'tab:purple', 'tab:red']
sbin3_order = ['q33', 'm', 'q66']
# Join keys for tables that also carry the s_bin / s_bin3 columns.
join_cols = [*id_cols, 's_bin', 's_bin3']

cond_names = ['High mood', 'Low mood']

# Save toggles: the later assignment wins, so saving is currently enabled.
bools.saveMe = False
bools.saveMe = True
bools.savePlot = True

testCase_dir = 'test3_mood_induction'
paths.plots_path = f'{testCase_dir}/_plots/_scores/'
# Import Path explicitly rather than relying on it leaking through the
# `_objects` star imports.
from pathlib import Path
Path(paths.plots_path).mkdir(parents=True, exist_ok=True)
paths.data_proc_dir = f'{testCase_dir}/_data/processed/'
paths.data_path = f'{testCase_dir}/_data/'

# Subjects excluded from every table loaded below.
sub_outliers = ['sub95_v2b', 'sub96_v2b', 'sub150_v2b', 'sub133_v2b', 'sub180_v3', 'sub131_v3']


def _load_autobio_rows(csv_name, data_dir, excluded_subs):
    """Read ``data_dir + csv_name`` and keep the rows flagged by ``autobio``.

    Rows whose ``sub`` value appears in *excluded_subs* are dropped too.

    Args:
        csv_name: File name inside *data_dir*.
        data_dir: Directory prefix (ends with '/', as built above).
        excluded_subs: Subject IDs to exclude.

    Returns:
        The filtered DataFrame; the original row index is preserved.
    """
    df = pd.read_csv(f'{data_dir}{csv_name}')
    # Assumes pandas parses `autobio` as a bool column (True/False in the CSV)
    # so it can be combined with `&` — TODO confirm for all three files.
    keep = df['autobio'] & ~df['sub'].isin(excluded_subs)
    return df[keep]


phq9_diff = _load_autobio_rows('phq9_diff_data_wide.csv', paths.data_proc_dir, sub_outliers)
mood_data = _load_autobio_rows('mood_data.csv', paths.data_proc_dir, sub_outliers)
recall_data_wide = _load_autobio_rows('recall_data.csv', paths.data_proc_dir, sub_outliers)
fig_no = 'Fig6'

measure_cols = ['phq9_q2', 'mood_diff', 'recall_diffSent']
# pd.merge defaults to an inner join: only rows whose keys match in all three
# tables reach the combined figure (the per-table panels use every row).
all_data = pd.merge(phq9_diff, mood_data, on=join_cols)
all_data = pd.merge(all_data, recall_data_wide, on=id_cols)[id_cols + measure_cols]
# Long format: each merged row becomes one row per measure, with the measure
# name in `measure` and its score in `value`.
all_data = all_data.melt(id_vars=id_cols, value_vars=measure_cols, var_name='measure')
#%% Plot measures
# Single-panel figure: every measure on the x axis, split by condition.
plt.close('all')
pc.r, pc.c, pc.mlt = 1, 1, 2.75
pc.figsize = ((pc.c + 1.55) * pc.mlt, (pc.r + 0.4) * pc.mlt)
fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
pc.i = 0
pc.j = 0
pc.onerow = False
# subplots(1, 1) returns a bare Axes; wrap it as a (1, 1) array so it can be
# indexed like the multi-panel figures (presumably pc.ax is pc.axes[pc.i, pc.j]).
axes = np.array([[axes]])
pc.axes = axes
pc.ax_ts(16, 1.1)
pc.l_fs(12, 0.85)
ts = 15
pc.xyt_ls(ts, ts)
pc.ax_ls(16)
pc.kde_lw = 3
# Panel-letter placement used by the ax.text call below: x, y, font size.
pc.p_lab_spec[2] = 12
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300
ssize = 5.75
slw = 0.2
alpha = 0.3
b_width = 0.25
v_width = 0.8
b_lw = 1.5
s_lw = 0.8
s_size = 2
s_out_size = 8
# Display names for measure_cols, in the same order.
xlabs = ['PHQ9 Q2', 'Momentary mood', 'Recall sentiment']

# The tick labels are assigned by position below, so pin every layer's
# category order to measure_cols instead of relying on seaborn's inference.
sns.violinplot(all_data, y='value', x='measure', hue='condition', ax=pc.ax, orient='v', inner=None,
               width=v_width, palette=hue_cols, hue_order=hue_order, legend=False, split=True,
               order=measure_cols)
bp = sns.boxplot(data=all_data, x='measure', y='value', width=b_width,
                 palette=hue_cols, hue='condition', hue_order=hue_order, fliersize=s_out_size,
                 linewidth=slw + 1, ax=pc.ax, legend=True, gap=0.2, dodge=True,
                 order=measure_cols)
pc.ax.legend(loc='upper left', bbox_to_anchor=(0.2, 0.3))
# Re-apply each patch's face colour with translucency `alpha`.
for patch in bp.patches:
    face_color = patch.get_facecolor()
    patch.set_facecolor((*face_color[:3], alpha))
sns.stripplot(data=all_data, x='measure', y='value', hue='condition',
              edgecolor='black', palette=hue_cols, hue_order=hue_order,
              linewidth=slw, dodge=True, jitter=0.05,
              size=ssize, legend=False, ax=pc.ax, order=measure_cols)
pc.ax.set_title('Affective change measures')
pc.ax.set_ylabel('FU - Baseline')
pc.ax.set_xlabel('\nMeasure change')
pc.ax.set_xticks(range(len(xlabs)))
pc.ax.set_xticklabels(xlabs)
# Panel letter, positioned in axes coordinates.
pc.ax.text(pc.p_lab_spec[0] + 0.05, pc.p_lab_spec[1], pc.p_labs[pc.i, pc.j], transform=pc.ax.transAxes,
           fontweight='bold',
           va='top', ha='right',
           fontsize=pc.p_lab_spec[2])

plt.tight_layout()
if bools.savePlot:
    plt.savefig(f"{paths.plots_path}{fig_no}_p1_{testCase_dir}_measures.pdf", dpi=pc.dpi_val)

# %% Plot score diff
# Fresh figure with one row of three panels.
plt.close('all')
pc.r, pc.c, pc.mlt = 1, 3, 2.75
pc.figsize = (
    (pc.c + 1.25) * pc.mlt,  # width
    (pc.r + 0.25) * pc.mlt,  # height
)
fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
pc.i = pc.j = 0
pc.onerow = False
# subplots(1, 3) returns a 1-D array of three Axes; promote it to (1, 3) and
# transpose to (3, 1), presumably so pc.axes[pc.i, pc.j] selects panel pc.i.
axes = pc.axes = np.array([axes]).T
# Font-size / label styling via PlotConfig helpers (call order kept).
pc.ax_ts(16, 1.1)
pc.l_fs(12, 0.85)
ts = 14
pc.xyt_ls(ts, 10)
pc.ax_ls(16)
pc.kde_lw = 3
# Panel-letter placement: x offset, y offset, font size.
pc.p_lab_spec[0], pc.p_lab_spec[1], pc.p_lab_spec[2] = -0.1, 1.1, 12
pc.dpi_val = 300
# Marker / line styling shared by the three panels.
ssize, slw, alpha = 5.75, 0.2, 0.3
b_width, v_width = 0.2, 0.5
b_lw, s_lw = 1.5, 0.8
s_size, s_out_size = 2, 8

def _plot_condition_panel(panel_idx, data, y_col, title, xlabel=None):
    """Draw a violin + box + strip panel of *y_col* for each condition.

    Selects the target subplot with ``pc.i = panel_idx``, then layers a
    violin, translucent box, and jittered points, each coloured per
    condition by ``hue_cols`` in ``hue_order``.

    Args:
        panel_idx: Panel index assigned to ``pc.i``.
        data: DataFrame with a ``condition`` column and *y_col*.
        y_col: Column plotted on the y axis.
        title: Panel title.
        xlabel: Optional x-axis label; when None, seaborn's default label
            (the column name) is left in place.
    """
    pc.i = panel_idx
    ax = pc.ax
    sns.violinplot(data, y=y_col, x='condition', hue='condition', ax=ax, orient='v', inner=None,
                   width=v_width, palette=hue_cols, hue_order=hue_order)
    bp = sns.boxplot(data=data, x='condition', y=y_col, width=b_width,
                     palette=hue_cols, hue='condition', hue_order=hue_order, fliersize=s_out_size,
                     linewidth=slw + 1, ax=ax, legend=leg_bool, gap=0.2)
    # Re-apply each patch's face colour with translucency `alpha`.
    for patch in bp.patches:
        face_color = patch.get_facecolor()
        patch.set_facecolor((*face_color[:3], alpha))
    # hue repeats x, so each category holds a single hue level. dodge=True
    # would still shift every level by its hue_order slot (first level left,
    # second right), pulling the points off the centred violin and box.
    sns.stripplot(data=data, x='condition', y=y_col, hue='condition',
                  edgecolor='black', palette=hue_cols, hue_order=hue_order,
                  linewidth=slw, dodge=False, jitter=0.05,
                  size=ssize, legend=False, ax=ax)
    ax.set_title(title)
    ax.set_ylabel('FU - Baseline')
    if xlabel is not None:
        ax.set_xlabel(xlabel)


# Box-plot legends are suppressed; conditions are already the x categories.
leg_bool = False

# Plot score diff for q2
_plot_condition_panel(0, phq9_diff, 'phq9_q2', 'PHQ9 Q2 Score Differences')
# Plot mood diff
_plot_condition_panel(1, mood_data, 'mood_diff', 'Momentary Mood Differences', xlabel='Condition')
# NOTE(review): cond_names is ordered ['High mood', 'Low mood'] while
# hue_order is ['ML', 'MH']; check the mapping before using it as tick labels.
# Plot recall
_plot_condition_panel(2, recall_data_wide, 'recall_diffSent', 'Recall sentiment Differences')

plt.tight_layout()
if bools.savePlot:
    # NOTE(review): this name mirrors the first cell's "p1 ... measures" file
    # and may be a copy-paste leftover — confirm the intended panel name.
    plt.savefig(f"{paths.plots_path}{fig_no}_{testCase_dir}_p1_measures.pdf", dpi=pc.dpi_val)
