import numpy as np
import os
import pandas as pd
from sklearn.decomposition import PCA
import os
import scipy.stats as stats
import seaborn as sns
import stimulus_features.nlp_model_util.elmo_utils as elmo_utils
import numexpr as ne
from matplotlib import pyplot as plt
from matplotlib.ticker import MaxNLocator

def getELMoOutput(textList, seq_len = 25, word_ind_to_extract= -1, remove_chars = ["?", ",", ".", "!", ":", ";"]):
    """Run ELMo over a word list and return the per-layer representations.

    Thin wrapper around ``elmo_utils.get_elmo_layer_representations``.

    Args:
        textList: words to embed; passed through unchanged to the helper.
        seq_len: number of words used as context for each representation
            (default 25).
        word_ind_to_extract: index, within each context sequence, of the word
            whose representation is kept; -1 means the last word.
        remove_chars: characters handed to the helper for removal.
            NOTE(review): the default is a shared list; this function never
            mutates it -- confirm the helper does not either.

    Returns:
        Whatever ``get_elmo_layer_representations`` returns. Callers in this
        file index the result by layer.
    """
    # The original bound the result to `ELMoOutput` but returned `elmoOutput`,
    # which raised NameError on every call; use one name throughout.
    elmoOutput = elmo_utils.get_elmo_layer_representations(seq_len, textList, remove_chars, word_ind_to_extract)
    return elmoOutput

def subsetFirst512ElmoFeatures(features, layer = 0):
    """Keep the first 512 values of every feature vector in one ELMo layer.

    Args:
        features: indexable by layer; ``features[layer]`` is iterated and each
            item is sliced.
        layer: which layer of ``features`` to read (default 0).

    Returns:
        A list with one ``vector[0:512]`` slice per item of the chosen layer.
    """
    # The forward-direction half is kept because it is closest to the
    # information available to the viewer.
    forwardHalf = []
    for wordVector in features[layer]:
        forwardHalf.append(wordVector[:512])
    return forwardHalf

def getELMo_first_hidden_layer(textList, seqLen = 25):
    """Return the forward-direction features of ELMo layer 0 for a word list.

    Args:
        textList: words to embed, passed on to ``getELMoOutput``.
        seqLen: number of context words handed to ELMo (default 25).

    Returns:
        List of per-word feature slices (first 512 values of each layer-0
        vector), as produced by ``subsetFirst512ElmoFeatures``.
    """
    # The original called `getElmoOutput`, which does not exist; the function
    # is defined as `getELMoOutput`.
    runElmo = getELMoOutput(textList, seq_len = seqLen)

    # 1024 features per layer - we only want the forward direction.
    return subsetFirst512ElmoFeatures(runElmo, 0)

# from Alex Huth / Gallant Lab
def lanczosinterp2D(data, oldtime, newtime, window=3):
    """Interpolate the columns of [data] from [oldtime] onto [newtime].

    The i'th row of [data] is assumed to correspond to oldtime[i]. A new
    matrix with the same number of columns and len(newtime) rows is returned.
    The time points in [newtime] are assumed to be evenly spaced; their
    frequency sets the low-pass cutoff of the sinc interpolation filter.
    [window] lobes of the sinc function are used; [window] should be an
    integer.

    Args:
        data: 2-D array-like, one row per entry of [oldtime].
        oldtime: 1-D sequence of sample times for the rows of [data].
        newtime: 1-D sequence of at least two target sample times.
        window: number of sinc lobes kept by the Lanczos kernel.

    Returns:
        ndarray of shape (len(newtime), number of columns of data).

    Raises:
        ValueError: if [newtime] has fewer than two points, because the
            cutoff frequency is then undefined.
    """
    # Accept plain Python lists as well as arrays.
    oldtime = np.asarray(oldtime, dtype=float)
    newtime = np.asarray(newtime, dtype=float)
    if newtime.size < 2:
        raise ValueError("newtime needs at least two time points to derive the cutoff frequency")

    ## Find the cutoff frequency ##
    cutoff = 1/np.mean(np.diff(newtime))
    print ("Doing sinc interpolation with cutoff={} and {} lobes.".format(cutoff, window))

    ## Build up sinc matrix ##
    # Broadcasting gives every (new, old) time difference at once, so the
    # kernel is evaluated in one call instead of once per output row.
    sincmat = lanczosfun(cutoff, newtime[:, np.newaxis] - oldtime[np.newaxis, :], window)

    ## Construct new signal by multiplying the sinc matrix by the data ##
    newdata = np.dot(sincmat, data)
    return newdata
def lanczosfun(cutoff, t, window=3):
    """Evaluate the Lanczos kernel with cutoff frequency [cutoff] at time(s) [t].

    [t] may be a scalar or a numpy array of any shape. Only the lowest-order
    [window] lobes of the sinc function are non-zero; everything outside
    them is set to zero.
    """
    # NOTE: the numexpr expression below looks up `window`, `pi` and `t` by
    # name in this function's scope, so these locals must keep their names.
    pi = np.pi
    t = t * cutoff

    val = ne.evaluate("window * sin(pi*t) * sin(pi*t/window) / (pi**2 * t**2)")

    # The closed form is 0/0 at the origin; the kernel's value there is 1.
    atOrigin = (t == 0)
    # Outside the requested number of lobes the kernel is truncated to zero.
    beyondLobes = (np.abs(t) > window)

    val[atOrigin] = 1.0
    val[beyondLobes] = 0.0
    return val

def getPCComponents(featureMatrix, numPCs):
    """Project `featureMatrix` onto its first `numPCs` principal components.

    A PCA using the 'full' SVD solver is fitted on `featureMatrix`, and the
    same matrix is then transformed with the fitted model.

    Returns:
        The transformed feature matrix.
    """
    reducer = PCA(svd_solver='full', n_components=numPCs)
    # fit() returns the estimator itself, so fitting and projecting can be chained.
    return reducer.fit(featureMatrix).transform(featureMatrix)

def getElmoPCs(ELMoFeatures, visualize = True, numPCs = 10):
    """Run PCA on ELMo features and return the leading principal components.

    Args:
        ELMoFeatures: 2-D feature matrix handed to ``PCA.fit_transform``.
        visualize: when True, plot the variance explained by each component
            and call ``plt.show()``.
        numPCs: number of leading components to report on and return
            (default 10, the value that used to be hard-coded).

    Returns:
        The first `numPCs` columns of the PCA-transformed features (fewer if
        the PCA produced fewer components).
    """
    modelPCA = PCA()
    ELMoPCs = modelPCA.fit_transform(ELMoFeatures)
    print(ELMoPCs.shape)

    # explained_variance_ratio_ holds fractions in [0, 1]; the plot axis and
    # the printed summary are both labelled in percent, so convert once here.
    variancePct = 100 * modelPCA.explained_variance_ratio_

    if visualize:
        fig = plt.figure(figsize=[10,6])
        ax = fig.add_subplot(1,2,1)
        ax.set_xlabel('Principal Component')
        ax.set_ylabel('Variance explained (%)')
        ax.set_title("ELMo PCs")
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.plot(variancePct, marker='o', color='gray', linestyle='dashed')
        plt.show()

    # get variance explained from the first numPCs PCs
    print("% variance explained from first {} PCS: ".format(numPCs) + str(np.sum(variancePct[0:numPCs])))

    # get the first numPCs PCs
    return ELMoPCs[:, 0:numPCs]

def _shiftRowsZeroPadded(featureMat, offset):
    """Return a copy of `featureMat` moved `offset` rows down (up if negative), zero-filled."""
    nRows = featureMat.shape[0]
    # The result is at least float64, matching the float zero padding.
    shifted = np.zeros(featureMat.shape, dtype=np.result_type(featureMat.dtype, np.float64))
    # Clamp so an offset larger than the matrix simply yields all zeros.
    k = min(abs(offset), nRows)
    if offset >= 0:
        shifted[k:] = featureMat[:nRows - k]
    else:
        shifted[:nRows - k] = featureMat[k:]
    return shifted


def offsetFeatureMatrixByTRsNegativeAndPosOffset(featureMat, offsetTRVect = (-5, -4, -3, -2, -1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)):
    """Stack time-shifted copies of a feature matrix next to the original.

    The output holds the original feature matrix (no offset) followed by one
    block of columns per entry of `offsetTRVect`, in order. A positive offset
    k delays the features by k rows (zeros fill the first k rows); a negative
    offset advances them (zeros fill the last |k| rows). Every block keeps the
    row count of `featureMat`, since the design matrix can only be as long as
    the fMRI data.

    Args:
        featureMat: 2-D array-like, TRs x features.
        offsetTRVect: integer TR offsets. Entries equal to 0 are skipped
            because the unshifted matrix is always included once. An empty
            sequence returns just the unshifted matrix.

    Returns:
        ndarray of shape (TRs, features * (1 + number of non-zero offsets)).

    Raises:
        ValueError: if `featureMat` is not 2-D.
    """
    featureMat = np.asarray(featureMat)
    if featureMat.ndim != 2:
        raise ValueError("featureMat must be 2-D (TRs x features), got shape {}".format(featureMat.shape))

    blocks = [featureMat]
    blocks.extend(_shiftRowsZeroPadded(featureMat, TRNum) for TRNum in offsetTRVect if TRNum != 0)
    return np.hstack(blocks)
    

def get_ELMo_feature_matrix(text, time, number_of_TRs, TRs_to_drop):
    """
    Given word list, corresponding time in TRs per word list, total number of TRs in run, list of TRs to drop
    obtains the ELMo embedding and processes this to be an input feature matrix for an encoding model. This is
    designed for 1 run of data.

    The matrix is also saved with ``np.save`` under the name
    "ELMo_first_hidden_layer_feature_matrix" in the current working directory.

    Args:
        text: sequential list of the words in the stimuli as they were presented to participants
        time: list (corresponding to the text list) of the onset time (in TRs) for each word in the text list, values are floats
        number_of_TRs: total number of TRs within the fMRI run
        TRs_to_drop: list of TRs to drop, values are ints
    Returns:
        Feature matrix, remaining TRs x ELMo embedding (# PCs x 9 {original + 1-8 TR offset})
    """
    # get first hidden layer of ELMo
    ELMo_first_hidden_layer = getELMo_first_hidden_layer(text)

    # lanczos filter: resample word-level features onto the TR grid.
    # The original referenced an undefined `timeList`; the parameter is `time`.
    run_aligned_features_first_hidden_layer = lanczosinterp2D(ELMo_first_hidden_layer, time, np.arange(0, number_of_TRs), window=3)

    # get pcs
    PCs_first_hidden_layer = getElmoPCs(run_aligned_features_first_hidden_layer)

    # offset by tr delay
    offset_first_hidden_layer = offsetFeatureMatrixByTRsNegativeAndPosOffset(PCs_first_hidden_layer, [1, 2, 3, 4, 5, 6, 7, 8])

    # drop the TRs that are being skipped
    dropped_TRs_first_hidden_layer = np.delete(offset_first_hidden_layer, TRs_to_drop, axis = 0)

    # save feature matrix
    np.save("ELMo_first_hidden_layer_feature_matrix", dropped_TRs_first_hidden_layer)

    # Also hand the matrix back, as the docstring has always promised.
    return dropped_TRs_first_hidden_layer


def get_ELMo_feature_matrix_4_runs(text1, text2, text3, text4, time1, time2, time3, time4, number_of_TRs1,
    number_of_TRs2, number_of_TRs3, number_of_TRs4, TRs_to_drop1, TRs_to_drop2, TRs_to_drop3, TRs_to_drop4):
    """
    Given word list, corresponding time in TRs per word list, total number of TRs in run, list of TRs to drop
    obtains the ELMo embedding and processes this to be an input feature matrix for an encoding model. This is
    designed for 4 runs of data, as in HCP.

    The PCA is fitted once on all four runs stacked together; the TR offsets and
    the TR dropping are then applied to each run separately before the runs are
    stacked again. The result is also saved with ``np.save`` under the name
    "ELMo_first_hidden_layer_feature_matrix" in the current working directory.

    Args:
        text#: sequential list of the words in the stimuli as they were presented to participants
        time#: list (corresponding to the text list) of the onset time (in TRs) for each word in the text list, values are floats
        number_of_TRs#: total number of TRs within the fMRI run
        TRs_to_drop#: list of TRs to drop, values are ints
        #: corresponds to what run the argument is from
    Returns:
        Feature matrix, remaining TRs of all runs x ELMo embedding (# PCs x 9 {original + 1-8 TR offset})
    """
    texts = (text1, text2, text3, text4)
    # The original referenced undefined `timeList1..4`; the parameters are `time1..4`.
    times = (time1, time2, time3, time4)
    run_lengths = (number_of_TRs1, number_of_TRs2, number_of_TRs3, number_of_TRs4)
    drops = (TRs_to_drop1, TRs_to_drop2, TRs_to_drop3, TRs_to_drop4)
    TR_offsets = [1, 2, 3, 4, 5, 6, 7, 8]

    # get first hidden layer of ELMo for every run
    hidden_layers = [getELMo_first_hidden_layer(text) for text in texts]

    # lanczos filter: resample each run's word-level features onto its TR grid
    aligned_runs = [
        lanczosinterp2D(layer, onsets, np.arange(0, n_TRs), window=3)
        for layer, onsets, n_TRs in zip(hidden_layers, times, run_lengths)
    ]

    # get pcs, fitted on all runs together
    PCs_first_hidden_layer = getElmoPCs(np.vstack(aligned_runs))

    # Row boundaries of each run inside the stacked PC matrix.
    run_edges = np.cumsum((0,) + run_lengths)

    processed_runs = []
    for start, stop, TRs_to_drop in zip(run_edges[:-1], run_edges[1:], drops):
        # offset by tr delay, within the run only so shifts never leak across runs
        offset_run = offsetFeatureMatrixByTRsNegativeAndPosOffset(PCs_first_hidden_layer[start:stop, :], TR_offsets)
        # drop the TRs that are being skipped (indices are relative to the run)
        processed_runs.append(np.delete(offset_run, TRs_to_drop, axis = 0))

    # save feature matrix
    dropped_TRs_first_hidden_layer = np.vstack(processed_runs)
    np.save("ELMo_first_hidden_layer_feature_matrix", dropped_TRs_first_hidden_layer)

    # Also hand the matrix back, as the docstring has always promised.
    return dropped_TRs_first_hidden_layer







