import pandas as pd
from natsort import natsorted
import os
import re
from scipy.stats import spearmanr, pearsonr


def spearmanr_pval(x, y):
    """Return only the p-value of the Spearman rank correlation of x and y.

    The two-argument form lets this be passed as the ``method`` callable of
    ``DataFrame.corr``, which is how ``get_corrs_gen_wphq9`` uses it.
    """
    _, p_value = spearmanr(x, y)
    return p_value


# Ordered list of location key names. Nothing in the visible part of this file
# reads it; presumably it is imported by other modules -- confirm before editing.
qa_loc_key_names = [
    'oq_qs',
    'oq_Answer',
    'oq_ans',
    'cq_qs',
    'cq_Answer',
    'last',
]


def process_logits_itemLevel(sample_config, paths, bools, task_v, phq9_qs_inv_map):
    """Pair mean LLM item scores with the participants' own closed-question scores.

    LLM responses are either collected from the directory tree under
    ``paths.responses_path`` (``bools.loadMe`` false) or read back from the
    cached CSV in ``paths.save_responses_path`` (``bools.loadMe`` true). Scores
    are averaged per ``(sub, question)`` and left-merged with the participants'
    PHQ-9, LVL1 and LVL2 scores read from ``paths.data_path``.

    Args:
        sample_config: object exposing ``model_name`` (used in the cache file
            name) and ``model_name_rp`` (substring selecting the model's
            directory inside each question directory).
        paths: object exposing ``responses_path``, ``save_responses_path`` and
            ``data_path``. Paths may be given with or without a trailing
            separator.
        bools: object exposing the ``loadMe`` and ``saveMe`` flags.
        task_v: iterable of task-version strings; a participant directory is
            used when everything after its first ``_`` equals one of them.
        phq9_qs_inv_map: mapping whose keys, with ``'s'`` appended, are matched
            against the PHQ-9 score column names and replaced by the mapped
            value.

    Returns:
        Tuple ``(responses_avg_merged, q_names)``: the merged frame with
        columns ``sub``, ``q_name``, ``score_llm`` and ``score_sub``, and the
        unique question names found in it.

    Raises:
        ValueError: if the directory scan finds no response CSV at all.
    """
    cache_csv = os.path.join(paths.save_responses_path,
                             f"sub_llm_responses_{sample_config.model_name}.csv")

    if not bools.loadMe:
        frames = []
        for sub in os.listdir(paths.responses_path):
            if '.DS_Store' in sub:
                continue
            # Directory names look like "<id>_<task version>"; keep requested versions only.
            version = '_'.join(sub.split('_')[1:])
            if not any(v == version for v in task_v):
                continue

            sub_path = os.path.join(paths.responses_path, sub)
            for q in os.listdir(sub_path):
                if '.DS_Store' in q:
                    continue
                sub_q_path = os.path.join(sub_path, q)
                spec = [d for d in os.listdir(sub_q_path)
                        if '.DS_Store' not in d and sample_config.model_name_rp in d]
                if not spec:
                    continue
                spec_path = os.path.join(sub_q_path, spec[0])
                # Only the first CSV whose name contains "responses" is read.
                csv_file = [f for f in os.listdir(spec_path) if '.csv' in f and 'responses' in f]
                if csv_file:
                    frames.append(pd.read_csv(os.path.join(spec_path, csv_file[0])))

        if not frames:
            raise ValueError(
                f"No response CSV files found under {paths.responses_path!r} "
                f"for model {sample_config.model_name_rp!r} and task versions {list(task_v)!r}")
        store_responses = pd.concat(frames)
        if bools.saveMe:
            store_responses.to_csv(cache_csv, index=False)
    else:
        store_responses = pd.read_csv(cache_csv)

    # Mean LLM score per participant and question across all samples.
    responses_mean = store_responses.groupby(['sub', 'question'], as_index=False)['score'].mean()
    responses_mean = responses_mean.rename(columns={'question': 'q_name'})
    responses_mean = responses_mean[responses_mean['q_name'] != 'rep_lvl2_q1'].sort_values(
        by=['sub', 'q_name']).reset_index(drop=True)

    # Participant PHQ-9 scores (item 9 is excluded), reshaped to long format.
    phq9_data = pd.read_csv(os.path.join(paths.data_path, 'phq9_data.csv'))
    phq9_names = [c for c in phq9_data.columns if 'phq9_q' in c and 'phq9_q9' not in c and 's' in c]
    phq9_data_long = pd.melt(phq9_data, id_vars=['sub'], value_vars=phq9_names, var_name='q_name', value_name='score')
    phq9_data_long['q_name'] = phq9_data_long['q_name'].replace({k + 's': v for k, v in phq9_qs_inv_map.items()})

    # Participant LVL1 closed score; the regex turns 'lvl1_closed_q1s' into 'lvl1_q1'.
    lvl1_closed_data = pd.read_csv(os.path.join(paths.data_path, 'lvl1_closed_data.csv'))
    lvl1_closed_data_long = pd.melt(lvl1_closed_data, id_vars=['sub'], value_vars=['lvl1_closed_q1s'],
                                    var_name='q_name',
                                    value_name='score')
    lvl1_closed_data_long['q_name'] = lvl1_closed_data_long['q_name'].str.replace(r'_closed|s', '', regex=True)

    # Participant LVL2 closed scores. The path is joined with a separator like
    # the other data files; the original concatenated it directly onto data_path.
    lvl2_closed_data = pd.read_csv(os.path.join(paths.data_path, 'lvl2_closed_data.csv'))
    lvl2_closed_data_long = pd.melt(lvl2_closed_data, id_vars=['sub'],
                                    value_vars=['lvl2_closed_q1s', 'lvl2_closed_q2s', 'lvl2_closed_q3s'],
                                    var_name='q_name', value_name='score')
    lvl2_closed_data_long['q_name'] = lvl2_closed_data_long['q_name'].str.replace(r'_closed|s', '', regex=True)

    # All participant closed responses in one long frame.
    closed_data_long = pd.concat([lvl1_closed_data_long, lvl2_closed_data_long, phq9_data_long],
                                 ignore_index=True)
    closed_data_long = closed_data_long.sort_values(by=['sub', 'q_name']).reset_index(drop=True)

    # Left merge keeps every LLM row; 'score' becomes score_llm / score_sub.
    responses_avg_merged = pd.merge(responses_mean, closed_data_long, on=['sub', 'q_name'], how='left',
                                    suffixes=['_llm', '_sub'])

    q_names = responses_avg_merged['q_name'].unique()

    return responses_avg_merged, q_names


def retrieve_logits_genLevel(gen_qs, bools, sample_config, paths):
    """Collect (or reload) mean LLM scores for the generalisation questionnaire.

    With ``bools.loadMe`` false, walks
    ``paths.gen_responses_path/<sub>/<q>/<gq>/<model dir>/`` in natural sort
    order, reads the first CSV whose name contains ``responses`` and averages
    ``score`` within that file over its grouping columns. With ``bools.loadMe``
    true, the previously saved CSV is read instead.

    Args:
        gen_qs: questionnaire name; only used in the cache file name here.
        bools: object exposing the ``loadMe`` and ``saveMe`` flags.
        sample_config: object exposing ``model_name`` (cache file name) and
            ``model_name_rp`` (substring selecting the model's directory).
        paths: object exposing ``gen_responses_path`` and
            ``save_responses_path``. Paths may be given with or without a
            trailing separator.

    Returns:
        DataFrame with one mean ``score`` per grouping-column combination.

    Raises:
        ValueError: if the directory scan finds no response CSV at all.
    """
    cache_csv = os.path.join(
        paths.save_responses_path,
        f"sub_llm_responses_gen_{gen_qs}_{sample_config.model_name}.csv")

    if bools.loadMe:
        return pd.read_csv(cache_csv)

    group_cols = ['sub', 'question', 'question_gen', 'model', 'instr_name', 'qs', 'gen_qs', 'label_perm',
                  'nSamples']

    def _visible_entries(directory):
        # Directory listing without macOS .DS_Store entries, in natural order.
        return natsorted([d for d in os.listdir(directory) if '.DS_Store' not in d])

    frames = []
    for sub in _visible_entries(paths.gen_responses_path):
        # Joined with a separator; the original concatenated sub directly onto
        # gen_responses_path here while using '/' for the deeper levels.
        sub_path = os.path.join(paths.gen_responses_path, sub)
        for q in _visible_entries(sub_path):
            sub_q_path = os.path.join(sub_path, q)
            for gq in _visible_entries(sub_q_path):
                sub_gq_path = os.path.join(sub_q_path, gq)
                spec = [d for d in os.listdir(sub_gq_path)
                        if '.DS_Store' not in d and sample_config.model_name_rp in d]
                if not spec:
                    continue
                spec_path = os.path.join(sub_gq_path, spec[0])
                csv_file = [f for f in os.listdir(spec_path) if '.csv' in f and 'responses' in f]
                if not csv_file:
                    continue
                tmp_df = pd.read_csv(os.path.join(spec_path, csv_file[0]))
                # Average over samples within this file before storing.
                frames.append(tmp_df.groupby(group_cols, as_index=False)['score'].mean())

    if not frames:
        raise ValueError(
            f"No response CSV files found under {paths.gen_responses_path!r} "
            f"for model {sample_config.model_name_rp!r}")
    store_responses = pd.concat(frames)
    if bools.saveMe:
        store_responses.to_csv(cache_csv, index=False)

    return store_responses


def process_logits_genLevel(gen_qs, bools, sample_config, paths, task_v):
    """Merge mean LLM scores on a generalisation questionnaire with participant scores.

    Args:
        gen_qs: questionnaire name; selects ``<gen_qs>_data.csv`` and its
            ``<gen_qs>_q<number>s`` score columns.
        bools, sample_config, paths: forwarded to ``retrieve_logits_genLevel``;
            ``paths.data_path`` is also read here.
        task_v: task versions to keep from the questionnaire file.

    Returns:
        Tuple ``(responses_avg_merged, gen_qs_data_long, q_names)``: the merged
        frame without rows containing NaN, the participants' long-format
        scores, and the unique question names of the merged frame.
    """
    llm_scores = retrieve_logits_genLevel(gen_qs, bools, sample_config, paths)
    llm_scores = llm_scores.rename(columns={'question': 'q_name_context', 'question_gen': 'q_name'})
    llm_scores = llm_scores.sort_values(by=['sub', 'q_name_context', 'q_name']).reset_index(drop=True)

    # Participants' closed questionnaire data, restricted to the requested task versions.
    questionnaire = pd.read_csv(f"{paths.data_path}/{gen_qs}_data.csv")
    questionnaire = questionnaire[questionnaire['task_version'].isin(task_v)]

    score_columns = []
    for column in questionnaire.columns:
        if gen_qs + '_q' in column and re.search("[0-9]+s", column):
            score_columns.append(column)

    gen_qs_data_long = questionnaire.melt(id_vars=['sub'], value_vars=score_columns,
                                          var_name='q_name', value_name='score')
    # Drop the last character of each column name (the trailing 's').
    gen_qs_data_long['q_name'] = gen_qs_data_long['q_name'].str[:-1]
    gen_qs_data_long = gen_qs_data_long.dropna()

    # Left merge on participant and question, then discard incomplete rows.
    responses_avg_merged = llm_scores.merge(gen_qs_data_long, on=['sub', 'q_name'], how='left',
                                            suffixes=['_llm', '_sub'])
    responses_avg_merged = responses_avg_merged.dropna()

    return responses_avg_merged, gen_qs_data_long, responses_avg_merged['q_name'].unique()


def get_corrs_gen_wphq9(paths, gen_qs, closed_data_long, context_names, q_names, qs_config, task_v, p_thr=1):
    """Item-level Spearman correlations between participants' PHQ-9 scores and
    their scores on another questionnaire.

    Participants with any missing item are removed before correlating.
    P-values are multiplied by ``len(q_names) * len(context_names)`` and capped
    at 1 (Bonferroni-style correction).

    Args:
        paths: object exposing ``data_path`` (location of ``phq9_data.csv``).
        gen_qs: questionnaire name; when it is ``'phq9'`` the items in
            ``closed_data_long`` are renamed ``phq9_ppt_*`` so they do not
            collide with the PHQ-9 items loaded here.
        closed_data_long: long-format frame with ``sub``, ``q_name`` and
            ``score`` columns. It is not modified.
        context_names, q_names: only their lengths are used, for the p-value
            correction.
        qs_config: object whose ``qs_n_qs[gen_qs]`` gives the number of items.
        task_v: task versions to keep from the PHQ-9 file.
        p_thr: unused; kept so existing callers keep working.

    Returns:
        Tuple ``(cross_corr, cross_corr_subset, cross_pvals, cross_pvals_subset)``;
        the subsets have PHQ-9 items as rows and ``gen_qs`` items as columns.
    """
    phq9_names = ['phq9_q' + str(q + 1) for q in range(9)]

    phq9_data = pd.read_csv(os.path.join(paths.data_path, 'phq9_data.csv'))
    phq9_data = phq9_data[phq9_data['task_version'].isin(task_v)]
    phq9_data = phq9_data[phq9_data['sub'].isin(closed_data_long['sub'].unique())]
    phq9_data = pd.melt(phq9_data, id_vars=['sub'], value_vars=[name + 's' for name in phq9_names],
                        var_name='q_name', value_name='score')
    phq9_data['q_name'] = phq9_data['q_name'].str.replace('s', '', regex=False)

    phq9_data = phq9_data[~phq9_data.isna().any(axis=1)]

    if gen_qs == 'phq9':
        # Rename on a copy: writing into the caller's frame would make a second
        # call produce 'phq9_ppt_ppt_*' names and fail the .loc lookup below.
        closed_data_long = closed_data_long.assign(
            q_name=closed_data_long['q_name'].str.replace('phq9', 'phq9_ppt', regex=False))

    closed_data_long_wphq9 = pd.concat([closed_data_long, phq9_data], axis=0)
    closed_data_wphq9 = pd.pivot(closed_data_long_wphq9, index='sub', columns='q_name', values='score')
    closed_data_wphq9 = closed_data_wphq9[~closed_data_wphq9.isna().any(axis=1)]
    closed_data_wphq9 = closed_data_wphq9[natsorted(closed_data_wphq9.columns)]

    cross_corr = closed_data_wphq9.corr(method='spearman')
    cross_pvals = closed_data_wphq9.corr(method=spearmanr_pval)
    cross_pvals *= len(q_names) * len(context_names)
    cross_pvals[cross_pvals > 1] = 1

    if gen_qs == 'phq9':
        gen_qs_names = [gen_qs + '_ppt_q' + str(q + 1) for q in range(qs_config.qs_n_qs[gen_qs])]
    else:
        gen_qs_names = [gen_qs + '_q' + str(q + 1) for q in range(qs_config.qs_n_qs[gen_qs])]
    cross_corr_subset = cross_corr.loc[phq9_names, gen_qs_names]
    cross_pvals_subset = cross_pvals.loc[phq9_names, gen_qs_names]
    return cross_corr, cross_corr_subset, cross_pvals, cross_pvals_subset


def get_corrs_pvals_genq_logits(responses_avg_merged, context_names, q_names, p_thr=1):
    """Spearman correlation of mean LLM item scores against participant scores.

    One correlation is computed for every (context, question) pair that has at
    least one row in ``responses_avg_merged``. Each p-value is multiplied by
    ``len(q_names) * len(context_names)`` and capped at 1.

    Args:
        responses_avg_merged: frame with ``sub``, ``q_name_context``,
            ``q_name``, ``score_llm`` and ``score_sub`` columns.
        context_names: contexts to iterate over.
        q_names: question names to iterate over.
        p_thr: not used by this function.

    Returns:
        Tuple ``(df_corr, df_corr_wide, N_range)``: the long table of results,
        the same results pivoted to contexts x questions, and the participant
        count (a single number, or ``"min-max"`` when it varies).
    """
    records = []
    for context_name in context_names:
        in_context = responses_avg_merged['q_name_context'] == context_name
        for q_name in q_names:
            subset = responses_avg_merged[in_context & (responses_avg_merged['q_name'] == q_name)]
            if len(subset) == 0:
                continue

            n_subs = len(subset['sub'].unique())
            rho, p_raw = spearmanr(subset['score_llm'], subset['score_sub'])
            p_corrected = min(p_raw * len(q_names) * len(context_names), 1)
            records.append({'context_name': context_name, 'q_name': q_name, 'r': rho,
                            'p-val': p_corrected, 'N': n_subs})

    df_corr = pd.DataFrame(records)
    df_corr_wide = df_corr.pivot(index='context_name', columns='q_name', values=['r', 'p-val'])
    df_corr_wide = df_corr_wide[natsorted(df_corr_wide.columns)]

    # Report one number when every pair has the same N, otherwise "min-max".
    n_values = df_corr['N'].unique()
    n_low, n_high = n_values.min(), n_values.max()
    if n_high != n_low:
        N_range = str(n_low) + '-' + str(n_high)
    else:
        N_range = n_low

    return df_corr, df_corr_wide, N_range


def responses_totals_gen_logits(responses, closed_data_long):
    """Correlate per-participant questionnaire totals from the LLM and the participant.

    The LLM total is the sum of ``score_llm`` within each context, averaged
    over contexts; the participant total is the sum of ``score``.

    Args:
        responses: frame with ``sub``, ``q_name_context`` and ``score_llm``.
        closed_data_long: frame with ``sub`` and ``score``.

    Returns:
        Tuple ``(totals_sub_llm, r, p)``: the totals for participants present
        in both inputs, and the Pearson correlation with its p-value.
    """
    # Sum within each (participant, context), then average those sums per participant.
    per_context = responses.groupby(['sub', 'q_name_context'], as_index=False)['score_llm'].sum()
    llm_totals = per_context.groupby(['sub'], as_index=False)['score_llm'].mean()
    llm_totals = llm_totals.rename(columns={'score_llm': 'total_llm'})

    sub_totals = closed_data_long.groupby(['sub'], as_index=False)['score'].sum()
    sub_totals = sub_totals.rename(columns={'score': 'total_sub'})

    totals_sub_llm = pd.merge(llm_totals, sub_totals, on=['sub'])

    correlation = pearsonr(totals_sub_llm['total_llm'], totals_sub_llm['total_sub'])
    return totals_sub_llm, correlation[0], correlation[1]
