from regression.lm_embeddings.embeddings_store import MEGFeatureMapStore, EmbeddingsStore
import numpy as np

#should make this 
def delayed_sentence_effects_features_old(configs, embedding_store_loc, meg_store_loc, delays = 40, fold_tensors = True):
    """Legacy builder of lagged three-way sentence-position features.

    Every entry returned by ``EmbeddingsStore.load_contexts`` gets a label:

    * 2 (start)  -- the first context (index 0), or a later one whose
      predecessor's last element is ``.``, ``?`` or ``!``;
    * 3 (end)    -- a later context whose own last element is one of those
      marks (this check wins over the start rule);
    * 1 (middle) -- anything else.

    Label 0 is reserved for MEG samples whose map entry is -1, which selects
    the sentinel appended after the per-context labels. Labels are one-hot
    encoded as middle -> column 0, start -> column 1, end -> column 2, and
    label 0 -> all zeros. For every MEG sample, the codes of that sample and
    the ``delays - 1`` samples before it form a ``(3, delays)`` window. The
    window is zero-padded before the first sample, and its last column is
    the current sample.

    Args:
        configs: iterable of configs; each needs a ``.story`` attribute and
            must be accepted by ``MEGFeatureMapStore.load_meg_map``.
        embedding_store_loc: location handed to ``EmbeddingsStore``.
        meg_store_loc: location handed to ``MEGFeatureMapStore``.
        delays: number of samples per window (must be >= 1).
        fold_tensors: if true, flatten each window to ``3 * delays`` values.

    Returns:
        A list with one float array per config, shaped
        ``(n_samples, 3, delays)``, or ``(n_samples, 3 * delays)`` when
        ``fold_tensors`` is true.
    """
    store_of_embeddings = EmbeddingsStore(embedding_store_loc)
    store_of_meg_maps = MEGFeatureMapStore(meg_store_loc)
    sentence_marks = ".?!"
    # Row k is the one-hot code for label k; row 0 ("no word") is all zeros.
    label_codes = np.eye(4, dtype=int, k=-1)[:, :-1]

    def label_of(contexts, position):
        # The first context is always a sentence start, without inspecting it.
        if position == 0:
            return 2
        if contexts[position][-1] in sentence_marks:
            return 3
        return 2 if contexts[position - 1][-1] in sentence_marks else 1

    results = []
    for config in configs:
        meg_map = store_of_meg_maps.load_meg_map(config)
        contexts = store_of_embeddings.load_contexts(config.story)
        labels = [label_of(contexts, position) for position in range(len(contexts))]
        labels.append(0)  # sentinel reached by meg_map entries of -1
        sample_codes = label_codes[np.array(labels)[meg_map]]
        history = np.concatenate((np.zeros((delays - 1, 3)), sample_codes), axis=0)
        # windows[t, c, d] == history[t + d, c]; copy so the result owns
        # writable, contiguous memory.
        windows = np.lib.stride_tricks.sliding_window_view(history, delays, axis=0).copy()
        if fold_tensors:
            windows = windows.reshape(windows.shape[0], -1)
        results.append(windows)
    return results

def delayed_sentence_effects_features(configs, embedding_store_loc, meg_store_loc, delays = 40, fold_tensors = True):
    """Build lagged sentence-start / sentence-end features for each config.

    Every entry returned by ``EmbeddingsStore.load_contexts`` gets a category:

    * 2 (end)    -- the context's last element is ``.``, ``?`` or ``!``;
    * 1 (start)  -- otherwise, if it is the first context (index 0) or the
      previous context's last element is one of those marks;
    * 0 (middle) -- anything else.

    Categories are one-hot encoded into two columns: start -> ``[1, 0]`` and
    end -> ``[0, 1]``. Middle-of-sentence words and MEG samples whose map
    entry is -1 (which selects the appended 0 sentinel) are both encoded as
    ``[0, 0]``. For every MEG sample, the codes of that sample and the
    ``delays - 1`` samples before it form a ``(2, delays)`` window. The
    window is zero-padded before the first sample, and its last column is
    the current sample.

    Args:
        configs: iterable of configs; each needs a ``.story`` attribute and
            must be accepted by ``MEGFeatureMapStore.load_meg_map``.
        embedding_store_loc: location handed to ``EmbeddingsStore``.
        meg_store_loc: location handed to ``MEGFeatureMapStore``.
        delays: number of samples per window (must be >= 1).
        fold_tensors: if true, flatten each window to ``2 * delays`` values.

    Returns:
        A list with one float array per config, shaped
        ``(n_samples, 2, delays)``, or ``(n_samples, 2 * delays)`` when
        ``fold_tensors`` is true.
    """
    embeddings_store = EmbeddingsStore(embedding_store_loc)
    meg_store = MEGFeatureMapStore(meg_store_loc)
    sentence_marks = ".?!"

    def context_index_to_category(contexts, index):
        # 2 = end of sentence, 1 = start of sentence, 0 = middle.
        if contexts[index][-1] in sentence_marks:
            return 2
        # Index 0 needs an explicit check: contexts[index - 1] would wrap
        # around to the story's LAST context and could mislabel the first word.
        if index == 0 or contexts[index - 1][-1] in sentence_marks:
            return 1
        return 0

    # Row k is the one-hot code for category k; row 0 (middle / no word) is all zeros.
    one_hot = np.eye(3, dtype=int, k=-1)[:, :-1]
    out = []
    for config in configs:
        meg_map = meg_store.load_meg_map(config)
        contexts = embeddings_store.load_contexts(config.story)
        word_category = [context_index_to_category(contexts, i) for i in range(len(contexts))]
        # Trailing 0 is the sentinel selected by meg_map entries of -1.
        categories_to_index = np.array(word_category + [0])
        word_features = one_hot[categories_to_index[meg_map]]  # (n_samples, 2)
        padded_word_features = np.concatenate(
            (np.zeros((delays - 1, word_features.shape[1])), word_features), axis=0
        )
        # windows[t, f, d] == padded_word_features[t + d, f]; copy so the
        # result owns writable, contiguous memory.
        windows = np.lib.stride_tricks.sliding_window_view(padded_word_features, delays, axis=0).copy()
        if fold_tensors:
            windows = windows.reshape(windows.shape[0], -1)
        out.append(windows)
    return out
    
def delayed_word_onset_features(meg_store_loc, configs, delays = 40, fold_tensors = True):
    """Build lagged word-presence indicators for each config.

    A MEG sample is 1.0 when its map entry differs from -1 and 0.0
    otherwise. For every sample, the indicators of that sample and the
    ``delays - 1`` samples before it form a ``(1, delays)`` window. The
    window is zero-padded before the first sample, and its last column is
    the current sample.

    Note: unlike the sentence-feature builders, this takes ``meg_store_loc``
    before ``configs``.

    Args:
        meg_store_loc: location handed to ``MEGFeatureMapStore``.
        configs: iterable of configs accepted by
            ``MEGFeatureMapStore.load_meg_map``.
        delays: number of samples per window (must be >= 1).
        fold_tensors: if true, flatten each window to ``delays`` values.

    Returns:
        A list with one float array per config, shaped
        ``(n_samples, 1, delays)``, or ``(n_samples, delays)`` when
        ``fold_tensors`` is true.
    """
    store = MEGFeatureMapStore(meg_store_loc)
    results = []
    for config in configs:
        index_map = store.load_meg_map(config)
        has_word = (index_map != -1).astype(float)
        history = np.concatenate((np.zeros(delays - 1), has_word))
        # windows[t, 0, d] == history[t + d]; copy so the result owns
        # writable, contiguous memory.
        features = np.lib.stride_tricks.sliding_window_view(history, delays)[:, np.newaxis, :].copy()
        if fold_tensors:
            features = features.reshape(features.shape[0], -1)
        results.append(features)
    return results