# %% Import libraries and load data
import csv
import os
import pandas as pd
import numpy as np
import matplotlib
import scipy.stats as stats
from transformers import pipeline
import statsmodels.api as sm
import statsmodels.formula.api as smf
from textblob import TextBlob
import re
from difflib import get_close_matches
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity as cossim

# Device string passed to SentenceTransformer when the embedding model is built.
device_name = 'mps'

# Select the GUI backend, then import pyplot.
# Alternative backend: matplotlib.use('TkAgg'); check with matplotlib.get_backend()
matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt

plt.ion()  # interactive plotting
import seaborn as sns

# Establish the working directory: take the current path up to and including
# its first "online_tasks" component, then descend into <base_task>/<path_to_go>.
base_task = 'qs_intervention'
path_to_go = '_analysis'
cwd = os.getcwd()
cwd_split = cwd.split('/')
_root_hits = np.argwhere([part == "online_tasks" for part in cwd_split])
# An IndexError on the next line means "online_tasks" is not part of the cwd.
_root_depth = _root_hits[0][0] + 1
cwd_base = '/'.join(cwd_split[:_root_depth])
os.chdir('/'.join([cwd_base, base_task, path_to_go]))

# --- Run configuration ---
# Task versions to pool (alternative: ['v4']).
task_versions = ['v2', 'v2b']
# Save toggle (alternative: False).
saveMe = True

# Project-local helpers. PlotConfig is expected to come from the star import.
from objects.plot_config import *
from utils.utils import percentile
from utils.plot_utils import set_a_hist

# Shared plotting state and output location for the figures of this script.
pc = PlotConfig()
plots_path = '_plots/int/'
# Set to False to skip plt.savefig in the plotting cells.
savePlot = True

# Sentence-embedding model used by every model.encode(...) call below.
model = SentenceTransformer("all-MiniLM-L6-v2", device=device_name)
# %% Load data and transcripts
# Pool the processed intervention data of every task version into one frame.
task_versions = ['v2', 'v2b']
store_int_dfs = []
for task_version in task_versions:
    file_path = f"_data/qs-intervention-{task_version}/processed/int_data.csv"
    int_data_pd = pd.read_csv(file_path).drop(columns=['task_version'])  # [int_cols]
    store_int_dfs.append(int_data_pd)
int_data = pd.concat(store_int_dfs).reset_index(drop=True)

# Keep group 'D' rows without any missing value. The explicit copy makes the
# column assignments below act on an independent frame (no SettingWithCopy).
int_data = int_data[int_data['group'] == 'D']
int_data = int_data[~int_data.isna().any(axis=1)].copy()
# Positional column layout: 5 text columns starting at position 3, everything
# from position 8 onwards is treated as a reaction-time column.
text_cols = int_data.columns[3:3 + 5]
rt_cols = int_data.columns[8:]

# act_data = int_data[['sub', 'condition', 'group', 'responses_act_0']]
# Normalise punctuation spacing, then turn newlines into single spaces.
int_data[text_cols] = int_data[text_cols].replace(r'\s+\.', '.', regex=True)
int_data[text_cols] = int_data[text_cols].replace(r'\s+,', ', ', regex=True)
# The previous pattern here (r'\n+,') could never match: the line above has
# already consumed every whitespace run that precedes a comma.
int_data[text_cols] = int_data[text_cols].replace(r'\n+', ' ', regex=True)
for text_col in text_cols:
    # Split on any whitespace run so double spaces do not create empty "words".
    int_data[f'{text_col}_wc'] = int_data[text_col].str.split().str.len()

# One embedding per line of each condition's transcript file.
transcripts = {'mh': [], 'ml': []}
transcripts_embd = {'mh': [], 'ml': []}
for k, v in transcripts.items():
    with open(f"_data/transcripts/{k}.txt") as f:
        for line in f:
            v.append(line)
    transcripts_embd[k] = model.encode(v)


def _mean_rt(cell):
    """Return the mean of the integer entries of a stringified list.

    Square brackets are stripped, the remainder is split on ', ' and entries
    equal to 'None' are skipped (e.g. '[512, None, 730]' -> 621.0).
    """
    entries = re.sub(r'\[|\]', '', cell).split(', ')
    return np.mean([int(entry) for entry in entries if entry != 'None'])


# Average reaction time per subject and RT column. rts_dict keeps the layout
# later cells rely on: the '<rt_col>_avg' columns followed by 'sub' as last column.
rts_dict = pd.DataFrame({f'{rt_col}_avg': int_data[rt_col].map(_mean_rt) for rt_col in rt_cols},
                        index=int_data.index)
rts_dict['sub'] = int_data['sub']
# Attach the averages row-by-row instead of merging on 'sub', so repeated
# subject ids cannot multiply rows.
int_data = pd.concat([int_data, rts_dict.drop(columns='sub')], axis=1).reset_index(drop=True)
rts_dict = rts_dict.reset_index(drop=True)
# %%
# Embed each subject's recreated texts and created diary, grouped by condition.
rec_cols = [col for col in text_cols if 'recreate' in col]
store_sim_mats = {'mh': [], 'ml': []}           # recreated-text x transcript-line cosine matrices
store_sub_act_embd = {'mh': [], 'ml': []}       # embedding of 'responses_act_0'
stores_sub_avg_rec_embd = {'mh': [], 'ml': []}  # mean embedding over the recreated texts
for r, row in int_data.iterrows():
    condition = row['condition'].lower()

    # Recreated texts: one embedding per 'recreate' column.
    sub_texts_embd = model.encode(row[rec_cols].values.tolist())
    stores_sub_avg_rec_embd[condition].append(sub_texts_embd.mean(axis=0))

    # Similarity of every recreated text to every transcript line of the same condition.
    store_sim_mats[condition].append(cossim(sub_texts_embd, transcripts_embd[condition]))

    # Created diary.
    store_sub_act_embd[condition].append(model.encode(row['responses_act_0']))

# %% Plot transcript and recreated diaries similarity
plt.close('all')

# 2 x 2 grid: rows = mean / std across subjects, columns = MH / ML.
pc.r, pc.c, pc.mlt = 2, 2, 2.75
pc.figsize = ((pc.c + 4.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)
fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
pc.axes = axes

# Font sizes / label settings applied through the shared PlotConfig.
ts = 14
pc.ax_ts(16, 1.1)
pc.l_fs(12, 0.85)
pc.xyt_ls(ts, 10)
pc.ax_ls(16)
pc.kde_lw = 3
pc.p_lab_spec[2] = 12
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300
pc.onerow = False

# Element-wise mean and std over the per-subject similarity matrices.
mh_sim_mat = np.mean(store_sim_mats['mh'], axis=0)
mh_sim_mat_st = np.std(store_sim_mats['mh'], axis=0)
ml_sim_mat = np.mean(store_sim_mats['ml'], axis=0)
ml_sim_mat_st = np.std(store_sim_mats['ml'], axis=0)

# (grid row, grid column, matrix, colormap, title), drawn in this order.
heatmap_panels = (
    (0, 0, mh_sim_mat, 'Oranges', 'MH mean'),
    (1, 0, mh_sim_mat_st, 'Oranges', 'MH std'),
    (0, 1, ml_sim_mat, 'Blues', 'ML mean'),
    (1, 1, ml_sim_mat_st, 'Blues', 'ML std'),
)
for panel_row, panel_col, panel_mat, panel_cmap, panel_title in heatmap_panels:
    pc.i = panel_row
    pc.j = panel_col
    sns.heatmap(panel_mat, cmap=panel_cmap, ax=pc.ax, annot=True)
    pc.ax.set_title(panel_title)

plt.suptitle('Mean and Std of pairwise cosine sim. between transcripts and recreated diaries')
plt.tight_layout()
if savePlot:
    plt.savefig(plots_path + "diary-transcript_sims.pdf", dpi=300)

# %% Reaction times
# Long format: one row per subject and averaged-RT column.
rt_id_cols = ['sub', 'group', 'condition']
rt_avg_cols = rts_dict.columns.tolist()[:-1]  # drop the trailing 'sub' column
int_rts_pd = pd.melt(int_data[rt_id_cols + rt_avg_cols], id_vars=rt_id_cols,
                     var_name='timepoint', value_name='rt')

plt.close('all')

# Single-panel figure, wrapped into a 2-D array for PlotConfig.
pc.r, pc.c, pc.mlt = 1, 1, 2.75
pc.figsize = ((pc.c + 4.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)
fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
pc.onerow = True
pc.i = 0
pc.j = 0
axes = np.array([[axes]])
pc.axes = axes

ts = 14
pc.ax_ts(16, 1.1)
pc.l_fs(12, 0.85)
pc.xyt_ls(ts, 10)
pc.ax_ls(16)
pc.kde_lw = 3
pc.p_lab_spec[2] = 12
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300

ssize = 5.75
slw = 0.2
alpha = 0.5
hue_cols = ['tab:blue', 'tab:orange']

# Box, violin and strip plots share data and axes; only the box plot draws the legend.
rt_plot_kws = dict(data=int_rts_pd, x='timepoint', y='rt', hue='condition', ax=pc.ax)
rt_dist_kws = dict(width=0.4, linewidth=slw + 1, gap=0.2)
bp = sns.boxplot(legend=True, **rt_plot_kws, **rt_dist_kws)
# (Optional) box transparency: set each patch in bp.patches to
# (*patch.get_facecolor()[:3], alpha).
sns.violinplot(legend=False, **rt_plot_kws, **rt_dist_kws)
sns.stripplot(edgecolor='black', linewidth=slw, dodge=True, jitter=0.05, size=ssize,
              legend=False, **rt_plot_kws)

plt.tight_layout()
if savePlot:
    plt.savefig(plots_path + "int_rts.pdf", dpi=300)

# %% Plot pairwise similarity between act texts - MH and ML together
# Stack the created-diary embeddings (MH rows first, then ML) and compare all pairs.
mh_act_emb = np.array(store_sub_act_embd['mh'])
ml_act_emb = np.array(store_sub_act_embd['ml'])
act_emb_all = np.concatenate((mh_act_emb, ml_act_emb))
act_emb_sim = cossim(act_emb_all, act_emb_all)
n_mh_subs = mh_act_emb.shape[0]

plt.close('all')

# Single-panel figure, wrapped into a 2-D array for PlotConfig.
pc.r, pc.c, pc.mlt = 1, 1, 2.75
pc.figsize = ((pc.c + 2.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)
fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
pc.onerow = True
pc.i = 0
pc.j = 0
axes = np.array([[axes]])
pc.axes = axes

ts = 14
pc.ax_ts(16, 1.1)
pc.l_fs(12, 0.85)
pc.xyt_ls(ts, 10)
pc.ax_ls(16)
pc.kde_lw = 3
pc.p_lab_spec[2] = 12
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300

ssize = 5.75
slw = 0.2
alpha = 0.5

sns.heatmap(act_emb_sim, cmap='Blues', ax=pc.ax)
pc.ax.set_aspect('equal', 'box')
pc.ax.set_title(f"pairwise sub created diaries similiarity\n MH: 0- {n_mh_subs}; ML- rest")
plt.tight_layout()
if savePlot:
    plt.savefig(plots_path + "act_sub_sims.pdf", dpi=300)

# %% Plot avg recreate vs act embedding
# Per subject: cosine similarity between the created-diary embedding and the
# mean embedding of that subject's recreated texts, grouped by condition.
act_avg_rec_sim = {
    c: [cossim(e_a.reshape(1, -1), e_r.reshape(1, -1))[0, 0]
        for e_a, e_r in zip(store_sub_act_embd[c], stores_sub_avg_rec_embd[c])]
    for c in ['mh', 'ml']
}

# The per-condition lists were filled while iterating int_data row by row, so
# the subject ids of a condition, taken in int_data order, line up with them.
# Using them as the index makes the 'sub' column of the CSV hold real subject
# ids instead of a positional counter that restarts for each condition.
# A length mismatch raises ValueError rather than silently mislabelling rows.
_row_condition = int_data['condition'].str.lower()

d_mh = pd.DataFrame(act_avg_rec_sim['mh'], columns=['sim'])
d_mh.insert(1, 'condition', 'mh')
d_mh.index = int_data.loc[_row_condition == 'mh', 'sub'].to_numpy()
d_ml = pd.DataFrame(act_avg_rec_sim['ml'], columns=['sim'])
d_ml.insert(1, 'condition', 'ml')
d_ml.index = int_data.loc[_row_condition == 'ml', 'sub'].to_numpy()

act_avg_rec_sim_df = pd.concat([d_mh, d_ml], axis=0)
act_avg_rec_sim_df.index.name = 'sub'

plt.close('all')
# Single-panel figure, wrapped into a 2-D array for PlotConfig.
pc.r, pc.c, pc.mlt = 1, 1, 2.75
pc.figsize = ((pc.c + 4.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)

fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
pc.onerow = True
pc.i = 0
pc.j = 0
axes = np.array([[axes]])
# axes = np.array([axes]).T
pc.axes = axes
pc.ax_ts(16, 1.1)
pc.l_fs(12, 0.85)
ts = 14
pc.xyt_ls(ts, 10)
pc.ax_ls(16)
pc.kde_lw = 3
pc.p_lab_spec[2] = 12
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300

ssize = 5.75
slw = 0.2
alpha = 0.5
bp = sns.boxplot(data=act_avg_rec_sim_df, x='condition', y='sim', width=0.4,
                 linewidth=slw + 1, ax=pc.ax, legend=True, gap=0.2)

# Make the boxes semi-transparent so the violins and points stay visible.
for patch in bp.patches:
    face_color = patch.get_facecolor()
    patch.set_facecolor((*face_color[:3], alpha))
sns.violinplot(data=act_avg_rec_sim_df, x='condition', y='sim', width=0.4, linewidth=slw + 1, ax=pc.ax,
               legend=False, gap=0.2)
sns.stripplot(data=act_avg_rec_sim_df, x='condition', y='sim', edgecolor='black', linewidth=slw, dodge=True,
              jitter=0.05, size=ssize, legend=False, ax=pc.ax)

pc.ax.set_ylabel('Cosine similarity')
pc.ax.set_title('Similarity between average embeddings of recreated text and created diary')
plt.tight_layout()
if savePlot:
    plt.savefig(f"{plots_path}avg-recreat_vs_create_cossim.pdf", dpi=300)

# Columns written: sub (subject id), sim, condition.
act_avg_rec_sim_df.to_csv("_data/avg-recreat_vs_create_cossim.csv", index=True, index_label='sub')
# %% Get open q data
# Pool the processed open-question data of every task version into one frame.
task_versions = ['v2', 'v2b']
store_openq_dfs = []
for task_version in task_versions:
    file_path = f"_data/qs-intervention-{task_version}/processed/openq_data.csv"
    openq_data_pd = pd.read_csv(file_path).drop(columns=['task_version'])  # [openq_cols]
    store_openq_dfs.append(openq_data_pd)

openq_data = pd.concat(store_openq_dfs).reset_index(drop=True)
# Keep group 'D' rows without any missing value. The explicit copy makes the
# column assignments below act on an independent frame (no SettingWithCopy).
openq_data = openq_data[openq_data['group'] == 'D']
openq_data = openq_data[~openq_data.isna().any(axis=1)].copy()

# Normalise punctuation spacing, then turn newlines into single spaces.
openq_cols = ['oq_energy_text', 'oq_mood_text', 'oq_pospert_text']
openq_data[openq_cols] = openq_data[openq_cols].replace(r'\s+\.', '.', regex=True)
openq_data[openq_cols] = openq_data[openq_cols].replace(r'\s+,', ', ', regex=True)
# The previous pattern here (r'\n+,') could never match: the line above has
# already consumed every whitespace run that precedes a comma.
openq_data[openq_cols] = openq_data[openq_cols].replace(r'\n+', ' ', regex=True)

for openq_col in openq_cols:
    # Split on any whitespace run so double spaces do not create empty "words".
    openq_data[f'{openq_col}_wc'] = openq_data[openq_col].str.split().str.len()

# Per subject: embed the three answers and compare the "baseline" (mean of the
# mood and energy embeddings) with the pospert answer.
store_openq_sims = {'mh': [], 'ml': []}
store_openq_embs = {'mh': {'e': [], 'm': [], 'p': []}, 'ml': {'e': [], 'm': [], 'p': []}}
for r, row in openq_data.iterrows():
    condition = row['condition'].lower()
    mood_emb = model.encode(row['oq_mood_text'])
    energy_emb = model.encode(row['oq_energy_text'])
    baseline_embd = np.array([mood_emb, energy_emb]).mean(axis=0)
    pospert_emb = model.encode(row['oq_pospert_text'])

    store_openq_embs[condition]['m'].append(mood_emb)
    store_openq_embs[condition]['e'].append(energy_emb)
    store_openq_embs[condition]['p'].append(pospert_emb)

    # cossim expects 2-D inputs, hence the (1, dim) reshape of each vector.
    tmp_sim = cossim(baseline_embd.reshape(1, -1), pospert_emb.reshape(1, -1))[0, 0]
    store_openq_sims[condition].append(tmp_sim)

#%% Baseline and pospert similarity
# One row per subject: baseline-vs-pospert cosine similarity and the condition.
d_mh = pd.DataFrame(store_openq_sims['mh'], columns=['sim'])
d_mh.insert(1, 'condition', 'mh')
d_ml = pd.DataFrame(store_openq_sims['ml'], columns=['sim'])
d_ml.insert(1, 'condition', 'ml')

# The CSV below labels the index 'sub'. The similarity lists were filled while
# iterating openq_data row by row, so the subject ids of each condition, taken
# in openq_data order, line up with them. Use them as the index so 'sub' holds
# real ids rather than a positional counter that restarts per condition.
# NOTE(review): assumes openq_data carries a 'sub' column like int_data does;
# without it the positional index is kept.
if 'sub' in openq_data.columns:
    _oq_condition = openq_data['condition'].str.lower()
    d_mh.index = openq_data.loc[_oq_condition == 'mh', 'sub'].to_numpy()
    d_ml.index = openq_data.loc[_oq_condition == 'ml', 'sub'].to_numpy()

openq_sims_df = pd.concat([d_mh, d_ml], axis=0)
openq_sims_df.index.name = 'sub'

plt.close('all')
# Single-panel figure, wrapped into a 2-D array for PlotConfig.
pc.r, pc.c, pc.mlt = 1, 1, 2.75
pc.onerow = True
pc.i = 0
pc.j = 0
pc.figsize = ((pc.c + 4.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)

fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
axes = np.array([[axes]])
# axes = np.array([axes]).T
pc.axes = axes
pc.ax_ts(16, 1.1)
pc.l_fs(12, 0.85)
ts = 14
pc.xyt_ls(ts, 10)
pc.ax_ls(16)
pc.kde_lw = 3
pc.p_lab_spec[2] = 12
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300

ssize = 5.75
slw = 0.2
alpha = 0.5
bp = sns.boxplot(data=openq_sims_df, x='condition', y='sim', width=0.4,
                 linewidth=slw + 1, ax=pc.ax, legend=True, gap=0.2)

# Make the boxes semi-transparent so the violins and points stay visible.
for patch in bp.patches:
    face_color = patch.get_facecolor()
    patch.set_facecolor((*face_color[:3], alpha))
sns.violinplot(data=openq_sims_df, x='condition', y='sim', width=0.4, linewidth=slw + 1, ax=pc.ax,
               legend=False, gap=0.2)
sns.stripplot(data=openq_sims_df, x='condition', y='sim', edgecolor='black', linewidth=slw, dodge=True,
              jitter=0.05, size=ssize, legend=False, ax=pc.ax)

pc.ax.set_ylabel('Cosine similarity')
pc.ax.set_title('Similarity between average embeddings of baseline quesitons text and positive perturbation')
plt.tight_layout()
if savePlot:
    plt.savefig(f"{plots_path}baseline_vs_pospert_sim.pdf", dpi=300)

# Columns written: sub, sim, condition.
openq_sims_df.to_csv("_data/baseline_vs_pospert_sim.csv", index=True, index_label='sub')