# %% Import and setup
import copy
import random
from pickle import load

import torch
from natsort import natsorted
from sympy.printing.pretty.pretty_symbology import line_width
from torch import dtype
from tqdm import tqdm
from matplotlib.colors import ListedColormap

import matplotlib
import os
import pandas as pd
import re
import numpy as np
import seaborn as sns

from factor_analyzer import FactorAnalyzer
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import Ridge
from sklearn.linear_model import RidgeCV
from scipy import stats
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.decomposition import PCA

# Presumably disables the HuggingFace `tokenizers` library's internal parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# establish working directory
# Truncate the current path right after its "online_tasks" component, then chdir into
# <online_tasks>/<base_task>/<path_to_go>. Every relative path used below (and, presumably, the project
# imports that follow) depends on this chdir.
# Assumes a '/'-separated (POSIX) path.
# NOTE(review): raises IndexError if "online_tasks" is not a component of the current working directory.
cwd = os.getcwd()
cwd_split = cwd.split('/')
base_task = 'qs_intervention'
path_to_go = '_analysis/'
cwd_base = '/'.join(cwd_split[:np.argwhere([p == "online_tasks" for p in cwd_split])[0][0] + 1])
os.chdir(f"{cwd_base}/{base_task}/{path_to_go}")

# from _llm_based.utils.sampling_utils import phq9_qs_inv_map
from utils.analysis_utils import spec_levels, spec_levels_crit, spec_level_dict, spec_names, qa_loc_key_names, \
    get_responses, get_responses_avg_merged, get_responses_merged_diff, responses_totals_shuffled
from utils.plot_utils import set_scatter_axes, set_a_hist
from objects.configs import *
from objects.plot_config import *
from _llm_based.objects.model_configs import *

# Qt5 backend, selected before pyplot is imported below.
matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt

plt.ion()  # interactive mode so figures show up while the cells run

# plots_path = '_plots/llm_based/lm_ridge/'
plots_path = '_plots/prelim/'
files_path = '_llm_based/files/'
# responses_path = f"{files_path}responses/subjects/"
# data_path = '_data/'
# states_path = '_llm_based/states/'
# states_path = f"{files_path}states/subjects"
# Directory of per-subject timepoint hidden-state .pt files (read in the prediction cell below).
hs_path = f'{files_path}act_dynamics/timepoint_hs/'

# Project configuration objects (classes come from the star imports above).
pc = PlotConfig()
paths = Paths(files_dir='files', plots_subdir='', plots_subsubdir='')
sample_config = SampleConfig(paths, model_name='MistralOo', qs_name='phq9', instr_name='instr2', temp='', top_p='')
# change the order where last layer is last
# (moves the first entry to the end; the stored hidden states are rolled the same way with torch.roll(..., -1, 0))
sample_config.layer_list = sample_config.layer_list[1:] + sample_config.layer_list[:1]

# Run flags.
# NOTE(review): loadMe / saveMe are not referenced anywhere in the visible part of this script.
# loadMe = False
# saveMe = True
loadMe = True
saveMe = False
savePlot = True  # gates the savefig of the timepoint figure
# savePlot = False
# %% Load models and hs
# Pre-fitted PCA and Ridge models; the prediction cell below maps hidden states through pca.transform and then
# clf.predict.
# NOTE(review): pickle.load can execute arbitrary code -- only load these files from a trusted source.
with open(f'_llm_based/saved_models/avg_pca_PCA.pkl', "rb") as f:
    pca = load(f)
with open(f'_llm_based/saved_models/avg_pca_RidgeReg.pkl', "rb") as f:
    clf = load(f)

# Intervention data from both task versions, restricted to group 'D' and to rows with no missing value.
task_versions = ['v2', 'v2b']
store_int_dfs = []
for task_version in task_versions:
    file_path = f"_data/qs-intervention-{task_version}/processed/int_data.csv"
    int_data_pd = pd.read_csv(file_path).drop(columns=['task_version'])  # [int_cols]
    store_int_dfs.append(int_data_pd)
int_data = pd.concat(store_int_dfs).reset_index(drop=True)
int_data = int_data[int_data['group'] == 'D']
int_data = int_data[~int_data.isna().any(axis=1)]

# 'responses_act_0' text per subject, with whitespace before '.' and ',' tidied up.
# The replace() calls run over every string column of the frame, not just the text column.
# NOTE(review): r'\n+,' only replaces newlines that are followed by a comma -- confirm r'\n+' was not intended.
act_data = int_data[['sub', 'condition', 'group', 'responses_act_0']]
act_data = act_data.replace(r'\s+\.', '.', regex=True)
act_data = act_data.replace(r'\s+,', ', ', regex=True)
act_data = act_data.replace(r'\n+,', ' ', regex=True)
act_data.rename(columns={'responses_act_0': 'act_text'}, inplace=True)
# Word count via split(' '): consecutive spaces count as extra (empty) "words", and it runs before strip() below.
act_data['n_words'] = act_data['act_text'].str.split(' ').apply(lambda x: len(x))
min_word = 100
# reset_index() without drop=True keeps the old index as an 'index' column (dropped again before the CSV export).
act_data = act_data[act_data['n_words'] >= min_word].reset_index()
act_data['act_text'] = act_data['act_text'].apply(lambda x: x.strip())
# %% Predict factor scores at every timepoint of each subject's text
# For each per-subject timepoint hidden-state file: take one layer, L2-normalise every timepoint vector, project it
# with the pickled PCA, re-normalise, and predict the factor scores with the pickled Ridge model.
best_layer_index = 5  # index into the first axis of each saved tensor; assumed to be the layer the models were fit on -- TODO confirm
# Only files that end in ".pt" (a substring test would also accept names such as "x.pth" or "x.pt.bak"),
# consistent with the r'\.pt$' filter used for the open-question state files further down.
pt_files = [f for f in os.listdir(hs_path) if f.endswith('.pt')]
yt_labs = ['Depression: Factor 1', 'Somatic & Emotional: Factor 2', 'Cognitive: Factor 3',
           'Appetite & Weight: Factor 4']
sub_frames = []  # one long-format frame per subject, concatenated once after the loop
for pt_file in pt_files:
    sub = pt_file.split('^^')[0]
    # NOTE(review): raises IndexError when `sub` is not in act_data (e.g. it was dropped by the min_word filter).
    sub_rows = act_data[act_data['sub'] == sub]
    group = sub_rows['group'].values[0]
    condition = sub_rows['condition'].values[0]
    # Selected layer; its rows are treated as timepoints below -- presumably (timepoint, hidden_dim).
    sub_hs_ts = torch.load(f'{hs_path}{pt_file}', weights_only=True)[best_layer_index].type(torch.float)
    sub_hs_ts = sub_hs_ts / np.linalg.norm(sub_hs_ts, axis=1, keepdims=True)

    sub_hs_ts_pca = pca.transform(sub_hs_ts)
    sub_hs_ts_pca = sub_hs_ts_pca / np.linalg.norm(sub_hs_ts_pca, axis=1, keepdims=True)
    sub_hs_ts_sds = clf.predict(sub_hs_ts_pca)  # (n_timepoints, n_factors)
    # Columns: sub, condition, group, t (1-based timepoint), f1..fK (predicted factor scores).
    tmp_dict = {'sub': sub, 'condition': condition, 'group': group,
                't': np.arange(1, sub_hs_ts_sds.shape[0] + 1)}
    for fs in range(sub_hs_ts_sds.shape[1]):
        tmp_dict[f'f{fs + 1}'] = sub_hs_ts_sds[:, fs]
    tmp_pd = pd.DataFrame(tmp_dict)
    sub_frames.append(tmp_pd)
    del sub_hs_ts
# One concat instead of growing the frame inside the loop (which is quadratic). The result is the same rows in the
# same order, with each subject keeping its own 0..T-1 index.
sds_preds_df = pd.concat(sub_frames, axis=0) if sub_frames else pd.DataFrame()

#%% get scores with text
# Per-subject mean and std over timepoints of each predicted factor score (still raw here; the z-scoring happens in
# the next cell), merged with the subject's text and exported. Flattened column names look like 'f1_mean', 'f1_std';
# the group-by keys keep their plain names.
# NOTE(review): the CSV is written unconditionally -- the saveMe flag is not consulted.
sds_preds_df_stats = sds_preds_df.drop(columns=['t']).groupby(['sub','condition','group'],as_index=False).aggregate(['mean','std'])
sds_preds_df_stats.columns = [v1+'_'+v2  if v2!='' else v1 for v1,v2 in zip(sds_preds_df_stats.columns.get_level_values(0), sds_preds_df_stats.columns.get_level_values(1))]
sds_pred_df_stats_wtext = pd.merge(act_data,sds_preds_df_stats,on=['sub','condition','group']).drop(columns='index')
sds_pred_df_stats_wtext.to_csv('_data/sds_scores_wtext.csv',index=False)
# %%%
# Standardise the predicted factor time courses and plot them per condition (MH / ML).
# Top row: factors 1-2; bottom row: factors 3-4. Each panel shows faint per-subject traces plus the group mean with
# a confidence band.
# scaler = StandardScaler()
fs_cols = [f'f{f + 1}' for f in range(4)]
fs_name_dict = {fs: fsname for fs, fsname in zip(fs_cols, yt_labs)}
sel_fs = yt_labs[0:2]
# z-score each factor over all subjects and timepoints (pandas std, ddof=1).
mU = sds_preds_df[fs_cols].mean()
siG = sds_preds_df[fs_cols].std()
sds_preds_df[fs_cols] = (sds_preds_df[fs_cols] - mU) / siG
sds_pred_df_long = pd.melt(sds_preds_df, id_vars=['sub', 'condition', 'group', 't'], var_name='factor')
sds_pred_df_long['factor'] = sds_pred_df_long['factor'].replace(fs_name_dict)

plt.close('all')
pc.r, pc.c, pc.mlt = 2, 2, 2.75
pc.figsize = ((pc.c + 4.25) * pc.mlt, (pc.r + 1.25) * pc.mlt)

fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize)
pc.axes = axes
pc.ax_ts(16, 1.1)
pc.l_fs(12, 0.85)
ts = 14
pc.xyt_ls(ts, 10)
pc.ax_ls(16)
pc.kde_lw = 3
pc.p_lab_spec[2] = 12
pc.p_lab_spec[0] = -0.1
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300

pc.i = 0
pc.onerow = False
pc.j = 0
ssize = 100.75
slw = 1  # base line width: subject traces use slw, mean traces slw * 3
alpha = 0.5

ylimval = 2.75  # symmetric y-limit for all panels (z-score units)


def _plot_factor_ts_panel(panel_df, title, palette):
    """Draw one time-course panel on the current ``pc.ax``.

    ``panel_df`` is a long-format slice of ``sds_pred_df_long`` (columns ``sub``, ``t``, ``factor``, ``value``);
    ``title`` is the panel title and ``palette`` holds one colour per factor present in ``panel_df``.
    """
    # One faint line per subject (no aggregation, no legend entry).
    sns.lineplot(panel_df, x='t', y='value', hue='factor', ax=pc.ax, units='sub', estimator=None, alpha=0.2, lw=slw,
                 legend=False, palette=palette)
    # Mean per timepoint with a confidence-interval band; this call supplies the legend.
    sns.lineplot(panel_df, x='t', y='value', hue='factor', ax=pc.ax, errorbar='ci', lw=slw * 3, palette=palette)
    pc.ax.set_ylim([-ylimval, ylimval])
    pc.ax.set_xticks(panel_df['t'].unique())
    pc.ax.set_xticklabels(panel_df['t'].unique())
    pc.ax.set_title(title)
    pc.ax.set_xlabel('Timepoint (token window)')
    pc.ax.set_ylabel('SDS Factor score')


# Row pc.i selects the factor pair and its colours, column pc.j the condition. The last iteration leaves sel_fs,
# hue_cols, sds_mh and sds_ml holding the factor 3-4 selection.
for pc.i, (sel_fs, hue_cols) in enumerate([(yt_labs[0:2], ['tab:blue', 'tab:orange']),
                                            (yt_labs[2:4], ['tab:purple', 'tab:green'])]):
    sds_mh = sds_pred_df_long[(sds_pred_df_long['condition'] == "MH") & (sds_pred_df_long['factor'].isin(sel_fs))]
    sds_ml = sds_pred_df_long[(sds_pred_df_long['condition'] == "ML") & (sds_pred_df_long['factor'].isin(sel_fs))]
    for pc.j, (cond_df, cond_name) in enumerate([(sds_mh, 'MH'), (sds_ml, 'ML')]):
        _plot_factor_ts_panel(cond_df, cond_name, hue_cols)

plt.tight_layout()
if savePlot:
    plt.savefig(f"{plots_path}hs_ts_sds.pdf", dpi=300)






# %%
# Settings for the open-question hidden-state / SDS factor analysis below.

task_v = ''
task_v = ['v4', 'v4_d', 'v4_dd']  # task versions to include (overrides the '' above)
common_tasks = 'v4'  # regex searched in state file names
# With task_v set to a list, the '' branch is not taken.
if task_v == '':
    task_v_name = task_v
else:
    task_v_name = '_' + '-'.join(task_v)

qs_list = ['lvl2_q1', 'lvl2_q2', 'lvl2_q3']
# Open questions to use (overrides the line above); restricts which state files are kept.
qs_list = ['lvl1_q1', 'lvl2_q1', 'lvl2_q2', 'lvl2_q3', 'lvl3_q1', 'lvl3_q2', 'lvl3_q3', 'lvl3_q4', 'lvl3_q5', 'lvl3_q6',
           'lvl3_q7', 'lvl3_q8']

# Only element [1] is used: an index into the location axis of each saved state tensor.
# tok_loc is the token after the open answer (per the loading cell's header); tok_loc_qs presumably a question
# token -- TODO confirm.
tok_loc = ['oq_ans', 2]
tok_loc_qs = ['oq_qs', 0]
from _llm_based.objects.data_configs import *

qs_config = QsConfig()

# for cross validations
# doSplit = True
# saveSplit = True
# run_search = True
# save_search = True

doSplit = False  # True: draw a new random train/test split; False: read the saved split file
saveSplit = False  # with doSplit, write the new split to _data/latent_train-test_list.txt
run_search = False  # True: run the per-layer RidgeCV search; False: read its saved results CSV
save_search = False  # with run_search, save the search results CSV


# run_search = True
# save_search = True

def spearmanr_pval(x, y, b):
    """Spearman p-value for ``x`` vs ``y``, multiplied by ``b`` (Bonferroni correction) and capped at 1."""
    _, raw_p = stats.spearmanr(x, y)
    corrected = raw_p * b
    return min(corrected, 1)


def spearmanr_corr(x, y):
    """Spearman rank-correlation coefficient of ``x`` and ``y`` (usable as a ``DataFrame.corr`` method)."""
    rho, _ = stats.spearmanr(x, y)
    return rho


def pearsonr_pval(x, y, b):
    """Pearson p-value for ``x`` vs ``y``, multiplied by ``b`` (Bonferroni correction) and capped at 1."""
    _, raw_p = stats.pearsonr(x, y)
    corrected = raw_p * b
    return min(corrected, 1)


def pearsonr_corr(x, y):
    """Pearson correlation coefficient of ``x`` and ``y`` (usable as a ``DataFrame.corr`` method)."""
    r, _ = stats.pearsonr(x, y)
    return r


tt_ratio = 0.75  # fraction of subjects put in the training set when a new split is drawn (doSplit)
n_alphas = 20  # number of Ridge alpha values searched
n_cv = 20  # number of CV folds passed to RidgeCV
alpha_vals = np.logspace(-2, 2, n_alphas)  # log-spaced alphas from 1e-2 to 1e2
savePlots = False  # gates savefig calls in this analysis (separate from savePlot used in the timepoint figure)
# savePlots = True

# %%  Load states files and open-q data
# Collect every file name found under states_path.
# NOTE(review): states_path is not assigned in the visible code (its definitions near the top are commented out);
# presumably it comes from one of the project star imports -- confirm.
files_list = []
dirs_list = []
for root, dirs, files in os.walk(f"{states_path}"):
    files_list.extend(files)
    dirs_list.extend(dirs)

# Keep .pt files of the selected task versions and model, and (when qs_list is non-empty) the selected questions.
pt_files = [f for f in files_list if
            re.search(common_tasks, f) and re.search(sample_config.model_name, f) and re.search(r'\.pt$', f)]
if len(qs_list) > 0:
    pt_files = [f for f in pt_files if any([True for q in qs_list if "^^" + q in f])]
# File names are split on '^^': the first part is the subject, the second the question.
subs = natsorted(list(set([f.split('^^')[0] for f in pt_files])))
qs = natsorted(list(set([f.split('^^')[1] for f in pt_files])))

# Rebuild each file's full path as <states_path>/<sub>/<question>/<model>/<file>, where <model> is the third
# '^^' field with everything from '_hidden_states' on removed.
pt_file_paths = []
for pt_file in pt_files:
    sub, q, model = pt_file.split('^^')
    model = model.split('_hidden_states')[0]
    pt_file_path = f"{states_path}/{sub}/{q}/{model}/{pt_file}"
    pt_file_paths.append(pt_file_path)
# Shape of one state tensor (this loads the whole first file just to read its shape; IndexError if none matched).
emb_size = torch.load(f"{pt_file_paths[0]}", weights_only=True).shape  # dim  (layer, location, hidden_dim)

# load open q data
# Subjects need at least min_words words in every selected open answer (checked in the state-loading loop).
min_words = 20
openq_data_long = pd.read_csv('_data/openq_data_long.csv')
openq_data_long = openq_data_long[openq_data_long['question'].isin(qs)]
openq_data_long = openq_data_long[openq_data_long['task_version'].isin(task_v)].reset_index(drop=True)
# Same whitespace clean-up as for act_data earlier in the file.
# NOTE(review): r'\n+,' only replaces newlines followed by a comma -- confirm r'\n+' was not intended.
openq_data_long = openq_data_long.replace(r'\s+\.', '.', regex=True)
openq_data_long = openq_data_long.replace(r'\s+,', ', ', regex=True)
openq_data_long = openq_data_long.replace(r'\n+,', ' ', regex=True)
# astype(str) turns a missing response into the string 'nan' (counted as one word).
openq_data_long['response'] = openq_data_long['response'].astype(str)
openq_data_long['n_words'] = openq_data_long['response'].apply(lambda x: len(x.split()))
# openq_data_long = openq_data_long[openq_data_long['n_words']>=min_words]

q_name = 'sds'
sds_q_names = ['sds_q' + str(q + 1) + 's' for q in range(20)]  # the 20 SDS item columns

# SDS item responses for the selected task versions; subjects with any missing value are listed in nan_sds_subs
# and skipped by the loading loop.
sds_data = pd.read_csv(f"_data/sds_data.csv")[['sub', 'task_version'] + sds_q_names]
sds_data = sds_data[sds_data['task_version'].isin(task_v)].reset_index(drop=True)
nan_sds_subs = sds_data[sds_data.isna().any(axis=1)]['sub'].to_list()
# %% Load tensor states for the last token after open answer for each layer
# For every subject that passes the word-count and SDS-completeness checks, stack one (layer, hidden) state per
# question for two token positions: tok_loc -> store_hs and tok_loc_qs -> store_hs_qs.
# Final shapes: (n_good_subjects, n_questions, n_layers, hidden_dim). openq_info['sub_index'] is the subject's
# row in these tensors; failing subjects are collected in bad_subs.
openq_info = []
store_hs = []
store_hs_qs = []
s_ind = -1
bad_subs = []
for sub in subs:
    store_sub_hs = []
    store_sub_hs_qs = []
    openq_data_sub = openq_data_long[openq_data_long['sub'] == sub]
    word_cond = all(openq_data_sub['n_words'] >= min_words)
    sds_cond = sub not in nan_sds_subs
    if word_cond and sds_cond:
        s_ind += 1
        q_ind = -1
        for q in qs:
            q_ind += 1
            file_key = sub + '^^' + q
            file_names = [f for f in pt_files if file_key in f]
            path_names = [f for f in pt_file_paths if file_key in f]
            if len(file_names) == 1:
                # store order of subjects/questions for states
                openq_data = openq_data_long.loc[(openq_data_long['sub'] == sub) & (openq_data_long['question'] == q)]
                tmp_info = {'sub_index': s_ind, 'sub': sub, 'q_index': q_ind, 'question': q,
                            'n_words': int(openq_data['n_words'].values[0])}
                openq_info.append(tmp_info)
                # Load the (layer, location, hidden) tensor once and take both token positions from it
                # (it used to be loaded from disk twice per question).
                tmp_embds_all = torch.load(f"{path_names[0]}", weights_only=True).to('cpu').type(torch.FloatTensor)

                # change the order where last layer is last; torch.roll returns a copy, so the in-place
                # normalisation below does not modify tmp_embds_all
                tmp_embds_qsub = torch.roll(tmp_embds_all[:, tok_loc[1], :], -1, 0)
                # normalise the embedding (length 1)
                tmp_embds_qsub /= np.linalg.norm(tmp_embds_qsub, axis=1, keepdims=True)
                store_sub_hs.append(tmp_embds_qsub)

                tmp_embds_qsub_qs = torch.roll(tmp_embds_all[:, tok_loc_qs[1], :], -1, 0)
                tmp_embds_qsub_qs /= np.linalg.norm(tmp_embds_qsub_qs, axis=1, keepdims=True)
                store_sub_hs_qs.append(tmp_embds_qsub_qs)
            else:
                if len(file_names) > 1:
                    print('Found many files')
                    print(file_names)
                else:
                    print('issue empty')
                # Append a NaN (layer, hidden) placeholder so every subject still has one entry per question and
                # the torch.stack calls below see equal shapes. Previously the placeholders were built but never
                # appended (and nothing was appended in the "many files" case), so torch.stack failed.
                tmp_embds_qsub = torch.full((emb_size[0], emb_size[2]), float('nan'), dtype=torch.float32)
                tmp_embds_qsub_qs = torch.full((emb_size[0], emb_size[2]), float('nan'), dtype=torch.float32)
                store_sub_hs.append(tmp_embds_qsub)
                store_sub_hs_qs.append(tmp_embds_qsub_qs)
        store_hs.append(torch.stack(store_sub_hs))
        store_hs_qs.append(torch.stack(store_sub_hs_qs))
    else:
        bad_subs.append(sub)
store_hs = torch.stack(store_hs)
store_hs_qs = torch.stack(store_hs_qs)
openq_info = pd.DataFrame(openq_info)
# %% Get sds and open q dataframes + subject list
pc.n_qs = qs_config.qs_n_qs['sds']  # number of SDS items, used for tick labels in the plots below
openq_data_long = openq_data_long[~openq_data_long['sub'].isin(bad_subs)].reset_index(drop=True)
# sds_data = pd.read_csv(f"_data/sds_data.csv")[['sub', 'task_version'] + sds_q_names]
# sds_data = sds_data[sds_data['task_version'].isin(task_v)].reset_index(drop=True)
# sds_data = sds_data[sds_data['sub'].isin(openq_info['sub'].unique())].reset_index(drop=True)
sds_data = sds_data[~sds_data['sub'].isin(bad_subs)].reset_index(drop=True)

# Attach each subject's row index into store_hs (sub_index); the inner merge drops subjects without loaded states.
sds_data = pd.merge(sds_data, openq_info[['sub_index', 'sub']].drop_duplicates(), on='sub')
# Reorder to: sub, sub_index, task_version, <item columns> (sub_index was appended last by the merge), then sort
# by sub_index -- so, assuming every sub_index has an SDS row, row k lines up with store_hs[k].
col_ord = sds_data.columns[:1].to_list() + sds_data.columns[-1:].to_list() + sds_data.columns[1:-1].to_list()
sds_data = sds_data[col_ord].sort_values(by='sub_index', ascending=True).reset_index(drop=True)

sds_wide = sds_data.iloc[:, 3:]  # item columns only

# z-score each item (StandardScaler, ddof=0). The covariance of these z-scores (pandas, ddof=1) equals the item
# correlation matrix scaled by N/(N-1).
sds_wide_std = pd.DataFrame(StandardScaler().fit_transform(sds_wide))
sds_cov = sds_wide_std.cov()

# nan_subs = sds_data[sds_data.isna().any(axis=1)][['sub', 'sub_index']]
# subs = [s for s in openq_info['sub'].unique()]  # if (s not in nan_subs['sub'].to_list())]
# From here on `subs` holds only the subjects whose states were loaded.
subs = natsorted(openq_info['sub'].unique().tolist())

# %% Plot covariance matrix
# Annotated heatmap of the standardized-item covariance; both axes are labelled with 1-based item numbers.
annot_fs = 5
plt.close('sds cov mat')
plt.figure('sds cov mat')
sns.heatmap(sds_cov, vmin=-0.3, vmax=1, annot=True, annot_kws={"size": annot_fs}, cmap="Oranges", fmt='.2f')
plt.yticks(np.arange(0.5, pc.n_qs + 0.5), np.arange(1, pc.n_qs + 1), rotation='horizontal')
plt.xticks(np.arange(0.5, pc.n_qs + 0.5), np.arange(1, pc.n_qs + 1), rotation='horizontal')
plt.xlabel('Question')
plt.ylabel('Question')
plt.title(f"Covariance matrix")
# if savePlots:
#     plt.savefig(f'{plots_path}cov_matrix_{q_name}.pdf', dpi=pc.dpi_val)

# %% Get factor scores
# Promax-rotated factor analysis of the standardized SDS items.
tmp_fa = FactorAnalyzer(rotation='promax', n_factors=qs_config.qs_factors[q_name])
tmp_fa.fit(sds_wide_std)
loadings = tmp_fa.loadings_
f_variances = tmp_fa.get_factor_variance()

# Bartlett factor scores: W = (L' Psi^-1 L)^-1 L' Psi^-1 with Psi = diag(uniquenesses); scores = Z @ W.T.
f_psi_inv = np.diag(1 / tmp_fa.get_uniquenesses())
fs_weights = (np.linalg.inv(loadings.T @ f_psi_inv @ loadings) @ loadings.T @ f_psi_inv).T
fs_scores = sds_wide_std @ fs_weights
fs_scores = fs_scores.to_numpy()
fs_scores = (fs_scores - np.nanmean(fs_scores, axis=0)) / np.nanstd(fs_scores, axis=0)  # standardise

# fs_scores = fs_scores / np.linalg.norm(fs_scores, axis=1, keepdims=True)
# fs_scores = fs_scores/fs_scores.sum(axis=1,keepdims=True)
# Insert f1_score, f2_score, ... right after the 'sub' and 'sub_index' columns.
# NOTE(review): re-running this cell raises ValueError because DataFrame.insert refuses existing column names.
for f in range(fs_scores.shape[1]):
    sds_data.insert(2 + f, f"f{f + 1}_score", fs_scores[:, f])

# tmp_embds_qsub /= np.linalg.norm(tmp_embds_qsub, axis=1, keepdims=True)

# response counts
# Long format (sub, question, score) with item names shortened from 'sds_q1s' to 'q1', then a
# score (rows) x question (columns, natural order) table of counts.
sds_score_cols = [c for c in sds_data.columns if 'sds_q' in c]
sds_long = sds_data.melt(id_vars=['sub'], value_vars=sds_score_cols, var_name='question', value_name='score')
sds_long['score'] = sds_long['score'].astype(int)
sds_long['question'] = sds_long['question'].str.replace('(sds_|s$)', '', regex=True)
sds_counts = sds_long.groupby(['question', 'score'], as_index=False).count().rename(
    columns={'sub': 'count'}).pivot(index=['question'], columns='score', values='count').T
sds_counts = sds_counts.loc[:, natsorted(sds_counts.columns)]

# %% Plot factor scores and response counts
# Two panels: (A) factor loadings, items x factors; (B) per-item response counts, items x score.
pc.r, pc.c, pc.mlt = 1, 2, 3
pc.figsize = ((pc.c + 2.5) * pc.mlt, (pc.r + 1) * pc.mlt)
pc.annot_fs = 12
pc.ax_ts(17, 1.1)
pc.l_fs(12, 0.85)
ts = 15
pc.xyt_ls(ts + 2, ts)
pc.ax_ls(18)
pc.kde_lw = 3
pc.p_lab_spec[2] = 16
pc.p_lab_spec[0] = -0.075
pc.p_lab_spec[1] = 1
pc.dpi_val = 300
pc.ms = 100
ax_space = 5  # NOTE(review): not used in this cell
pc.onerow = True
pc.i = 0
fname = 'factor loading and counts'
plt.close(fname)
# plt.figure(fname)
fig, pc.axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, num=fname)

pc.j = 0
sns.heatmap(loadings, annot=True, fmt='.2f', annot_kws={'fontsize': pc.annot_fs}, ax=pc.ax)
pc.ax.set_xticks(np.arange(fs_scores.shape[1]) + 0.5)
pc.ax.set_xticklabels([f'Factor {f + 1}' for f in range(fs_scores.shape[1])], rotation='horizontal')
pc.ax.set_yticks(np.arange(pc.n_qs) + 0.5)
pc.ax.set_yticklabels([f'{q + 1}' for q in range(pc.n_qs)], rotation='horizontal')
pc.ax.set_xlabel('Factor')
pc.ax.set_ylabel('Question')
pc.ax.set_title(f"Factor loadings")
# Bold panel letter in the top-left corner (axes coordinates).
pc.ax.text(pc.p_lab_spec[0], pc.p_lab_spec[1], pc.p_labs[pc.i, pc.j], transform=pc.ax.transAxes, fontweight='bold',
           va='top', ha='right',
           fontsize=pc.p_lab_spec[2])

pc.j = 1
# sds_counts is score x question, so it is transposed to put questions on the rows like panel A.
sns.heatmap(sds_counts.T, annot=True, fmt='.0f', annot_kws={'fontsize': pc.annot_fs}, ax=pc.ax)
pc.ax.set_yticks(np.arange(pc.n_qs) + 0.5)
pc.ax.set_yticklabels([f'{q + 1}' for q in range(pc.n_qs)], rotation='horizontal')
# pc.ax.set_yticks(np.arange(0.5, pc.n_qs + 0.5), np.arange(1, pc.n_qs + 1), rotation='horizontal')
pc.ax.set_xlabel('Score')
pc.ax.set_ylabel('Question')
pc.ax.set_title(f"Question score counts")
pc.ax.text(pc.p_lab_spec[0], pc.p_lab_spec[1], pc.p_labs[pc.i, pc.j], transform=pc.ax.transAxes, fontweight='bold',
           va='top', ha='right',
           fontsize=pc.p_lab_spec[2])

plt.tight_layout()
# if savePlots:
#     plt.savefig(f'{plots_path}factor_loading_and_counts_{q_name}.pdf', dpi=pc.dpi_val)

# %% Plot correlations between factor scores and closed scores
p_th = 0.05
# Closed-question scores for levels 1-3, sorted by sub_index like sds_data (and therefore like fs_scores).
# NOTE(review): row alignment with fs_scores assumes every sub_index is present in each of these files.
lvl1_closed_data = pd.read_csv('_data/lvl1_closed_data.csv')
lvl1_closed_data = pd.merge(lvl1_closed_data, openq_info[['sub_index', 'sub']].drop_duplicates(), on='sub').sort_values(
    by='sub_index').reset_index(drop=True)
lvl2_closed_data = pd.read_csv('_data/lvl2_closed_data.csv')
lvl2_closed_data = pd.merge(lvl2_closed_data, openq_info[['sub_index', 'sub']].drop_duplicates(), on='sub').sort_values(
    by='sub_index').reset_index(drop=True)

lvl3_closed_data = pd.read_csv('_data/phq9_data.csv')
lvl3_closed_data = pd.merge(lvl3_closed_data, openq_info[['sub_index', 'sub']].drop_duplicates(), on='sub').sort_values(
    by='sub_index').reset_index(drop=True)

# Positional column picks; the stop index excludes the trailing sub_index column appended by the merge.
# These ranges depend on each CSV's column layout -- re-check them if the files change.
lvl1_closed_data = lvl1_closed_data.iloc[:, -2:-1]
lvl2_closed_data = lvl2_closed_data.iloc[:, -4:-1]
lvl3_closed_data = lvl3_closed_data.iloc[:, -10:-2]

closed_scores = pd.concat([lvl1_closed_data, lvl2_closed_data, lvl3_closed_data], axis=1)
closed_scores = pd.DataFrame(StandardScaler().fit_transform(closed_scores))
closed_scores.columns = qs  # assumes the concatenated columns follow the natsorted qs order
# Derive the counts from the data rather than hard-coding 4 factors x 12 questions, so the column names, the
# Bonferroni factor and the slicing below stay correct if qs_config.qs_factors or qs_list change.
n_factors = fs_scores.shape[1]
n_closed_qs = closed_scores.shape[1]
scores = pd.concat([pd.DataFrame(fs_scores, columns=['Factor ' + str(f) for f in np.arange(1, n_factors + 1)]),
                    closed_scores], axis=1)

# Factor x closed-question block of the correlation matrix; p-values are Bonferroni-corrected over all the cells
# of that block and capped at 1.
scores_corr = scores.corr(method=pearsonr_corr)
scores_p_vals = scores.corr(method=lambda x, y: pearsonr_pval(x, y, n_factors * n_closed_qs))
scores_corr = scores_corr.iloc[:n_factors, n_factors:]
scores_p_vals = scores_p_vals.iloc[:n_factors, n_factors:]

# Heatmap of factor x closed-question correlations: significant cells (p < p_th) in Oranges with a colour bar,
# the rest in Greys on top of the masked cells.
pc.r, pc.c, pc.mlt = 1, 1, 3
pc.figsize = ((pc.c + 4) * pc.mlt, (pc.r + 0) * pc.mlt)
pc.annot_fs = 12
pc.ax_ts(17, 1.1)
pc.l_fs(12, 0.85)
ts = 15
pc.xyt_ls(ts - 2, ts)
pc.ax_ls(18)
pc.kde_lw = 3
pc.p_lab_spec[2] = 16
pc.p_lab_spec[0] = -0.075
pc.p_lab_spec[1] = 1
pc.dpi_val = 300
pc.ms = 100
ax_space = 5
pc.onerow = True
pc.i = 0

fname = 'factor scores vs closed scores corrs'
plt.close(fname)
fig, pc.axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, num=fname)
# A 1x1 grid returns a single Axes; wrap it into a (1, 1) array for pc.ax.
axes = np.array([pc.axes])[:, np.newaxis]
pc.axes = axes
pc.j = 0

sns.heatmap(scores_corr, mask=scores_p_vals >= p_th, annot=True, fmt='.2f', cmap='Oranges', ax=pc.ax,
            annot_kws={'fontsize': pc.annot_fs})
sns.heatmap(scores_corr, mask=scores_p_vals < p_th, annot=True, fmt='.2f', cmap='Greys', ax=pc.ax, cbar=False,
            annot_kws={'fontsize': pc.annot_fs})
# 'lvl2_q1' -> 'Level 2 Q1'
xt_labs = [re.sub('_q', ' Q', re.sub('lvl', 'Level ', q)) for q in qs]
pc.ax.set_xticks(np.arange(len(qs)) + 0.5)
pc.ax.set_xticklabels(xt_labs, rotation=25)
# plt.xlabel('Closed question score')
pc.ax.set_ylabel('Factor')
pc.ax.set_xlabel('Closed question')
pc.ax.set_yticklabels([f'Factor {f + 1}' for f in range(fs_scores.shape[1])], rotation='horizontal')

# NOTE(review): prev_name and ii are not used.
prev_name = int(qs[0][3])
ii = -1
# Separators after the single level-1 question and after the three level-2 questions (natsorted qs order).
lvl_points = [1, 4]
for lp in (lvl_points):
    plt.axvline(x=lp, color='tab:blue', linewidth=3)

plt.tight_layout()
# if savePlots:
#     plt.savefig(f'{plots_path}fscores_vs_qs_scores_{q_name}.pdf', dpi=pc.dpi_val)

# %% Prepare train and test set shuffle or load preshuffled list
# Either draw a fresh random split of subjects (optionally saving it) or read the saved one. The split file has
# two lines: "Train: <sub>, <sub>, ..." and "Test: <sub>, <sub>, ...".
if doSplit:
    subs_shuffle = copy.deepcopy(subs)
    random.shuffle(subs_shuffle)  # unseeded: a new split on every run unless saveSplit stores it
    subs_train = subs_shuffle[:int(len(subs_shuffle) * tt_ratio)]
    subs_test = subs_shuffle[int(len(subs_shuffle) * tt_ratio):]

    train_test_list = 'Train: ' + ', '.join(subs_train) + '\n' + 'Test: ' + ', '.join(subs_test)
    if saveSplit:
        with open(f"_data/latent_train-test_list.txt", 'w') as tt_file:
            tt_file.write(train_test_list)
else:
    with open(f"_data/latent_train-test_list.txt", 'r') as tt_file:
        # splitlines() drops the line terminators. readlines() kept the '\n' on the "Train:" line, so the last
        # training subject became e.g. 'sub42\n' and was silently excluded by the isin() filters below.
        tt_file_content = tt_file.read().splitlines()
    subs_train = tt_file_content[0].split(': ', 1)[1].split(', ')
    subs_test = tt_file_content[1].split(': ', 1)[1].split(', ')

# Row indices into store_hs / sds_data['sub_index'] for each subset.
train_ind = openq_info[openq_info['sub'].isin(subs_train)]['sub_index'].unique()
test_ind = openq_info[openq_info['sub'].isin(subs_test)]['sub_index'].unique()
# % Prepare data for CV Ridge
f_cols = [f"f{f + 1}_score" for f in range(qs_config.qs_factors[q_name])]
Y_train = sds_data[sds_data['sub_index'].isin(train_ind)].loc[:, f_cols].to_numpy()
Y_train = (Y_train - np.nanmean(Y_train, axis=0)) / np.nanstd(Y_train, axis=0)  # z-score within the training set

# %% Run CV to find best alpha and best layer with PCA
# % Prepare data for CV Ridge
# run_search=True
# run_search=False
# Same targets as in the split cell, rebuilt so this cell can be re-run on its own.
f_cols = [f"f{f + 1}_score" for f in range(qs_config.qs_factors[q_name])]
Y_train = sds_data[sds_data['sub_index'].isin(train_ind)].loc[:, f_cols].to_numpy()
Y_train = (Y_train - np.nanmean(Y_train, axis=0)) / np.nanstd(Y_train, axis=0)

analysis_fname = 'avg_pca'
store_res_pca = []

# For each layer: average the states over questions, L2-normalise, keep the PCA components explaining 95% of the
# variance, L2-normalise the PCA scores, then RidgeCV over alpha_vals with n_cv folds.
# NOTE(review): PCA is fit on all training subjects before the CV folds, so the CV score is slightly optimistic.
# NOTE(review): `pca` and `clf` here overwrite the pickled models loaded at the top of the file.
if run_search:
    for l, layer in tqdm(enumerate(sample_config.layer_list)):
        print(f"Layer {layer}")
        X_train = store_hs[train_ind, :, l, :].mean(axis=1).numpy()  # average question embedding
        X_train = X_train / np.linalg.norm(X_train, axis=1, keepdims=True)

        pca = PCA(n_components=0.95, svd_solver='full')
        # pca = PCA(n_components=4, svd_solver='full')
        pca.fit(X_train)
        n_pca = pca.n_components_
        X_train_rec = pca.transform(X_train)
        X_train_rec = X_train_rec / np.linalg.norm(X_train_rec, axis=1, keepdims=True)

        clf = RidgeCV(alphas=alpha_vals, cv=n_cv)

        clf.fit(X_train_rec, Y_train)
        tmp_res = {'layer_index': l, 'layer': layer, 'best_alpha': clf.alpha_, 'best_score': clf.best_score_,
                   'n_cv': n_cv, 'n_pca': n_pca, 'q_name': q_name}
        store_res_pca.append(tmp_res)
    store_res_pca = pd.DataFrame(store_res_pca).sort_values(by='best_score', ascending=False).reset_index(drop=True)
    if save_search:
        store_res_pca.to_csv(f"_data/ridge_CV_{analysis_fname}_results_sds_factors.csv", index=False)
else:
    store_res_pca = pd.read_csv(f"_data/ridge_CV_{analysis_fname}_results_sds_factors.csv", index_col=False)
# Row 0 is the layer with the highest CV score (results are sorted descending).
best_l = store_res_pca.loc[0, 'layer_index']
best_alpha = store_res_pca.loc[0, 'best_alpha']
n_pca = store_res_pca.loc[0, 'n_pca']
print(f"Best layer: {best_l}, best alpha: {best_alpha:.3f}, n_pca: {n_pca}")

# Fit the best model
# Refit PCA + Ridge on the training subjects with the layer / alpha / n_pca chosen by the CV search.
X_train = store_hs[train_ind, :, best_l, :].mean(axis=1).numpy()  # average over questions
X_train = X_train / np.linalg.norm(X_train, axis=1, keepdims=True)
pca = PCA(n_components=n_pca, svd_solver='full')
pca.fit(X_train)
X_train_rec = pca.transform(X_train)
X_train_rec = X_train_rec / np.linalg.norm(X_train_rec, axis=1, keepdims=True)

clf = Ridge(alpha=best_alpha)
clf.fit(X_train_rec, Y_train)
print(f"Train score: {clf.score(X_train_rec, Y_train)}")

# Score best model with PCA on test set
# Same preprocessing as the training data, including the conversion to numpy.
X_test = store_hs[test_ind, :, best_l, :].mean(axis=1).numpy()
X_test = X_test / np.linalg.norm(X_test, axis=1, keepdims=True)
X_test_rec = pca.transform(X_test)
X_test_rec = X_test_rec / np.linalg.norm(X_test_rec, axis=1, keepdims=True)

# Observed test scores and predictions are each z-scored within the test set.
Y_test = sds_data[sds_data['sub_index'].isin(test_ind)].loc[:, f_cols].to_numpy()
Y_test = (Y_test - np.nanmean(Y_test, axis=0)) / np.nanstd(Y_test, axis=0)

Y_pred = clf.predict(X_test_rec)
Y_pred = (Y_pred - np.nanmean(Y_pred, axis=0)) / np.nanstd(Y_pred, axis=0)

# Per-factor R^2 = 1 - RSS / TSS, where TSS is the spread of the *observed* scores around their mean.
# (It was computed from Y_pred before; that only matched numerically because both arrays are z-scored above.)
my_rss = ((Y_pred - Y_test) ** 2).sum(axis=0)
my_tss = ((Y_test - Y_test.mean(axis=0)) ** 2).sum(axis=0)
my_r2 = (1 - my_rss / my_tss)

print(f"My R2: {my_r2}, mean r2: {my_r2.mean()}")
# sklearn's R^2 uses the model's raw (not z-scored) predictions against the z-scored Y_test.
print(f"Test score: {clf.score(X_test_rec, Y_test)}")

# Test correlations
# Pearson correlation along axis 0 (one value per factor column); p-values Bonferroni-corrected for the number of
# factors and capped at 1.
corr_test = stats.pearsonr(Y_test, Y_pred)
p_vals = np.min([corr_test.pvalue * Y_test.shape[1], np.ones_like(corr_test.pvalue)], axis=0)

# %% Plot scatter and correlations with PCA
# One panel per factor: true vs predicted (z-scored) test-set scores with a regression line and 95% CI band.
pc.r, pc.c, pc.mlt = 1, Y_test.shape[1], 2.1
pc.figsize = ((pc.c + 1.5) * pc.mlt, (pc.r + 0.5) * pc.mlt)
pc.ax_ts(15, 1.1)
pc.l_fs(12, 0.85)
pc.xyt_ls(15, 15)
pc.ax_ls(18)
pc.kde_lw = 3
pc.p_lab_spec[2] = 14
pc.p_lab_spec[0] = -0.025
pc.p_lab_spec[1] = 1.1
pc.dpi_val = 300
ms = 70  # scatter marker size
lw = 4  # regression line width
fname = 'reg results'

plt.close(fname)
fig, axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, num=fname)
pc.axes = axes.flatten()
pc.onerow = True
pc.i = 0
N = Y_test.shape[0]

# Replaces the earlier yt_labs with line-broken labels for the panel titles.
yt_labs = ['Depression: \nFactor 1', 'Somatic & Emotional: \nFactor 2', 'Cognitive: \nFactor 3',
           'Appetite & Weight: \nFactor 4']
for pc.j in range(pc.c):
    # pc.ax.scatter(Y_test[:,pc.j],Y_pred[:,pc.j])
    # sns.regplot(x=Y_test[:, pc.j], y=Y_pred[:, pc.j], ax=pc.ax, ci=95, label=True)
    sns.regplot(x=Y_test[:, pc.j], y=Y_pred[:, pc.j], ax=pc.ax, ci=95, label=True, scatter=True, scatter_kws={'s': ms},
                line_kws={'lw': lw}, color='tab:orange')
    # sns.regplot(x=x, y=y)
    # pc.ax.set_title(f"factor {pc.j+1}")
    pc.ax.set_xlabel('True factor score')
    pc.ax.set_ylabel('Predicted factor score')
    # corr_test = stats.spearmanr(Y_test[:,pc.j], Y_pred[:,pc.j])
    # p_vals = np.min([corr_test.pvalue * 4, np.ones_like(corr_test.pvalue)], axis=0)
    # NOTE(review): relies on the SciPy result exposing `.correlation` (older name of `.statistic`) -- confirm for
    # the installed SciPy version.
    pc.ax.set_title(
        f'{yt_labs[pc.j]}, r={corr_test.correlation[pc.j]:.3f}')
    # pc.ax.set_title(
    #     f'{yt_labs[pc.j]} \nr={corr_test.correlation[pc.j]:.3f}; p={p_vals[pc.j]:.2e}')
    # pc.ax.set_title(
    #     f'Factor {pc.j + 1}\nr={corr_test.correlation[pc.j]:.3f}; p={p_vals[pc.j]:.2e}')
    pc.ax.text(pc.p_lab_spec[0], pc.p_lab_spec[1], pc.p_labs[pc.i, pc.j], transform=pc.ax.transAxes, fontweight='bold',
               va='top', ha='right',
               fontsize=pc.p_lab_spec[2])
    # pc.ax.set_aspect('equal', 'box')
    # Fixed ticks/limits suited to z-scored values; points outside +-xlim / +-ylim are clipped from view.
    pc.ax.set_xticks(np.arange(-2, 3, 2))
    xlim = 2.95
    ylim = 2.5
    pc.ax.set_xlim([-xlim, xlim])
    pc.ax.set_yticks(np.arange(-2, 3, 2))
    pc.ax.set_ylim([-ylim, ylim])

# plt.suptitle(f"Ridge regression: true vs predicted factor scores (N={N}) - Test set")
plt.tight_layout()
if savePlots:
    plt.savefig(f'{plots_path}scatter_{analysis_fname}_results_{q_name}.pdf', dpi=pc.dpi_val)
# %% Plot regression weights
# Heatmap of the Ridge coefficients: one row per factor, one column per model feature (PCA component here).
plt.close('reg weights')
dd = clf.coef_.shape[1]  # number of model features
# NOTE(review): features / hidden size -- presumably 0 now that the model is fit on PCA components (n_pca is
# expected to be below hidden_dim), in which case the vertical separator loop below draws nothing. xt and xt_lab
# are only used in commented-out lines.
nq = int(dd / emb_size[2])
nf = clf.coef_.shape[0]  # number of factors
# xt = np.arange(0,dd+emb_size[2],emb_size[2])
xt = np.arange(emb_size[2] / 2, dd, emb_size[2])
xt_lab = ['open q1', 'open q2', 'open q3']
plt.figure('reg weights')
sns.heatmap(clf.coef_, cmap='vlag', annot=True)
# plt.xticks(np.arange(0.5, pc.n_qs + 0.5),np.arange(1, pc.n_qs +1),rotation='horizontal')
plt.yticks(np.arange(0.5, nf + 0.5), ['factor ' + str(f) for f in np.arange(1, nf + 1)], rotation='horizontal')
# plt.xticks(xt+0.5,xt)
# plt.xticks(xt + 0.5, xt_lab, rotation='horizontal')
# plt.ticks(xt+0.5,xt)
# plt.xticks(np.arange(0.5, pc.n_qs + 0.5),np.arange(1, pc.n_qs +1),rotation='horizontal')
for i in range(nq):
    if i > 0:
        plt.axvline(x=i * emb_size[2], color='black')
# Horizontal lines between factor rows.
for i in range(nf):
    if i > 0:
        plt.axhline(y=i, color='black')
    # print(i*emb_size[2])
# plt.add_patch()
plt.title(f"Regression weights")
plt.tight_layout()
if savePlots:
    plt.savefig(f'{plots_path}reg_{analysis_fname}_weights_{q_name}.pdf', dpi=pc.dpi_val)
# %% Predict factor scores from each individual-question hidden state (PCA-reduced) and calculate correlation for each factor
# Y_all is only used by the commented-out alternative inside the loop below.
Y_all = sds_data.loc[:, f_cols].to_numpy()
Y_all = (Y_all - np.nanmean(Y_all, axis=0)) / np.nanstd(Y_all, axis=0)
p_th = 0.05
store_q_pred = []
# for sub_h in store_hs[:, :, best_l, :]:
# For each test subject: apply the fitted PCA + Ridge to every question's state separately.
for sub_h in store_hs[test_ind, :, best_l, :]:
    sub_h = pca.transform(sub_h)
    sub_h /= np.linalg.norm(sub_h, axis=1, keepdims=True)
    q_pred = clf.predict(sub_h).T  # (n_factors, n_questions)
    # NOTE(review): axis=0 of the transposed array is the factor axis, so this z-scores across the factors of
    # each question (within one subject), not across subjects as done for Y_pred -- confirm this is intended.
    q_pred = (q_pred - np.nanmean(q_pred, axis=0)) / np.nanstd(q_pred, axis=0)
    store_q_pred.append(q_pred)

store_q_pred = np.stack(store_q_pred)  # (n_test_subjects, n_factors, n_questions)
store_corrs = []
store_pvals = []
# The loop variable `f` indexes the last axis, i.e. the *question*; each iteration correlates Y_test with that
# question's predictions column by column (one r per factor).
for f in range(store_q_pred.shape[-1]):
    corr_test_tmp = stats.pearsonr(Y_test, store_q_pred[:, :, f])
    # corr_test_tmp = stats.pearsonr(Y_all, store_q_pred[:, :, f])
    store_corrs.append(corr_test_tmp.correlation)
    # Bonferroni over 4 factors x 12 questions (hard-coded), capped at 1.
    p_vals_tmp = np.min([corr_test_tmp.pvalue * 4 * 12, np.ones_like(corr_test_tmp.pvalue)], axis=0)
    store_pvals.append(p_vals_tmp)
# (n_factors, n_questions) after the transpose.
store_corrs = np.stack(store_corrs).T
store_pvals = np.stack(store_pvals).T

# Figure: per-question correlation of predicted vs. true factor scores (rows = factors, columns = questions).
fname = 'individual question to factor corrs'
# Close any earlier instance of THIS figure before recreating it. Previously plt.close ran before fname
# was reassigned, so it closed whichever figure the old fname referred to.
plt.close(fname)
pc.r, pc.c, pc.mlt = 1, 1, 2.4
pc.figsize = ((pc.c + 6.1) * pc.mlt, (pc.r + 0.3) * pc.mlt)
pc.annot_fs = 17
pc.ax_ts(17, 1.1)
pc.l_fs(12, 0.85)
ts = 14
pc.xyt_ls(ts - 2, ts)
pc.ax_ls(16)
pc.kde_lw = 3
pc.p_lab_spec[2] = 16
pc.p_lab_spec[0] = -0.075
pc.p_lab_spec[1] = 1
pc.dpi_val = 300
pc.ms = 100
pc.onerow = True
pc.i = 0
fig, pc.axes = plt.subplots(pc.r, pc.c, figsize=pc.figsize, num=fname)
# Wrap the single Axes into a (1, 1) array. Presumably pc.ax indexes pc.axes[pc.i, pc.j];
# confirm in PlotConfig.
axes = np.array([pc.axes])[:, np.newaxis]
pc.axes = axes
pc.j = 0

p_trend = 0.2  # upper p-value bound of the "trend" band shaded in purple
# Layer 1: correlations of non-significant cells (p >= p_th), white background, near-white annotation text.
sns.heatmap(store_corrs, mask=store_pvals < p_th, annot=True, cmap=ListedColormap(['white']), cbar=False,
            fmt='.2f', annot_kws={'size': pc.annot_fs, 'color': "#fcfbfd"}, ax=pc.ax)
# Layer 2: shade trend cells (p_th <= p <= p_trend) by p-value. The mask must hide everything outside
# that band, i.e. p < p_th OR p > p_trend. The previous `&` made the mask always False.
sns.heatmap(store_pvals, mask=(store_pvals < p_th) | (store_pvals > p_trend), annot=False, cmap='Purples_r',
            cbar=True, vmax=p_trend, vmin=p_th, ax=pc.ax, cbar_kws={'label': 'Trend P-value', 'pad': -0.07})
# Layer 3: significant cells (p < p_th), annotated correlations on an orange scale.
sns.heatmap(store_corrs, mask=store_pvals >= p_th, annot=True, cmap='Oranges', vmin=0, fmt='.2f',
            annot_kws={'size': pc.annot_fs}, ax=pc.ax, cbar_kws={'label': 'Correlation', 'pad': 0.01})

# Thick blue column boundaries (presumably separating question levels, cf. 'lvl' in qs names; confirm),
# then thin black lines between every question column.
lvl_points = [1, 4]
for lp in lvl_points:
    pc.ax.axvline(x=lp, color='tab:blue', linewidth=3)
for lp in range(len(qs) + 1):
    pc.ax.axvline(x=lp, color='black', linewidth=0.5)

# Short tick labels: 'lvl' -> 'L' and '_q' -> ' Q', followed by a short description of each question.
xt_labs = [re.sub('_q', ' Q', re.sub('lvl', 'L', q)) for q in qs]
qs_decr = ["Well-being\n& functioning", "Mood & \nfeelings", "Sleep,\n energy & \n cognition", "Self-worth",
           "Interest & \npleasure",
           "Mood & \nfeelings", "Sleep\nissues", "Energy\nlevels", "Appetite\nchanges", "Self-worth",
           "Focus", "Slowness or \nrestlessness"]
xt_labs = [f"{xl}\n({qd})" for xl, qd in zip(xt_labs, qs_decr)]

yt_labs = ['Depression-1', 'Somatic & Emotional-2', 'Cognitive-3', 'Appetite & Weight-4']
pc.ax.set_xticks(np.arange(len(qs)) + 0.5)
pc.ax.set_xticklabels(xt_labs, rotation="horizontal")
pc.ax.set_ylabel('Factor')
pc.ax.set_xlabel('Open question')
pc.ax.set_yticklabels(yt_labs, rotation=5)
pc.ax.xaxis.set_tick_params(width=2, length=10)
plt.tight_layout()
fig.subplots_adjust(right=1.075)  # extend the axes past the right edge after tight_layout (layout hack)
if savePlots:
    plt.savefig(f'{plots_path}reg_{analysis_fname}_factor_score_from_each_q_{q_name}.pdf', dpi=pc.dpi_val)
# %% ARCHIVE BELOW
# ARCHIVE BELOW
# # %% Similarity to regression weights
# # store_hs_avg = store_hs_qs[:, :, best_l, :]
# store_hs_avg = store_hs[test_ind, :, best_l, :].mean(axis=0)
# store_hs_avg /= np.linalg.norm(store_hs_avg, axis=1, keepdims=True)
# store_hs_avg_rec = pca.transform(store_hs_avg)
# store_hs_avg_rec /= np.linalg.norm(store_hs_avg_rec, axis=1, keepdims=True)
# # store_hs_avg_rec = []
# # for q in range(store_hs_avg.shape[1]):
# #     tmp_hs_avg_rec = pca.transform(store_hs_avg[:,q,:])
# #     tmp_hs_avg_rec /= np.linalg.norm(tmp_hs_avg_rec, axis=1, keepdims=True)
# #     store_hs_avg_rec.append(tmp_hs_avg_rec)
# # store_hs_avg_rec = np.stack(store_hs_avg_rec,axis=1)
#
#
# w = clf.coef_ / np.linalg.norm(clf.coef_, axis=1, keepdims=True)
#
# plt.close('reg weights sim')
# plt.figure('reg weights sim')
# # sim = cosine_similarity(w, store_hs_avg)
# sim = cosine_similarity(w, store_hs_avg_rec)
# # sim = cosine_similarity(w_active, store_hs_avg)
# sns.heatmap(sim, annot=True, annot_kws={'fontsize': 8}, fmt='.2f', cmap='coolwarm')
# plt.xlabel('Open question')
# plt.xticks(np.arange(len(qs)) + 0.5, qs, rotation=80)
# plt.ylabel('Factor')
# plt.yticks(np.arange(nf) + 0.5, np.arange(nf) + 1, rotation=0)
# plt.title('Regression vector similiarty to average question vector')
# plt.tight_layout()
# if savePlots:
#     plt.savefig(f'{plots_path}reg_{analysis_fname}_factor_weights_sim_wQ_{q_name}.pdf', dpi=pc.dpi_val)
#
# # %%
# store_q_pred = []
# for y_true, sub_h in zip(Y_all, store_hs[:, :, best_l, :]):
#     sub_h = pca.transform(sub_h)
#     sub_h /= np.linalg.norm(sub_h, axis=1, keepdims=True)
#     q_pred = clf.predict(sub_h).T
#     q_pred = (q_pred - np.nanmean(q_pred, axis=0)) / np.nanstd(q_pred, axis=0)
#     y_t = y_true[np.newaxis, :].T
#     # err = (1/(q_pred.shape[0]*q_pred.shape[1]))*(q_pred - y_t)**2
#     err = (q_pred - y_t) ** 2
#     # clf.score(sub_h[0][np.newaxis,:],y_true[np.newaxis,:])
#     # store_q_pred.append(q_pred)
#     store_q_pred.append(err)
# store_q_pred = np.stack(store_q_pred).mean(axis=0)
# plt.close('reg weights q pred avg')
# plt.figure('reg weights q pred avg')
# sns.heatmap(store_q_pred, annot=True, annot_kws={'fontsize': 8}, fmt='.2f', cmap='Blues_r')
# plt.xlabel('Open question')
# plt.xticks(np.arange(len(qs)) + 0.5, qs, rotation=80)
# plt.ylabel('Factor')
# plt.yticks(np.arange(nf) + 0.5, np.arange(nf) + 1, rotation=0)
# plt.tight_layout()
# # %%
# q_pred = clf.predict(store_hs_avg_rec).T
# q_pred = (q_pred - np.nanmean(q_pred, axis=0)) / np.nanstd(q_pred, axis=0)
# plt.close('reg weights q pred')
# plt.figure('reg weights q pred')
# sns.heatmap(q_pred, annot=True, annot_kws={'fontsize': 8}, fmt='.2f', cmap='coolwarm')
# plt.xlabel('Open question')
# plt.xticks(np.arange(len(qs)) + 0.5, qs, rotation=80)
# plt.ylabel('Factor')
# plt.yticks(np.arange(nf) + 0.5, np.arange(nf) + 1, rotation=0)
# plt.tight_layout()
#
# # %%
#
#
# # %% Sum abs of top regression weights per factor
# top_p = 0.05
# # from scipy.stats import t
# # t_critical = t.ppf(1 - 0.05 / 2, dd)  # Two-tailed
# w_abs = np.abs(clf.coef_)
# # w_abs = clf.coef_
# w_active = np.zeros_like(w_abs)
# ranks = np.argsort(-w_abs, axis=1)
# top_ranks = ranks[:, :int(dd * top_p)]
# for i in range(nf):
#     w_active[i, top_ranks[i, :]] = w_abs[i, top_ranks[i, :]]
#     # w_active[i, top_ranks[i, :]] = w_abs[i, top_ranks[i, :]]/w_abs[i,:].sum()
#
# plt.close('active reg weights explore')
# plt.figure('active reg weights explore')
# sns.heatmap(w_active)
# plt.tight_layout()
# # %%
# # thr_val = (w_abs.mean(axis=1)/(w_abs.std(axis=1)*np.sqrt(dd)))[:,np.newaxis]
# # thr_m=w_abs.mean(axis=1)[:,np.newaxis]
# # thr_s=(w_abs.std(axis=1)*np.sqrt(dd))[:,np.newaxis]
# # t_vals = (w_abs-thr_m)/thr_s
# # w_active = t_vals
# # w_active = np.abs(t_vals)
# # w_active = t_vals[t_vals>t_critical]
# # w_active[w_abs>w_abs.mean()] = w_abs[w_abs>w_abs.mean()]
# # w_active[w_abs>w_abs.mean(axis=1)[:,np.newaxis]] = w_abs[w_abs>w_abs.mean(axis=1)[:,np.newaxis]]
# # w_active[w_abs>thr_val] = w_abs[w_abs>thr_val]
# # w_active[w_abs>w_abs.mean()] = 1
#
# w_abs_sum = np.zeros((nf, nq))
# for i in range(nq):
#     print(i)
#     # w_abs[:, i * emb_size[2]:(i + 1) * emb_size[2]].mean
#
#     # w_abs_sum[:,i] = w_abs[:,i*emb_size[2]:(i+1)*emb_size[2]].sum(axis=1)
#
#     # w_abs_sum[:, i] = w_active[:, i * emb_size[2]:(i + 1) * emb_size[2]].sum(axis=1) <=
#
#     # w_abs_sum[:, i] = w_active[:, i * emb_size[2]:(i + 1) * emb_size[2]].sum(axis=1)/w_abs[:,i * emb_size[2]:(i + 1) * emb_size[2]].sum(axis=1)
#     w_abs_sum[:, i] = w_active[:, i * emb_size[2]:(i + 1) * emb_size[2]].sum(axis=1) / w_abs.sum(axis=1)
#
# plt.close('reg weights sum per q')
# plt.figure('reg weights sum per q')
# sns.heatmap(w_abs_sum, cmap='Oranges', annot=True, fmt='.3f')
# plt.title(f"Sum of top {top_p} active weights")
# plt.xlabel('Open question dimensions')
# plt.ylabel('Factor')
# plt.yticks(np.arange(0.5, nf + 0.5), ['Factor ' + str(f) for f in np.arange(1, nf + 1)], rotation=45)
# plt.xticks(np.arange(0.5, nq + 0.5), ['Q' + str(q) for q in np.arange(1, nq + 1)], rotation='horizontal')
# for i in range(nf):
#     if i > 0:
#         plt.axhline(y=i, color='black')
# plt.tight_layout()
# if savePlots:
#     plt.savefig(f'{plots_path}top_{top_p}_{analysis_fname}_active_weights_{q_name}.pdf', dpi=pc.dpi_val)
#
