# python3 -m stats.test_conditions +data=finetuning_template +exp=tstat ++data.cached_data_array=/storage/czw/seeg_decoding/cached_data_arrays ++data.duration=1.0 ++data.interval_duration=1.0 ++data.name="sentence_position_finetuning" ++exp.target="pos"
from data.subject_data import SubjectData
import logging
import os
from omegaconf import DictConfig, OmegaConf
from data.electrode_selection import get_clean_laplacian_electrodes
import hydra
import json
from pathlib import Path
import numpy as np
from scipy import stats
from glob import glob
import pandas as pd

# Sampling rate used to turn interval times into sample indices (presumably Hz).
samp_rate = 2048

log = logging.getLogger(__name__)

# Word-level features that get_control_idxs splits into high / low groups by percentile.
# NOTE that is_onset and bin_head are categorical variables, but for our purposes,
# can be treated as a float.
FLOAT_FEATURES = [
    "delta_magnitude",
    "delta_pitch",
    "magnitude",
    "word_length",
    "rms",
    "pitch",
    "max_vector_magnitude",
    "mean_pixel_brightness",
    "max_global_angle",
    "gpt2_surprisal",
    "morpheme_count",
    "delta_rms",
    "face_num",
    "is_onset",
    "max_global_magnitude",
    "max_vector_angle",
    "max_mean_magnitude",
    "charecter_num",
    "word_diff",
    "bin_head",
]
# Word-level features treated as categorical.
CAT_FEATURES = [
    "idx_in_sentence",
    "pos",
]


def get_subject_df(file_path):
    """Load one GLM-results CSV and standardize it for its subject.

    The subject and target names are read from the two path components directly
    above the file (``.../<subj>/<target>/<file>.csv``). Both are printed; only the
    subject is passed on, and it is used to suffix the column names.
    """
    *_, subj, target, _filename = file_path.split('/')
    print(subj, target)
    raw = pd.read_csv(file_path)
    return standardize_df(raw, subj)

def remove_intercept(df):
    """Return the rows of ``df`` whose ``label`` does not contain 'Intercept'."""
    is_intercept = df["label"].str.contains('Intercept')
    return df[~is_intercept]

def standardize_df(df, subj):
    """Normalize one subject's GLM-results table so tables can be merged across subjects.

    Steps:
      1. drop the 'DC10', 'DC4' and 'TRIG4' columns;
      2. name the first column 'label' and suffix every other column with
         ``_<subj>`` (a later column already called 'label' is left unchanged);
      3. remove '#' and '*' characters from the column names;
      4. rewrite labels whose last '-'-separated part is 'pos[T.VERB]' to 'posVERB';
      5. drop intercept rows (via ``remove_intercept``);
      6. cast the non-'label' columns of every row whose label does not start
         with 'Sig' to float32.

    Args:
        df: table as read from a results CSV; its first column holds row labels.
        subj: subject identifier used as the column-name suffix.

    Returns:
        A new standardized DataFrame; ``df`` itself is not modified.
    """
    # .copy() gives an independent frame, so later column writes don't hit a view.
    df = df.loc[:, ~df.columns.isin(['DC10', 'DC4', 'TRIG4'])].copy()

    # Assign names positionally instead of via a rename mapping, so a column that
    # happens to share the first column's name cannot be renamed to 'label' too.
    columns = list(df.columns)
    df.columns = ['label'] + [f'{c}_{subj}' if c != 'label' else c for c in columns[1:]]

    # Literal character removal; no regex (and no backslash escapes) needed.
    df.columns = df.columns.str.replace('#', '', regex=False).str.replace('*', '', regex=False)

    def rename_pos(label):
        # Unify 'pos[T.VERB]' (presumably treatment-coded names from the GLM formula)
        # with the 'posVERB' spelling; only the last '-'-separated part is inspected.
        head, sep, last = label.rpartition('-')
        if last in ("posVERB", "pos[T.VERB]"):
            return f'{head}{sep}posVERB'
        return label

    df['label'] = df['label'].apply(rename_pos)
    # remove_intercept returns a filtered slice; copy before assigning into it.
    df = remove_intercept(df).copy()

    # Rows whose label starts with 'Sig' are left as-is; everything else is numeric.
    numeric_rows = ~df['label'].str.startswith('Sig')
    value_cols = ~df.columns.isin(['label'])
    df.loc[numeric_rows, value_cols] = df.loc[numeric_rows, value_cols].astype('float32')
    return df

def get_all_results(results_dir):
    """Load every 'amp' GLM-results CSV under ``results_dir`` and merge them side by side.

    Files are matched with ``<results_dir>/*/*/*.csv`` and kept when the substring
    'amp' occurs anywhere in their path. Each file is standardized by
    ``get_subject_df`` and indexed by its 'label' column; the tables are then
    concatenated column-wise (aligned on label).

    Returns:
        DataFrame indexed by label, holding the suffixed columns of all files.

    Raises:
        ValueError: if no matching CSV file is found.
    """
    # glob order is arbitrary; sort so the merged column order is reproducible.
    files = sorted(f for f in glob(f'{results_dir}/*/*/*.csv') if 'amp' in f)
    if not files:
        raise ValueError(f"No 'amp' result CSVs found under {results_dir}")

    all_dfs = []
    for file_path in files:
        df = get_subject_df(file_path).set_index("label")
        # Report files that contain a phoneme_num estimate.
        if 'Estimate-phoneme_num' in df.index:
            print(file_path)
        all_dfs.append(df)

    return pd.concat(all_dfs, axis=1)

def subsample_control_equally_among_targets(features, control, target_1_idxs, target_2_idxs):
    """Split words into onset / middle / offset groups, balancing the two targets.

    Args:
        features: per-word DataFrame with 'idx_in_sentence', 'is_onset' and
            'is_offset' columns.
        control: must be "idx_in_sentence" (the only supported control).
        target_1_idxs, target_2_idxs: row labels of ``features`` belonging to each
            target class (e.g. nouns and verbs).

    Returns:
        [onset_idxs, midset_idxs, offset_idxs]:
          * onset_idxs -- every word with is_onset == 1 (not balanced by target);
          * midset_idxs -- for each position 1 .. max_idx-1, the non-offset words at
            that position, truncated so both targets contribute the same number;
          * offset_idxs -- for each position 1 .. max_idx, the offset words at that
            position, balanced the same way.
        midset_idxs / offset_idxs are numpy arrays (possibly empty).

    Raises:
        ValueError: if ``control`` is not "idx_in_sentence".
    """
    if control != "idx_in_sentence":
        raise ValueError(f"control {control!r} not supported; expected 'idx_in_sentence'")

    def balanced(position_idxs):
        # Keep equally many target-1 and target-2 words (the smaller of the two counts).
        t1 = np.intersect1d(position_idxs, target_1_idxs)
        t2 = np.intersect1d(position_idxs, target_2_idxs)
        n = min(len(t1), len(t2))
        return [t1[:n], t2[:n]]

    def concat(parts):
        # np.concatenate([]) raises; an empty position range yields an empty array instead.
        return np.concatenate(parts) if parts else np.array([], dtype=int)

    onset_idxs = features[features["is_onset"] == 1].index
    max_idx = int(max(features.idx_in_sentence))

    not_offset = features["is_offset"] == 0
    is_offset = features["is_offset"] == 1

    midset_parts = []
    for i in range(1, max_idx):  # positions in the middle of a sentence
        midset_parts += balanced(features[(features.idx_in_sentence == i) & not_offset].index)

    offset_parts = []
    for i in range(1, max_idx + 1):  # sentence-final words, grouped by their position
        offset_parts += balanced(features[(features.idx_in_sentence == i) & is_offset].index)

    return [onset_idxs, concat(midset_parts), concat(offset_parts)]

def get_control_idxs(features, control, target_1_idxs=None, target_2_idxs=None):
    '''
        features -- dataframe of word features per word
        control -- the name of a feature, e.g. rms; or "uncontrolled"
        target_1_idxs, target_2_idxs -- index sets of the two targets; only used
            (and required) when control == "idx_in_sentence"
        returns
            a list of indexes, each of which corresponds to a subsample of the control conditions:
              "idx_in_sentence"     -> [onset, midset, offset] (subsample_control_equally_among_targets)
              a FLOAT_FEATURES name -> [high, low]
              "pos"                 -> [noun_idxs, verb_idxs]
              "uncontrolled"        -> [all word indexes]
        raises
            ValueError if control is none of the above

        For a float feature, "high" holds the words at or above the 75th percentile and
        "low" those at or below the 25th (fixed 0.6 / 0.4 thresholds for bin_head and
        is_onset). If the thresholds do not separate the values, the top and bottom
        quarter of the words sorted by the feature are used instead. Words whose value
        is missing (NaN) are placed in neither group.
    '''
    if control == "idx_in_sentence":
        assert target_1_idxs is not None and target_2_idxs is not None
        return subsample_control_equally_among_targets(features, control, target_1_idxs, target_2_idxs)
    elif control in FLOAT_FEATURES:
        # Drop missing values first: np.percentile returns NaN if any value is NaN, and
        # sort_values puts NaNs last, which would drop them into the "high" group.
        values = features[control].dropna()
        if len(values) > 0:
            high = np.percentile(values, 75)
            low = np.percentile(values, 25)
        else:
            high = low = np.nan
        if control in ["bin_head", "is_onset"]:
            # presumably 0/1-valued: 1s form the high group, 0s the low group
            low, high = 0.4, 0.6
        if high <= low or np.isnan(high) or np.isnan(low):
            # thresholds do not separate the values (e.g. heavy ties): fall back to the
            # top and bottom quarter of the words sorted by the feature
            sorted_idx = values.sort_values().index
            quarter = len(values) // 4
            # slice from len - quarter (not -quarter) so quarter == 0 gives an empty group
            control_high_idxs = sorted_idx[len(sorted_idx) - quarter:]
            control_low_idxs = sorted_idx[:quarter]
        else:
            control_high_idxs = values[values >= high].index
            control_low_idxs = values[values <= low].index
        assert len(control_high_idxs) > 2 and len(control_low_idxs) > 2
        return [control_high_idxs, control_low_idxs]
    elif control == "pos":
        noun_idxs = features[features["pos"] == "NOUN"].index
        verb_idxs = features[features["pos"] == "VERB"].index
        return [noun_idxs, verb_idxs]
    elif control == "uncontrolled":
        return [features.index]
    raise ValueError(f"control not found: {control}")

def ttest_target_within_control(control_idxs, data_arr, target_1_idxs, target_2_idxs):
    """Compare two targets inside one control subset with an independent t-test.

    Each selected word's response is averaged over the last axis of ``data_arr``;
    ``scipy.stats.ttest_ind`` then runs along the word axis, giving one p-value
    per entry of the leading axis.

    Returns:
        (pval, cohens_d): ``pval`` is the p-value array from ttest_ind;
        ``cohens_d`` is the difference of the two groups' grand means divided by
        the standard deviation (np.std, ddof=0) of all their values pooled together.
    """
    def group_means(target_idxs):
        words = np.intersect1d(control_idxs, target_idxs)
        return data_arr[:, words, :].mean(axis=-1)

    group_1 = group_means(target_1_idxs)
    group_2 = group_means(target_2_idxs)
    _, pval = stats.ttest_ind(group_1, group_2, axis=-1)

    pooled_std = np.std(np.concatenate([group_1, group_2], axis=1).flatten())
    return pval, (group_1.mean() - group_2.mean()) / pooled_std

def equalize_target_among_control(all_control_idxs, target_idxs):
    """Pick the same number of target words from every control subset.

    Each control subset is intersected with ``target_idxs`` (keeping the subset's
    own order); every intersection is truncated to the size of the smallest one and
    the pieces are concatenated.

    Args:
        all_control_idxs: list of index collections -- pandas Index objects or
            array-likes such as the numpy arrays produced for the
            'idx_in_sentence' control.
        target_idxs: index labels of the target words.

    Returns:
        numpy array of the selected indexes.
    """
    # Wrap in pd.Index so array-like subsets (numpy arrays have no .intersection) work too.
    controlled_target = [pd.Index(idxs).intersection(target_idxs) for idxs in all_control_idxs]
    min_size = min(len(idxs) for idxs in controlled_target)
    return np.concatenate([idxs[:min_size] for idxs in controlled_target])

def ttest_target_equalize_control(all_control_idxs, data_arr, target_1_idxs, target_2_idxs):
    '''
        Compare two targets after equalizing each of them across the control subsets.

        all_control_idxs is a list of lists. each member list is a list of indexes.
        For each target, equally many words are taken from every control subset
        (equalize_target_among_control); each word's response is averaged over the
        last axis of data_arr and the two groups are compared with scipy's ttest_ind.
        returns
            (pval, cohens_d) -- pval as a Python scalar (exactly one p-value is expected);
            cohens_d = difference of the grand means / np.std of all pooled values
    '''
    words_1 = equalize_target_among_control(all_control_idxs, target_1_idxs)
    words_2 = equalize_target_among_control(all_control_idxs, target_2_idxs)
    group_1 = data_arr[:, words_1, :].mean(axis=-1)
    group_2 = data_arr[:, words_2, :].mean(axis=-1)
    pval = stats.ttest_ind(group_1, group_2, axis=-1).pvalue

    spread = np.std(np.concatenate([group_1, group_2], axis=1).flatten())
    cohens_d = (group_1.mean() - group_2.mean()) / spread
    assert pval.shape == (1,)
    return pval.item(), cohens_d

def run_ttest(data_arr, features, target, control, control_type, cfg):
    '''
    target - usually part-of-speech. The variable we are interested in. Are verbs and nouns different?
             "pos" compares NOUN vs VERB words; "gpt2_surprisal" compares high vs low surprisal.
    control - e.g. volume; any name accepted by get_control_idxs
    control_type -- equalize or sub_sample
        sub_sample: for every control subset, slide a window of cfg.exp.interval seconds
                    from cfg.exp.search_start to cfg.exp.search_end and keep the smallest
                    p-value (and its Cohen's d); a subset with no p-value below 1 reports
                    1.0 and None
        equalize:   one t-test on target groups equalized across all control subsets
    returns (pvals, cohens) lists, e.g. ttest results for high and low volume
    data_arr shape [n_electrodes, n_words, n_timesteps]
    raises ValueError for an unhandled target or control_type
    '''
    assert data_arr.shape[1] == features.shape[0]
    assert len(data_arr.shape) == 3
    assert len(features) == features.index[-1] + 1
    if target == "pos":
        target_1_idxs = features[(features[target] == "NOUN")].index
        target_2_idxs = features[(features[target] == "VERB")].index
    elif target in ["gpt2_surprisal"]:
        target_1_idxs, target_2_idxs = get_control_idxs(features, target)
    else:
        raise ValueError(f"target {target!r} not handled yet")

    all_control_idxs = get_control_idxs(features, control, target_1_idxs, target_2_idxs)

    if control_type == "sub_sample":
        all_pvals, all_cohens = [], []
        for control_idxs in all_control_idxs:
            min_pval, cohens = 1.0, None
            for interval_start in np.arange(cfg.exp.search_start, cfg.exp.search_end, cfg.exp.interval):
                # get the data in the small interval; data_arr starts at cfg.data.delta seconds
                data_delta = cfg.data.delta
                assert interval_start > data_delta
                assert interval_start + cfg.exp.interval < data_delta + cfg.data.duration

                interval_start_idx = int((interval_start - data_delta) * samp_rate)
                interval_end_idx = interval_start_idx + int(samp_rate * cfg.exp.interval)

                pval, d = ttest_target_within_control(control_idxs, data_arr[:, :, interval_start_idx:interval_end_idx], target_1_idxs, target_2_idxs)
                if pval < min_pval:
                    min_pval, cohens = pval, d
            # min_pval is still the scalar 1.0 when no window gave a smaller (non-NaN)
            # p-value; np.concatenate cannot join 0-d values, so make it 1-d.
            all_pvals.append(np.atleast_1d(min_pval))
            all_cohens.append(cohens)
        return np.concatenate(all_pvals).tolist(), all_cohens
    elif control_type == "equalize":
        pval, d = ttest_target_equalize_control(all_control_idxs, data_arr, target_1_idxs, target_2_idxs)
        return [pval], [d]
    raise ValueError(f"control_type {control_type!r} not found")
    
def run_all_features(neural_data, words_df, sig_features, subj, elec, target, cfg, control_type="sub_sample"):
    """Run the target t-test once per control feature, plus once uncontrolled.

    Every label in ``sig_features``' index serves as a control except ``target``
    (its first occurrence is removed when present); "uncontrolled" is appended last.
    ``subj`` and ``elec`` are accepted but not used here.

    Returns:
        dict mapping each control name to {"pvals": ..., "cohens": ...}, the two
        lists returned by run_ttest.
    """
    controls = sig_features.index.to_list()
    if target in sig_features:
        controls.remove(target)
    controls += ["uncontrolled"]

    results_by_control = {}
    for control in controls:
        control_pvals, control_cohens = run_ttest(neural_data, words_df, target, control, control_type, cfg)
        results_by_control[control] = {"pvals": control_pvals, "cohens": control_cohens}
    return results_by_control
           
@hydra.main(config_path="../conf")
def main(cfg: DictConfig) -> None:
    '''
        Run the controlled target t-tests for every electrode found in the GLM results.

        cfg.exp.control_type:
            sub_sample -- find rms high and rms low. For each domain, compare nouns vs. verbs
            equalize -- find rms high and rms low. equalize each for nouns and equalize each for verbs. compare n vs v.

        Writes <out_dir>/all_results.json mapping "<elec>_<subj>" to the per-control
        results of run_all_features. No significance filter is applied: every electrode
        and every feature is tested.
    '''
    log.info("Run testing for all electrodes in all test_subjects")
    log.info(OmegaConf.to_yaml(cfg, resolve=True))
    assert cfg.exp.control_type in ["sub_sample", "equalize"]

    out_dir = cfg.exp.get("out_dir", None)
    if not out_dir:
        out_dir = os.getcwd()
    else:
        Path(out_dir).mkdir(exist_ok=True, parents=True)
    log.info(f'Working directory {out_dir}')

    glm_results = get_all_results(cfg.test.glm_results_path)

    # Keep the rows whose label starts with "P", drop the first 6 characters of the
    # label, and rename the part-of-speech row to the target name "pos".
    pvals_df = glm_results[glm_results.index.str.startswith("P")]
    pvals_df.index = pvals_df.index.map(lambda x: x[6:])
    pvals_df = pvals_df.rename(lambda x: "pos" if x == "posVERB" else x)
    # Transposed once: rows are "<elec>_<subj>" columns of the results, columns are features.
    pvals_df_t = pvals_df.transpose()

    with open(cfg.test.test_split_path, "r") as f:
        test_splits = json.load(f)

    data_cfg = cfg.data
    target = cfg.exp.target
    all_test_results = {}

    for elec_subj in pvals_df_t.index:
        # Column names were built as f"{column}_{subj}"; split on the last underscore so
        # electrode names that themselves contain "_" still parse.
        elec, subj = elec_subj.rsplit("_", 1)
        log.info(f"Subject {subj}")
        data_cfg.subject = subj
        data_cfg.electrodes = [elec]
        data_cfg.brain_runs = test_splits[subj]

        # Feature -> value for this electrode; every feature is used as a control.
        sig_features = pvals_df_t.loc[elec_subj]

        subj_data = SubjectData(data_cfg)
        words_df = subj_data.words.reset_index(drop=True)
        neural_data = subj_data.neural_data
        all_test_results[elec_subj] = run_all_features(neural_data, words_df, sig_features, subj, elec, target, cfg, cfg.exp.control_type)

    results_path = os.path.join(out_dir, 'all_results.json')
    with open(results_path, "w") as f:
        json.dump(all_test_results, f)
    log.info(f'Working directory {out_dir}')

if __name__ == "__main__":
    # Script entry point; main is wrapped by @hydra.main, which supplies cfg from the
    # config directory and command-line overrides (see the example invocation at the top).
    main()
