import pickle
import numpy as np
import pdb
from tqdm import tqdm
from scipy.interpolate import interp1d
from scipy.signal import resample,  convolve, get_window
from scipy.ndimage import convolve1d, gaussian_filter1d

def filter_window(signal, window_name, window_length=10):
    """Smooth `signal` by convolving it with a named scipy window.

    The window comes from ``scipy.signal.get_window(window_name, window_length)``
    and is used as-is (it is not normalized to unit sum). Convolution is done
    by ``scipy.ndimage.convolve1d`` with its default axis and boundary mode.
    """
    return convolve1d(signal, get_window(window_name, window_length))

# Registry of smoothing functions, looked up by the `filter_fn` string and
# called as FILTER_DICT[filter_fn](signal, **filter_kwargs).
FILTER_DICT = dict(
    gaussian=gaussian_filter1d,
    none=lambda x, **kwargs: x,  # identity: return the signal unchanged
    window=filter_window,
)

def moving_center(X, n, axis=0):
    """Subtract a centered moving average from X along `axis`.

    The averaging width is `n`, bumped to `n + 1` when `n` is even so the
    window has a well-defined center sample. The result is X minus its local
    mean, computed in a single convolution with the kernel
    ``delta - 1/width``.
    """
    width = n + 1 if n % 2 == 0 else n
    kernel = np.full(width, -1.0 / width)
    # Add the identity tap at the center so the kernel computes x - mean(x).
    kernel[width // 2] += 1
    return convolve1d(X, kernel, axis=axis)

def sinc_filter(X, fc, axis=0, b=0.08):
    """Low-pass filter X with a Blackman-windowed sinc kernel.

    Parameters
    ----------
    X : array_like
        Signal to filter. May be N-dimensional; filtering runs along `axis`.
    fc : float
        Cutoff frequency as a fraction of the sampling rate.
    axis : int
        Axis along which to filter.
    b : float
        Transition band as a fraction of the sampling rate (in (0, 0.5)).
        The kernel length is ceil(4 / b), bumped up to the next odd number
        (51 taps for the default).

    Returns
    -------
    ndarray
        The full linear convolution of X with the unity-gain kernel along
        `axis`. That axis grows by (kernel length - 1) samples; all other axes
        keep their size.
    """
    X = np.asarray(X)

    N = int(np.ceil((4 / b)))
    if not N % 2:
        N += 1  # Make sure that N is odd.
    n = np.arange(N)

    # Compute sinc filter.
    h = np.sinc(2 * fc * (n - (N - 1) / 2))

    # Compute Blackman window.
    w = 0.42 - 0.5 * np.cos(2 * np.pi * n / (N - 1)) + \
        0.08 * np.cos(4 * np.pi * n / (N - 1))

    # Multiply sinc filter by window, then normalize to get unity gain.
    h = h * w
    h = h / np.sum(h)

    # scipy.signal.convolve needs both inputs to have the same ndim. Lay the
    # kernel out along `axis` with singleton dims elsewhere, so each 1-D slice
    # of X is convolved independently. For 1-D input this is exactly
    # convolve(X, h).
    kernel_shape = [1] * X.ndim
    kernel_shape[axis] = N
    return convolve(X, h.reshape(kernel_shape))

def window_spike_array(spike_times, tstart, tend):
    """Apply window_spikes to every trial/unit entry of a 2-D spike-time array.

    Parameters
    ----------
    spike_times : array indexed as spike_times[trial, unit]
        Each entry holds one unit's spike times on one trial.
    tstart, tend : sequence of float
        Per-trial window boundaries, indexed by trial.

    Returns
    -------
    ndarray of object, same shape as spike_times
        Entry [i, j] is the windowed spike train that window_spikes returns
        for trial i, unit j, using the window (tstart[i], tend[i]).
    """
    # `np.object` was removed in NumPy 1.24; the builtin `object` is the
    # supported spelling with identical meaning.
    windowed_spike_times = np.zeros(spike_times.shape, dtype=object)

    n_trials, n_units = spike_times.shape[0], spike_times.shape[1]
    for i in range(n_trials):
        for j in range(n_units):
            wst, _ = window_spikes(spike_times[i, j], tstart[i], tend[i])
            windowed_spike_times[i, j] = wst

    return windowed_spike_times

def window_spikes(spike_times, tstart, tend, start_idx=0):
    """Extract the spikes strictly inside (tstart, tend), re-referenced to tstart.

    Parameters
    ----------
    spike_times : array_like
        Spike times, assumed sorted ascending (required by the binary search).
    tstart, tend : float
        Window boundaries. Spikes exactly at either boundary are excluded.
    start_idx : int
        Index into spike_times at which to start searching; earlier spikes are
        ignored. This lets a caller resume a scan over consecutive windows.

    Returns
    -------
    windowed_spike_times : ndarray
        The spikes in the window minus tstart. This is always a new array;
        spike_times is never modified.
    next_idx : int
        Absolute index (into the full spike_times) of the last spike before
        tend, floored at start_idx. When there are no spikes left to search,
        this is start_idx.
    """
    remaining = np.asarray(spike_times)[start_idx:]
    if remaining.size == 0:
        return np.array([]), start_idx

    # First spike strictly after tstart, and first spike at or after tend.
    # searchsorted returns len(remaining) when no such spike exists. An
    # argmax/argmin over a boolean mask would return 0 there, which wrongly
    # empties any window that extends past the last spike.
    lo = np.searchsorted(remaining, tstart, side='right')
    hi = np.searchsorted(remaining, tend, side='left')

    # Subtract into a new array: `-=` on the slice (a view) would silently
    # shift the caller's spike times.
    windowed_spike_times = remaining[lo:hi] - tstart

    return windowed_spike_times, int(start_idx + max(hi - 1, 0))

def align_behavior(x, T, bin_width):
    """Resample each column of `x` onto the centers of time bins spanning [0, T].

    Each column is treated as uniformly sampled over [0, T], with its first
    sample at 0 and its last at T. It is linearly interpolated at the bin
    centers.

    Parameters
    ----------
    x : ndarray, indexed as x[sample, feature]
    T : float
        Duration covered by x.
    bin_width : float
        Nominal bin width; the edges are np.linspace(0, T, int(T // bin_width)).

    Returns
    -------
    ndarray of shape (len(edges) - 1, x.shape[-1])
        Interpolated values, one row per bin.
    """
    edges = np.linspace(0, T, int(T // bin_width))
    centers = (edges + (edges[1] - edges[0]) / 2)[:-1]

    n_features = x.shape[-1]
    out = np.zeros((centers.size, n_features))
    for col in range(n_features):
        column = x[:, col]
        sample_times = np.linspace(0, T, column.size)
        out[:, col] = interp1d(sample_times, column)(centers)

    return out

def align_peanut_behavior(t, x, bins):
    """Linearly interpolate behavior samples onto the centers of spike bins.

    Parameters
    ----------
    t : array_like
        Sample times of x. They are shifted so that t[0] maps to 0; the
        caller's array is not modified.
    x : array_like
        Behavior samples; axis 0 is time (one row per entry of t).
    bins : array_like
        Bin edges on the zero-based time axis; bin spacing is taken from
        bins[1] - bins[0].

    Returns
    -------
    xaligned : ndarray
        x interpolated at each bin center (len(bins) - 1 rows).
    bin_centers : ndarray
        The bin centers used.
    """
    # Offset to 0 in a new array. An in-place `t -= t[0]` would rewrite the
    # caller's data.
    t = np.asarray(t) - t[0]
    bins = np.asarray(bins)

    bin_centers = (bins + (bins[1] - bins[0]) / 2)[:-1]
    xaligned = interp1d(t, x, axis=0)(bin_centers)
    return xaligned, bin_centers

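# spike_times: indexed as [trial, neuron], shape (n_trials, n_neurons); each
#  entry holds one neuron's spike times on one trial.
# trial_threshold: fraction of trials on which a neuron must pass the spike
#  threshold to be kept. trial_threshold = 1 requires the spike threshold to
#  hold for the neuron on all trials; 0 imposes no requirement on any trial.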
def postprocess_spikes(spike_times, T, bin_width, boxcox, filter_fn, filter_kwargs,
                       spike_threshold=0, trial_threshold=1, high_pass=False, return_unit_filter=False):
    """Bin, transform, smooth and unit-filter trial-structured spike times.

    Parameters
    ----------
    spike_times : array indexed as spike_times[trial, unit]
        Each entry holds the spike times of one unit on one trial, in the
        units of T. Entries containing NaN mark a trial/unit to ignore.
    T : float or sequence of float
        Trial duration. A scalar means every trial has the same duration; a
        sequence gives one duration per trial ("ragged" trials).
    bin_width : float
        Nominal bin width in the units of T. The bin edges are
        ``np.linspace(0, T, int(T // bin_width))``.
    boxcox : float or None
        If not None, apply ``(counts ** boxcox - 1) / boxcox`` to the counts.
    filter_fn : str
        Key into FILTER_DICT selecting the smoothing function.
    filter_kwargs : dict
        Keyword arguments forwarded to the smoothing function.
    spike_threshold : int or None
        A trial/unit whose total count is <= spike_threshold is flagged as
        insufficient. None disables the check.
    trial_threshold : float
        A unit is kept when its number of insufficient trials is below
        ``(1 - (trial_threshold - 1e-3)) * n_trials``. A value of 1
        effectively requires every trial to be sufficient; 0 keeps every unit.
    high_pass : bool
        If True, remove slow trends with ``moving_center(..., 600)``.
    return_unit_filter : bool
        If True, also return the indices of the kept units.

    Returns
    -------
    spike_rates
        For scalar T, an ndarray of shape (n_trials, n_bins, n_kept_units).
        For ragged T, a list with one (n_bins_i, n_kept_units) array per trial.
    sufficient_spikes : ndarray of int
        Indices (along axis 1 of spike_times) of the kept units. Returned only
        when return_unit_filter is True.
    """
    # Trials are of different duration
    ragged_trials = not np.isscalar(T)
    n_trials, n_units = spike_times.shape[0], spike_times.shape[1]

    # Discretize time over bins
    if ragged_trials:
        # One edge array per trial. A plain list avoids np.array(..., dtype=object)
        # silently building a 2-D array when all trials share a length.
        # (`np.object` itself was removed in NumPy 1.24.)
        bins = [np.linspace(0, T[i], int(T[i] // bin_width)) for i in range(len(T))]
        spike_rates = np.zeros((n_trials, n_units), dtype=object)
    else:
        bins = np.linspace(0, T, int(T // bin_width))
        spike_rates = np.zeros((n_trials, n_units, bins.size - 1))

    # Did the trial/unit have enough spikes?
    insufficient_spikes = np.zeros(spike_times.shape)
    for i in range(n_trials):
        trial_bins = bins[i] if ragged_trials else bins
        for j in range(n_units):

            # Ignore this trial/unit combo
            if np.any(np.isnan(spike_times[i, j])):
                insufficient_spikes[i, j] = 1

            spike_counts = np.histogram(spike_times[i, j], bins=trial_bins)[0]

            if spike_threshold is not None and np.sum(spike_counts) <= spike_threshold:
                insufficient_spikes[i, j] = 1

            # Apply a boxcox transformation (element-wise, vectorized)
            if boxcox is not None:
                spike_counts = (np.power(spike_counts, boxcox) - 1) / boxcox

            if ragged_trials:
                # Trials differ in length, so each one is filtered on its own.
                # Equal-length trials are filtered in one vectorized call below.
                spike_rates_ = FILTER_DICT[filter_fn](spike_counts.astype(float), **filter_kwargs)

                # High pass to remove long term trends (needed for sabes data)
                if high_pass:
                    spike_rates_ = moving_center(spike_rates_, 600)
            else:
                spike_rates_ = spike_counts
            spike_rates[i, j] = spike_rates_

    # Filter out bad units
    n_insufficient = np.sum(insufficient_spikes, axis=0)
    sufficient_spikes = np.arange(n_units)[n_insufficient < (1 - (trial_threshold - 1e-3)) * n_trials]
    spike_rates = spike_rates[:, list(sufficient_spikes)]

    # Transpose so time is along the second 'axis'
    if ragged_trials:
        spike_rates = [np.array([spike_rates[i, j] for j in range(spike_rates.shape[1])]).T
                       for i in range(spike_rates.shape[0])]
    else:
        # Filter the resulting spike counts (along the last, time, axis)
        spike_rates = FILTER_DICT[filter_fn](spike_rates, **filter_kwargs)
        # High pass to remove long term trends (needed for sabes data)
        if high_pass:
            spike_rates = moving_center(spike_rates, 600, axis=-1)

        spike_rates = np.transpose(spike_rates, (0, 2, 1))

    if return_unit_filter:
        return spike_rates, sufficient_spikes
    return spike_rates

def load_peanut(fpath, epoch, spike_threshold, bin_width=25, boxcox=0.5,
                filter_fn='none', speed_threshold=4, region='HPc', filter_kwargs=None):
    '''
        Load binned spike rates and position aligned to the same bins for one
        epoch of the peanut dataset.

        Parameters:
            fpath: str
                 path to a pickle file. NOTE: pickle.load can execute arbitrary
                 code -- only load files from a trusted source.
            epoch: int
                which epoch (session) to load; selects the
                'peanut_day14_epoch%d' entry. The rat is sleeping during odd
                numbered epochs
            spike_threshold: int
                throw away neurons with no more than spike_threshold spikes
                during the epoch
            bin_width:  float
                Bin width for binning spikes, in ms. Note the behavior is sampled at 25ms
            boxcox: float or None
                Apply boxcox transformation (counts**boxcox - 1) / boxcox
            filter_fn: str
                Key into FILTER_DICT
            speed_threshold: float or None
                keep only bins whose absolute linearized speed exceeds this
                value; None keeps every bin
            region: str
                'HPc', 'OFC' or 'both' -- which brain region's units to load
            filter_kwargs: dict or None
                keyword arguments for filter_fn. The caller's dict is not modified.

        Returns:
            dict with keys 'unit_ids', 'spike_rates' (bins x units),
            'behavior' (bins x 2), 'behavior_linear' (bins x 1) and 'time'
            (bin centers).

        Raises:
            ValueError: if region is not one of 'HPc', 'OFC', 'both'.
    '''
    # Work on a copy: the sigma conversion below writes into this dict. With
    # a shared default or a caller-owned dict, that write would compound
    # across calls.
    filter_kwargs = {} if filter_kwargs is None else dict(filter_kwargs)

    with open(fpath, 'rb') as f:
        data = pickle.load(f)
    dict_ = data['peanut_day14_epoch%d' % epoch]

    # Collect single units located in the requested brain region(s)
    region_of = dict_['identification']['nt_brain_region_dict']
    HPc_probes = {key for key, value in region_of.items() if value == 'HPc'}
    OFC_probes = {key for key, value in region_of.items() if value == 'OFC'}

    if region == 'HPc':
        probes = HPc_probes
    elif region == 'OFC':
        probes = OFC_probes
    elif region == 'both':
        probes = HPc_probes | OFC_probes
    else:
        raise ValueError("region must be 'HPc', 'OFC' or 'both', got %r" % (region,))

    # Sorted float arrays, one per unit
    spike_times = []
    unit_ids = []
    for probe, units in dict_['spike_times'].items():
        probe_id = probe.split('_')[-1]
        if probe_id not in probes:
            continue
        for unit, times in units.items():
            spike_times.append(np.sort(np.asarray(list(times), dtype=float)))
            unit_ids.append((probe_id, unit))

    # Apply spike threshold
    keep = [idx for idx, times in enumerate(spike_times) if len(times) > spike_threshold]
    spike_times = [spike_times[idx] for idx in keep]
    unit_ids = np.array(unit_ids)[keep]

    t = dict_['position_df']['time'].values
    T = t[-1] - t[0]
    # Convert bin width to s
    bin_width = bin_width / 1000

    # Convert smoothing bandwidth to units of bins
    if filter_fn == 'gaussian':
        filter_kwargs['sigma'] /= bin_width
        # NOTE(review): min() caps sigma at one bin, discarding any wider
        # bandwidth; max(1, ...) (at least one bin) may have been intended.
        # Confirm before changing.
        filter_kwargs['sigma'] = min(1, filter_kwargs['sigma'])

    bins = np.linspace(0, T, int(T // bin_width))

    # Translate spikes so the first position sample is time 0 (new arrays)
    spike_times = [times - t[0] for times in spike_times]

    spike_rates = np.zeros((bins.size - 1, len(spike_times)))
    for i, times in enumerate(spike_times):
        spike_counts = np.histogram(times, bins=bins)[0]
        if boxcox is not None:
            spike_counts = (np.power(spike_counts, boxcox) - 1) / boxcox
        # `np.float` was removed in NumPy 1.24; builtin float is equivalent.
        spike_rates[:, i] = FILTER_DICT[filter_fn](spike_counts.astype(float), **filter_kwargs)

    # unit_ids has one (probe_id, unit) row per unit, so count rows, not elements
    print('%d units\n' % len(unit_ids))
    print('Session duration: %f seconds' % (np.max([np.max(s) for s in spike_times]) -
                                            np.min([np.min(s) for s in spike_times])))

    # Align behavior with the binned spike rates
    position_df = dict_['position_df']
    pos_linear = position_df['position_linear'].values
    pos_xy = np.array([position_df['x-loess'], position_df['y-loess']]).T
    pos_linear, taligned = align_peanut_behavior(t, pos_linear, bins)
    pos_xy, _ = align_peanut_behavior(t, pos_xy, bins)

    # Apply movement threshold
    if speed_threshold is not None:
        vel = np.divide(np.diff(pos_linear), np.diff(taligned))
        moving = np.abs(vel) > speed_threshold
        # vel has one fewer entry than the bins; trim off the first bin to match lengths
        spike_rates = spike_rates[1:, ...][moving]
        pos_linear = pos_linear[1:, ...][moving]
        pos_xy = pos_xy[1:, ...][moving]
        # NOTE(review): 'time' below is neither trimmed nor speed-filtered, so
        # it no longer lines up row-for-row with spike_rates/behavior when a
        # threshold is applied. Confirm this is intended.

    dat = {
        'unit_ids': unit_ids,
        'spike_rates': spike_rates,
        'behavior': pos_xy,
        'behavior_linear': pos_linear[:, np.newaxis],
        'time': taligned,
    }
    return dat

