import tensorflow as tf

# Suppress TensorFlow log messages below ERROR level (largely v1/v2 compatibility warnings)
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

import scipy.ndimage.filters
import numpy as np

from twpca import TWPCA
from twpca.regularizers import curvature

def load_bci_session(idx, dataDirs, rootDir, charDef):
    """Load one session's single-letter data and normalize its neural activity.

    Reads ``<rootDir>/Datasets/<dataDirs[idx]>/singleLetters.mat``. Then, for
    every character in ``charDef['charList']``, it replaces
    ``dat['neuralActivityCube_<char>']`` with a float64 copy in which:

    - each trial has had the mean of the block it belongs to subtracted
      (``dat['meansPerBlock']``), and
    - the result has been divided by ``dat['stdAcrossAllData']``.

    Parameters
    ----------
    idx : int
        Index into ``dataDirs`` selecting the session to load.
    dataDirs : sequence of str
        Session directory names under ``<rootDir>/Datasets``.
    rootDir : str
        Root directory containing the ``Datasets`` folder.
    charDef : dict
        Must provide ``'charList'``, the characters whose cubes are normalized.

    Returns
    -------
    dict
        The dictionary returned by ``scipy.io.loadmat``, with each normalized
        cube written back in place of the original one.
    """
    # Import explicitly rather than relying on some other import (or SciPy's
    # lazy submodule loading) having already made ``scipy.io`` available.
    import scipy.io

    dataDir = dataDirs[idx]

    print('current session: ' + dataDir)
    dat = scipy.io.loadmat(rootDir + '/Datasets/' + dataDir + '/singleLetters.mat')

    characterCues = dat['characterCues']
    blockList = dat['blockList']

    # Because baseline firing rates drift over time, we normalize each
    # electrode's firing rate by subtracting its mean firing rate within each
    # block of data (re-centering it). We then divide by each electrode's
    # standard deviation to normalize the units.
    for char in charDef['charList']:
        neuralCube = dat['neuralActivityCube_' + char].astype(np.float64)

        # Rows of characterCues whose cue is this character.
        trlIdx = [t for t in range(characterCues.shape[0])
                  if characterCues[t, 0] == char]

        # Block number of each trial, looked up at its go-period onset bin.
        # ravel (not squeeze) keeps this 1-D even when there is a single trial.
        blockIdx = np.ravel(
            dat['blockNumsTimeSeries'][dat['goPeriodOnsetTimeBin'][trlIdx]])

        # Subtract block-specific means from each trial.
        for b in range(blockList.shape[0]):
            trialsFromThisBlock = np.ravel(blockIdx == blockList[b])
            neuralCube[trialsFromThisBlock, :, :] -= dat['meansPerBlock'][np.newaxis, b, :]

        # Divide by the standard deviation to normalize the units.
        neuralCube = neuralCube / dat['stdAcrossAllData'][np.newaxis, :, :]

        # Replace the original cube with the normalized one.
        dat['neuralActivityCube_' + char] = neuralCube

    return dat

def time_warping_bci(dat, charDef, n_components=5, smooth_sigma=3.0):
    """Time-warp each character's trials with TWPCA.

    For every character in ``charDef['charList']``, the cube
    ``dat['neuralActivityCube_<char>']`` is processed as follows:

    1. It is Gaussian-smoothed along axis 1.
    2. A TWPCA model is fit to the smoothed data.
    3. The fitted model is used to align the *unsmoothed* cube.

    Parameters
    ----------
    dat : dict
        Session data (e.g. as returned by ``load_bci_session``). It must
        contain ``'neuralActivityCube_<char>'`` for every listed character.
    charDef : dict
        Must provide ``'charList'``.
    n_components : int, optional
        Number of low-rank factors used to approximate (denoise) the data
        while time-warping. Defaults to 5.
    smooth_sigma : float, optional
        Standard deviation, in bins along axis 1, of the Gaussian kernel.
        It is applied both before fitting and to the aligned output.
        Defaults to 3.0.

    Returns
    -------
    alignedDat : dict
        ``alignedDat[char]`` is the aligned cube.
        ``alignedDat[char + '_T']`` is a transposed copy of the fitted
        ``params['warp']``.
    smoothedDat : dict
        ``smoothedDat[char]`` is the aligned cube smoothed along axis 1.
    """
    # ``scipy.ndimage.filters`` is a deprecated namespace; use ``scipy.ndimage``.
    import scipy.ndimage

    # The TWPCA fit below runs in TF1 graph mode. Disabling eager execution is a
    # process-wide switch, so do it once up front instead of on every iteration.
    tf.compat.v1.disable_eager_execution()

    alignedDat = {}
    smoothedDat = {}
    for char in charDef['charList']:
        print('Warping character: ' + char)
        # Clear the previous character's graph.
        tf.compat.v1.reset_default_graph()

        # Add an L1 penalty on the second-order finite difference of the
        # warping functions. This encourages them to be piecewise linear.
        warp_regularizer = curvature(scale=0.001, power=1)

        # Add an L2 penalty on the second-order finite difference of the
        # temporal factors. This encourages them to be smooth in time.
        time_regularizer = curvature(scale=1.0, power=2, axis=0)

        neuralCube = dat['neuralActivityCube_' + char]

        # Smooth the binned spike counts before time-warping to denoise them
        # (this step is key!).
        smoothed_spikes = scipy.ndimage.gaussian_filter1d(neuralCube, smooth_sigma, axis=1)

        # Fit the time-warping model.
        model_pca = TWPCA(smoothed_spikes,
                          n_components,
                          warp_regularizer=warp_regularizer,
                          time_regularizer=time_regularizer).fit(progressbar=False)

        # Use the fitted model to align the original (unsmoothed) data.
        estimated_aligned_data = model_pca.transform(neuralCube)

        # Store the smoothed aligned data.
        smoothedDat[char] = scipy.ndimage.gaussian_filter1d(
            estimated_aligned_data, smooth_sigma, axis=1)

        # Store the aligned data and the time-warping functions.
        alignedDat[char] = estimated_aligned_data
        alignedDat[char + '_T'] = model_pca.params['warp'].T.copy()

    return alignedDat, smoothedDat

def generate_bci_data_loader(batch_size, day_spike, day_cursor_pos_xy):
    """Yield consecutive ``(spikes, labels, count)`` mini-batches forever.

    Batches are taken in order along axis 0 of ``day_spike`` and
    ``day_cursor_pos_xy``. The last batch of a pass may be shorter than
    ``batch_size``. After it, iteration restarts from the first sample.

    Parameters
    ----------
    batch_size : int
        Maximum number of samples per batch. Must be positive.
    day_spike : array-like
        Inputs, sliced along axis 0. Must expose ``.shape``.
    day_cursor_pos_xy : array-like
        Targets, sliced with the same indices as ``day_spike``.

    Yields
    ------
    tuple
        ``(batch_spike, batch_label, batch_num)``, where ``batch_num`` is a
        0-d ``np.ndarray`` holding the number of samples in the batch.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive. Because this is a generator, the
        error surfaces on the first ``next()`` call.
    """
    # A non-positive batch size would never advance the cursor (or would
    # produce negative slices), yielding empty/bogus batches forever.
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got %r' % (batch_size,))

    n_samples = day_spike.shape[0]
    start_index = 0
    while True:
        # Wrap around once a full pass over the data has been served.
        if start_index >= n_samples:
            start_index = 0

        end_index = min(start_index + batch_size, n_samples)
        batch_spike = day_spike[start_index:end_index]
        batch_label = day_cursor_pos_xy[start_index:end_index]
        batch_num = np.array(end_index - start_index)

        yield batch_spike, batch_label, batch_num

        start_index += batch_size