import numpy as np
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy import signal
from scipy import linalg
from scipy import stats
from scipy.signal import hilbert
import scipy.io as sio
import os
import shutil
from scipy.interpolate import interp1d
from scipy.stats import pearsonr
import pickle
import copy
import random
from multiprocessing import Pool, current_process
import seaborn as sns
import warnings

from utils import *

def delayprofile_centroid_segments(sigorg):
    """Compute a centroid delay profile for each pre-cut signal segment.

    Each segment is a 2-D array whose axis 0 is the within-segment sample
    (delay) index and whose axis 1 holds independent columns (presumably
    channels -- confirm with callers). Columns whose minimum is negative are
    shifted up so that their minimum becomes 0, then the centroid of every
    column along axis 0 is taken as that column's delay.

    Parameters
    ----------
    sigorg : iterable of ndarray
        Segments to analyse. All segments must have the same number of
        columns so the per-segment profiles stack into one array. The input
        is not modified.

    Returns
    -------
    dprfs : ndarray, shape (n_segments, n_columns)
        Centroid of each column relative to the segment start. NaN entries
        (e.g. columns that are all zero after the shift) are replaced by the
        mean of the remaining entries in the same row.
    binpks : ndarray of int
        Per-segment mean centroid rounded to the nearest index. A segment in
        which no column has a defined centroid falls back to its centre index
        ``(n_samples - 1) / 2`` (rounded).
    """
    dprfs = []
    binpks = []

    with warnings.catch_warnings():
        # Zero-weight columns produce 0/0 -> NaN; those are handled explicitly.
        warnings.filterwarnings("ignore", category=RuntimeWarning)
        for sigseg in sigorg:
            sigseg = sigseg.copy()  # never mutate the caller's data
            sig_min = sigseg.min(axis=0)
            idx = sig_min < 0
            # Centroid weights must be non-negative: lift negative columns to min 0.
            sigseg[:, idx] = sigseg[:, idx] - sig_min[idx]

            t = np.arange(sigseg.shape[0])
            dprfseg = sigseg.T @ t / sigseg.sum(axis=0)
            if np.isnan(dprfseg).all():
                # No column carries any weight, so the centroid is undefined.
                # Casting NaN to int below would produce an arbitrary integer,
                # so use the centroid a flat profile would have instead.
                peak = (sigseg.shape[0] - 1) / 2
            else:
                peak = np.nanmean(dprfseg)

            dprfs.append(dprfseg)
            binpks.append(peak)

    dprfs = np.array(dprfs)
    binpks = np.round(binpks).astype('int')

    # Fill undefined centroids with the mean of the other columns in that segment.
    idx0, idx1 = np.isnan(dprfs).nonzero()
    dprfs[idx0, idx1] = np.nanmean(dprfs, axis=1)[idx0]

    return dprfs, binpks

def delayprofile_centroid(sigorg, binEdgeIdx):
    """Principal delay profile from per-bin centroids of a continuous signal.

    Columns of the whole signal whose minimum is negative are first shifted
    so their minimum is 0 (centroid weights must be non-negative). The rows
    are then split into bins by ``binEdgeIdx``. Inside every bin, the centroid
    of each column along the row axis gives that column's delay.

    Parameters
    ----------
    sigorg : ndarray
        2-D signal: rows are samples, columns are reduced independently
        (presumably channels). Not modified.
    binEdgeIdx : sequence of int
        Consecutive bin edges; bin ``k`` spans rows
        ``binEdgeIdx[k]:binEdgeIdx[k + 1]``.

    Returns
    -------
    dprfs : ndarray, shape (n_bins, n_columns)
        Column centroids relative to their bin start. NaN entries are
        replaced by the mean of the other entries in their row.
    binpks : list of float
        Mean centroid of each bin as an absolute row position (NaN when no
        column in the bin has a defined centroid).
    """
    sig = sigorg.copy()

    col_floor = sig.min(axis=0)
    below_zero = col_floor < 0
    sig[:, below_zero] = sig[:, below_zero] - col_floor[below_zero]

    profiles = []
    bin_peaks = []
    with warnings.catch_warnings():
        # Zero-weight columns give 0/0 = NaN; silence the RuntimeWarning.
        warnings.filterwarnings("ignore", category=RuntimeWarning)
        for first, last in zip(binEdgeIdx[:-1], binEdgeIdx[1:]):
            weights = sig[first:last, :]
            offsets = np.arange(last - first)
            centroid = (weights.T @ offsets) / weights.sum(axis=0)
            profiles.append(centroid)
            bin_peaks.append(np.nanmean(centroid) + first)

    dprfs = np.array(profiles)
    nan_rows, nan_cols = np.isnan(dprfs).nonzero()
    dprfs[nan_rows, nan_cols] = np.nanmean(dprfs, axis=1)[nan_rows]

    return dprfs, bin_peaks

def delayprofile_max(sigorg, binEdgeIdx):
    """Principal delay profile from the per-column argmax inside each bin.

    Parameters
    ----------
    sigorg : ndarray
        2-D signal. Its rows are split into bins by ``binEdgeIdx`` and each
        column is reduced independently.
    binEdgeIdx : sequence of int
        Consecutive bin edges; bin ``k`` spans rows
        ``binEdgeIdx[k]:binEdgeIdx[k + 1]``.

    Returns
    -------
    dprfs : ndarray, shape (n_bins, n_columns)
        Row offset, relative to the bin start, of each column's maximum.
    binpks : list of float
        Per-bin mean of those offsets, shifted to an absolute row index.
    """
    sig = sigorg.copy()
    bins = list(zip(binEdgeIdx[:-1], binEdgeIdx[1:]))

    argmax_rows = [np.argmax(sig[lo:hi, :], axis=0) for lo, hi in bins]
    binpks = [np.nanmean(rows) + lo for rows, (lo, _hi) in zip(argmax_rows, bins)]

    dprfs = np.array(argmax_rows)
    # Same NaN-to-row-mean fill as the other delay-profile functions
    # (argmax itself never yields NaN, so this normally changes nothing).
    nan_r, nan_c = np.isnan(dprfs).nonzero()
    dprfs[nan_r, nan_c] = np.nanmean(dprfs, axis=1)[nan_r]

    return dprfs, binpks

def delayprofile_peak(sigorg, binEdgeIdx):
    """Principal delay profile from the tallest local peak of each column per bin.

    Parameters
    ----------
    sigorg : ndarray
        2-D signal. Its rows are split into bins by ``binEdgeIdx`` and each
        column is searched for peaks independently.
    binEdgeIdx : sequence of int
        Consecutive bin edges; bin ``k`` spans rows
        ``binEdgeIdx[k]:binEdgeIdx[k + 1]``.

    Returns
    -------
    dprfs : ndarray, shape (n_bins, n_columns)
        Row offset (relative to the bin start) of the highest local maximum
        found by ``scipy.signal.find_peaks``. Columns without any peak are
        filled with the mean of the other columns in the same bin.
    binpks : list of float
        Per-bin mean peak offset as an absolute row index. A bin in which no
        column has a peak falls back to the bin midpoint ``(start + stop) / 2``.
    """
    sig = sigorg.copy()

    def _highest_peak(trace):
        # Index of the tallest local maximum in ``trace``, or NaN if it has none.
        peak_idx = signal.find_peaks(trace)[0]
        if peak_idx.shape[0] == 0:
            return np.nan
        return peak_idx[np.argmax(trace[peak_idx])]

    dprfs, binpks = [], []
    for start, stop in zip(binEdgeIdx[:-1], binEdgeIdx[1:]):
        window = sig[start:stop, :]
        profile = [_highest_peak(column) for column in window.T]

        if not np.isnan(profile).all():
            binpks.append(np.nanmean(profile) + start)
        else:
            binpks.append((start + stop) / 2)
        dprfs.append(profile)

    dprfs = np.array(dprfs)
    rows, cols = np.isnan(dprfs).nonzero()
    dprfs[rows, cols] = np.nanmean(dprfs, axis=1)[rows]

    return dprfs, binpks

def get_window_signal_around_anchors(signals, anchorInd, winSize, axis=0):
    """Cut windows of ``2 * winSize + 1`` samples centred on each anchor index.

    ``signals`` is NaN-padded by ``winSize`` samples at both ends of ``axis``.
    Windows that run past either edge are therefore filled with NaN rather
    than truncated.

    Parameters
    ----------
    signals : array-like
        Input signal of any dimensionality.
    anchorInd : iterable of int
        Centre indices along ``axis``; each is expected to lie in
        ``[0, signals.shape[axis])``.
    winSize : int
        Half-width of the window (number of samples on each side of the anchor).
    axis : int, optional
        Axis along which to window (default 0). Any valid axis, including a
        negative one, is accepted.

    Returns
    -------
    ndarray
        Windows stacked along a new leading axis (one entry per anchor). Each
        window has the shape of ``signals`` except that ``axis`` has length
        ``2 * winSize + 1``. The NaN padding promotes integer input to float.
    """
    signals = np.asarray(signals)

    pad_shape = list(signals.shape)
    pad_shape[axis] = winSize
    pad = np.full(pad_shape, np.nan)
    # Pad along the windowing axis itself; np.concatenate defaults to axis 0.
    sigext = np.concatenate((pad, signals, pad), axis=axis)

    # Anchor ``a`` of the original signal sits at ``a + winSize`` in the padded
    # one, so the centred window starts at padded index ``a``.
    index = [slice(None)] * sigext.ndim
    sigwins = []
    for anchor in anchorInd:
        index[axis] = slice(anchor, anchor + 2 * winSize + 1)
        sigwins.append(sigext[tuple(index)])
    return np.array(sigwins)

def re_reference_lfps(lfps, ref_time):
    """Resample LFPs onto new time points by linear interpolation.

    Parameters
    ----------
    lfps : xarray.DataArray
        Signal with ``time`` and ``channel`` coordinates. Its values are
        assumed to be laid out as (time, channel); the result is built with
        those dims, so the input presumably uses the same order -- TODO
        confirm with callers.
    ref_time : array-like
        Target time points. They must lie within the span of ``lfps.time``;
        the endpoints themselves are allowed.

    Returns
    -------
    xarray.DataArray
        Interpolated signal with dims ``['time', 'channel']``. Its ``time``
        coordinate is ``ref_time`` and its ``channel`` coordinate is copied
        from ``lfps``.

    Raises
    ------
    AssertionError
        If ``ref_time`` reaches outside the recorded time range.
    """
    src_time = lfps.time.values
    query = np.asarray(ref_time)
    # Inclusive bounds: interp1d evaluates points equal to the first or last
    # sample without error, so those must not be rejected here. (With
    # assertions disabled, interp1d still raises ValueError for points that
    # are genuinely out of range.)
    assert src_time.min() <= query.min() and query.max() <= src_time.max(), 'time information is invalid'

    # Linear interpolation along the time axis (axis 0), independently per channel.
    intp = interp1d(src_time, lfps.values, axis=0)
    sig_intp = intp(ref_time)
    lfps_intp = xr.DataArray(sig_intp, dims=['time', 'channel'],
                             coords={'time': ref_time,
                                     'channel': lfps.channel.values})
    return lfps_intp